@@ -17,12 +17,7 @@ class API:
    def load(self, name):
        _dict = torch.load(name)
        self.pipeline = _dict['pipeline']
        self.model = _dict['model']

    def save(self, path):
        _dict = {'pipeline': self.pipeline,
                 'model': self.model}
        torch.save(_dict, path)


class POS_tagger(API):
@@ -64,6 +59,38 @@ class POS_tagger(API):
        self.tag_vocab = _dict["tag_vocab"]

class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first. Or specify pretrain=True.")
        sentence_list = []
        # 1. Check the type of sentence
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. Build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)
        output = dataset['output']
        if isinstance(sentence, str):
            return output[0]
        elif isinstance(sentence, list):
            return output

if __name__ == "__main__":
    tagger = POS_tagger()
    print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]]))
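
# For reference, a usage sketch of the new CWS.predict above: per the isinstance branches, a single
# str comes back as one segmented string and a list of str comes back as a list of strings.
# The model path below is an assumption for illustration; the default 'xxx' in __init__ is a placeholder.
cws = CWS(model_path='/path/to/cws_model.pkl')

print(cws.predict('我是学生。'))              # str in  -> one segmented string out
print(cws.predict(['我是学生。', '你好。']))   # list in -> list of segmented strings out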
@@ -1,32 +0,0 @@
from fastNLP.api.api import API
from fastNLP.core.dataset import DataSet


class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):
        if hasattr(self, 'model') and hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first. Or specify pretrain=True.")

        sentence_list = []
        # 1. Check the type of sentence
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. Build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)

        # 4. TODO: this should be handed off to something like an iterator to predict the result
        # 5. TODO: once the result is obtained, consider whether it needs to be mapped back, plus the post_process step
@@ -1,9 +1,13 @@
import torch
from collections import defaultdict
import re

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
import re


class Processor:
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
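
# Every concrete processor in this file follows the convention stored by Processor.__init__ above:
# read self.field_name from each instance and write the result into self.new_added_field_name.
# A minimal hypothetical subclass as a sketch; it assumes subclasses override process(dataset) and
# that instances are indexable by field name, the way SegApp2OutputProcessor below does.
class LowerCaseProcessor(Processor):
    """Illustration only: lower-cases one string field into another."""

    def __init__(self, field_name, new_added_field_name):
        super(LowerCaseProcessor, self).__init__(field_name, new_added_field_name)

    def process(self, dataset):
        for ins in dataset:
            ins[self.new_added_field_name] = ins[self.field_name].lower()
        return dataset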
@@ -172,12 +176,6 @@ class SeqLenProcessor(Processor):
        dataset.set_need_tensor(**{self.new_added_field_name: True})
        return dataset

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
import torch
from collections import defaultdict


class ModelProcessor(Processor):
    def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
        """
@@ -205,9 +203,12 @@ class ModelProcessor(Processor):
            for key, value in prediction.items():
                tmp_batch = []
                value = value.cpu().numpy()
                for idx, seq_len in enumerate(seq_lens):
                    tmp_batch.append(value[idx, :seq_len])
                batch_output[key].extend(tmp_batch)
                if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                    batch_output[key].extend(value.tolist())
                else:
                    for idx, seq_len in enumerate(seq_lens):
                        tmp_batch.append(value[idx, :seq_len])
                    batch_output[key].extend(tmp_batch)

            batch_output[self.seq_len_field_name].extend(seq_lens)
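
# The branch added above separates per-sentence predictions from per-token ones: a value of shape
# (B,) or (B, 1) is stored as a flat list, while anything with a time dimension is cut back to each
# sentence's true length so the padding is dropped. A small numpy-only sketch; shapes and key names
# are made up for illustration.
import numpy as np

seq_lens = [3, 5]                              # true lengths of two padded sentences
per_sentence = np.array([0, 1])                # shape (B,): e.g. one label per sentence
per_token = np.zeros((2, 5), dtype=np.int64)   # shape (B, L): e.g. one tag id per character

batch_output = {}
batch_output['sent_label'] = per_sentence.tolist()                         # (B,) -> flat list
batch_output['tags'] = [per_token[i, :l] for i, l in enumerate(seq_lens)]  # trim padding per row

print([a.shape for a in batch_output['tags']])   # [(3,), (5,)]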
@@ -216,7 +216,7 @@ class SeqLenProcessor(Processor):
        return dataset


class SegApp2OutputProcessor(Processor):
    def __init__(self, chars_field_name='chars', tag_field_name='pred_tags', new_added_field_name='output'):
    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
        super(SegApp2OutputProcessor, self).__init__(None, None)

        self.chars_field_name = chars_field_name
@@ -235,6 +235,6 @@ class SegApp2OutputProcessor(Processor):
                if tag == 1:
                    # Substituting the original text back in is not handled yet
                    words.append(''.join(chars[start_idx:idx + 1]))
                    start_idx = idx
                    start_idx = idx + 1
            ins[self.new_added_field_name] = ' '.join(words)
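
# Read together with the loop above: a predicted tag of 1 at position idx closes the word
# chars[start_idx:idx + 1], and the corrected update start_idx = idx + 1 makes the next word begin
# after the boundary character rather than on it. A tiny self-contained sketch of the decoding,
# using toy chars/tags rather than real model output:
chars = ['我', '是', '学', '生', '。']
pred_tags = [1, 1, 0, 1, 1]          # 1 marks the last character of a word

words, start_idx = [], 0
for idx, tag in enumerate(pred_tags):
    if tag == 1:
        words.append(''.join(chars[start_idx:idx + 1]))
        start_idx = idx + 1          # the fix from this change

print(' '.join(words))               # 我 是 学生 。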
@@ -20,8 +20,10 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

ds_name = 'msr'
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
                                                                                            ds_name)
dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
                                                                                           ds_name)

reader = NaiveCWSReader()
@@ -32,17 +34,17 @@ dev_dataset = reader.load(dev_filename)
# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc.add_span_converter(EmailConverter())
# sp_proc.add_span_converter(MixNumAlphaConverter())
sp_proc.add_span_converter(AlphaSpanConverter())
sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(AlphaSpanConverter())
# sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(TimeConverter())

char_proc = CWSCharSegProcessor('sentence', 'chars_list')
char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')

tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')

bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
@@ -52,7 +54,7 @@ bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
# 2. Apply the processors
fs2hs_proc(tr_dataset)

sp_proc(tr_dataset)
# sp_proc(tr_dataset)

char_proc(tr_dataset)
tag_proc(tr_dataset)
@@ -73,7 +75,7 @@ seq_len_proc(tr_dataset)
# 2.1 Process dev_dataset
fs2hs_proc(dev_dataset)

sp_proc(dev_dataset)
# sp_proc(dev_dataset)

char_proc(dev_dataset)
tag_proc(dev_dataset)
@@ -133,7 +135,7 @@ for num_epoch in range(num_epochs):
    for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
        optimizer.zero_grad()

        pred_dict = cws_model(batch_x)  # B x L x tag_size
        pred_dict = cws_model(**batch_x)  # B x L x tag_size

        seq_lens = pred_dict['seq_lens']
        masks = seq_lens_to_mask(seq_lens).float()
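
# The switch to cws_model(**batch_x) above unpacks the batch dict into keyword arguments, so the
# model's forward receives each field by name instead of the whole dict. A hypothetical stand-in
# model sketches the calling convention; the field names chars/seq_lens are assumptions for
# illustration, not necessarily what CWSBiLSTMSegApp expects.
import torch


class TinyModel(torch.nn.Module):
    def forward(self, chars, seq_lens):
        # pretend prediction: one score per sentence, plus the lengths passed through
        return {'pred': chars.float().sum(dim=1), 'seq_lens': seq_lens}


batch_x = {'chars': torch.ones(2, 5, dtype=torch.long),
           'seq_lens': torch.tensor([5, 3])}

out = TinyModel()(**batch_x)   # equivalent to forward(chars=..., seq_lens=...)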
@@ -176,7 +178,7 @@ cws_model.load_state_dict(best_state_dict)
# 4. Assemble what needs to be saved
pp = Pipeline()
pp.add_processor(fs2hs_proc)
pp.add_processor(sp_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
@@ -187,7 +189,7 @@ pp.add_processor(seq_len_proc)
te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_dataset = reader.load(te_filename)
pp(te_dataset)
@@ -216,7 +218,7 @@ output_proc = SegApp2OutputProcessor()
pp = Pipeline()
pp.add_processor(fs2hs_proc)
pp.add_processor(sp_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)