将PyTorch模型导出为ONNX格式并使用ONNX Runtime进行推理的过程可以分为以下几个步骤:
首先,你需要将训练好的PyTorch模型导出为ONNX格式。以下是一个简单的示例代码:
import torch
import torch.onnx
# 假设你已经有一个训练好的模型
class SimpleModel(torch.nn.Module):
def __init__(self):
super(SimpleModel, self).__init__()
self.fc = torch.nn.Linear(10, 1)
def forward(self, x):
return self.fc(x)
# 实例化模型
model = SimpleModel()
# 设置模型为评估模式
model.eval()
# 创建一个示例输入
dummy_input = torch.randn(1, 10)
# 导出模型为ONNX格式
onnx_model_path = "model.onnx"
torch.onnx.export(model, # 要导出的模型
dummy_input, # 示例输入
onnx_model_path, # 导出的ONNX文件路径
export_params=True, # 导出模型参数
opset_version=11, # ONNX算子集版本
do_constant_folding=True, # 是否进行常量折叠优化
input_names=['input'], # 输入名称
output_names=['output'], # 输出名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态轴
'output': {0: 'batch_size'}})
导出ONNX模型后,你可以使用ONNX Runtime进行推理。以下是一个简单的示例代码:
import onnxruntime as ort
import numpy as np
# 加载ONNX模型
onnx_model_path = "model.onnx"
session = ort.InferenceSession(onnx_model_path)
# 准备输入数据
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
# 创建一个示例输入
input_data = np.random.randn(1, 10).astype(np.float32)
# 进行推理
result = session.run([output_name], {input_name: input_data})
# 输出推理结果
print("Inference result:", result)
在导出ONNX模型后,建议验证导出的模型是否与原始PyTorch模型一致。你可以通过比较两者的输出来验证:
# 使用PyTorch模型进行推理
with torch.no_grad():
torch_output = model(torch.tensor(input_data))
# 使用ONNX Runtime进行推理
onnx_output = session.run([output_name], {input_name: input_data})[0]
# 比较两者的输出
print("PyTorch output:", torch_output.numpy())
print("ONNX output:", onnx_output)
print("Are outputs close?", np.allclose(torch_output.numpy(), onnx_output, atol=1e-6))
如果你的模型需要处理动态输入形状(例如不同的批量大小),你可以在导出ONNX模型时指定动态轴(如步骤1所示)。在推理时,ONNX Runtime会自动处理不同形状的输入。
ONNX Runtime提供了一些工具来优化ONNX模型,例如使用onnxruntime.tools.optimize_model
函数:
from onnxruntime.tools import optimize_model
# 优化模型
optimized_model = optimize_model(onnx_model_path)
# 保存优化后的模型
optimized_model.save_model("optimized_model.onnx")
通过以上步骤,你可以将PyTorch模型导出为ONNX格式,并使用ONNX Runtime进行推理。ONNX Runtime是一个高效的推理引擎,支持多种硬件加速器(如CPU、GPU、TPU等),并且可以在多种平台上运行。