@staticmethod
def _decide_input_format(model, args):
    """Arrange ``args`` in the positional order of the model's ``forward``.

    Local replacement for the ``torch.onnx.utils._decide_input_format``
    helper this module used to import.

    The signature of ``model.forward`` (or of ``model`` itself when it has
    no ``forward`` attribute) is inspected.  Values already given
    positionally are kept; every remaining parameter is then filled, in
    declaration order, from the trailing mapping of ``args`` or, failing
    that, from the parameter's default value.

    Args:
        model: An object with a ``forward`` method, or a callable.
        args: A single value, a list/tuple of positional values, or a
            mapping of keyword values.  If the last element of a
            list/tuple is a mapping it is treated as keyword values.

    Returns:
        A list when ``args`` is a list, otherwise a tuple.  The input is
        never modified in place.  ``args`` is returned unchanged when the
        signature cannot be inspected, when there is nothing to arrange,
        or when an unexpected error occurs (a warning is logged).
    """
    import inspect

    target = getattr(model, 'forward', model)
    try:
        if not callable(target):
            raise ValueError(
                'model has no forward method and is not callable')
        sig = inspect.signature(target)
    except ValueError as e:
        logger.warning('%s, skipping _decide_input_format', e)
        return args

    try:
        param_names = list(sig.parameters)
        if param_names and param_names[0] == 'self':
            param_names = param_names[1:]

        # Always work on a copy so the caller's list is not mutated.
        if isinstance(args, (list, tuple)):
            positional = list(args)
        else:
            positional = [args]

        # Models without parameters, or an empty args container.
        if not param_names or not positional:
            logger.warning('No input args, skipping _decide_input_format')
            return args

        # A trailing mapping carries the keyword values.  ``Mapping`` (not
        # just ``dict``) keeps this consistent with callers that pass any
        # mapping type as dummy inputs.
        keyword = {}
        if isinstance(positional[-1], Mapping):
            keyword = positional.pop()

        for name in param_names[len(positional):]:
            if name in keyword:
                positional.append(keyword[name])
                continue
            default = sig.parameters[name].default
            # Identity check against the sentinel: ``!=`` would invoke the
            # default value's own comparison (ambiguous for arrays).
            if default is not inspect.Parameter.empty:
                positional.append(default)
            # Parameters with neither a value nor a default (this includes
            # *args / **kwargs) are skipped, so later values shift left.

        return positional if isinstance(args, list) else tuple(positional)
    except Exception as e:  # best effort: fall back to the caller's input
        logger.warning('Skipping _decide_input_format\n {}'.format(e))

    return args
dummy_inputs, strict=strict) + model, tuple(dummy_inputs), strict=strict) torch.jit.save(traced_model, output) if validation: @@ -249,6 +315,10 @@ class TorchModelExporter(Exporter): outputs = numpify_tensor_nested(outputs) outputs_origin = model.forward(*dummy_inputs) outputs_origin = numpify_tensor_nested(outputs_origin) + if isinstance(outputs, dict): + outputs = list(outputs.values()) + if isinstance(outputs_origin, dict): + outputs_origin = list(outputs_origin.values()) tols = {} if rtol is not None: tols['rtol'] = rtol diff --git a/modelscope/models/base/base_model.py b/modelscope/models/base/base_model.py index e01d1f05..1ca7e030 100644 --- a/modelscope/models/base/base_model.py +++ b/modelscope/models/base/base_model.py @@ -161,5 +161,12 @@ class Model(ABC): assert config is not None, 'Cannot save the model because the model config is empty.' if isinstance(config, Config): config = config.to_dict() + if 'preprocessor' in config and config['preprocessor'] is not None: + if 'mode' in config['preprocessor']: + config['preprocessor']['mode'] = 'inference' + elif 'val' in config['preprocessor'] and 'mode' in config[ + 'preprocessor']['val']: + config['preprocessor']['val']['mode'] = 'inference' + save_pretrained(self, target_folder, save_checkpoint_names, save_function, config, **kwargs) diff --git a/modelscope/models/nlp/bert/text_ranking.py b/modelscope/models/nlp/bert/text_ranking.py index d6bbf277..b5ac8d7e 100644 --- a/modelscope/models/nlp/bert/text_ranking.py +++ b/modelscope/models/nlp/bert/text_ranking.py @@ -36,6 +36,7 @@ class BertForTextRanking(BertForSequenceClassification): output_attentions=None, output_hidden_states=None, return_dict=None, + *args, **kwargs) -> AttentionTextClassificationModelOutput: outputs = self.base_model.forward( input_ids=input_ids, diff --git a/modelscope/models/nlp/structbert/text_classification.py b/modelscope/models/nlp/structbert/text_classification.py index 044cf8d0..8797beb3 100644 --- 
a/modelscope/models/nlp/structbert/text_classification.py +++ b/modelscope/models/nlp/structbert/text_classification.py @@ -109,6 +109,7 @@ class SbertForSequenceClassification(SbertPreTrainedModel): output_attentions=None, output_hidden_states=None, return_dict=None, + *args, **kwargs): r""" Args: diff --git a/modelscope/trainers/trainer.py b/modelscope/trainers/trainer.py index aaf24cfa..7478d8e4 100644 --- a/modelscope/trainers/trainer.py +++ b/modelscope/trainers/trainer.py @@ -672,7 +672,7 @@ class EpochBasedTrainer(BaseTrainer): self.model, cfg=cfg, default_args=default_args) except KeyError as e: self.logger.error( - f'Build optimizer error, the optimizer {cfg} is native torch optimizer, ' + f'Build optimizer error, the optimizer {cfg} is a torch native component, ' f'please check if your torch with version: {torch.__version__} matches the config.' ) raise e @@ -682,7 +682,7 @@ class EpochBasedTrainer(BaseTrainer): return build_lr_scheduler(cfg=cfg, default_args=default_args) except KeyError as e: self.logger.error( - f'Build lr_scheduler error, the lr_scheduler {cfg} is native torch lr_scheduler, ' + f'Build lr_scheduler error, the lr_scheduler {cfg} is a torch native component, ' f'please check if your torch with version: {torch.__version__} matches the config.' 
) raise e diff --git a/tests/export/test_export_sbert_sequence_classification.py b/tests/export/test_export_sbert_sequence_classification.py index 0e4f8349..7533732d 100644 --- a/tests/export/test_export_sbert_sequence_classification.py +++ b/tests/export/test_export_sbert_sequence_classification.py @@ -23,7 +23,7 @@ class TestExportSbertSequenceClassification(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skip + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_export_sbert_sequence_classification(self): model = Model.from_pretrained(self.model_id) print( diff --git a/tests/trainers/test_finetune_sequence_classification.py b/tests/trainers/test_finetune_sequence_classification.py index 02dd9d2f..061d37d3 100644 --- a/tests/trainers/test_finetune_sequence_classification.py +++ b/tests/trainers/test_finetune_sequence_classification.py @@ -38,7 +38,7 @@ class TestFinetuneSequenceClassification(unittest.TestCase): shutil.rmtree(self.tmp_dir) super().tearDown() - @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + @unittest.skip def test_trainer_cfg_class(self): dataset = MsDataset.load('clue', subset_name='tnews') train_dataset = dataset['train'] diff --git a/tests/trainers/test_trainer_with_nlp.py b/tests/trainers/test_trainer_with_nlp.py index d9d56b60..f1d9e414 100644 --- a/tests/trainers/test_trainer_with_nlp.py +++ b/tests/trainers/test_trainer_with_nlp.py @@ -72,7 +72,7 @@ class TestTrainerWithNlp(unittest.TestCase): output_dir = os.path.join(self.tmp_dir, ModelFile.TRAIN_OUTPUT_DIR) pipeline_sentence_similarity(output_dir) - @unittest.skipUnless(test_level() >= 3, 'skip test in current test level') + @unittest.skip def test_trainer_with_backbone_head(self): model_id = 'damo/nlp_structbert_sentiment-classification_chinese-base' kwargs = dict(