插件窝 干货文章 如何将PyTorch ResNet50模型导出为支持动态batch size的ONNX格式?

如何将PyTorch ResNet50模型导出为支持动态batch size的ONNX格式?

batch ONNX 模型 输入 209    来源:    2025-03-11

要将PyTorch的ResNet50模型导出为支持动态batch size的ONNX格式,需通过dynamic_axes参数指定动态维度。以下是详细步骤和代码示例:


步骤说明

  1. 加载预训练模型并设置为评估模式
    确保模型处于推理状态(禁用Dropout和BatchNorm的随机性)。

  2. 创建虚拟输入(Dummy Input)
    生成一个符合输入形状的张量,batch size可为任意值(如1),因为后续会动态化。

  3. 导出ONNX模型
    使用torch.onnx.export函数,通过dynamic_axes参数标记输入和输出的动态维度(第0维为batch)。

  4. 验证ONNX模型
    使用ONNX Runtime测试不同batch size的输入,确保模型正常工作。


完整代码示例

import torch
import torchvision.models as models

# 1. 加载ResNet50模型并设为评估模式
model = models.resnet50(pretrained=True)
model.eval()

# 2. 创建虚拟输入(batch_size=1,但导出时会设为动态)
dummy_input = torch.randn(1, 3, 224, 224)

# 3. 导出ONNX模型(关键:设置dynamic_axes)
torch.onnx.export(
    model,
    dummy_input,
    "resnet50_dynamic.onnx",
    input_names=["input"],   # 输入节点名称
    output_names=["output"], # 输出节点名称
    dynamic_axes={
        "input": {0: "batch_size"},  # 第0维(batch)动态
        "output": {0: "batch_size"}  # 输出同样支持动态batch
    },
    opset_version=13  # 推荐使用较高版本确保兼容性
)

print("导出成功!")

验证导出的ONNX模型

使用ONNX Runtime测试不同batch size的输入是否兼容:

import onnx
import onnxruntime as ort
import numpy as np

# 检查模型格式是否正确
onnx_model = onnx.load("resnet50_dynamic.onnx")
onnx.checker.check_model(onnx_model)

# 测试不同batch size(如1, 4, 8)
for batch_size in [1, 4, 8]:
    # 生成随机输入数据
    dummy_input = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)

    # 创建ONNX Runtime会话
    ort_session = ort.InferenceSession("resnet50_dynamic.onnx")

    # 执行推理
    outputs = ort_session.run(None, {"input": dummy_input})
    print(f"Batch size {batch_size} 推理成功!输出形状:{outputs[0].shape}")

关键参数解释

  • dynamic_axes
    指定输入和输出的动态维度,格式为字典:

    • "input": {0: "batch_size"}:输入的第0维(batch)为动态,命名为batch_size
    • "output": {0: "batch_size"}:输出的第0维同步为动态。
  • opset_version
    建议设置为11或更高版本(如13),以支持更多ONNX算子。


常见问题排查

  1. 导出后模型仍是静态batch
    检查dynamic_axes是否正确定义了输入和输出的第0维。

  2. 推理时报维度错误
    确保输入数据形状为(batch_size, 3, 224, 224),且数据类型为float32

  3. 自定义模型兼容性
    如果模型包含非标准操作(如自定义层),需确保这些操作在ONNX中有对应实现,或通过扩展ONNX支持。


通过以上步骤,即可成功导出支持动态batch size的ResNet50 ONNX模型,并验证其灵活性。