插件窝 干货文章 如何将PyTorch模型导出为ONNX并使用ONNX Runtime进行推理?

如何将PyTorch模型导出为ONNX并使用ONNX Runtime进行推理?

ONNX 模型 导出 input 196    来源:    2025-03-14

将PyTorch模型导出为ONNX格式并使用ONNX Runtime进行推理的过程可以分为以下几个步骤:

1. 安装必要的库

首先,确保你已经安装了PyTorch、ONNX和ONNX Runtime库。你可以使用以下命令安装这些库:

pip install torch onnx onnxruntime

2. 导出PyTorch模型为ONNX格式

假设你有一个训练好的PyTorch模型,你可以使用torch.onnx.export函数将其导出为ONNX格式。

import torch
import torchvision

# 假设我们有一个预训练的ResNet18模型
model = torchvision.models.resnet18(pretrained=True)
model.eval()  # 将模型设置为评估模式

# 创建一个示例输入张量
dummy_input = torch.randn(1, 3, 224, 224)

# 导出模型为ONNX格式
onnx_path = "resnet18.onnx"
torch.onnx.export(model,               # 要导出的模型
                  dummy_input,         # 示例输入
                  onnx_path,          # 导出的ONNX文件路径
                  export_params=True,  # 导出模型参数
                  opset_version=11,    # ONNX算子集版本
                  do_constant_folding=True,  # 是否进行常量折叠优化
                  input_names=['input'],     # 输入名称
                  output_names=['output'],   # 输出名称
                  dynamic_axes={'input': {0: 'batch_size'},  # 动态轴
                                'output': {0: 'batch_size'}})

3. 使用ONNX Runtime进行推理

导出ONNX模型后,你可以使用ONNX Runtime加载模型并进行推理。

import onnxruntime as ort
import numpy as np

# 加载ONNX模型
onnx_path = "resnet18.onnx"
session = ort.InferenceSession(onnx_path)

# 准备输入数据
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 进行推理
outputs = session.run([output_name], {input_name: input_data})

# 输出结果
print(outputs)

4. 验证ONNX模型的正确性

在导出ONNX模型后,建议验证导出的模型是否与原始PyTorch模型的行为一致。

# 使用PyTorch模型进行推理
with torch.no_grad():
    torch_output = model(torch.from_numpy(input_data)).numpy()

# 比较ONNX Runtime和PyTorch的输出
np.testing.assert_allclose(torch_output, outputs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNX Runtime, and the result looks good!")

5. 处理动态输入形状(可选)

如果你的模型需要处理不同大小的输入(例如不同的批量大小),你可以在导出ONNX模型时指定动态轴。

torch.onnx.export(model, 
                  dummy_input, 
                  onnx_path, 
                  export_params=True, 
                  opset_version=11, 
                  do_constant_folding=True, 
                  input_names=['input'], 
                  output_names=['output'], 
                  dynamic_axes={'input': {0: 'batch_size'}, 
                                'output': {0: 'batch_size'}})

6. 使用ONNX Runtime进行推理时的优化

ONNX Runtime提供了多种优化选项,例如使用不同的执行提供者(如CPU、CUDA、TensorRT等)来加速推理。

# 使用CUDA执行提供者
session = ort.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'])

# 使用TensorRT执行提供者
session = ort.InferenceSession(onnx_path, providers=['TensorrtExecutionProvider'])

总结

通过以上步骤,你可以将PyTorch模型导出为ONNX格式,并使用ONNX Runtime进行推理。ONNX Runtime提供了跨平台、高性能的推理能力,适用于多种硬件和操作系统环境。