From 9667c524a403504e68fbc9a95d3f880e723cc6a3 Mon Sep 17 00:00:00 2001
From: yh
Date: Sun, 11 Nov 2018 15:53:33 +0800
Subject: [PATCH] Basically finished the predict method of cws
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/api/api.py                                 | 37 ++++++++++++++++---
 fastNLP/api/cws.py                                 | 32 ----------------
 fastNLP/api/processor.py                           | 23 ++++++------
 .../process/cws_processor.py                       |  4 +-
 .../chinese_word_segment/train_context.py          | 28 +++++++-------
 5 files changed, 61 insertions(+), 63 deletions(-)
 delete mode 100644 fastNLP/api/cws.py

diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py
index c7d48326..823e0ee0 100644
--- a/fastNLP/api/api.py
+++ b/fastNLP/api/api.py
@@ -17,12 +17,7 @@ class API:
     def load(self, name):
         _dict = torch.load(name)
         self.pipeline = _dict['pipeline']
-        self.model = _dict['model']
 
-    def save(self, path):
-        _dict = {'pipeline': self.pipeline,
-                 'model': self.model}
-        torch.save(_dict, path)
 
 
 class POS_tagger(API):
@@ -64,6 +59,38 @@ class POS_tagger(API):
         self.tag_vocab = _dict["tag_vocab"]
 
 
+
+class CWS(API):
+    def __init__(self, model_path='xxx'):
+        super(CWS, self).__init__()
+        self.load(model_path)
+
+    def predict(self, sentence, pretrain=False):
+
+        if not hasattr(self, 'pipeline'):
+            raise ValueError("You have to load model first. Or specify pretrain=True.")
+
+        sentence_list = []
+        # 1. check the type of sentence
+        if isinstance(sentence, str):
+            sentence_list.append(sentence)
+        elif isinstance(sentence, list):
+            sentence_list = sentence
+
+        # 2. build the dataset
+        dataset = DataSet()
+        dataset.add_field('raw_sentence', sentence_list)
+
+        # 3. run the pipeline
+        self.pipeline(dataset)
+
+        output = dataset['output']
+        if isinstance(sentence, str):
+            return output[0]
+        elif isinstance(sentence, list):
+            return output
+
+
 if __name__ == "__main__":
     tagger = POS_tagger()
     print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]]))
diff --git a/fastNLP/api/cws.py b/fastNLP/api/cws.py
deleted file mode 100644
index ea6f96e6..00000000
--- a/fastNLP/api/cws.py
+++ /dev/null
@@ -1,32 +0,0 @@
-
-
-from fastNLP.api.api import API
-from fastNLP.core.dataset import DataSet
-
-class CWS(API):
-    def __init__(self, model_path='xxx'):
-        super(CWS, self).__init__()
-        self.load(model_path)
-
-    def predict(self, sentence, pretrain=False):
-
-        if hasattr(self, 'model') and hasattr(self, 'pipeline'):
-            raise ValueError("You have to load model first. Or specify pretrain=True.")
-
-        sentence_list = []
-        # 1. check the type of sentence
-        if isinstance(sentence, str):
-            sentence_list.append(sentence)
-        elif isinstance(sentence, list):
-            sentence_list = sentence
-
-        # 2. build the dataset
-        dataset = DataSet()
-        dataset.add_field('raw_sentence', sentence_list)
-
-        # 3. run the pipeline
-        self.pipeline(dataset)
-
-        # 4. TODO this should be handed over to something iterator-like to predict the result
-
-        # 5. TODO get the result; consider whether it needs to be mapped back, plus the post_process step
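The CWS class moved into fastNLP/api/api.py above is meant to be used the same way as POS_tagger: construct it with the path of a saved pipeline and call predict() with either a single string or a list of strings. A minimal usage sketch, with a placeholder model path and assuming the saved dict contains a 'pipeline' that fills an 'output' field:

    from fastNLP.api.api import CWS

    cws = CWS(model_path='/path/to/cws_model.pkl')   # placeholder path to a torch-saved {'pipeline': ...} dict
    print(cws.predict('我是学生。'))                   # a single str returns one segmented string
    print(cws.predict(['我是学生。', '他是老师。']))     # a list of str returns a list of segmented strings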
diff --git a/fastNLP/api/processor.py b/fastNLP/api/processor.py
index a7223b38..d809b7cc 100644
--- a/fastNLP/api/processor.py
+++ b/fastNLP/api/processor.py
@@ -1,9 +1,13 @@
+import torch
+from collections import defaultdict
+import re
+
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.batch import Batch
+from fastNLP.core.sampler import SequentialSampler
 
-import re
-
 
 class Processor:
     def __init__(self, field_name, new_added_field_name):
         self.field_name = field_name
@@ -172,12 +176,6 @@ class SeqLenProcessor(Processor):
         dataset.set_need_tensor(**{self.new_added_field_name: True})
         return dataset
 
-
-from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import SequentialSampler
-import torch
-from collections import defaultdict
-
 class ModelProcessor(Processor):
     def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
         """
@@ -205,9 +203,12 @@
             for key, value in prediction.items():
                 tmp_batch = []
                 value = value.cpu().numpy()
-                for idx, seq_len in enumerate(seq_lens):
-                    tmp_batch.append(value[idx, :seq_len])
-                batch_output[key].extend(tmp_batch)
+                if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1):
+                    for idx, seq_len in enumerate(seq_lens):
+                        tmp_batch.append(value[idx, :seq_len])
+                    batch_output[key].extend(tmp_batch)
+                else:
+                    batch_output[key].extend(value.tolist())
 
             batch_output[self.seq_len_field_name].extend(seq_lens)
 
diff --git a/reproduction/chinese_word_segment/process/cws_processor.py b/reproduction/chinese_word_segment/process/cws_processor.py
index 2aa05bef..4aaff5af 100644
--- a/reproduction/chinese_word_segment/process/cws_processor.py
+++ b/reproduction/chinese_word_segment/process/cws_processor.py
@@ -216,7 +216,7 @@ class SeqLenProcessor(Processor):
         return dataset
 
 class SegApp2OutputProcessor(Processor):
-    def __init__(self, chars_field_name='chars', tag_field_name='pred_tags', new_added_field_name='output'):
+    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
         super(SegApp2OutputProcessor, self).__init__(None, None)
 
         self.chars_field_name = chars_field_name
@@ -235,6 +235,6 @@
                 if tag==1:
                     # substituting the original text back is not handled for now
                     words.append(''.join(chars[start_idx:idx+1]))
-                    start_idx = idx
+                    start_idx = idx + 1
             ins[self.new_added_field_name] = ' '.join(words)
 
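The start_idx change above is what makes SegApp2OutputProcessor decode seg-app tags correctly: tag 1 marks the last character of a word, so the next word must start at the following index. A small self-contained sketch of that decoding, with made-up chars and tags for illustration:

    chars = ['我', '是', '学', '生', '。']
    tags = [1, 1, 0, 1, 1]                # expected segmentation: 我 / 是 / 学生 / 。

    words = []
    start_idx = 0
    for idx, tag in enumerate(tags):
        if tag == 1:
            words.append(''.join(chars[start_idx:idx + 1]))
            start_idx = idx + 1           # with the old start_idx = idx, characters would be emitted twice
    print(' '.join(words))                # -> 我 是 学生 。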
diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
index ac0b8471..18e59989 100644
--- a/reproduction/chinese_word_segment/train_context.py
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -20,8 +20,10 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
 from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
 
 ds_name = 'msr'
-tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
-dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
+tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
+                                                                                            ds_name)
+dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
+                                                                                           ds_name)
 
 reader = NaiveCWSReader()
 
@@ -32,17 +34,17 @@ dev_dataset = reader.load(dev_filename)
 
 
 # 1. prepare the processors
 fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
 
-sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
+# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
 # sp_proc.add_span_converter(EmailConverter())
 # sp_proc.add_span_converter(MixNumAlphaConverter())
-sp_proc.add_span_converter(AlphaSpanConverter())
-sp_proc.add_span_converter(DigitSpanConverter())
+# sp_proc.add_span_converter(AlphaSpanConverter())
+# sp_proc.add_span_converter(DigitSpanConverter())
 # sp_proc.add_span_converter(TimeConverter())
 
-char_proc = CWSCharSegProcessor('sentence', 'chars_list')
+char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')
 
-tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
+tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')
 
 bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
 
@@ -52,7 +54,7 @@ bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
 
 # 2. apply the processors
 fs2hs_proc(tr_dataset)
-sp_proc(tr_dataset)
+# sp_proc(tr_dataset)
 
 char_proc(tr_dataset)
 tag_proc(tr_dataset)
@@ -73,7 +75,7 @@ seq_len_proc(tr_dataset)
 
 # 2.1 process dev_dataset
 fs2hs_proc(dev_dataset)
-sp_proc(dev_dataset)
+# sp_proc(dev_dataset)
 
 char_proc(dev_dataset)
 tag_proc(dev_dataset)
@@ -133,7 +135,7 @@ for num_epoch in range(num_epochs):
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
         optimizer.zero_grad()
 
-        pred_dict = cws_model(batch_x)  # B x L x tag_size
+        pred_dict = cws_model(**batch_x)  # B x L x tag_size
 
         seq_lens = pred_dict['seq_lens']
         masks = seq_lens_to_mask(seq_lens).float()
@@ -176,7 +178,7 @@ cws_model.load_state_dict(best_state_dict)
 # 4. assemble what needs to be saved
 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
-pp.add_processor(sp_proc)
+# pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(tag_proc)
 pp.add_processor(bigram_proc)
@@ -187,7 +189,7 @@ pp.add_processor(seq_len_proc)
 
 
 
-te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
+te_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
 te_dataset = reader.load(te_filename)
 pp(te_dataset)
 
@@ -216,7 +218,7 @@ output_proc = SegApp2OutputProcessor()
 
 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
-pp.add_processor(sp_proc)
+# pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)
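Since API.load() now only reads the 'pipeline' entry (the 'model' entry and API.save() were removed in this commit), the training script has to persist the assembled inference Pipeline itself so that CWS can pick it up. A minimal sketch of what that could look like at the end of train_context.py; the file name is only an example:

    import torch

    # API.load(name) does torch.load(name)['pipeline'], so saving the assembled
    # inference pipeline under that key is enough for CWS(model_path=...) to find it.
    torch.save({'pipeline': pp}, 'cws_{}_pipeline.pkl'.format(ds_name))   # example file name

    # later, in application code:
    #   cws = CWS(model_path='cws_msr_pipeline.pkl')
    #   print(cws.predict('我是学生。'))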