diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py index 300dd8ac..3f8cc057 100644 --- a/fastNLP/api/processor.py +++ b/fastNLP/api/processor.py @@ -67,7 +67,7 @@ class FullSpaceToHalfSpaceProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: - sentence = ins[self.field_name].text + sentence = ins[self.field_name] new_sentence = [None]*len(sentence) for idx, char in enumerate(sentence): if char in self.convert_map: @@ -78,12 +78,13 @@ class FullSpaceToHalfSpaceProcessor(Processor): class IndexerProcessor(Processor): - def __init__(self, vocab, field_name, new_added_field_name): + def __init__(self, vocab, field_name, new_added_field_name, delete_old_field=False): assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) super(IndexerProcessor, self).__init__(field_name, new_added_field_name) self.vocab = vocab + self.delete_old_field = delete_old_field def set_vocab(self, vocab): assert isinstance(vocab, Vocabulary), "Only Vocabulary class is allowed, not {}.".format(type(vocab)) @@ -97,6 +98,11 @@ class IndexerProcessor(Processor): index = [self.vocab.to_index(token) for token in tokens] ins[self.new_added_field_name] = index + dataset.set_need_tensor(**{self.new_added_field_name:True}) + + if self.delete_old_field: + dataset.delete_field(self.field_name) + return dataset diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 397a3ddb..856a6eac 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -55,14 +55,15 @@ class Batch(object): indices = self.idx_list[self.curidx:endidx] - for field_name, field in self.dataset.get_fields(): - batch = torch.from_numpy(field.get(indices)) - if not field.need_tensor: #TODO 修改 - pass - elif field.is_target: - batch_y[field_name] = batch - else: - batch_x[field_name] = batch + for field_name, field in self.dataset.get_fields().items(): + if field.need_tensor: + batch = torch.from_numpy(field.get(indices)) + if not field.need_tensor: + pass + elif field.is_target: + batch_y[field_name] = batch + else: + batch_x[field_name] = batch self.curidx = endidx diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 18da9bd7..cffe95a9 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -75,11 +75,13 @@ class DataSet(object): assert len(self) == len(fields) self.field_arrays[name] = FieldArray(name, fields) + def delete_field(self, name): + self.field_arrays.pop(name) + def get_fields(self): return self.field_arrays def __getitem__(self, name): - assert name in self.field_arrays return self.field_arrays[name] def __len__(self): diff --git a/reproduction/chinese_word_segment/io/__init__.py b/reproduction/chinese_word_segment/cws_io/__init__.py similarity index 100% rename from reproduction/chinese_word_segment/io/__init__.py rename to reproduction/chinese_word_segment/cws_io/__init__.py diff --git a/reproduction/chinese_word_segment/io/cws_reader.py b/reproduction/chinese_word_segment/cws_io/cws_reader.py similarity index 100% rename from reproduction/chinese_word_segment/io/cws_reader.py rename to reproduction/chinese_word_segment/cws_io/cws_reader.py diff --git a/reproduction/chinese_word_segment/models/cws_model.py b/reproduction/chinese_word_segment/models/cws_model.py index dfcfcafe..1fc1af26 100644 --- a/reproduction/chinese_word_segment/models/cws_model.py +++ b/reproduction/chinese_word_segment/models/cws_model.py @@ -35,13 +35,6 @@ class CWSBiLSTMEncoder(BaseModel): self.bigram_embedding = nn.Embedding(num_embeddings=bigram_vocab_num, embedding_dim=bigram_embed_dim) self.input_size += self.num_bigram_per_char*bigram_embed_dim - if self.num_criterion!=None: - if bidirectional: - self.backward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion, - embedding_dim=self.hidden_size) - self.forward_criterion_embedding = nn.Embedding(num_embeddings=self.num_criterion, - embedding_dim=self.hidden_size) - if not self.embed_drop_p is None: self.embedding_drop = nn.Dropout(p=self.embed_drop_p) @@ -102,13 +95,14 @@ class CWSBiLSTMSegApp(BaseModel): self.decoder_model = MLP(size_layer) - def forward(self, **kwargs): - chars = kwargs['chars'] - if 'bigram' in kwargs: - bigrams = kwargs['bigrams'] + def forward(self, batch_dict): + device = self.parameters().__next__().device + chars = batch_dict['indexed_chars_list'].to(device) + if 'bigram' in batch_dict: + bigrams = batch_dict['indexed_chars_list'].to(device) else: bigrams = None - seq_lens = kwargs['seq_lens'] + seq_lens = batch_dict['seq_lens'].to(device) feats = self.encoder_model(chars, bigrams, seq_lens) probs = self.decoder_model(feats) @@ -119,6 +113,10 @@ class CWSBiLSTMSegApp(BaseModel): return pred_dict + def predict(self, batch_dict): + pass + + def loss_fn(self, pred_dict, true_dict): seq_lens = pred_dict['seq_lens'] masks = seq_lens_to_mask(seq_lens).float() @@ -131,5 +129,4 @@ class CWSBiLSTMSegApp(BaseModel): true_y.view(-1), reduction='none')*masks.view(-1)/torch.sum(masks) - return loss - + return loss \ No newline at end of file diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py index c025895f..27a6fb1d 100644 --- a/reproduction/chinese_word_segment/process/cws_processor.py +++ b/reproduction/chinese_word_segment/process/cws_processor.py @@ -21,7 +21,7 @@ class SpeicalSpanProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: - sentence = ins[self.field_name].text + sentence = ins[self.field_name] for span_converter in self.span_converters: sentence = span_converter.find_certain_span_and_replace(sentence) ins[self.new_added_field_name] = sentence @@ -42,10 +42,9 @@ class CWSCharSegProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: - sentence = ins[self.field_name].text + sentence = ins[self.field_name] chars = self._split_sent_into_chars(sentence) - new_token_field = TokenListFiled(chars, is_target=False) - ins[self.new_added_field_name] = new_token_field + ins[self.new_added_field_name] = chars return dataset @@ -109,10 +108,11 @@ class CWSTagProcessor(Processor): def process(self, dataset): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: - sentence = ins[self.field_name].text + sentence = ins[self.field_name] tag_list = self._generate_tag(sentence) new_tag_field = SeqLabelField(tag_list) ins[self.new_added_field_name] = new_tag_field + dataset.set_is_target(**{self.new_added_field_name:True}) return dataset def _tags_from_word_len(self, word_len): @@ -123,6 +123,8 @@ class CWSSegAppTagProcessor(CWSTagProcessor): def __init__(self, field_name, new_added_field_name=None): super(CWSSegAppTagProcessor, self).__init__(field_name, new_added_field_name) + self.tag_size = 2 + def _tags_from_word_len(self, word_len): tag_list = [] for _ in range(word_len-1): @@ -140,10 +142,9 @@ class BigramProcessor(Processor): assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: - characters = ins[self.field_name].content + characters = ins[self.field_name] bigrams = self._generate_bigram(characters) - new_token_field = TokenListFiled(bigrams) - ins[self.new_added_field_name] = new_token_field + ins[self.new_added_field_name] = bigrams return dataset @@ -190,9 +191,26 @@ class VocabProcessor(Processor): for dataset in datasets: assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) for ins in dataset: - tokens = ins[self.field_name].content + tokens = ins[self.field_name] self.vocab.update(tokens) def get_vocab(self): self.vocab.build_vocab() return self.vocab + + def get_vocab_size(self): + return len(self.vocab) + + +class SeqLenProcessor(Processor): + def __init__(self, field_name, new_added_field_name='seq_lens'): + + super(SeqLenProcessor, self).__init__(field_name, new_added_field_name) + + def process(self, dataset): + assert isinstance(dataset, DataSet), "Only Dataset class is allowed, not {}.".format(type(dataset)) + for ins in dataset: + length = len(ins[self.field_name]) + ins[self.new_added_field_name] = length + dataset.set_need_tensor(**{self.new_added_field_name:True}) + return dataset diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py index c44294ee..c5e7b2a4 100644 --- a/reproduction/chinese_word_segment/train_context.py +++ b/reproduction/chinese_word_segment/train_context.py @@ -9,35 +9,22 @@ from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegPr from reproduction.chinese_word_segment.process.cws_processor import CWSSegAppTagProcessor from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor from reproduction.chinese_word_segment.process.cws_processor import VocabProcessor +from reproduction.chinese_word_segment.process.cws_processor import SeqLenProcessor from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter -from reproduction.chinese_word_segment.io.cws_reader import NaiveCWSReader +from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp -tr_filename = '' -dev_filename = '' +tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_train.txt' +dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_dev.txt' reader = NaiveCWSReader() -tr_sentences = reader.load(tr_filename, cut_long_sent=True) -dev_sentences = reader.load(dev_filename) +tr_dataset = reader.load(tr_filename, cut_long_sent=True) +dev_dataset = reader.load(dev_filename) -# TODO 如何组建成为一个Dataset -def construct_dataset(sentences): - dataset = DataSet() - for sentence in sentences: - instance = Instance() - instance['raw_sentence'] = sentence - dataset.append(instance) - - return dataset - - -tr_dataset = construct_dataset(tr_sentences) -dev_dataset = construct_dataset(dev_sentences) - # 1. 准备processor fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') @@ -45,14 +32,14 @@ sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence') sp_proc.add_span_converter(AlphaSpanConverter()) sp_proc.add_span_converter(DigitSpanConverter()) -char_proc = CWSCharSegProcessor('sentence', 'char_list') +char_proc = CWSCharSegProcessor('sentence', 'chars_list') -tag_proc = CWSSegAppTagProcessor('sentence', 'tag') +tag_proc = CWSSegAppTagProcessor('sentence', 'tags') -bigram_proc = Pre2Post2BigramProcessor('char_list', 'bigram_list') +bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list') -char_vocab_proc = VocabProcessor('char_list') -bigram_vocab_proc = VocabProcessor('bigram_list') +char_vocab_proc = VocabProcessor('chars_list') +bigram_vocab_proc = VocabProcessor('bigrams_list') # 2. 使用processor fs2hs_proc(tr_dataset) @@ -66,15 +53,18 @@ bigram_proc(tr_dataset) char_vocab_proc(tr_dataset) bigram_vocab_proc(tr_dataset) -char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list') -bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list','indexed_bigrams_list') +char_index_proc = IndexerProcessor(char_vocab_proc.get_vocab(), 'chars_list', 'indexed_chars_list', + delete_old_field=True) +bigram_index_proc = IndexerProcessor(bigram_vocab_proc.get_vocab(), 'bigrams_list','indexed_bigrams_list', + delete_old_field=True) +seq_len_proc = SeqLenProcessor('indexed_chars_list') char_index_proc(tr_dataset) bigram_index_proc(tr_dataset) +seq_len_proc(tr_dataset) # 2.1 处理dev_dataset fs2hs_proc(dev_dataset) - sp_proc(dev_dataset) char_proc(dev_dataset) @@ -83,14 +73,148 @@ bigram_proc(dev_dataset) char_index_proc(dev_dataset) bigram_index_proc(dev_dataset) +seq_len_proc(dev_dataset) +print("Finish preparing data.") # 3. 得到数据集可以用于训练了 -# TODO pretrain的embedding是怎么解决的? -cws_model = CWSBiLSTMSegApp(vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None, - hidden_size=200, bidirectional=True, embed_drop_p=None, num_layers=1, tag_size=2) +from itertools import chain + +def refine_ys_on_seq_len(ys, seq_lens): + refined_ys = [] + for b_idx, length in enumerate(seq_lens): + refined_ys.append(list(ys[b_idx][:length])) + + return refined_ys + +def flat_nested_list(nested_list): + return list(chain(*nested_list)) + +def calculate_pre_rec_f1(model, batcher): + true_ys, pred_ys, seq_lens = decode_iterator(model, batcher) + refined_true_ys = refine_ys_on_seq_len(true_ys, seq_lens) + refined_pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens) + true_ys = flat_nested_list(refined_true_ys) + pred_ys = flat_nested_list(refined_pred_ys) + + cor_num = 0 + yp_wordnum = pred_ys.count(1) + yt_wordnum = true_ys.count(1) + start = 0 + for i in range(len(true_ys)): + if true_ys[i] == 1: + flag = True + for j in range(start, i + 1): + if true_ys[j] != pred_ys[j]: + flag = False + break + if flag: + cor_num += 1 + start = i + 1 + P = cor_num / (float(yp_wordnum) + 1e-6) + R = cor_num / (float(yt_wordnum) + 1e-6) + F = 2 * P * R / (P + R + 1e-6) + return P, R, F + +def decode_iterator(model, batcher): + true_ys = [] + pred_ys = [] + seq_lens = [] + with torch.no_grad(): + model.eval() + for batch_x, batch_y in batcher: + pred_dict = model(batch_x) + seq_len = pred_dict['seq_lens'].cpu().numpy() + probs = pred_dict['pred_probs'] + _, pred_y = probs.max(dim=-1) + true_y = batch_y['tags'] + pred_y = pred_y.cpu().numpy() + true_y = true_y.cpu().numpy() + + true_ys.extend(list(true_y)) + pred_ys.extend(list(pred_y)) + seq_lens.extend(list(seq_len)) + model.train() + + return true_ys, pred_ys, seq_lens +# TODO pretrain的embedding是怎么解决的? +from reproduction.chinese_word_segment.utils import FocalLoss +from reproduction.chinese_word_segment.utils import seq_lens_to_mask +from fastNLP.core.batch import Batch +from fastNLP.core.sampler import RandomSampler +from fastNLP.core.sampler import SequentialSampler + +import torch +from torch import optim +import sys +from tqdm import tqdm + + +tag_size = tag_proc.tag_size + +cws_model = CWSBiLSTMSegApp(char_vocab_proc.get_vocab_size(), embed_dim=100, + bigram_vocab_num=bigram_vocab_proc.get_vocab_size(), + bigram_embed_dim=100, num_bigram_per_char=8, + hidden_size=200, bidirectional=True, embed_drop_p=None, + num_layers=1, tag_size=tag_size) + +num_epochs = 3 +loss_fn = FocalLoss(class_num=tag_size) +optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01) + + +print_every = 50 +batch_size = 32 +tr_batcher = Batch(tr_dataset, batch_size, RandomSampler(), use_cuda=False) +dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False) +num_batch_per_epoch = len(tr_dataset) // batch_size +best_f1 = 0 +best_epoch = 0 +for num_epoch in range(num_epochs): + print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10) + sys.stdout.flush() + avg_loss = 0 + with tqdm(total=num_batch_per_epoch, leave=True) as pbar: + pbar.set_description_str('Epoch:%d' % (num_epoch + 1)) + cws_model.train() + for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1): + pred_dict = cws_model(batch_x) # B x L x tag_size + seq_lens = batch_x['seq_lens'] + masks = seq_lens_to_mask(seq_lens) + tags = batch_y['tags'] + loss = torch.sum(loss_fn(pred_dict['pred_prob'].view(-1, tag_size), + tags.view(-1)) * masks.view(-1)) / torch.sum(masks) + # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float()) + + avg_loss += loss.item() + + loss.backward() + for group in optimizer.param_groups: + for param in group['params']: + param.grad.clamp_(-5, 5) + + optimizer.step() + + if batch_idx % print_every == 0: + pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every)) + avg_loss = 0 + pbar.update(print_every) + + # 验证集 + pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher) + print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1*100, + pre*100, + rec*100)) + if best_f1