插件窝 干货文章 如何解决PyTorch ResNet50模型导出ONNX时动态Batch Size的问题?

如何解决PyTorch ResNet50模型导出ONNX时动态Batch Size的问题?

导出 ONNX 模型 输入 499    来源:    2025-03-12

在将PyTorch的ResNet50模型导出为ONNX格式时,如果你希望支持动态的Batch Size,可以通过在导出过程中指定动态的输入维度来实现。以下是具体的步骤和代码示例:

1. 安装必要的库

确保你已经安装了PyTorch和ONNX库。如果没有安装,可以使用以下命令进行安装:

pip install torch onnx

2. 加载预训练的ResNet50模型

首先,加载预训练的ResNet50模型:

import torch
import torchvision.models as models

# 加载预训练的ResNet50模型
model = models.resnet50(pretrained=True)
model.eval()  # 将模型设置为评估模式

3. 定义动态输入维度

在导出ONNX模型时,你可以通过torch.onnx.export函数的dynamic_axes参数来指定动态的输入维度。对于Batch Size,你可以将输入张量的第0维设置为动态。

# 定义一个示例输入张量,Batch Size为1
dummy_input = torch.randn(1, 3, 224, 224)

# 指定动态的Batch Size
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

# 导出ONNX模型
torch.onnx.export(model,               # 要导出的模型
                  dummy_input,         # 示例输入
                  "resnet50_dynamic.onnx",  # 导出的ONNX文件名
                  export_params=True,  # 导出模型参数
                  opset_version=11,    # ONNX算子集版本
                  do_constant_folding=True,  # 是否进行常量折叠优化
                  input_names=['input'],  # 输入名称
                  output_names=['output'],  # 输出名称
                  dynamic_axes=dynamic_axes)  # 动态维度

4. 验证导出的ONNX模型

你可以使用ONNX Runtime或其他工具来验证导出的ONNX模型是否支持动态Batch Size。

import onnx
import onnxruntime as ort

# 加载导出的ONNX模型
onnx_model = onnx.load("resnet50_dynamic.onnx")
onnx.checker.check_model(onnx_model)

# 使用ONNX Runtime进行推理
ort_session = ort.InferenceSession("resnet50_dynamic.onnx")

# 准备输入数据,Batch Size为2
input_data = torch.randn(2, 3, 224, 224).numpy()

# 进行推理
outputs = ort_session.run(None, {'input': input_data})

print(outputs)

5. 总结

通过上述步骤,你可以成功地将PyTorch的ResNet50模型导出为支持动态Batch Size的ONNX格式。在导出时,使用dynamic_axes参数指定输入和输出的动态维度,特别是Batch Size所在的维度(通常是第0维)。这样,导出的ONNX模型可以在推理时接受不同大小的Batch Size输入。