插件窝 干货文章 如何用ONNX Runtime正确加载和推理PyTorch导出的ONNX模型?

如何用ONNX Runtime正确加载和推理PyTorch导出的ONNX模型?

模型 ONNX 输入 导出 812    来源:    2025-03-14

要使用ONNX Runtime正确加载和推理PyTorch导出的ONNX模型,你需要按照以下步骤操作:

1. 导出PyTorch模型为ONNX格式

首先,确保你已经将PyTorch模型导出为ONNX格式。以下是一个简单的示例:

import torch
import torch.onnx

# 假设你有一个简单的PyTorch模型
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(10, 1)

    def forward(self, x):
        return self.linear(x)

# 实例化模型
model = SimpleModel()

# 创建一个示例输入
dummy_input = torch.randn(1, 10)

# 导出模型为ONNX格式
torch.onnx.export(model, dummy_input, "simple_model.onnx", verbose=True)

2. 安装ONNX Runtime

确保你已经安装了ONNX Runtime。你可以通过以下命令安装:

pip install onnxruntime

3. 使用ONNX Runtime加载和推理ONNX模型

接下来,你可以使用ONNX Runtime加载导出的ONNX模型并进行推理。

import onnxruntime
import numpy as np

# 加载ONNX模型
session = onnxruntime.InferenceSession("simple_model.onnx")

# 获取输入名称
input_name = session.get_inputs()[0].name

# 创建一个示例输入(与导出时的输入形状一致)
input_data = np.random.randn(1, 10).astype(np.float32)

# 进行推理
outputs = session.run(None, {input_name: input_data})

# 输出结果
print(outputs)

4. 解释代码

  • onnxruntime.InferenceSession: 用于加载ONNX模型并创建一个推理会话。
  • session.get_inputs(): 获取模型的输入信息,包括输入的名称和形状。
  • session.run(): 执行推理,传入输入数据并获取输出结果。

5. 注意事项

  • 输入数据格式: 确保输入数据的形状和数据类型与模型期望的输入一致。
  • 输出结果: session.run() 返回一个列表,包含所有输出张量。根据模型的输出结构,你可能需要处理这些输出。

6. 进一步优化

  • GPU加速: 如果你有GPU,可以使用ONNX Runtime的GPU版本进行加速。安装时使用 pip install onnxruntime-gpu,并在创建 InferenceSession 时指定 providers=['CUDAExecutionProvider']
  • 动态输入形状: 如果模型支持动态输入形状,可以在导出ONNX模型时指定动态轴。
torch.onnx.export(model, dummy_input, "simple_model.onnx", verbose=True, dynamic_axes={'input': {0: 'batch_size'}})

通过这些步骤,你可以成功加载和推理PyTorch导出的ONNX模型。