插件窝 干货文章 如何将PyTorch ResNet50模型导出为支持动态batch size的ONNX格式?

如何将PyTorch ResNet50模型导出为支持动态batch size的ONNX格式?

ONNX 模型 导出 动态 927    来源:    2025-03-13

要将PyTorch中的ResNet50模型导出为支持动态batch size的ONNX格式,你可以按照以下步骤进行操作:

1. 安装必要的库

确保你已经安装了torchonnx库。如果没有安装,可以使用以下命令进行安装:

pip install torch onnx

2. 加载预训练的ResNet50模型

首先,加载预训练的ResNet50模型:

import torch
import torchvision.models as models

# 加载预训练的ResNet50模型
model = models.resnet50(pretrained=True)

3. 设置模型为评估模式

在导出模型之前,确保将模型设置为评估模式:

model.eval()

4. 定义动态batch size的输入

为了支持动态batch size,你需要定义一个带有动态维度的输入张量。可以使用torch.onnx.exportdynamic_axes参数来指定哪些维度是动态的。

# 定义一个示例输入张量,batch size为动态维度
dummy_input = torch.randn(1, 3, 224, 224)  # 这里的1是占位符,实际batch size可以是任意的

# 定义动态维度
dynamic_axes = {'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}

5. 导出模型为ONNX格式

使用torch.onnx.export函数将模型导出为ONNX格式,并指定动态维度:

# 导出模型为ONNX格式
onnx_path = "resnet50_dynamic.onnx"
torch.onnx.export(
    model,                  # 要导出的模型
    dummy_input,            # 示例输入
    onnx_path,              # 导出的ONNX文件路径
    export_params=True,     # 导出模型参数
    opset_version=11,       # ONNX opset版本
    do_constant_folding=True,  # 是否进行常量折叠优化
    input_names=['input'],  # 输入名称
    output_names=['output'],# 输出名称
    dynamic_axes=dynamic_axes  # 动态维度
)

print(f"Model exported to {onnx_path}")

6. 验证导出的ONNX模型

你可以使用ONNX Runtime或其他工具来验证导出的ONNX模型是否正确。

import onnx

# 加载导出的ONNX模型
onnx_model = onnx.load(onnx_path)

# 检查模型是否有效
onnx.checker.check_model(onnx_model)

print("ONNX model is valid!")

7. 使用ONNX模型进行推理

你可以使用ONNX Runtime来加载并运行导出的ONNX模型:

import onnxruntime as ort

# 创建ONNX Runtime会话
ort_session = ort.InferenceSession(onnx_path)

# 准备输入数据
input_name = ort_session.get_inputs()[0].name
output_name = ort_session.get_outputs()[0].name

# 使用动态batch size的输入进行推理
dynamic_input = torch.randn(4, 3, 224, 224).numpy()  # batch size为4
outputs = ort_session.run([output_name], {input_name: dynamic_input})

print(outputs)

总结

通过以上步骤,你可以将PyTorch中的ResNet50模型导出为支持动态batch size的ONNX格式。导出的模型可以在不同的batch size下进行推理,适用于各种部署场景。