打印

训练机器推理之三 模型导出

[复制链接]
997|0
手机看帖
扫描二维码
随时随地手机跟帖
跳转到指定楼层
楼主
丙丁先生|  楼主 | 2024-7-5 07:37 | 只看该作者 回帖奖励 |倒序浏览 |阅读模式
导图方法使用 linger 导出 onnx
如果调用linger.init(...)接口后,使用torch.onnx.export会被自动替换为linger.onnx.export进行调用,即torch.onnx.export = linger.onnx.export
import linger.....linger.init(...)torch.onnx.export(...) # 实际上调用的是 linger.onnx.exportCopy


  • 导出支持动态输入大小的图
torch.onnx.export(torch_model,               # model being run                  x,                         # model input (or a tuple for multiple inputs)                  "super_resolution.onnx",   # where to save the model (can be a file or file-like object)                  export_params=True,        # store the trained parameter weights inside the model file                  opset_version=12,          # the ONNX version to export the model to                  do_constant_folding=True,  # whether to execute constant folding for optimization                  input_names = ['input',   # the model's input names                  output_names = ['output', # the model's output names                  dynamic_axes={'input' : {0 : 'batch_size'},    # variable lenght axes                                'output' : {0 : 'batch_size'}})Copy


其中 dynamic_axes使用有几种形式:
  • 仅提供索引信息
    例如下例子表示 把input_1的0,2,3维作为动态输入,第1仍然保持固定输入,'input_2'第0维作为动态输入,output的0,1维作为动态输入,对于动态输入的维度,PyTorch会自动给该维度生成一个名字以替换维度信息
dynamic_axes = {'input_1':[0, 2, 3,                  'input_2':[0,                  'output':[0, 1}Copy


  • 对于给定的索引信息,指定名字
    对于input_1,指定动态维0、1、2的名字分别为batch、width、height,其他输入同理
dynamic_axes = {'input_1':{0:'batch',                             1:'width',                             2:'height'},                  'input_2':{0:'batch'},                  'output':{0:'batch',                            1:'detections'}Copy


  • 将上面两者进行混用
dynamic_axes = {'input_1':[0, 2, 3,                  'input_2':{0:'batch'},                  'output':[0,1}Copy


  • 带有可选参数的导出
    例如想命名输入输出tensor名字或者比较超前的op可以加上torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
import torchimport torch.onnxtorch_model = ...# set the model to inference modetorch_model.eval()dummy_input = torch.randn(1,3,244,244)torch.onnx.export(torch_model,dummy_input,"test.onnx",                    opset_version=11,input_names=["input",output_names=["output",operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)Copy


  • torch.no_grad()报错
    torch 1.6 版本后,需要with torch.no_grad(),即
import torchimport torch.onnxtorch_model = ...# set the model to inference modetorch_model.eval()dummy_input = torch.randn(1,3,244,244)with torch.no_grad():    torch.onnx.export(torch_model,dummy_input,"test.onnx",                        opset_version=11,input_names=["input",output_names=["output",operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)Copy


警告:如果不使用with torch.no_grad(),则会报以下错误
RuntimeError: isDifferentiableType(variable.scalar_type()) INTERNAL ASSERT FAILED at "/pytorch/torch/csrc/autograd/functions/utils.h":59, please report a bug to PyTorch.






使用特权

评论回复

相关帖子

发新帖 我要提问
您需要登录后才可以回帖 登录 | 注册

本版积分规则

616

主题

2137

帖子

5

粉丝