diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py index 65ed7ab8..037d6081 100644 --- a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py +++ b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py @@ -63,8 +63,10 @@ class Conll2003DataLoader(DataSetLoader): data.datasets[name] = dataset # 对construct vocab - word_vocab = Vocabulary(min_freq=3) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) - word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT) + word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) + # word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT) + # TODO 这样感觉不规范呐 + word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT) word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT) data.vocabs[Const.INPUT] = word_vocab diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py index bf1ab71e..5abfe7c5 100644 --- a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py +++ b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py @@ -87,7 +87,8 @@ class OntoNoteNERDataLoader(DataSetLoader): # 对construct vocab word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt) - word_vocab.from_dataset(data.datasets['train'], field_name='raw_words') + # word_vocab.from_dataset(data.datasets['train'], field_name='raw_words') + word_vocab.from_dataset(*data.datasets.values(), field_name=Const.INPUT) word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name=Const.INPUT) data.vocabs[Const.INPUT] = word_vocab diff --git a/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py b/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py index 79fa7a76..36d86651 100644 --- a/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py +++ b/reproduction/seqence_labelling/ner/model/lstm_cnn_crf.py @@ -4,7 +4,7 @@ from torch import nn from fastNLP import seq_len_to_mask from fastNLP.modules import Embedding from fastNLP.modules import LSTM -from fastNLP.modules import ConditionalRandomField, allowed_transitions, TimestepDropout +from fastNLP.modules import ConditionalRandomField, allowed_transitions import torch.nn.functional as F from fastNLP import Const @@ -17,13 +17,12 @@ class CNNBiLSTMCRF(nn.Module): self.lstm = LSTM(input_size=self.embedding.embedding_dim+self.char_embedding.embedding_dim, hidden_size=hidden_size//2, num_layers=num_layers, bidirectional=True, batch_first=True, dropout=dropout) - self.forward_fc = nn.Linear(hidden_size//2, len(tag_vocab)) - self.backward_fc = nn.Linear(hidden_size//2, len(tag_vocab)) + self.fc = nn.Linear(hidden_size, len(tag_vocab)) - transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=False) - self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=False, allowed_transitions=transitions) + transitions = allowed_transitions(tag_vocab.idx2word, encoding_type=encoding_type, include_start_end=True) + self.crf = ConditionalRandomField(len(tag_vocab), include_start_end_trans=True, allowed_transitions=transitions) - self.dropout = TimestepDropout(dropout, inplace=True) + self.dropout = nn.Dropout(dropout, inplace=True) for name, param in self.named_parameters(): if 'ward_fc' in name: @@ -40,13 +39,8 @@ class CNNBiLSTMCRF(nn.Module): words = torch.cat([words, chars], dim=-1) outputs, _ = self.lstm(words, seq_len) self.dropout(outputs) - forwards, backwards = outputs.chunk(2, dim=-1) - # forward_logits = F.log_softmax(self.forward_fc(forwards), dim=-1) - # backward_logits = F.log_softmax(self.backward_fc(backwards), dim=-1) - - logits = self.forward_fc(forwards) + self.backward_fc(backwards) - self.dropout(logits) + logits = F.log_softmax(self.fc(outputs), dim=-1) if target is not None: loss = self.crf(logits, target, seq_len_to_mask(seq_len)) diff --git a/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py b/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py index 278ff42f..507be4f6 100644 --- a/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py +++ b/reproduction/seqence_labelling/ner/train_cnn_lstm_crf_conll2003.py @@ -10,7 +10,8 @@ from fastNLP import BucketSampler from fastNLP import Const from torch.optim import SGD, Adam from fastNLP import GradientClipCallback -from fastNLP.core.callback import FitlogCallback +from fastNLP.core.callback import FitlogCallback, LRScheduler +from torch.optim.lr_scheduler import LambdaLR import fitlog fitlog.debug() @@ -19,7 +20,7 @@ from reproduction.seqence_labelling.ner.data.Conll2003Loader import Conll2003Dat encoding_type = 'bioes' data = Conll2003DataLoader(encoding_type=encoding_type).process('/hdd/fudanNLP/fastNLP/others/data/conll2003', - word_vocab_opt=VocabularyOption(min_freq=3)) + word_vocab_opt=VocabularyOption(min_freq=2)) print(data) char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], kernel_sizes=[3]) @@ -28,15 +29,18 @@ word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], requires_grad=True) word_embed.embedding.weight.data = word_embed.embedding.weight.data/word_embed.embedding.weight.data.std() -model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], +model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) -optimizer = Adam(model.parameters(), lr=0.001) +optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9) +scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) -callbacks = [GradientClipCallback(clip_type='value'), FitlogCallback({'test':data.datasets['test']}, verbose=1)] +callbacks = [GradientClipCallback(clip_type='value', clip_value=5), FitlogCallback({'test':data.datasets['test'], + 'train':data.datasets['train']}, verbose=1), + scheduler] trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(), - device=0, dev_data=data.datasets['dev'], batch_size=32, + device=0, dev_data=data.datasets['dev'], batch_size=10, metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type), callbacks=callbacks, num_workers=1, n_epochs=100) trainer.train() \ No newline at end of file diff --git a/reproduction/seqence_labelling/ner/train_ontonote.py b/reproduction/seqence_labelling/ner/train_ontonote.py index 6f443dfd..e2a4158a 100644 --- a/reproduction/seqence_labelling/ner/train_ontonote.py +++ b/reproduction/seqence_labelling/ner/train_ontonote.py @@ -25,10 +25,10 @@ word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT], model_dir_or_name='/hdd/fudanNLP/pretrain_vectors/glove.6B.100d.txt', requires_grad=True) -model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], +model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=400, num_layers=2, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) -optimizer = Adam(model.parameters(), lr=0.001) +optimizer = SGD(model.parameters(), lr=0.015, momentum=0.9) callbacks = [GradientClipCallback(), FitlogCallback(data.datasets['test'], verbose=1)]