From 73ba3b5eec62583475baaf85fa6c461a3aa03e5c Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 10 Nov 2018 15:17:58 +0800
Subject: [PATCH 1/2] bug fix for pipeline

---
 fastNLP/api/cws.py                                 | 32 +++++++++++++++++++
 fastNLP/api/pipeline.py                            |  2 +-
 .../chinese_word_segment/train_context.py          | 13 ++++++++
 3 files changed, 46 insertions(+), 1 deletion(-)
 create mode 100644 fastNLP/api/cws.py

diff --git a/fastNLP/api/cws.py b/fastNLP/api/cws.py
new file mode 100644
index 00000000..ea6f96e6
--- /dev/null
+++ b/fastNLP/api/cws.py
@@ -0,0 +1,32 @@
+
+
+from fastNLP.api.api import API
+from fastNLP.core.dataset import DataSet
+
+class CWS(API):
+    def __init__(self, model_path='xxx'):
+        super(CWS, self).__init__()
+        self.load(model_path)
+
+    def predict(self, sentence, pretrain=False):
+
+        if not (hasattr(self, 'model') and hasattr(self, 'pipeline')):
+            raise ValueError("You have to load model first. Or specify pretrain=True.")
+
+        sentence_list = []
+        # 1. check the type of sentence
+        if isinstance(sentence, str):
+            sentence_list.append(sentence)
+        elif isinstance(sentence, list):
+            sentence_list = sentence
+
+        # 2. build the DataSet
+        dataset = DataSet()
+        dataset.add_field('raw_sentence', sentence_list)
+
+        # 3. run the pipeline
+        self.pipeline(dataset)
+
+        # 4. TODO the prediction should be delegated to something like an iterator over the dataset
+
+        # 5. TODO collect the results; decide whether the converted spans need to be mapped back, and what post_process should do
diff --git a/fastNLP/api/pipeline.py b/fastNLP/api/pipeline.py
index 745c8874..0edceb19 100644
--- a/fastNLP/api/pipeline.py
+++ b/fastNLP/api/pipeline.py
@@ -13,7 +13,7 @@ class Pipeline:
     def process(self, dataset):
         assert len(self.pipeline)!=0, "You need to add some processor first."
 
-        for proc_name, proc in self.pipeline:
+        for proc in self.pipeline:
             dataset = proc(dataset)
 
         return dataset
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index e43f8a24..184380e0 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -223,8 +223,21 @@ pp = Pipeline()
 pp.add_processor(fs2hs_proc)
 pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
+pp.add_processor(tag_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)
 
+te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt'
+te_dataset = reader.load(te_filename)
+pp(te_dataset)
+
+batch_size = 64
+te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
+pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher)
+print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+                                                 pre * 100,
+                                                 rec * 100))
+
+

From ea1c8c1100d523605013ef5c53901202fa6d65cf Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Sat, 10 Nov 2018 19:59:32 +0800
Subject: [PATCH 2/2] Word segmentation accuracy of the current version now reaches the normal segmentation score
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/sampler.py                            |  3 +-
 .../process/cws_processor.py                       |  4 +-
 .../chinese_word_segment/train_context.py          | 37 +++++++++++++------
 3 files changed, 30 insertions(+), 14 deletions(-)

diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py
index d2d1b301..652bc97e 100644
--- a/fastNLP/core/sampler.py
+++ b/fastNLP/core/sampler.py
@@ -78,7 +78,8 @@ class BucketSampler(BaseSampler):
         for i in range(num_batch_per_bucket):
             batchs.append(left_init_indexes[i*self.batch_size:(i+1)*self.batch_size])
         left_init_indexes = left_init_indexes[num_batch_per_bucket*self.batch_size:]
-
+        if len(left_init_indexes) != 0:
+            batchs.append(left_init_indexes)
         np.random.shuffle(batchs)
         return list(chain(*batchs))
diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py
index e93431ff..8363ca75 100644
--- a/reproduction/chinese_word_segment/process/cws_processor.py
+++ b/reproduction/chinese_word_segment/process/cws_processor.py
@@ -182,10 +182,10 @@ class Pre2Post2BigramProcessor(BigramProcessor):
 # Processor了
 
 class VocabProcessor(Processor):
-    def __init__(self, field_name):
+    def __init__(self, field_name, min_count=1, max_vocab_size=None):
         super(VocabProcessor, self).__init__(field_name, None)
 
-        self.vocab = Vocabulary()
+        self.vocab = Vocabulary(min_freq=min_count, max_size=max_vocab_size)
 
     def process(self, *datasets):
         for dataset in datasets:
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index 184380e0..21b7ab89 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -11,11 +11,15 @@ from reproduction.chinese_word_segment.process.cws_processor import SeqLenProces
 from reproduction.chinese_word_segment.process.span_converter import AlphaSpanConverter
 from reproduction.chinese_word_segment.process.span_converter import DigitSpanConverter
+from reproduction.chinese_word_segment.process.span_converter import TimeConverter
+from reproduction.chinese_word_segment.process.span_converter import MixNumAlphaConverter
+from reproduction.chinese_word_segment.process.span_converter import EmailConverter
 from reproduction.chinese_word_segment.cws_io.cws_reader import NaiveCWSReader
 from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
 
-tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_train.txt'
-dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_dev.txt'
+ds_name = 'pku'
+tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
+dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
 
 reader = NaiveCWSReader()
@@ -27,8 +31,12 @@ dev_dataset = reader.load(dev_filename)
 fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
 
 sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
+sp_proc.add_span_converter(EmailConverter())
+sp_proc.add_span_converter(MixNumAlphaConverter())
 sp_proc.add_span_converter(AlphaSpanConverter())
 sp_proc.add_span_converter(DigitSpanConverter())
+sp_proc.add_span_converter(TimeConverter())
+
 
 char_proc = CWSCharSegProcessor('sentence', 'chars_list')
@@ -37,7 +45,7 @@ tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
 bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
 
 char_vocab_proc = VocabProcessor('chars_list')
-bigram_vocab_proc = VocabProcessor('bigrams_list')
+bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
 
 # 2. apply the processors
 fs2hs_proc(tr_dataset)
@@ -74,6 +82,8 @@ bigram_index_proc(dev_dataset)
 seq_len_proc(dev_dataset)
 
 print("Finish preparing data.")
+print("Vocab size:{}, bigram size:{}.".format(char_vocab_proc.get_vocab_size(), bigram_vocab_proc.get_vocab_size()))
+
 # 3. the data sets are now ready for training
 
 from itertools import chain
 
@@ -89,11 +99,10 @@ def flat_nested_list(nested_list):
     return list(chain(*nested_list))
 
 def calculate_pre_rec_f1(model, batcher):
-    true_ys, pred_ys, seq_lens = decode_iterator(model, batcher)
-    refined_true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
-    refined_pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
-    true_ys = flat_nested_list(refined_true_ys)
-    pred_ys = flat_nested_list(refined_pred_ys)
+    true_ys, pred_ys = decode_iterator(model, batcher)
+
+    true_ys = flat_nested_list(true_ys)
+    pred_ys = flat_nested_list(pred_ys)
 
     cor_num = 0
     yp_wordnum = pred_ys.count(1)
@@ -134,7 +143,10 @@ def decode_iterator(model, batcher):
             seq_lens.extend(list(seq_len))
     model.train()
 
-    return true_ys, pred_ys, seq_lens
+    true_ys = refine_ys_on_seq_len(true_ys, seq_lens)
+    pred_ys = refine_ys_on_seq_len(pred_ys, seq_lens)
+
+    return true_ys, pred_ys
 
 
 # TODO how are pretrained embeddings supposed to be handled?
@@ -161,7 +173,7 @@ cws_model.cuda()
 
 num_epochs = 3
 loss_fn = FocalLoss(class_num=tag_size)
-optimizer = optim.Adagrad(cws_model.parameters(), lr=0.01)
+optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)
 
 print_every = 50
@@ -179,6 +191,8 @@ for num_epoch in range(num_epochs):
     pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
     cws_model.train()
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
+        optimizer.zero_grad()
+
         pred_dict = cws_model(batch_x) # B x L x tag_size
         seq_lens = pred_dict['seq_lens']
@@ -217,6 +231,7 @@ for num_epoch in range(num_epochs):
         }
         best_epoch = num_epoch
 
+cws_model.load_state_dict(best_state_dict)
 # 4. assemble the components that need to be saved
 pp = Pipeline()
@@ -229,7 +244,7 @@ pp.add_processor(char_index_proc)
 pp.add_processor(bigram_index_proc)
 pp.add_processor(seq_len_proc)
 
-te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/pku/middle_files/pku_test.txt'
+te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
 te_dataset = reader.load(te_filename)
 pp(te_dataset)
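
As a reference for the two TODOs left in CWS.predict (steps 4 and 5), below is a minimal sketch that reuses the same Batch / SequentialSampler iteration the training script already uses for evaluation. Only the Batch(...) call signature and the 'seq_lens' key are taken from this patch; the import paths, the 'pred_tags' output key, and the tensor shapes are assumptions made for illustration.

import torch
from fastNLP.core.batch import Batch                # assumed import path
from fastNLP.core.sampler import SequentialSampler  # assumed import path

def predict_tags(model, dataset, batch_size=32):
    # Step 4 (sketch): run the processed dataset through the model batch by
    # batch, mirroring the evaluation loop in train_context.py.
    batcher = Batch(dataset, batch_size, SequentialSampler(), use_cuda=False)
    model.eval()
    results = []
    with torch.no_grad():
        for batch_x, _ in batcher:
            pred_dict = model(batch_x)
            tags = pred_dict['pred_tags']       # hypothetical key, assumed B x L LongTensor
            seq_lens = pred_dict['seq_lens']    # assumed length-B tensor, as in the training loop
            # Step 5 (sketch, first half): drop the padded positions.
            for tag_seq, seq_len in zip(tags.tolist(), seq_lens.tolist()):
                results.append(tag_seq[:int(seq_len)])
    model.train()
    # Step 5 (second half), mapping tag sequences back to words and undoing
    # the span conversions, would belong to post_process.
    return results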
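
The BucketSampler change keeps indexes that do not fill a complete batch by appending them as one final short batch. A small self-contained sketch of that tail-batch behaviour, reduced to the logic visible in the hunk (the bucketing itself is omitted):

import numpy as np
from itertools import chain

def cut_batches(indexes, batch_size):
    # Cut full batches first, then keep whatever is left over as one short
    # batch so that no index is silently dropped.
    batchs = []
    left_init_indexes = list(indexes)
    np.random.shuffle(left_init_indexes)
    num_batch = len(left_init_indexes) // batch_size
    for i in range(num_batch):
        batchs.append(left_init_indexes[i * batch_size:(i + 1) * batch_size])
    left_init_indexes = left_init_indexes[num_batch * batch_size:]
    if len(left_init_indexes) != 0:
        batchs.append(left_init_indexes)
    np.random.shuffle(batchs)
    return list(chain(*batchs))

print(sorted(cut_batches(range(10), batch_size=4)))   # all 10 indexes are kept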
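
calculate_pre_rec_f1 counts a word as correct only when the predicted tags reproduce the exact span of the gold word. Below is a standalone sketch of that kind of word-level scoring, assuming a seg-app scheme in which tag 1 starts a new word and tag 0 appends to the current one; the patch does not show the actual tag semantics of CWSSegAppTagProcessor, so treat the scheme (and the helper name) as an assumption.

def word_level_prf(true_ys, pred_ys):
    # true_ys / pred_ys: flat 0/1 tag lists over the same characters, with
    # padding already stripped (decode_iterator does this via seq_lens).
    # Assumed scheme: 1 starts a new word ("seg"), 0 appends to it ("app").
    assert len(true_ys) == len(pred_ys)
    n = len(true_ys)
    cor_num = 0
    start = 0
    for end in range(1, n + 1):
        if end == n or true_ys[end] == 1:
            # gold word occupies [start, end); it is correct only if the
            # prediction starts a word at `start`, places no boundary inside,
            # and also ends the word at `end`.
            same_start = pred_ys[start] == 1
            no_inner_break = all(t == 0 for t in pred_ys[start + 1:end])
            same_end = end == n or pred_ys[end] == 1
            if same_start and no_inner_break and same_end:
                cor_num += 1
            start = end
    pre = cor_num / (pred_ys.count(1) + 1e-6)
    rec = cor_num / (true_ys.count(1) + 1e-6)
    f1 = 2 * pre * rec / (pre + rec + 1e-6)
    return pre, rec, f1

# gold "AB|C|DE" vs predicted "AB|CDE"
print(word_level_prf([1, 0, 1, 1, 0], [1, 0, 1, 0, 0]))   # roughly (0.5, 0.33, 0.4)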