当前位置:首页>编程日记>正文

mxnet 查看中间层结果

本站寻求有缘人接手,详细了解请联系站长QQ1493399855


import mxnet as mx
from mxnet import nd
from mxnet.gluon import nnmx.cpu(), mx.gpu(), mx.gpu(0)

查看mxnet网络所有节点

    import jsonwith open('./model-symbol.json', 'r', encoding='utf8') as fp:conf = json.load(fp)# conf = json.loads(symbol.tojson())nodes = conf["nodes"]heads = set(conf["heads"][0])symbols = []for i, node in enumerate(nodes):op = node["op"]if op == "null" and i > 0:continueif op != "null" or i in heads:print(node['name'])

cpu模式下,只能返回一层

gpu(0)模式下,能返回多层结果。

查看权重
在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。

'''
查看权重示例代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents

'''
查看权重示例代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
'''
import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Module
mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
mod.set_params(arg_params,aux_params)
import numpy as np
import cv2
def get_image(filename):img = cv2.imread(filename)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = cv2.resize(img,(224,224))img = np.swapaxes(img,0,2)img = np.swapaxes(img,1,2)img = img[np.newaxis,:]return img
from collections import namedtuple
Batch = namedtuple('Batch',['data'])
img = get_image('val_1000/0.jpg') #获取图片
mod.forward(Batch([mx.nd.array(img)])) #预测结果
################################################
#debug模式下,获取权重信息
keys = mod.get_params()[0].keys() # 列出所有权重名称
conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weight
print conv_w.asnumpy() #查看具体数值
################################################
prob = mod.get_outputs()[0].asnumpy()
y = np.argsort(np.squeeze(prob))[::-1]
print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))



查看中间输出结果
由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。

'''
方法一
查看中间结果代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
'''
import mxnet as mx
net = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
out = mx.symbol.SoftmaxOutput(data=net, name='softmax')
# 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果
group = mx.symbol.Group([fc1, out]) 
print group.list_outputs()


方法二
有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,
这时候我们使用get_internals()方法来查找自己需要查看的中间层
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
 

这个出来是list,存放的不同的层的结果。

prob = mod.get_outputs()


import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
########################################################################
args = sym.get_internals().list_outputs() #获得所有中间输出
internals = model.symbol.get_internals()
fc1 = internals['fc1_output']
conv = internals['stage4_unit3_conv1_output']
group = mx.symbol.Group([fc1, sym, conv])  #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
#########################################################################
mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Module
mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
mod.set_params(arg_params,aux_params)
import numpy as np
import cv2
def get_image(filename):img = cv2.imread(filename)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = cv2.resize(img,(224,224))img = np.swapaxes(img,0,2)img = np.swapaxes(img,1,2)img = img[np.newaxis,:]return img
from collections import namedtuple
Batch = namedtuple('Batch',['data'])
img = get_image('val_1000/0.jpg') #获取图片
mod.forward(Batch([mx.nd.array(img)])) #预测结果
prob = mod.get_outputs()[0].asnumpy()
y = np.argsort(np.squeeze(prob))[::-1]
print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))


原文链接:https://blog.csdn.net/u010414386/article/details/55668880

打印所有层输出:

import mxnet as mxdef get_output_symbol(symbol):"""Parameters----------symbol: SymbolSymbol to be visualized."""import jsonfrom mxnet.symbol.symbol import Symbolif not isinstance(symbol, Symbol):raise TypeError("symbol must be Symbol")conf = json.loads(symbol.tojson())nodes = conf["nodes"]heads = set(conf["heads"][0])symbols = []for i, node in enumerate(nodes):op = node["op"]if op == "null" and i > 0:continueif op != "null" or i in heads:symbols.append(node['name'])return symbolsdef debug_model(model):# prepare data 准备输入数据input_blob=mx.nd.zeros(shape=(1,3,112,112),ctx=mx.cpu())db = mx.io.DataBatch(data=(input_blob,))# get output symbol 找到特征层,获取输出节点symbols = get_output_symbol(model.symbol)symbols = [x for x in symbols if x != 'data']arg_params, aux_params = model.get_params()internals = model.symbol.get_internals()outputs = internals.list_outputs()symbols_output_name = [x + '_output' for x in symbols]symbols_output = [internals[x] for x in symbols_output_name]# 重建符号与模型group = mx.symbol.Group(symbols_output)mod = mx.mod.Module(symbol=group, context=mx.cpu())mod.bind(data_shapes=[('data', (1, 3, 112, 112))])  # 绑定输入shapemod.set_params(arg_params, aux_params)mod.forward(db, is_train=False)output = mod.get_outputs()output_dict = {k: v.asnumpy() for k, v in zip(symbols, output)}# 保存结果import osfrom collections import Iterableif not os.path.exists('output'):os.mkdir('output')for k, v in output_dict.items():with open('output/{}.txt'.format(k), 'w') as f:print('Shape is {}, data type is {}'.format(v.shape, v.dtype), file=f)for i, batch in enumerate(v):print('Batch {}:'.format(i), file=f)for j, channel in enumerate(batch):print('{}Channel {}:'.format(' ' * 4, j), file=f)if isinstance(channel, Iterable):for k, width in enumerate(channel):print(' ' * 8, file=f, end='')for m, height in enumerate(width):print(height, end='  ', file=f)print(file=f)else:print(' ' * 8 + str(channel), file=f)# 加载与训练模型
def get_model(ctx, image_size, model_str, layer):_vec = model_str.split(',')assert len(_vec)==2prefix = _vec[0]epoch = int(_vec[1])print('loading',prefix, epoch)sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)all_layers = sym.get_internals()sym = all_layers[layer+'_output']model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])model.set_params(arg_params, aux_params)# 打印输出shapearg_shape, out_shape, _ = sym.infer_shape(data=(1, 3, image_size[0], image_size[1]))mx.viz.print_summary(sym, {'data': (1, 3, image_size[0], image_size[1])})return model
if __name__=='__main__':model = get_model(mx.cpu(), (112, 112), 'model-y1-test2/model,0', 'fc1')	debug_model(model)

http://www.coolblog.cn/news/ab3ec0df04aadae0.html

相关文章:

  • asp多表查询并显示_SpringBoot系列(五):SpringBoot整合Mybatis实现多表关联查询
  • s7day2学习记录
  • 【求锤得锤的故事】Redis锁从面试连环炮聊到神仙打架。
  • 矿Spring入门Demo
  • 拼音怎么写_老师:不会写的字用圈代替,看到孩子试卷,网友:人才
  • Linux 实时流量监测(iptraf中文图解)
  • Win10 + Python + GPU版MXNet + VS2015 + RTools + R配置
  • 美颜
  • shell访问php文件夹,Shell获取某目录下所有文件夹的名称
  • 如何优雅的实现 Spring Boot 接口参数加密解密?
  • LeCun亲授的深度学习入门课:从飞行器的发明到卷积神经网络
  • Mac原生Terminal快速登录ssh
  • 法拉利虚拟学院2010 服务器,法拉利虚拟学院2010
  • 支撑微博千亿调用的轻量级RPC框架:Motan
  • mysql commit 机制_1024MySQL事物提交机制
  • java受保护的数据与_Javascript类定义语法,私有成员、受保护成员、静态成员等介绍...
  • 2019-9
  • jquery 使用小技巧
  • 科学计算工具NumPy(3):ndarray的元素处理
  • vscode pylint 错误_将实际未错误的py库添加到pylint白名单
  • 工程师在工作电脑存 64G 不雅文件,被公司开除后索赔 41 万,结果…
  • linux批量创建用户和密码
  • js常用阻止冒泡事件
  • 气泡图在开源监控工具中的应用效果
  • newinsets用法java_Java XYPlot.setInsets方法代碼示例
  • 各类型土地利用图例_划重点!国土空间总体规划——土地利用
  • php 启动服务器监听
  • dubbo简单示例
  • Ubuntu13.10:[3]如何开启SSH SERVER服务
  • [iptables]Redhat 7.2下使用iptables实现NAT
  • Django View(视图系统)
  • 【设计模式】 模式PK:策略模式VS状态模式
  • CSS小技巧——CSS滚动条美化
  • JS实现-页面数据无限加载
  • 最新DOS大全
  • 阿里巴巴分布式服务框架 Dubbo
  • 阿里大鱼.net core 发送短信
  • Sorenson Capital:值得投资的 5 种 AI 技术
  • 程序员入错行怎么办?
  • Arm芯片的新革命在缓缓上演
  • 两张超级大表join优化
  • 第九天函数
  • Linux软件安装-----apache安装
  • HDU 5988 最小费用流
  • 《看透springmvc源码分析与实践》读书笔记一
  • 通过Spark进行ALS离线和Stream实时推荐
  • nagios自写插件—check_file
  • python3 错误 Max retries exceeded with url 解决方法
  • 正式开课!如何学习相机模型与标定?(单目+双目+鱼眼+深度相机)
  • 行为模式之Template Method模式