将PyTorch模型导出为ONNX格式并使用ONNX Runtime进行推理的过程可以分为以下几个步骤:
首先,确保你已经安装了PyTorch、ONNX和ONNX Runtime库。你可以使用以下命令安装这些库:
pip install torch onnx onnxruntime
假设你有一个训练好的PyTorch模型,你可以使用torch.onnx.export
函数将其导出为ONNX格式。
import torch
import torchvision
# 假设我们有一个预训练的ResNet18模型
model = torchvision.models.resnet18(pretrained=True)
model.eval() # 将模型设置为评估模式
# 创建一个示例输入张量
dummy_input = torch.randn(1, 3, 224, 224)
# 导出模型为ONNX格式
onnx_path = "resnet18.onnx"
torch.onnx.export(model, # 要导出的模型
dummy_input, # 示例输入
onnx_path, # 导出的ONNX文件路径
export_params=True, # 导出模型参数
opset_version=11, # ONNX算子集版本
do_constant_folding=True, # 是否进行常量折叠优化
input_names=['input'], # 输入名称
output_names=['output'], # 输出名称
dynamic_axes={'input': {0: 'batch_size'}, # 动态轴
'output': {0: 'batch_size'}})
导出ONNX模型后,你可以使用ONNX Runtime加载模型并进行推理。
import onnxruntime as ort
import numpy as np
# 加载ONNX模型
onnx_path = "resnet18.onnx"
session = ort.InferenceSession(onnx_path)
# 准备输入数据
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# 进行推理
outputs = session.run([output_name], {input_name: input_data})
# 输出结果
print(outputs)
在导出ONNX模型后,建议验证导出的模型是否与原始PyTorch模型的行为一致。
# 使用PyTorch模型进行推理
with torch.no_grad():
torch_output = model(torch.from_numpy(input_data)).numpy()
# 比较ONNX Runtime和PyTorch的输出
np.testing.assert_allclose(torch_output, outputs[0], rtol=1e-03, atol=1e-05)
print("Exported model has been tested with ONNX Runtime, and the result looks good!")
如果你的模型需要处理不同大小的输入(例如不同的批量大小),你可以在导出ONNX模型时指定动态轴。
torch.onnx.export(model,
dummy_input,
onnx_path,
export_params=True,
opset_version=11,
do_constant_folding=True,
input_names=['input'],
output_names=['output'],
dynamic_axes={'input': {0: 'batch_size'},
'output': {0: 'batch_size'}})
ONNX Runtime提供了多种优化选项,例如使用不同的执行提供者(如CPU、CUDA、TensorRT等)来加速推理。
# 使用CUDA执行提供者
session = ort.InferenceSession(onnx_path, providers=['CUDAExecutionProvider'])
# 使用TensorRT执行提供者
session = ort.InferenceSession(onnx_path, providers=['TensorrtExecutionProvider'])
通过以上步骤,你可以将PyTorch模型导出为ONNX格式,并使用ONNX Runtime进行推理。ONNX Runtime提供了跨平台、高性能的推理能力,适用于多种硬件和操作系统环境。