Browse Source

revert a mis modification

master
雨泓 3 years ago
parent
commit
2eb633ec93
1 changed files with 6 additions and 6 deletions
  1. +6
    -6
      modelscope/models/nlp/masked_language_model.py

+ 6
- 6
modelscope/models/nlp/masked_language_model.py View File

@@ -31,20 +31,20 @@ class MaskedLMModelBase(Model):
return self.model.config
return None

def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]:
"""return the result by the model

Args:
inputs (Dict[str, Any]): the preprocessed data
input (Dict[str, Any]): the preprocessed data

Returns:
Dict[str, np.ndarray]: results
"""
rst = self.model(
input_ids=inputs['input_ids'],
attention_mask=inputs['attention_mask'],
token_type_ids=inputs['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']}
input_ids=input['input_ids'],
attention_mask=input['attention_mask'],
token_type_ids=input['token_type_ids'])
return {'logits': rst['logits'], 'input_ids': input['input_ids']}


@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert)


Loading…
Cancel
Save