|
|
@@ -7,9 +7,9 @@ from typing import Any, Dict, Mapping |
|
|
|
import torch |
|
|
|
from torch import nn |
|
|
|
from torch.onnx import export as onnx_export |
|
|
|
from torch.onnx.utils import _decide_input_format |
|
|
|
|
|
|
|
from modelscope.models import TorchModel |
|
|
|
from modelscope.outputs import ModelOutputBase |
|
|
|
from modelscope.pipelines.base import collate_fn |
|
|
|
from modelscope.utils.constant import ModelFile |
|
|
|
from modelscope.utils.logger import get_logger |
|
|
@@ -102,6 +102,53 @@ class TorchModelExporter(Exporter): |
|
|
|
""" |
|
|
|
return None |
|
|
|
|
|
|
|
@staticmethod |
|
|
|
def _decide_input_format(model, args): |
|
|
|
import inspect |
|
|
|
|
|
|
|
def _signature(model) -> inspect.Signature: |
|
|
|
should_be_callable = getattr(model, 'forward', model) |
|
|
|
if callable(should_be_callable): |
|
|
|
return inspect.signature(should_be_callable) |
|
|
|
raise ValueError('model has no forward method and is not callable') |
|
|
|
|
|
|
|
try: |
|
|
|
sig = _signature(model) |
|
|
|
except ValueError as e: |
|
|
|
logger.warn('%s, skipping _decide_input_format' % e) |
|
|
|
return args |
|
|
|
try: |
|
|
|
ordered_list_keys = list(sig.parameters.keys()) |
|
|
|
if ordered_list_keys[0] == 'self': |
|
|
|
ordered_list_keys = ordered_list_keys[1:] |
|
|
|
args_dict: Dict = {} |
|
|
|
if isinstance(args, list): |
|
|
|
args_list = args |
|
|
|
elif isinstance(args, tuple): |
|
|
|
args_list = list(args) |
|
|
|
else: |
|
|
|
args_list = [args] |
|
|
|
if isinstance(args_list[-1], dict): |
|
|
|
args_dict = args_list[-1] |
|
|
|
args_list = args_list[:-1] |
|
|
|
n_nonkeyword = len(args_list) |
|
|
|
for optional_arg in ordered_list_keys[n_nonkeyword:]: |
|
|
|
if optional_arg in args_dict: |
|
|
|
args_list.append(args_dict[optional_arg]) |
|
|
|
# Check if this arg has a default value |
|
|
|
else: |
|
|
|
param = sig.parameters[optional_arg] |
|
|
|
if param.default != param.empty: |
|
|
|
args_list.append(param.default) |
|
|
|
args = args_list if isinstance(args, list) else tuple(args_list) |
|
|
|
# Cases of models with no input args |
|
|
|
except IndexError: |
|
|
|
logger.warn('No input args, skipping _decide_input_format') |
|
|
|
except Exception as e: |
|
|
|
logger.warn('Skipping _decide_input_format\n {}'.format(e.args[0])) |
|
|
|
|
|
|
|
return args |
|
|
|
|
|
|
|
def _torch_export_onnx(self, |
|
|
|
model: nn.Module, |
|
|
|
output: str, |
|
|
@@ -179,16 +226,21 @@ class TorchModelExporter(Exporter): |
|
|
|
with torch.no_grad(): |
|
|
|
model.eval() |
|
|
|
outputs_origin = model.forward( |
|
|
|
*_decide_input_format(model, dummy_inputs)) |
|
|
|
if isinstance(outputs_origin, Mapping): |
|
|
|
outputs_origin = numpify_tensor_nested( |
|
|
|
list(outputs_origin.values())) |
|
|
|
*self._decide_input_format(model, dummy_inputs)) |
|
|
|
if isinstance(outputs_origin, (Mapping, ModelOutputBase)): |
|
|
|
outputs_origin = list( |
|
|
|
numpify_tensor_nested(outputs_origin).values()) |
|
|
|
elif isinstance(outputs_origin, (tuple, list)): |
|
|
|
outputs_origin = numpify_tensor_nested(outputs_origin) |
|
|
|
outputs_origin = list(numpify_tensor_nested(outputs_origin)) |
|
|
|
outputs = ort_session.run( |
|
|
|
onnx_outputs, |
|
|
|
numpify_tensor_nested(dummy_inputs), |
|
|
|
) |
|
|
|
outputs = numpify_tensor_nested(outputs) |
|
|
|
if isinstance(outputs, dict): |
|
|
|
outputs = list(outputs.values()) |
|
|
|
elif isinstance(outputs, tuple): |
|
|
|
outputs = list(outputs) |
|
|
|
|
|
|
|
tols = {} |
|
|
|
if rtol is not None: |
|
|
@@ -232,12 +284,26 @@ class TorchModelExporter(Exporter): |
|
|
|
'Model property dummy_inputs must be set.') |
|
|
|
dummy_inputs = collate_fn(dummy_inputs, device) |
|
|
|
if isinstance(dummy_inputs, Mapping): |
|
|
|
dummy_inputs = tuple(dummy_inputs.values()) |
|
|
|
dummy_inputs = self._decide_input_format(model, dummy_inputs) |
|
|
|
dummy_inputs_filter = [] |
|
|
|
for _input in dummy_inputs: |
|
|
|
if _input is not None: |
|
|
|
dummy_inputs_filter.append(_input) |
|
|
|
else: |
|
|
|
break |
|
|
|
|
|
|
|
if len(dummy_inputs) != len(dummy_inputs_filter): |
|
|
|
logger.warn( |
|
|
|
f'Dummy inputs is not continuous in the forward method, ' |
|
|
|
f'origin length: {len(dummy_inputs)}, ' |
|
|
|
f'the length after filtering: {len(dummy_inputs_filter)}') |
|
|
|
dummy_inputs = dummy_inputs_filter |
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
|
model.eval() |
|
|
|
with replace_call(): |
|
|
|
traced_model = torch.jit.trace( |
|
|
|
model, dummy_inputs, strict=strict) |
|
|
|
model, tuple(dummy_inputs), strict=strict) |
|
|
|
torch.jit.save(traced_model, output) |
|
|
|
|
|
|
|
if validation: |
|
|
@@ -249,6 +315,10 @@ class TorchModelExporter(Exporter): |
|
|
|
outputs = numpify_tensor_nested(outputs) |
|
|
|
outputs_origin = model.forward(*dummy_inputs) |
|
|
|
outputs_origin = numpify_tensor_nested(outputs_origin) |
|
|
|
if isinstance(outputs, dict): |
|
|
|
outputs = list(outputs.values()) |
|
|
|
if isinstance(outputs_origin, dict): |
|
|
|
outputs_origin = list(outputs_origin.values()) |
|
|
|
tols = {} |
|
|
|
if rtol is not None: |
|
|
|
tols['rtol'] = rtol |
|
|
|