| @@ -217,7 +217,7 @@ class CrossEntropyLoss(LossBase): | |||
| 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第 | |||
| 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等, | |||
| 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。 | |||
| :param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | |||
| :param ignore_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | |||
| 传入seq_len. | |||
| :param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
| @@ -227,10 +227,11 @@ class CrossEntropyLoss(LossBase): | |||
| """ | |||
| def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'): | |||
| def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, ignore_idx=-100, reduction='mean', **kwargs): | |||
| super(CrossEntropyLoss, self).__init__() | |||
| self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
| self.padding_idx = padding_idx | |||
| ignore_idx = kwargs.pop('padding_idx', ignore_idx) | |||
| self.ignore_idx = ignore_idx | |||
| assert reduction in ('mean', 'sum', 'none') | |||
| self.reduction = reduction | |||
| self.class_in_dim = class_in_dim | |||
| @@ -238,7 +239,7 @@ class CrossEntropyLoss(LossBase): | |||
| def get_loss(self, pred, target, seq_len=None): | |||
| if seq_len is not None and target.dim()>1: | |||
| mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False) | |||
| target = target.masked_fill(mask, self.padding_idx) | |||
| target = target.masked_fill(mask, self.ignore_idx) | |||
| if pred.dim() > 2: | |||
| if self.class_in_dim == -1: | |||
| @@ -250,7 +251,7 @@ class CrossEntropyLoss(LossBase): | |||
| target = target.reshape(-1) | |||
| return F.cross_entropy(input=pred, target=target, | |||
| ignore_index=self.padding_idx, reduction=self.reduction) | |||
| ignore_index=self.ignore_idx, reduction=self.reduction) | |||
| class L1Loss(LossBase): | |||
| @@ -318,16 +319,30 @@ class BCEWithLogits(LossBase): | |||
| :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
| :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
| :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes) | |||
| 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第 | |||
| 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等, | |||
| 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。 | |||
| :param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
| """ | |||
| def __init__(self, pred=None, target=None, reduction='mean'): | |||
| def __init__(self, pred=None, target=None, class_in_dim=-1, reduction='mean'): | |||
| super(BCEWithLogits, self).__init__() | |||
| self._init_param_map(pred=pred, target=target) | |||
| assert reduction in ('mean', 'sum', 'none') | |||
| self.reduction = reduction | |||
| self.class_in_dim = class_in_dim | |||
| def get_loss(self, pred, target): | |||
| if pred.dim() > 2: | |||
| if self.class_in_dim == -1: | |||
| if pred.size(1) != target.size(1): # 有可能顺序替换了 | |||
| pred = pred.transpose(1, 2) | |||
| else: | |||
| pred = pred.transpose(-1, self.class_in_dim) | |||
| pred = pred.reshape(-1, pred.size(-1)) | |||
| target = target.reshape(-1) | |||
| return F.binary_cross_entropy_with_logits(input=pred, target=target, reduction=self.reduction) | |||
| @@ -336,22 +351,41 @@ class NLLLoss(LossBase): | |||
| 负对数似然损失函数 | |||
| """ | |||
| def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||
| def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, ignore_idx=-100, reduction='mean'): | |||
| r""" | |||
| :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
| :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
| :param seq_len: 句子的长度, 长度之外的token不会计算loss。仅在输出为3d时需要 | |||
| :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes) | |||
| 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第 | |||
| 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等, | |||
| 那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。 | |||
| :param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||
| 传入seq_len. | |||
| :param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
| """ | |||
| super(NLLLoss, self).__init__() | |||
| self._init_param_map(pred=pred, target=target) | |||
| self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
| assert reduction in ('mean', 'sum', 'none') | |||
| self.reduction = reduction | |||
| self.ignore_idx = ignore_idx | |||
| self.class_in_dim = class_in_dim | |||
| def get_loss(self, pred, target): | |||
| def get_loss(self, pred, target, seq_len=None): | |||
| if seq_len is not None and target.dim()>1: | |||
| mask = seq_len_to_mask(seq_len, max_len=target.size(1)).eq(False) | |||
| target = target.masked_fill(mask, self.ignore_idx) | |||
| if pred.dim() > 2: | |||
| if self.class_in_dim == -1: | |||
| if pred.size(1) != target.size(1): # 有可能顺序替换了 | |||
| pred = pred.transpose(1, 2) | |||
| else: | |||
| pred = pred.transpose(-1, self.class_in_dim) | |||
| pred = pred.reshape(-1, pred.size(-1)) | |||
| target = target.reshape(-1) | |||
| return F.nll_loss(input=pred, target=target, ignore_index=self.ignore_idx, reduction=self.reduction) | |||
| @@ -322,7 +322,8 @@ class SortedSampler(Sampler): | |||
| def __init__(self, seq_len_field_name='seq_len', descending=True): | |||
| """ | |||
| :param str seq_len_field_name: 对应序列长度的 `field` 的名字 | |||
| :param str seq_len_field_name: 按哪个field进行排序。如果传入的field是数字,则直接按照该数字大小排序;如果传入的field不是 | |||
| 数字,则使用该field的长度进行排序 | |||
| :param bool descending: 是否降序排列 | |||
| """ | |||
| self.seq_len_field_name = seq_len_field_name | |||
| @@ -330,6 +331,11 @@ class SortedSampler(Sampler): | |||
| def __call__(self, data_set): | |||
| seq_lens = data_set.get_field(self.seq_len_field_name).content | |||
| try: | |||
| seq_lens = list(map(len, seq_lens)) | |||
| except: | |||
| pass | |||
| orders = np.argsort(seq_lens).tolist() # 从小到大的顺序 | |||
| if self.descending: | |||
| orders = orders[::-1] | |||
| @@ -523,6 +523,7 @@ class Trainer(object): | |||
| self._forward_func = self.model.forward | |||
| self.fp16 = fp16 | |||
| self.verbose = kwargs.get('verbose', 0) | |||
| # check fp16相关的设置 | |||
| self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||
| @@ -608,7 +609,7 @@ class Trainer(object): | |||
| self.callback_manager = CallbackManager(env={"trainer": self}, | |||
| callbacks=callbacks) | |||
| def train(self, load_best_model=True, on_exception='auto'): | |||
| def train(self, load_best_model=True, on_exception='auto', **kwargs): | |||
| r""" | |||
| 使用该函数使Trainer开始训练。 | |||
| @@ -617,6 +618,8 @@ class Trainer(object): | |||
| :param str on_exception: 在训练过程遭遇exception,并被 :py:class:Callback 的on_exception()处理后,是否继续抛出异常。 | |||
| 支持'ignore','raise', 'auto': 'ignore'将捕获异常,写在Trainer.train()后面的代码将继续运行; 'raise'将异常抛出; | |||
| 'auto'将ignore以下两种Exception: CallbackException与KeyboardInterrupt, raise其它exception. | |||
| :param kwargs: | |||
| int verbose: 为1时在发生异常时会打印异常发生时batch中的数据在dataset中的index | |||
| :return dict: 返回一个字典类型的数据, | |||
| 内含以下内容:: | |||
| @@ -629,6 +632,7 @@ class Trainer(object): | |||
| """ | |||
| results = {} | |||
| verbose = kwargs.get('verbose', 0) | |||
| if self.n_epochs <= 0: | |||
| self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.") | |||
| results['seconds'] = 0. | |||
| @@ -650,6 +654,8 @@ class Trainer(object): | |||
| except BaseException as e: | |||
| self.callback_manager.on_exception(e) | |||
| if verbose>0: | |||
| self.logger.info(f"The data indices for current batch are: {self.data_iterator.cur_batch_indices}.") | |||
| if on_exception == 'auto': | |||
| if not isinstance(e, (CallbackException, KeyboardInterrupt)): | |||
| raise e | |||
| @@ -393,7 +393,7 @@ class _BertWordModel(nn.Module): | |||
| else: | |||
| pos_num_output_layer = max(layer, pos_num_output_layer) | |||
| self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | |||
| self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name) | |||
| self.encoder = BertModel.from_pretrained(model_dir_or_name, | |||
| neg_num_output_layer=neg_num_output_layer, | |||
| pos_num_output_layer=pos_num_output_layer, | |||
| @@ -432,14 +432,14 @@ class _BertWordModel(nn.Module): | |||
| word = '[UNK]' | |||
| elif vocab.word_count[word] < min_freq: | |||
| word = '[UNK]' | |||
| word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
| word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||
| word_pieces = self.tokenizer.wordpiece_tokenizer.tokenize(word) | |||
| word_pieces = self.tokenizer.convert_tokens_to_ids(word_pieces) | |||
| word_to_wordpieces.append(word_pieces) | |||
| word_pieces_lengths.append(len(word_pieces)) | |||
| self._cls_index = self.tokenzier.vocab['[CLS]'] | |||
| self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
| self._cls_index = self.tokenizer.vocab['[CLS]'] | |||
| self._sep_index = self.tokenizer.vocab['[SEP]'] | |||
| self._word_pad_index = vocab.padding_idx | |||
| self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece | |||
| self._wordpiece_pad_index = self.tokenizer.vocab['[PAD]'] # 需要用于生成word_piece | |||
| self.word_to_wordpieces = np.array(word_to_wordpieces, dtype=object) | |||
| self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths)) | |||
| logger.debug("Successfully generate word pieces.") | |||
| @@ -566,7 +566,7 @@ class _BertWordModel(nn.Module): | |||
| :param str folder: | |||
| :return: | |||
| """ | |||
| self.tokenzier.save_pretrained(folder) | |||
| self.tokenizer.save_pretrained(folder) | |||
| self.encoder.save_pretrained(folder) | |||
| @@ -579,7 +579,7 @@ class _BertWordPieceModel(nn.Module): | |||
| def __init__(self, model_dir_or_name: str, layers: str = '-1', pooled_cls: bool=False): | |||
| super().__init__() | |||
| self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | |||
| self.tokenizer = BertTokenizer.from_pretrained(model_dir_or_name) | |||
| self.encoder = BertModel.from_pretrained(model_dir_or_name) | |||
| # 检查encoder_layer_number是否合理 | |||
| encoder_layer_number = len(self.encoder.encoder.layer) | |||
| @@ -599,10 +599,10 @@ class _BertWordPieceModel(nn.Module): | |||
| assert layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
| f"a bert model with {encoder_layer_number} layers." | |||
| self._cls_index = self.tokenzier.cls_index | |||
| self._sep_index = self.tokenzier.sep_index | |||
| self._wordpiece_unknown_index = self.tokenzier.unk_index | |||
| self._wordpiece_pad_index = self.tokenzier.pad_index # 需要用于生成word_piece | |||
| self._cls_index = self.tokenizer.cls_index | |||
| self._sep_index = self.tokenizer.sep_index | |||
| self._wordpiece_unknown_index = self.tokenizer.unk_index | |||
| self._wordpiece_pad_index = self.tokenizer.pad_index # 需要用于生成word_piece | |||
| self.pooled_cls = pooled_cls | |||
| def index_datasets(self, *datasets, field_name, add_cls_sep=True): | |||
| @@ -615,7 +615,7 @@ class _BertWordPieceModel(nn.Module): | |||
| :return: | |||
| """ | |||
| encode_func = partial(self.tokenzier.encode, add_special_tokens=add_cls_sep) | |||
| encode_func = partial(self.tokenizer.encode, add_special_tokens=add_cls_sep) | |||
| for index, dataset in enumerate(datasets): | |||
| try: | |||
| @@ -654,5 +654,5 @@ class _BertWordPieceModel(nn.Module): | |||
| :param folder: | |||
| :return: | |||
| """ | |||
| self.tokenzier.save_pretrained(folder) | |||
| self.tokenizer.save_pretrained(folder) | |||
| self.encoder.save_pretrained(folder) | |||
| @@ -328,7 +328,7 @@ def _beam_search_generate(decoder: Seq2SeqDecoder, tokens=None, state=None, max_ | |||
| max_len_eos_mask = max_lengths.eq(cur_len+1) | |||
| eos_scores = scores[:, _eos_token_id] | |||
| # 如果已经达到最大长度,就把eos的分数加大 | |||
| scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e12, eos_scores) | |||
| scores[:, _eos_token_id] = torch.where(max_len_eos_mask, eos_scores+1e32, eos_scores) | |||
| if do_sample: | |||
| if temperature > 0 and temperature != 1: | |||