diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 21c024f0..05e5b440 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -28,6 +28,7 @@ from .utils import _check_arg_dict_list
 from .utils import _check_function_or_method
 from .utils import _get_func_signature
 from .utils import seq_len_to_mask
+import warnings
 
 
 class LossBase(object):
@@ -226,7 +227,8 @@ class CrossEntropyLoss(LossBase):
     def get_loss(self, pred, target, seq_len=None):
         if pred.dim() > 2:
             if pred.size(1) != target.size(1):  # the two dimensions may have been swapped
-                pred = pred.transpose(1, 2)
+                raise RuntimeError("It seems like your prediction's shape is (batch_size, num_labels, max_len)."
+                                   " It should be (batch_size, max_len, num_labels).")
             pred = pred.reshape(-1, pred.size(-1))
             target = target.reshape(-1)
         if seq_len is not None:
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index a85b7fee..6d18fd48 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -942,7 +942,7 @@ def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAUL
 
     if dev_data is not None:
         tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
-                        batch_size=batch_size, verbose=-1)
+                        batch_size=batch_size, verbose=-1, use_tqdm=False)
         evaluate_results = tester.test()
         _check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics)
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
index 261007ae..ea5e84ac 100644
--- a/fastNLP/embeddings/bert_embedding.py
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -11,7 +11,7 @@ from ..core.vocabulary import Vocabulary
 from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
 from .contextual_embedding import ContextualEmbedding
-
+import warnings
 
 class BertEmbedding(ContextualEmbedding):
     """
@@ -229,6 +229,10 @@ class _WordBertModel(nn.Module):
         # Step 1: collect the word pieces that are needed, then build the new embed and word_piece vocab and fill in the values
         word_piece_dict = {'[CLS]':1, '[SEP]':1}  # word pieces in use, plus newly added ones
         found_count = 0
+        self._has_sep_in_vocab = '[SEP]' in vocab  # used to decide whether token_type_ids must be generated for the input
+        if "[CLS]" in vocab:
+            warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CLS] and [SEP] to the beginning "
+                          "and end of the sentence automatically.")
         for word, index in vocab:
             if index == vocab.padding_idx:  # pad is a special symbol
                 word = '[PAD]'
@@ -316,9 +320,18 @@ class _WordBertModel(nn.Module):
         word_pieces[:, 0].fill_(self._cls_index)
         batch_indexes = torch.arange(batch_size).to(words)
         word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index
+        if self._has_sep_in_vocab:  # token_type_ids are only needed when [SEP] occurs in the vocab
+            with torch.no_grad():
+                sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
+                sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+                token_type_ids = sep_mask_cumsum.fmod(2)
+                if token_type_ids[0, 0].item():  # if the first position is odd, flip the result so that sequences start with segment 0
+                    token_type_ids = token_type_ids.eq(0).float()
+        else:
+            token_type_ids = torch.zeros_like(word_pieces)
         # 2. get the hidden states and pool them per word according to word_pieces
         # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
-        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks,
+        bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
                                                 output_all_encoded_layers=True)
         # output_layers = [self.layers]  # len(self.layers) x batch_size x real_word_piece_length x hidden_size
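
A note on the token_type_ids block added above: flipping the [SEP] mask along the sequence axis, taking a cumulative sum, flipping back and taking mod 2 gives every word piece the parity of the number of [SEP] tokens at or after its position, which alternates 0/1 between [SEP]-delimited segments; the final check on position (0, 0) keeps the first segment of the first sequence labelled 0. A minimal sketch of the trick in isolation (the toy ids below are illustrative, not taken from the diff):

    import torch

    word_pieces = torch.tensor([[101, 7, 8, 102, 9, 10, 102, 0],   # [CLS] a b [SEP] c d [SEP] pad
                                [101, 5, 102, 9, 102, 0, 0, 0]])   # [CLS] a [SEP] b [SEP] pad pad pad
    sep_index = 102

    sep_mask = word_pieces.eq(sep_index)                                        # batch_size x max_len
    sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])   # [SEP] count at or after each position
    token_type_ids = sep_mask_cumsum.fmod(2)                                    # alternating 0/1 per segment
    if token_type_ids[0, 0].item():                                             # keep the first segment labelled 0
        token_type_ids = token_type_ids.eq(0).long()
    print(token_type_ids)
    # tensor([[0, 0, 0, 0, 1, 1, 1, 0],
    #         [0, 0, 0, 1, 1, 0, 0, 0]])
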
diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py
index c07c070e..7d89cacb 100644
--- a/reproduction/joint_cws_parse/models/CharParser.py
+++ b/reproduction/joint_cws_parse/models/CharParser.py
@@ -224,11 +224,11 @@ class CharBiaffineParser(BiaffineParser):
         batch_size, seq_len, _ = arc_pred.shape
         flip_mask = (mask == 0)
-        _arc_pred = arc_pred.clone()
-        _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
+        # _arc_pred = arc_pred.clone()
+        _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
 
-        arc_true[:, 0].fill_(-1)
-        label_true[:, 0].fill_(-1)
+        arc_true.data[:, 0].fill_(-1)
+        label_true.data[:, 0].fill_(-1)
 
         arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
         label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)
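
The CharParser change above replaces a clone followed by an in-place masked_fill_ with a single out-of-place masked_fill, which computes the same values while leaving arc_pred untouched and avoiding an in-place op in the autograd graph; the first column of arc_true/label_true is now filled through .data, so that in-place fill_ bypasses autograd's version tracking. A quick equivalence check of the two masking styles (toy shapes, not the parser's real inputs):

    import torch

    arc_pred = torch.randn(2, 4, 4)
    flip_mask = torch.tensor([[False, False, True, True],
                              [False, True,  True, True]])

    old = arc_pred.clone()
    old.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))            # previous: clone + in-place fill
    new = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))  # new: one out-of-place fill
    assert torch.equal(old, new)
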
diff --git a/reproduction/joint_cws_parse/train.py b/reproduction/joint_cws_parse/train.py
index 0c34614b..ed4b07f0 100644
--- a/reproduction/joint_cws_parse/train.py
+++ b/reproduction/joint_cws_parse/train.py
@@ -14,6 +14,7 @@ from torch.optim.lr_scheduler import StepLR
 from fastNLP import Tester
 from fastNLP import GradientClipCallback, LRScheduler
 import os
+from fastNLP import cache_results
 
 def set_random_seed(random_seed=666):
     import random, numpy, torch
@@ -39,43 +40,42 @@ label_mlp_size = 100
 batch_size = 32
 update_every = 4
 n_epochs = 100
-data_folder = ''  # path to the folder holding the data; it should contain the train, dev and test files
-vector_folder = ''  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
+data_name = 'new_ctb7'  ####################################################
+data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output'  # path to the folder holding the data; it should contain the train, dev and test files
+vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors'  # pretrained vectors; the folder should contain three files: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
 
 set_random_seed(1234)
 device = 0
-# @cache_results('caches/{}.pkl'.format(data_name))
-# def get_data():
-data = CTBxJointLoader().process(data_folder)
-
-char_labels_vocab = data.vocabs['char_labels']
-
-pre_chars_vocab = data.vocabs['pre_chars']
-pre_bigrams_vocab = data.vocabs['pre_bigrams']
-pre_trigrams_vocab = data.vocabs['pre_trigrams']
-
-chars_vocab = data.vocabs['chars']
-bigrams_vocab = data.vocabs['bigrams']
-trigrams_vocab = data.vocabs['trigrams']
-
-pre_chars_embed = StaticEmbedding(pre_chars_vocab,
-                                  model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
-                                  init_method=uniform_init, normalize=False)
-pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std()
-pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
-                                    model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
-                                    init_method=uniform_init, normalize=False)
-pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std()
-pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
-                                     model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
-                                     init_method=uniform_init, normalize=False)
-pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std()
-
-    # return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
-
-# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
+@cache_results('caches/{}.pkl'.format(data_name))
+def get_data():
+    data = CTBxJointLoader().process(data_folder)
+    char_labels_vocab = data.vocabs['char_labels']
+
+    pre_chars_vocab = data.vocabs['pre_chars']
+    pre_bigrams_vocab = data.vocabs['pre_bigrams']
+    pre_trigrams_vocab = data.vocabs['pre_trigrams']
+
+    chars_vocab = data.vocabs['chars']
+    bigrams_vocab = data.vocabs['bigrams']
+    trigrams_vocab = data.vocabs['trigrams']
+    pre_chars_embed = StaticEmbedding(pre_chars_vocab,
+                                      model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
+                                      init_method=uniform_init, normalize=False)
+    pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std()
+    pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
+                                        model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
+                                        init_method=uniform_init, normalize=False)
+    pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std()
+    pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
+                                         model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
+                                         init_method=uniform_init, normalize=False)
+    pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std()
+
+    return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
+
+chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
 
 print(data)
 
 model = CharParser(char_vocab_size=len(chars_vocab),
@@ -104,11 +104,24 @@ optimizer = optim.Adam([param for param in model.parameters() if param.requires_
 
 sampler = BucketSampler(seq_len_field_name='seq_lens')
 callbacks = []
+
+from fastNLP.core.callback import Callback
+from torch.optim.lr_scheduler import LambdaLR
+class SchedulerCallback(Callback):
+    def __init__(self, scheduler):
+        super().__init__()
+        self.scheduler = scheduler
+
+    def on_backward_end(self):
+        if self.step % self.update_every == 0:
+            self.scheduler.step()
+
+scheduler = LambdaLR(optimizer, lr_lambda=lambda step: (0.75) ** (step // 5000))
 # scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
-scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
-# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
+# scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
+scheduler_callback = SchedulerCallback(scheduler)
 # callbacks.append(optim_callback)
-scheduler_callback = LRScheduler(scheduler)
+# scheduler_callback = LRScheduler(scheduler)
 callbacks.append(scheduler_callback)
 
 callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
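
For context on the scheduler change above: SchedulerCallback fires scheduler.step() only on every update_every-th backward pass, so with update_every = 4 the LambdaLR advances roughly once per effective optimizer update, and lr_lambda = 0.75 ** (step // 5000) multiplies the base learning rate by a further 0.75 after every 5000 such scheduler steps. A quick arithmetic check of the resulting multiplier (plain Python, nothing imported from the training script):

    # decay multiplier applied by the LambdaLR above at a few scheduler steps
    for step in (0, 4999, 5000, 10000, 25000, 50000):
        print(step, round(0.75 ** (step // 5000), 4))
    # 0 -> 1.0, 4999 -> 1.0, 5000 -> 0.75, 10000 -> 0.5625, 25000 -> 0.2373, 50000 -> 0.0563
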
@@ -119,6 +132,6 @@ callbacks.append(dev_callback)
 trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size,
                   print_every=3, validate_every=-1, dev_data=data.datasets['dev'], save_path=None,
                   optimizer=optimizer,
-                  check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
+                  check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True,
                   device=device, callbacks=callbacks, update_every=update_every)
 trainer.train()
\ No newline at end of file
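
Finally, a note on the @cache_results('caches/{}.pkl'.format(data_name)) decorator now wrapped around get_data: as used here it persists the function's return value on disk (caches/new_ctb7.pkl for this data_name) so that repeated runs skip the CTBxJointLoader preprocessing and embedding loading. If that behaviour needs to be reproduced outside fastNLP, a minimal sketch of the same pattern with plain pickle could look like the following (the cached helper and its path handling are illustrative, not fastNLP's implementation):

    import os
    import pickle

    def cached(cache_path):
        # illustrative stand-in for fastNLP's cache_results: cache a no-argument
        # builder function's return value on disk and reuse it on later runs
        def decorator(fn):
            def wrapper():
                if os.path.exists(cache_path):
                    with open(cache_path, 'rb') as f:
                        return pickle.load(f)
                result = fn()
                if os.path.dirname(cache_path):
                    os.makedirs(os.path.dirname(cache_path), exist_ok=True)
                with open(cache_path, 'wb') as f:
                    pickle.dump(result, f)
                return result
            return wrapper
        return decorator
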