From 3ce186622490e71792facd47df1785aeb3b44f63 Mon Sep 17 00:00:00 2001 From: "wanggui.hwg" Date: Wed, 7 Dec 2022 17:23:52 +0800 Subject: [PATCH] [to #42322933] Fix bugs for UniTE Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11011725 --- .../nlp/translation_evaluation_pipeline.py | 8 ++++---- tests/pipelines/test_translation_evaluation.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/modelscope/pipelines/nlp/translation_evaluation_pipeline.py b/modelscope/pipelines/nlp/translation_evaluation_pipeline.py index bc942342..3ec3ee7d 100644 --- a/modelscope/pipelines/nlp/translation_evaluation_pipeline.py +++ b/modelscope/pipelines/nlp/translation_evaluation_pipeline.py @@ -77,14 +77,14 @@ class TranslationEvaluationPipeline(Pipeline): self.preprocessor.eval_mode = eval_mode return - def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs): + def __call__(self, input: Dict[str, Union[str, List[str]]], **kwargs): r"""Implementation of __call__ function. Args: - input_dict: The formatted dict containing the inputted sentences. + input: The formatted dict containing the inputted sentences. An example of the formatted dict: ``` - input_dict = { + input = { 'hyp': [ 'This is a sentence.', 'This is another sentence.', @@ -100,7 +100,7 @@ class TranslationEvaluationPipeline(Pipeline): } ``` """ - return super().__call__(input=input_dict, **kwargs) + return super().__call__(input=input, **kwargs) def forward(self, input_ids: List[torch.Tensor]) -> Dict[str, torch.Tensor]: diff --git a/tests/pipelines/test_translation_evaluation.py b/tests/pipelines/test_translation_evaluation.py index 0c73edca..76720ac0 100644 --- a/tests/pipelines/test_translation_evaluation.py +++ b/tests/pipelines/test_translation_evaluation.py @@ -18,7 +18,7 @@ class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck): @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name_for_unite_large(self): - input_dict = { + input = { 'hyp': [ 'This is a sentence.', 'This is another sentence.', @@ -34,17 +34,17 @@ class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck): } pipeline_ins = pipeline(self.task, model=self.model_id_large) - print(pipeline_ins(input_dict)) + print(pipeline_ins(input=input)) pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC) - print(pipeline_ins(input_dict)) + print(pipeline_ins(input=input)) pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF) - print(pipeline_ins(input_dict)) + print(pipeline_ins(input=input)) @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') def test_run_with_model_name_for_unite_base(self): - input_dict = { + input = { 'hyp': [ 'This is a sentence.', 'This is another sentence.', @@ -60,13 +60,13 @@ class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck): } pipeline_ins = pipeline(self.task, model=self.model_id_base) - print(pipeline_ins(input_dict)) + print(pipeline_ins(input=input)) pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC) - print(pipeline_ins(input_dict)) + print(pipeline_ins(input=input)) pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF) - print(pipeline_ins(input_dict)) + print(pipeline_ins(input=input)) if __name__ == '__main__':