标签搜索

目 录CONTENT

文章目录

获取ONNX模型中间层的输出

陈铭
2024-06-07 / 0 评论 / 0 点赞 / 28 阅读 / 175 字 / 正在检测是否收录...

以第一个需要的中间层为例,这个层叫/relu/Relu,输出叫/relu/Relu_output_0,我只要遍历整个模型,把对应层的输出加到定义中就行
image

import onnx
import onnxruntime
import torch.onnx

layer_names = ['/relu/Relu',
               '/layer1/layer1.2/relu_2/Relu',
               '/layer2/layer2.3/relu_2/Relu',
               ]

onnx_model = onnx.load("yundikan/unet_model_wj.onnx")
for node in onnx_model.graph.node:
    if node.name not in layer_names:
        continue
    for output in node.output:
        onnx_model.graph.output.extend([onnx.ValueInfoProto(name=output)])
        print('node is: ', onnx.ValueInfoProto(name=output))

providers = [
    ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
    'CPUExecutionProvider',
]
session = onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=providers)
ort_inputs = {session.get_inputs()[0].name: torch.FloatTensor(1, 3, 1024, 1024).numpy()}
ort_out = session.run(None, ort_inputs)
0

评论区