|
|
@@ -31,20 +31,20 @@ class MaskedLMModelBase(Model): |
|
|
|
return self.model.config |
|
|
|
return None |
|
|
|
|
|
|
|
def forward(self, inputs: Dict[str, Tensor]) -> Dict[str, np.ndarray]: |
|
|
|
def forward(self, input: Dict[str, Tensor]) -> Dict[str, np.ndarray]: |
|
|
|
"""return the result by the model |
|
|
|
|
|
|
|
Args: |
|
|
|
inputs (Dict[str, Any]): the preprocessed data |
|
|
|
input (Dict[str, Any]): the preprocessed data |
|
|
|
|
|
|
|
Returns: |
|
|
|
Dict[str, np.ndarray]: results |
|
|
|
""" |
|
|
|
rst = self.model( |
|
|
|
input_ids=inputs['input_ids'], |
|
|
|
attention_mask=inputs['attention_mask'], |
|
|
|
token_type_ids=inputs['token_type_ids']) |
|
|
|
return {'logits': rst['logits'], 'input_ids': inputs['input_ids']} |
|
|
|
input_ids=input['input_ids'], |
|
|
|
attention_mask=input['attention_mask'], |
|
|
|
token_type_ids=input['token_type_ids']) |
|
|
|
return {'logits': rst['logits'], 'input_ids': input['input_ids']} |
|
|
|
|
|
|
|
|
|
|
|
@MODELS.register_module(Tasks.fill_mask, module_name=Models.structbert) |
|
|
|