要将PyTorch的ResNet50模型导出为支持动态batch size的ONNX格式,需通过dynamic_axes
参数指定动态维度。以下是详细步骤和代码示例:
加载预训练模型并设置为评估模式
确保模型处于推理状态(禁用Dropout和BatchNorm的随机性)。
创建虚拟输入(Dummy Input)
生成一个符合输入形状的张量,batch size可为任意值(如1),因为后续会动态化。
导出ONNX模型
使用torch.onnx.export
函数,通过dynamic_axes
参数标记输入和输出的动态维度(第0维为batch)。
验证ONNX模型
使用ONNX Runtime测试不同batch size的输入,确保模型正常工作。
import torch
import torchvision.models as models
# 1. 加载ResNet50模型并设为评估模式
model = models.resnet50(pretrained=True)
model.eval()
# 2. 创建虚拟输入(batch_size=1,但导出时会设为动态)
dummy_input = torch.randn(1, 3, 224, 224)
# 3. 导出ONNX模型(关键:设置dynamic_axes)
torch.onnx.export(
model,
dummy_input,
"resnet50_dynamic.onnx",
input_names=["input"], # 输入节点名称
output_names=["output"], # 输出节点名称
dynamic_axes={
"input": {0: "batch_size"}, # 第0维(batch)动态
"output": {0: "batch_size"} # 输出同样支持动态batch
},
opset_version=13 # 推荐使用较高版本确保兼容性
)
print("导出成功!")
使用ONNX Runtime测试不同batch size的输入是否兼容:
import onnx
import onnxruntime as ort
import numpy as np
# 检查模型格式是否正确
onnx_model = onnx.load("resnet50_dynamic.onnx")
onnx.checker.check_model(onnx_model)
# 测试不同batch size(如1, 4, 8)
for batch_size in [1, 4, 8]:
# 生成随机输入数据
dummy_input = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
# 创建ONNX Runtime会话
ort_session = ort.InferenceSession("resnet50_dynamic.onnx")
# 执行推理
outputs = ort_session.run(None, {"input": dummy_input})
print(f"Batch size {batch_size} 推理成功!输出形状:{outputs[0].shape}")
dynamic_axes
指定输入和输出的动态维度,格式为字典:
"input": {0: "batch_size"}
:输入的第0维(batch)为动态,命名为batch_size
。"output": {0: "batch_size"}
:输出的第0维同步为动态。opset_version
建议设置为11或更高版本(如13),以支持更多ONNX算子。
导出后模型仍是静态batch
检查dynamic_axes
是否正确定义了输入和输出的第0维。
推理时报维度错误
确保输入数据形状为(batch_size, 3, 224, 224)
,且数据类型为float32
。
自定义模型兼容性
如果模型包含非标准操作(如自定义层),需确保这些操作在ONNX中有对应实现,或通过扩展ONNX支持。
通过以上步骤,即可成功导出支持动态batch size的ResNet50 ONNX模型,并验证其灵活性。