插件窝 干货文章 如何用ONNX Runtime正确加载和推理PyTorch导出的ONNX模型?

如何用ONNX Runtime正确加载和推理PyTorch导出的ONNX模型?

ONNX 模型 输入 导出 467    来源:    2025-03-13

要使用ONNX Runtime正确加载和推理PyTorch导出的ONNX模型,你需要按照以下步骤进行操作:

1. 导出PyTorch模型为ONNX格式

首先,你需要将PyTorch模型导出为ONNX格式。假设你已经有一个训练好的PyTorch模型 model,并且有一个示例输入 dummy_input,你可以使用以下代码将其导出为ONNX格式:

import torch
import torch.onnx

# 假设你已经有一个训练好的PyTorch模型
model = ...  # 你的PyTorch模型

# 创建一个示例输入
dummy_input = torch.randn(1, 3, 224, 224)  # 假设输入是1张3通道的224x224图像

# 导出模型为ONNX格式
onnx_model_path = "model.onnx"
torch.onnx.export(model, dummy_input, onnx_model_path, verbose=True)

2. 使用ONNX Runtime加载ONNX模型

接下来,你可以使用ONNX Runtime加载导出的ONNX模型。

import onnxruntime as ort

# 加载ONNX模型
onnx_model_path = "model.onnx"
session = ort.InferenceSession(onnx_model_path)

# 获取输入和输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

3. 准备输入数据

确保你的输入数据格式与模型期望的格式一致。通常,输入数据需要是NumPy数组。

import numpy as np

# 假设输入是一个1张3通道的224x224图像
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

4. 进行推理

使用ONNX Runtime进行推理。

# 进行推理
outputs = session.run([output_name], {input_name: input_data})

# 输出结果
print(outputs)

5. 处理输出结果

根据你的模型输出,你可能需要对输出结果进行进一步处理。例如,如果模型输出是分类概率,你可能需要找到最大概率对应的类别。

# 假设输出是一个分类概率向量
output = outputs[0]
predicted_class = np.argmax(output)
print(f"Predicted class: {predicted_class}")

6. 完整代码示例

以下是一个完整的代码示例:

import torch
import torch.onnx
import onnxruntime as ort
import numpy as np

# 假设你已经有一个训练好的PyTorch模型
model = ...  # 你的PyTorch模型

# 创建一个示例输入
dummy_input = torch.randn(1, 3, 224, 224)  # 假设输入是1张3通道的224x224图像

# 导出模型为ONNX格式
onnx_model_path = "model.onnx"
torch.onnx.export(model, dummy_input, onnx_model_path, verbose=True)

# 加载ONNX模型
session = ort.InferenceSession(onnx_model_path)

# 获取输入和输出名称
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

# 准备输入数据
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 进行推理
outputs = session.run([output_name], {input_name: input_data})

# 输出结果
output = outputs[0]
predicted_class = np.argmax(output)
print(f"Predicted class: {predicted_class}")

注意事项

  • 输入数据格式:确保输入数据的形状和类型与模型期望的格式一致。
  • ONNX模型版本:确保ONNX Runtime的版本与导出的ONNX模型版本兼容。
  • 硬件加速:如果你有GPU,可以配置ONNX Runtime使用GPU进行推理,以提高性能。

通过以上步骤,你可以成功加载和推理PyTorch导出的ONNX模型。