|
- # Copyright (c) OpenMMLab. All rights reserved.
- import os
- import os.path as osp
- import warnings
-
- import numpy as np
- import onnx
- import onnxruntime as ort
- import torch
- import torch.nn as nn
-
# Path to the mmcv custom-op shared library for ONNXRuntime. It stays an
# empty string when the mmcv helper cannot be imported, in which case the
# ``osp.exists`` check in ``get_ort_model_output`` skips the registration.
ort_custom_op_path = ''
try:
    from mmcv.ops import get_onnxruntime_op_path
    ort_custom_op_path = get_onnxruntime_op_path()
except (ImportError, ModuleNotFoundError):
    # Adjacent string literals are used instead of a backslash continuation
    # inside the literal, so the message does not absorb whatever precedes
    # the text on the continuation line (indentation, prefixes, ...).
    warnings.warn('If input model has custom op from mmcv, '
                  'you may have to build mmcv with ONNXRuntime from source.')
-
-
class WrapFunction(nn.Module):
    """Adapter that exposes a plain callable as an ``nn.Module``.

    ``torch.onnx.export`` is given a module in ``ort_validate``; this wrapper
    lets a bare function be passed there by forwarding every call to it.

    Args:
        wrapped_function (callable): The function invoked by ``forward``.
    """

    def __init__(self, wrapped_function):
        super().__init__()
        self.wrapped_function = wrapped_function

    def forward(self, *args, **kwargs):
        """Call the wrapped function with the given arguments unchanged."""
        result = self.wrapped_function(*args, **kwargs)
        return result
-
-
def ort_validate(model, feats, onnx_io='tmp.onnx'):
    """Validate the output of the onnxruntime backend is the same as the output
    generated by torch.

    The model is exported to ``onnx_io``, run through onnxruntime, and the
    results are compared element-wise against the torch forward pass. The
    exported file is removed afterwards, even if export or inference fails.

    Args:
        model (nn.Module | function): the function of model or model
            to be verified.
        feats (tuple(list(torch.Tensor)) | list(torch.Tensor) | torch.Tensor):
            the input of model. When a tuple is given, each element is one
            positional argument of the model; list elements are flattened
            into the onnxruntime input list and tensor elements are passed
            through as a single input.
        onnx_io (str): the name of onnx output file.

    Raises:
        AssertionError: If a torch output and the matching onnxruntime
            output differ beyond ``rtol=1e-03`` / ``atol=1e-05``.
    """
    # if model is not an instance of nn.Module, then it is a normal
    # function and it should be wrapped.
    if isinstance(model, nn.Module):
        wrap_model = model
    else:
        wrap_model = WrapFunction(model)
    wrap_model.cpu().eval()

    try:
        with torch.no_grad():
            torch.onnx.export(
                wrap_model,
                feats,
                onnx_io,
                export_params=True,
                keep_initializers_as_inputs=True,
                do_constant_folding=True,
                verbose=False,
                opset_version=11)

        if isinstance(feats, tuple):
            ort_feats = []
            for feat in feats:
                if isinstance(feat, torch.Tensor):
                    # A tensor argument is exactly one graph input; extending
                    # the list with it would split it along dim 0 instead.
                    ort_feats.append(feat)
                else:
                    ort_feats += feat
        else:
            ort_feats = feats
        # Run the file that was just exported, not the default file name.
        onnx_outputs = get_ort_model_output(ort_feats, onnx_io)
    finally:
        # remove temp file, also when export or inference raised
        if osp.exists(onnx_io):
            os.remove(onnx_io)

    if isinstance(feats, tuple):
        torch_outputs = convert_result_list(wrap_model.forward(*feats))
    else:
        torch_outputs = convert_result_list(wrap_model.forward(feats))
    # The forward pass above runs outside ``no_grad``, hence the detach.
    torch_outputs = [
        torch_output.detach().numpy() for torch_output in torch_outputs
    ]

    # match torch_outputs and onnx_outputs
    for i, onnx_output in enumerate(onnx_outputs):
        np.testing.assert_allclose(
            torch_outputs[i], onnx_output, rtol=1e-03, atol=1e-05)
-
-
def get_ort_model_output(feat, onnx_io='tmp.onnx'):
    """Run an exported ONNX model with onnxruntime and return its outputs.

    Args:
        feat (torch.Tensor | list[torch.Tensor]): Input tensor(s). A single
            tensor is bound to the first session input; a list is bound to
            the session inputs by position.
        onnx_io (str): Path of the ONNX file to load. Defaults to
            'tmp.onnx'.

    Returns:
        list[np.array]: onnxruntime infer result, each is a np.array
    """
    # Run the ONNX checker on the loaded model before creating a session.
    onnx.checker.check_model(onnx.load(onnx_io))

    options = ort.SessionOptions()
    # register custom op for onnxruntime (only when the library file exists)
    if osp.exists(ort_custom_op_path):
        options.register_custom_ops_library(ort_custom_op_path)
    session = ort.InferenceSession(onnx_io, options)

    input_names = [node.name for node in session.get_inputs()]
    if isinstance(feat, torch.Tensor):
        feed = {input_names[0]: feat.numpy()}
    else:
        feed = {
            input_names[idx]: tensor.numpy()
            for idx, tensor in enumerate(feat)
        }
    return session.run(None, feed)
-
-
def convert_result_list(outputs):
    """Flatten nested torch forward outputs into a flat list of tensors.

    Args:
        outputs (Tensor | list | tuple): the outputs in torch env, possibly
            containing nested structures such as list or tuple.

    Returns:
        list(Tensor): a list only containing torch.Tensor, in the order the
            tensors are encountered when walking ``outputs`` depth-first.
    """
    # A tensor is a leaf: wrap it so callers can always concatenate lists.
    if isinstance(outputs, torch.Tensor):
        return [outputs]

    # Any other value is iterated and each element flattened recursively.
    return [
        tensor
        for element in outputs
        for tensor in convert_result_list(element)
    ]
|