当前位置:首页>编程日记>正文

mxnet 查看中间层结果


import mxnet as mx
from mxnet import nd
from mxnet.gluon import nnmx.cpu(), mx.gpu(), mx.gpu(0)

查看mxnet网络所有节点

    import jsonwith open('./model-symbol.json', 'r', encoding='utf8') as fp:conf = json.load(fp)# conf = json.loads(symbol.tojson())nodes = conf["nodes"]heads = set(conf["heads"][0])symbols = []for i, node in enumerate(nodes):op = node["op"]if op == "null" and i > 0:continueif op != "null" or i in heads:print(node['name'])

cpu模式下,只能返回一层

gpu(0)模式下,能返回多层结果。

查看权重
在训练过程中,有时候我们为了debug而需要查看中间某一步的权重信息,在mxnet中,我们可以很方便的调用get_params()方法来得到权重信息。

'''
查看权重示例代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents

'''
查看权重示例代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
'''
import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
mod = mx.mod.Module(symbol=sym,context=mx.gpu()) #创建Module
mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
mod.set_params(arg_params,aux_params)
import numpy as np
import cv2
def get_image(filename):img = cv2.imread(filename)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = cv2.resize(img,(224,224))img = np.swapaxes(img,0,2)img = np.swapaxes(img,1,2)img = img[np.newaxis,:]return img
from collections import namedtuple
Batch = namedtuple('Batch',['data'])
img = get_image('val_1000/0.jpg') #获取图片
mod.forward(Batch([mx.nd.array(img)])) #预测结果
################################################
#debug模式下,获取权重信息
keys = mod.get_params()[0].keys() # 列出所有权重名称
conv_w = mod.get_params()[0]['conv0_weight'] #获取想要查看的权重信息,如conv_weight
print conv_w.asnumpy() #查看具体数值
################################################
prob = mod.get_outputs()[0].asnumpy()
y = np.argsort(np.squeeze(prob))[::-1]
print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))



查看中间输出结果
由于mxnet的网络由symbol组成,而symbol又属于符号式编程,所以我们不能像上面查看权重一样直接查看,我们需要把我们想看的输出结果保存下来。

'''
方法一
查看中间结果代码
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
'''
import mxnet as mx
net = mx.symbol.Variable('data')
fc1 = mx.symbol.FullyConnected(data=net, name='fc1', num_hidden=128)
net = mx.symbol.Activation(data=fc1, name='relu1', act_type="relu")
net = mx.symbol.FullyConnected(data=net, name='fc2', num_hidden=64)
out = mx.symbol.SoftmaxOutput(data=net, name='softmax')
# 通过把两个输出组成一个group来得到自己需要查看的中间层输出结果
group = mx.symbol.Group([fc1, out]) 
print group.list_outputs()


方法二
有时候我们使用别人的模型,所以无法像方法一一样在定义模型的时候就确定需要查看的中间层输出结果,
这时候我们使用get_internals()方法来查找自己需要查看的中间层
转载时注明地址:http://blog.csdn.net/u010414386?viewmode=contents
 

这个出来是list,存放的不同的层的结果。

prob = mod.get_outputs()


import mxnet as mx
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet-50',0)#载入模型
########################################################################
args = sym.get_internals().list_outputs() #获得所有中间输出
internals = model.symbol.get_internals()
fc1 = internals['fc1_output']
conv = internals['stage4_unit3_conv1_output']
group = mx.symbol.Group([fc1, sym, conv])  #把需要输出的结果按group方式组合起来,这样就可以得到中间层的输出
#########################################################################
mod = mx.mod.Module(symbol=group,context=mx.gpu()) #创建Module
mod.bind(for_training=False,data_shapes=[('data',(1,3,224,224))]) #绑定,此代码为预测代码,所以training参数设为False
mod.set_params(arg_params,aux_params)
import numpy as np
import cv2
def get_image(filename):img = cv2.imread(filename)img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)img = cv2.resize(img,(224,224))img = np.swapaxes(img,0,2)img = np.swapaxes(img,1,2)img = img[np.newaxis,:]return img
from collections import namedtuple
Batch = namedtuple('Batch',['data'])
img = get_image('val_1000/0.jpg') #获取图片
mod.forward(Batch([mx.nd.array(img)])) #预测结果
prob = mod.get_outputs()[0].asnumpy()
y = np.argsort(np.squeeze(prob))[::-1]
print('truth label %d; top-1 predict label %d' % (val_label[0], y[0]))


原文链接:https://blog.csdn.net/u010414386/article/details/55668880

打印所有层输出:

import mxnet as mxdef get_output_symbol(symbol):"""Parameters----------symbol: SymbolSymbol to be visualized."""import jsonfrom mxnet.symbol.symbol import Symbolif not isinstance(symbol, Symbol):raise TypeError("symbol must be Symbol")conf = json.loads(symbol.tojson())nodes = conf["nodes"]heads = set(conf["heads"][0])symbols = []for i, node in enumerate(nodes):op = node["op"]if op == "null" and i > 0:continueif op != "null" or i in heads:symbols.append(node['name'])return symbolsdef debug_model(model):# prepare data 准备输入数据input_blob=mx.nd.zeros(shape=(1,3,112,112),ctx=mx.cpu())db = mx.io.DataBatch(data=(input_blob,))# get output symbol 找到特征层,获取输出节点symbols = get_output_symbol(model.symbol)symbols = [x for x in symbols if x != 'data']arg_params, aux_params = model.get_params()internals = model.symbol.get_internals()outputs = internals.list_outputs()symbols_output_name = [x + '_output' for x in symbols]symbols_output = [internals[x] for x in symbols_output_name]# 重建符号与模型group = mx.symbol.Group(symbols_output)mod = mx.mod.Module(symbol=group, context=mx.cpu())mod.bind(data_shapes=[('data', (1, 3, 112, 112))])  # 绑定输入shapemod.set_params(arg_params, aux_params)mod.forward(db, is_train=False)output = mod.get_outputs()output_dict = {k: v.asnumpy() for k, v in zip(symbols, output)}# 保存结果import osfrom collections import Iterableif not os.path.exists('output'):os.mkdir('output')for k, v in output_dict.items():with open('output/{}.txt'.format(k), 'w') as f:print('Shape is {}, data type is {}'.format(v.shape, v.dtype), file=f)for i, batch in enumerate(v):print('Batch {}:'.format(i), file=f)for j, channel in enumerate(batch):print('{}Channel {}:'.format(' ' * 4, j), file=f)if isinstance(channel, Iterable):for k, width in enumerate(channel):print(' ' * 8, file=f, end='')for m, height in enumerate(width):print(height, end='  ', file=f)print(file=f)else:print(' ' * 8 + str(channel), file=f)# 加载与训练模型
def get_model(ctx, image_size, model_str, layer):_vec = model_str.split(',')assert len(_vec)==2prefix = _vec[0]epoch = int(_vec[1])print('loading',prefix, epoch)sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)all_layers = sym.get_internals()sym = all_layers[layer+'_output']model = mx.mod.Module(symbol=sym, context=ctx, label_names = None)model.bind(data_shapes=[('data', (1, 3, image_size[0], image_size[1]))])model.set_params(arg_params, aux_params)# 打印输出shapearg_shape, out_shape, _ = sym.infer_shape(data=(1, 3, image_size[0], image_size[1]))mx.viz.print_summary(sym, {'data': (1, 3, image_size[0], image_size[1])})return model
if __name__=='__main__':model = get_model(mx.cpu(), (112, 112), 'model-y1-test2/model,0', 'fc1')	debug_model(model)

http://www.coolblog.cn/news/ab3ec0df04aadae0.html

相关文章:

  • asp多表查询并显示_SpringBoot系列(五):SpringBoot整合Mybatis实现多表关联查询
  • s7day2学习记录
  • 【求锤得锤的故事】Redis锁从面试连环炮聊到神仙打架。
  • 矿Spring入门Demo
  • 拼音怎么写_老师:不会写的字用圈代替,看到孩子试卷,网友:人才
  • Linux 实时流量监测(iptraf中文图解)
  • Win10 + Python + GPU版MXNet + VS2015 + RTools + R配置
  • 美颜
  • shell访问php文件夹,Shell获取某目录下所有文件夹的名称
  • 如何优雅的实现 Spring Boot 接口参数加密解密?
  • LeCun亲授的深度学习入门课:从飞行器的发明到卷积神经网络
  • Mac原生Terminal快速登录ssh
  • java受保护的数据与_Javascript类定义语法,私有成员、受保护成员、静态成员等介绍...
  • mysql commit 机制_1024MySQL事物提交机制
  • 支撑微博千亿调用的轻量级RPC框架:Motan
  • jquery 使用小技巧
  • 2019-9
  • 法拉利虚拟学院2010 服务器,法拉利虚拟学院2010
  • vscode pylint 错误_将实际未错误的py库添加到pylint白名单
  • 科学计算工具NumPy(3):ndarray的元素处理
  • 工程师在工作电脑存 64G 不雅文件,被公司开除后索赔 41 万,结果…
  • linux批量创建用户和密码
  • newinsets用法java_Java XYPlot.setInsets方法代碼示例
  • js常用阻止冒泡事件
  • 气泡图在开源监控工具中的应用效果
  • 各类型土地利用图例_划重点!国土空间总体规划——土地利用
  • php 启动服务器监听
  • dubbo简单示例
  • 【设计模式】 模式PK:策略模式VS状态模式
  • [iptables]Redhat 7.2下使用iptables实现NAT
  • Ubuntu13.10:[3]如何开启SSH SERVER服务
  • CSS小技巧——CSS滚动条美化
  • JS实现-页面数据无限加载
  • 阿里巴巴分布式服务框架 Dubbo
  • 最新DOS大全
  • Django View(视图系统)
  • 阿里大鱼.net core 发送短信
  • 程序员入错行怎么办?
  • 两张超级大表join优化
  • 第九天函数
  • Linux软件安装-----apache安装
  • HDU 5988 最小费用流
  • Sorenson Capital:值得投资的 5 种 AI 技术
  • 《看透springmvc源码分析与实践》读书笔记一
  • 正式开课!如何学习相机模型与标定?(单目+双目+鱼眼+深度相机)
  • Arm芯片的新革命在缓缓上演
  • nagios自写插件—check_file
  • python3 错误 Max retries exceeded with url 解决方法
  • 行为模式之Template Method模式
  • 通过Spark进行ALS离线和Stream实时推荐