@@ -17,12 +17,7 @@ class API:
    def load(self, name):
        _dict = torch.load(name)
        self.pipeline = _dict['pipeline']
        self.model = _dict['model']

    def save(self, path):
        _dict = {'pipeline': self.pipeline,
                 'model': self.model}
        torch.save(_dict, path)

class POS_tagger(API):
@@ -64,6 +59,38 @@ class POS_tagger(API):
        self.tag_vocab = _dict["tag_vocab"]

class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):
        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first. Or specify pretrain=True.")
        sentence_list = []
        # 1. Check the type of `sentence`
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. Build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)
        output = dataset['output']
        if isinstance(sentence, str):
            return output[0]
        elif isinstance(sentence, list):
            return output
if __name__ == "__main__": | if __name__ == "__main__": | ||||
tagger = POS_tagger() | tagger = POS_tagger() | ||||
print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]])) | print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]])) |
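
A minimal usage sketch for the new CWS API above, assuming a model file saved as a dict with a 'pipeline' entry (which is what API.load reads); the path below is hypothetical:

    from fastNLP.api.api import CWS

    cws = CWS(model_path='/path/to/cws_model.pkl')   # hypothetical path
    print(cws.predict('我是学生。'))                   # str in -> segmented str out
    print(cws.predict(['我是学生。', '他是老师。']))     # list in -> list of outputs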
@@ -1,32 +0,0 @@
from fastNLP.api.api import API
from fastNLP.core.dataset import DataSet


class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):
        if hasattr(self, 'model') and hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first. Or specify pretrain=True.")

        sentence_list = []
        # 1. Check the type of `sentence`
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. Build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)

        # 4. TODO: this should be handed to something like an iterator to predict the result
        # 5. TODO: after getting the result, consider whether it needs to be mapped back, and the post_process step
@@ -1,9 +1,13 @@
import torch
from collections import defaultdict
import re
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
import re


class Processor:
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
@@ -172,12 +176,6 @@ class SeqLenProcessor(Processor):
        dataset.set_need_tensor(**{self.new_added_field_name: True})
        return dataset

from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
import torch
from collections import defaultdict

class ModelProcessor(Processor):
    def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
        """
@@ -205,9 +203,12 @@ class ModelProcessor(Processor):
            for key, value in prediction.items():
                tmp_batch = []
                value = value.cpu().numpy()
                for idx, seq_len in enumerate(seq_lens):
                    tmp_batch.append(value[idx, :seq_len])
                batch_output[key].extend(tmp_batch)
                if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                    for idx, seq_len in enumerate(seq_lens):
                        tmp_batch.append(value[idx, :seq_len])
                    batch_output[key].extend(tmp_batch)
                else:
                    batch_output[key].extend(value.tolist())

            batch_output[self.seq_len_field_name].extend(seq_lens)
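
A toy illustration of the row-truncation pattern `value[idx, :seq_len]` used above: each row of a padded B x L_max prediction array is cut back to its true sequence length before being collected (the array and lengths below are made up):

    import numpy as np

    value = np.array([[1, 2, 3, 0, 0],
                      [4, 5, 0, 0, 0]])    # B x L_max predictions, zero-padded
    seq_lens = [3, 2]                       # true length of each sequence

    tmp_batch = []
    for idx, seq_len in enumerate(seq_lens):
        tmp_batch.append(value[idx, :seq_len])
    print([row.tolist() for row in tmp_batch])   # [[1, 2, 3], [4, 5]]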
@@ -216,7 +216,7 @@ class SeqLenProcessor(Processor):
        return dataset


class SegApp2OutputProcessor(Processor):
    def __init__(self, chars_field_name='chars', tag_field_name='pred_tags', new_added_field_name='output'):
    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
        super(SegApp2OutputProcessor, self).__init__(None, None)
        self.chars_field_name = chars_field_name
@@ -235,6 +235,6 @@ class SegApp2OutputProcessor(Processor):
                if tag == 1:
                    # Substituting the original text back in is not handled here yet
                    words.append(''.join(chars[start_idx:idx + 1]))
                    start_idx = idx
                    start_idx = idx + 1
            ins[self.new_added_field_name] = ' '.join(words)
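
The `start_idx = idx + 1` change matters because, judging from the loop above, a tag of 1 marks the last character of a word, so the next word has to start at the following character. A toy decoding example (made-up characters and tags):

    chars = ['我', '是', '学', '生', '。']
    tags = [1, 1, 0, 1, 1]                # hypothetical seg-app predictions

    words = []
    start_idx = 0
    for idx, tag in enumerate(tags):
        if tag == 1:
            words.append(''.join(chars[start_idx:idx + 1]))
            start_idx = idx + 1           # with the old `start_idx = idx`, adjacent words would overlap by one character
    print(' '.join(words))                # 我 是 学生 。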
@@ -20,8 +20,10 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

ds_name = 'msr'
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
                                                                                            ds_name)
dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
                                                                                           ds_name)

reader = NaiveCWSReader()
@@ -32,17 +34,17 @@ dev_dataset = reader.load(dev_filename)

# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc.add_span_converter(EmailConverter())
# sp_proc.add_span_converter(MixNumAlphaConverter())
sp_proc.add_span_converter(AlphaSpanConverter())
sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(AlphaSpanConverter())
# sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(TimeConverter())

char_proc = CWSCharSegProcessor('sentence', 'chars_list')
char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')
tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')

bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
@@ -52,7 +54,7 @@ bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)

# 2. Apply the processors
fs2hs_proc(tr_dataset)
sp_proc(tr_dataset)
# sp_proc(tr_dataset)
char_proc(tr_dataset)
tag_proc(tr_dataset)
@@ -73,7 +75,7 @@ seq_len_proc(tr_dataset)

# 2.1 Process dev_dataset
fs2hs_proc(dev_dataset)
sp_proc(dev_dataset)
# sp_proc(dev_dataset)
char_proc(dev_dataset)
tag_proc(dev_dataset)
@@ -133,7 +135,7 @@ for num_epoch in range(num_epochs):
    for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
        optimizer.zero_grad()

        pred_dict = cws_model(batch_x)  # B x L x tag_size
        pred_dict = cws_model(**batch_x)  # B x L x tag_size

        seq_lens = pred_dict['seq_lens']
        masks = seq_lens_to_mask(seq_lens).float()
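
A sketch (an assumption, not the repository's actual util) of what seq_lens_to_mask is expected to produce here: a B x L_max 0/1 mask with ones over the valid positions of each sequence, suitable for masking the per-token loss:

    import torch

    def seq_lens_to_mask_sketch(seq_lens, max_len=None):
        seq_lens = torch.as_tensor(seq_lens)
        max_len = max_len or int(seq_lens.max())
        positions = torch.arange(max_len).unsqueeze(0)      # 1 x L_max
        return (positions < seq_lens.unsqueeze(1)).long()   # B x L_max

    print(seq_lens_to_mask_sketch([3, 1]))
    # tensor([[1, 1, 1],
    #         [1, 0, 0]])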
@@ -176,7 +178,7 @@ cws_model.load_state_dict(best_state_dict)

# 4. Assemble the components that need to be saved
pp = Pipeline()
pp.add_processor(fs2hs_proc)
pp.add_processor(sp_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
@@ -187,7 +189,7 @@ pp.add_processor(seq_len_proc)

te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_dataset = reader.load(te_filename)
pp(te_dataset)
@@ -216,7 +218,7 @@ output_proc = SegApp2OutputProcessor()

pp = Pipeline()
pp.add_processor(fs2hs_proc)
pp.add_processor(sp_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)
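
Once this inference pipeline is assembled, it can be stored in the dict layout that the new API.load expects (a dict with a 'pipeline' entry); the save path below is hypothetical, and this assumes torch.save/torch.load round-trips the Pipeline object:

    import torch

    torch.save({'pipeline': pp}, '/path/to/cws_model.pkl')   # hypothetical path

    # The saved file can then back the CWS API added in api.py:
    # cws = CWS(model_path='/path/to/cws_model.pkl')
    # print(cws.predict('我是学生。'))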