diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 6aca0b20..fa46df24 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -201,21 +201,6 @@ class MetricBase(object):
                     f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
                     f"initialization parameters, or change its signature.")
 
-    def _fast_param_map(self, pred_dict, target_dict):
-        """Only used as an inner function. When pred_dict and target_dict are unambiguous, users don't need to pass a key_map,
-            e.g. pred_dict has one element and target_dict has one element.
-
-        :param pred_dict:
-        :param target_dict:
-        :return: dict, if the dict is not {}, pass it to self.evaluate. Otherwise do the mapping.
-        """
-        fast_param = {}
-        if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(target_dict.values())[0]
-            return fast_param
-        return fast_param
-
     def __call__(self, pred_dict, target_dict):
         """
         This method calls self.evaluate.
@@ -231,10 +216,6 @@ class MetricBase(object):
 
         :return:
         """
-        fast_param = self._fast_param_map(pred_dict, target_dict)
-        if fast_param:
-            self.evaluate(**fast_param)
-            return
 
         if not self._checked:
             if not callable(self.evaluate):
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index e615ba64..3e2b98be 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -193,7 +193,7 @@ class BertWordPieceEncoder(nn.Module):
         with torch.no_grad():
             sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
             if token_type_ids is None:
-                sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+                sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
                 token_type_ids = sep_mask_cumsum.fmod(2)
                 if token_type_ids[0, 0].item():  # if the count at the start is odd, flip the result, since the sequence must start with segment 0
                     token_type_ids = token_type_ids.eq(0).long()
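Context on the first change: the removed `_fast_param_map` shortcut bypassed the signature checks in `__call__` and always invoked `self.evaluate` with the hard-coded keyword names `pred` and `target` whenever `_param_map` had exactly two entries. Below is a standalone sketch (plain Python, not fastNLP code; the `evaluate` signature is hypothetical) of why that shortcut could break for a metric whose `evaluate` uses renamed parameters mapped through `_param_map`:

```python
# Standalone sketch: the removed fast path called evaluate with the
# hard-coded keywords 'pred' and 'target', even when the metric's
# evaluate used other parameter names mapped via _param_map.
def evaluate(logits, label):  # hypothetical metric signature
    return logits == label

fast_param = {'pred': 1, 'target': 1}  # what _fast_param_map returned

try:
    evaluate(**fast_param)
except TypeError as err:
    print(err)  # evaluate() got an unexpected keyword argument 'pred'
```

After the removal, every call goes through the regular mapping-and-check path in `__call__`, so such metrics are handled consistently.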
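For the second change: `Tensor.eq` returns a boolean (formerly uint8) mask, and arithmetic ops like `cumsum`/`fmod` on such masks are unsupported or overflow-prone on some torch versions, which is presumably what the added `.long()` guards against. A minimal runnable sketch of the segment-id derivation, with made-up word-piece ids and `sep_index` standing in for `self._sep_index`:

```python
import torch

sep_index = 102  # a typical [SEP] id in BERT vocabularies (assumption)
# One sentence pair: tokens 5, 6, [SEP], then 7, [SEP] (made-up ids).
word_pieces = torch.tensor([[5, 6, sep_index, 7, sep_index]])

sep_mask = word_pieces.eq(sep_index)  # bool mask, batch_size x max_len
# Cast to long before flip/cumsum: the reverse cumulative sum counts how
# many [SEP] tokens sit at or after each position, and fmod(2) turns that
# count into alternating segment ids.
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item():  # odd [SEP] count: flip so segment A is 0
    token_type_ids = token_type_ids.eq(0).long()

print(token_type_ids)  # tensor([[0, 0, 0, 1, 1]])
```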