wanggui.hwg wenmeng.zwm 2 years ago
parent
commit
3ce1866224
2 changed files with 12 additions and 12 deletions
  1. +4
    -4
      modelscope/pipelines/nlp/translation_evaluation_pipeline.py
  2. +8
    -8
      tests/pipelines/test_translation_evaluation.py

+ 4
- 4
modelscope/pipelines/nlp/translation_evaluation_pipeline.py View File

@@ -77,14 +77,14 @@ class TranslationEvaluationPipeline(Pipeline):
self.preprocessor.eval_mode = eval_mode
return

def __call__(self, input_dict: Dict[str, Union[str, List[str]]], **kwargs):
def __call__(self, input: Dict[str, Union[str, List[str]]], **kwargs):
r"""Implementation of __call__ function.

Args:
input_dict: The formatted dict containing the inputted sentences.
input: The formatted dict containing the inputted sentences.
An example of the formatted dict:
```
input_dict = {
input = {
'hyp': [
'This is a sentence.',
'This is another sentence.',
@@ -100,7 +100,7 @@ class TranslationEvaluationPipeline(Pipeline):
}
```
"""
return super().__call__(input=input_dict, **kwargs)
return super().__call__(input=input, **kwargs)

def forward(self,
input_ids: List[torch.Tensor]) -> Dict[str, torch.Tensor]:


+ 8
- 8
tests/pipelines/test_translation_evaluation.py View File

@@ -18,7 +18,7 @@ class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_unite_large(self):
input_dict = {
input = {
'hyp': [
'This is a sentence.',
'This is another sentence.',
@@ -34,17 +34,17 @@ class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):
}

pipeline_ins = pipeline(self.task, model=self.model_id_large)
print(pipeline_ins(input_dict))
print(pipeline_ins(input=input))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
print(pipeline_ins(input_dict))
print(pipeline_ins(input=input))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
print(pipeline_ins(input_dict))
print(pipeline_ins(input=input))

@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_name_for_unite_base(self):
input_dict = {
input = {
'hyp': [
'This is a sentence.',
'This is another sentence.',
@@ -60,13 +60,13 @@ class TranslationEvaluationTest(unittest.TestCase, DemoCompatibilityCheck):
}

pipeline_ins = pipeline(self.task, model=self.model_id_base)
print(pipeline_ins(input_dict))
print(pipeline_ins(input=input))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.SRC)
print(pipeline_ins(input_dict))
print(pipeline_ins(input=input))

pipeline_ins.change_eval_mode(eval_mode=EvaluationMode.REF)
print(pipeline_ins(input_dict))
print(pipeline_ins(input=input))


if __name__ == '__main__':


Loading…
Cancel
Save