以第一个需要的中间层为例：这个层叫 `/relu/Relu`，它的输出叫 `/relu/Relu_output_0`。只需遍历整个模型，找到对应的层，把它的输出加到模型的输出定义（`graph.output`）中即可。
import onnx
import onnxruntime
import torch.onnx
# Expose selected intermediate layers of an ONNX model as extra graph
# outputs, then run the modified model once with ONNX Runtime so their
# activations can be inspected.

# Node names (not tensor names) of the layers whose outputs we want.
layer_names = [
    '/relu/Relu',
    '/layer1/layer1.2/relu_2/Relu',
    '/layer2/layer2.3/relu_2/Relu',
]
onnx_model = onnx.load("yundikan/unet_model_wj.onnx")

wanted_layers = set(layer_names)  # O(1) membership while walking the graph
# Tensor names that are already graph outputs; used to avoid registering
# the same tensor as an output twice.
existing_outputs = {out.name for out in onnx_model.graph.output}
found_layers = set()

for node in onnx_model.graph.node:
    if node.name not in wanted_layers:
        continue
    found_layers.add(node.name)
    for output in node.output:
        if output in existing_outputs:
            continue
        # A name-only ValueInfoProto: type and shape are left unspecified,
        # so the runtime infers them when the session is built.
        onnx_model.graph.output.extend([onnx.ValueInfoProto(name=output)])
        existing_outputs.add(output)
        print('node is: ', output)

# Report requested layers that were not found instead of silently
# ignoring them (e.g. a typo or a different exporter naming scheme).
missing_layers = wanted_layers - found_layers
if missing_layers:
    print('warning: layers not found in graph: ', sorted(missing_layers))

# Try CUDA first and fall back to CPU.
providers = [
    ("CUDAExecutionProvider", {"cudnn_conv_algo_search": "DEFAULT"}),
    'CPUExecutionProvider',
]
session = onnxruntime.InferenceSession(
    onnx_model.SerializeToString(), providers=providers
)

# torch.FloatTensor(*shape) returns *uninitialized* memory, which can hold
# NaN/Inf garbage. Use a well-defined random input instead.
# NOTE(review): assumes the model takes a single float32 input of shape
# (1, 3, 1024, 1024), as the original script did; confirm against the model.
ort_inputs = {
    session.get_inputs()[0].name: torch.rand(1, 3, 1024, 1024).numpy()
}
ort_out = session.run(None, ort_inputs)

# session.run returns outputs in session.get_outputs() order; key them by
# tensor name so intermediate activations can be looked up directly,
# e.g. outputs_by_name['/relu/Relu_output_0'].
outputs_by_name = dict(zip((o.name for o in session.get_outputs()), ort_out))
评论区