diff --git a/modelscope/models/nlp/masked_language_model.py b/modelscope/models/nlp/masked_language_model.py index 928bda7b..6a8c6626 100644 --- a/modelscope/models/nlp/masked_language_model.py +++ b/modelscope/models/nlp/masked_language_model.py @@ -31,20 +31,20 @@ class MaskedLMModelBase(Model): return self.model.config return None - def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: + def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]: """return the result by the model Args: - inputs (Dict[str, Any]): the preprocessed data + input (Dict[str, Any]): the preprocessed data Returns: Dict[str, np.ndarray]: results """ rst = self.model( - input_ids=inputs['input_ids'], - attention_mask=inputs['attention_mask'], - token_type_ids=inputs['token_type_ids']) - return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} + input_ids=input['input_ids'], + attention_mask=input['attention_mask'], + token_type_ids=input['token_type_ids']) + return {'logits': rst['logits'], 'input_ids': input['input_ids']} @MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)