导图方法使用 linger 导出 onnx 如果调用linger.init(...)接口后,使用torch.onnx.export会被自动替换为linger.onnx.export进行调用,即torch.onnx.export = linger.onnx.export import linger.....linger.init(...)torch.onnx.export(...) # 实际上调用的是 linger.onnx.exportCopy
torch.onnx.export(torch_model, # model being run x, # model input (or a tuple for multiple inputs) "super_resolution.onnx", # where to save the model (can be a file or file-like object) export_params=True, # store the trained parameter weights inside the model file opset_version=12, # the ONNX version to export the model to do_constant_folding=True, # whether to execute constant folding for optimization input_names = ['input', # the model's input names output_names = ['output', # the model's output names dynamic_axes={'input' : {0 : 'batch_size'}, # variable lenght axes 'output' : {0 : 'batch_size'}})Copy
其中 dynamic_axes使用有几种形式: - 仅提供索引信息
例如下例子表示 把input_1的0,2,3维作为动态输入,第1仍然保持固定输入,'input_2'第0维作为动态输入,output的0,1维作为动态输入,对于动态输入的维度,PyTorch会自动给该维度生成一个名字以替换维度信息
dynamic_axes = {'input_1':[0, 2, 3, 'input_2':[0, 'output':[0, 1}Copy
- 对于给定的索引信息,指定名字
对于input_1,指定动态维0、1、2的名字分别为batch、width、height,其他输入同理
dynamic_axes = {'input_1':{0:'batch', 1:'width', 2:'height'}, 'input_2':{0:'batch'}, 'output':{0:'batch', 1:'detections'}Copy
dynamic_axes = {'input_1':[0, 2, 3, 'input_2':{0:'batch'}, 'output':[0,1}Copy
- 带有可选参数的导出
例如想命名输入输出tensor名字或者比较超前的op可以加上torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
import torchimport torch.onnxtorch_model = ...# set the model to inference modetorch_model.eval()dummy_input = torch.randn(1,3,244,244)torch.onnx.export(torch_model,dummy_input,"test.onnx", opset_version=11,input_names=["input",output_names=["output",operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)Copy
- torch.no_grad()报错
torch 1.6 版本后,需要with torch.no_grad(),即
import torchimport torch.onnxtorch_model = ...# set the model to inference modetorch_model.eval()dummy_input = torch.randn(1,3,244,244)with torch.no_grad(): torch.onnx.export(torch_model,dummy_input,"test.onnx", opset_version=11,input_names=["input",output_names=["output",operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)Copy
警告:如果不使用with torch.no_grad(),则会报以下错误 RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/functions/utils.h":59, please report a bug to PyTorch.
|