Browse Source

1.修复BertEncoder中的flip时报错的bug; 2.删除Metric中fast_param_map防止误触发导致不好debug的问题

tags/v0.5.5
yh_cc 4 years ago
parent
commit
218f3aaded
2 changed files with 1 additions and 20 deletions
  1. +0
    -19
      fastNLP/core/metrics.py
  2. +1
    -1
      fastNLP/embeddings/bert_embedding.py

+ 0
- 19
fastNLP/core/metrics.py View File

@@ -201,21 +201,6 @@ class MetricBase(object):
f"Parameter `{func_param}` is not in {_get_func_signature(self.evaluate)}. Please check the "
f"initialization parameters, or change its signature.")

def _fast_param_map(self, pred_dict, target_dict):
"""Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
such as pred_dict has one element, target_dict has one element

:param pred_dict:
:param target_dict:
:return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
"""
fast_param = {}
if len(self._param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
fast_param['pred'] = list(pred_dict.values())[0]
fast_param['target'] = list(target_dict.values())[0]
return fast_param
return fast_param

def __call__(self, pred_dict, target_dict):
"""
这个方法会调用self.evaluate 方法.
@@ -231,10 +216,6 @@ class MetricBase(object):
:return:
"""

fast_param = self._fast_param_map(pred_dict, target_dict)
if fast_param:
self.evaluate(**fast_param)
return

if not self._checked:
if not callable(self.evaluate):


+ 1
- 1
fastNLP/embeddings/bert_embedding.py View File

@@ -193,7 +193,7 @@ class BertWordPieceEncoder(nn.Module):
with torch.no_grad():
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
if token_type_ids is None:
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
sep_mask_cumsum = sep_mask.long().flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long()


Loading…
Cancel
Save