From 218f3aaded3fd869f99a7c8a8b42702ca97dad6b Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Tue, 24 Dec 2019 00:30:38 +0800
Subject: [PATCH] 1. Fix the bug where flip() in BertEncoder raises an error;
 2. Remove fast_param_map from Metric, whose accidental triggering made
 debugging difficult
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/metrics.py              | 19 -------------------
 fastNLP/embeddings/bert_embedding.py |  2 +-
 2 files changed, 1 insertion(+), 20 deletions(-)

diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index 6aca0b20..fa46df24 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -201,21 +201,6 @@ class MetricBase(object):
                     f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
                     f"initialization parameters, or change its signature.")
 
-    def _fast_param_map(self, pred_dict, target_dict):
-        """Only used as an inner function. When pred_dict and target_dict are unambiguous, users do not need to
-        pass key_map, e.g. when pred_dict has one element and target_dict has one element.
-
-        :param pred_dict:
-        :param target_dict:
-        :return: dict. If the dict is not {}, pass it to self.evaluate; otherwise do the mapping.
-        """
-        fast_param = {}
-        if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
-            fast_param['pred'] = list(pred_dict.values())[0]
-            fast_param['target'] = list(target_dict.values())[0]
-            return fast_param
-        return fast_param
-
     def __call__(self, pred_dict, target_dict):
         """
         This method calls the self.evaluate method.
@@ -231,10 +216,6 @@ class MetricBase(object):
         :return:
         """
 
-        fast_param = self._fast_param_map(pred_dict, target_dict)
-        if fast_param:
-            self.evaluate(**fast_param)
-            return
 
         if not self._checked:
             if not callable(self.evaluate):
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index e615ba64..3e2b98be 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -193,7 +193,7 @@ class BertWordPieceEncoder(nn.Module):
         with torch.no_grad():
             sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
             if token_type_ids is None:
-                sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+                sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
                 token_type_ids = sep_mask_cumsum.fmod(2)
                 if token_type_ids[0, 0].item():  # if the first value is odd, flip the result, since it must start at 0
                     token_type_ids = token_type_ids.eq(0).long()
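
A minimal, self-contained sketch of what the bert_embedding.py change fixes, assuming PyTorch is installed and using a made-up [SEP] id of 102 in place of self._sep_index (the real id comes from the tokenizer's vocabulary). On some PyTorch versions, flip()/cumsum() on the bool mask returned by eq() raises a RuntimeError, which the .long() cast avoids:

    import torch

    # Toy batch: [CLS] a b [SEP] c d [SEP]; all ids here are hypothetical.
    word_pieces = torch.tensor([[101, 7, 8, 102, 9, 10, 102]])
    sep_mask = word_pieces.eq(102)  # batch_size x max_len, dtype=torch.bool

    # The .long() cast is the fix. The reversed cumsum gives every token up
    # to and including each [SEP] the same count, and fmod(2) maps those
    # counts to alternating 0/1 segment ids.
    sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
    token_type_ids = sep_mask_cumsum.fmod(2)
    if token_type_ids[0, 0].item():  # segment ids must start at 0, so invert if odd
        token_type_ids = token_type_ids.eq(0).long()

    print(token_type_ids)  # tensor([[0, 0, 0, 0, 1, 1, 1]])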