
Basically completed the predict method of CWS

tags/v0.2.0
yh, 6 years ago
commit 9667c524a4
5 changed files with 61 additions and 63 deletions
  1. fastNLP/api/api.py  +32 -5
  2. fastNLP/api/cws.py  +0 -32
  3. fastNLP/api/processor.py  +12 -11
  4. reproduction/chinese_word_segment/process/cws_processor.py  +2 -2
  5. reproduction/chinese_word_segment/train_context.py  +15 -13

fastNLP/api/api.py  +32 -5

@@ -17,12 +17,7 @@ class API:
     def load(self, name):
         _dict = torch.load(name)
         self.pipeline = _dict['pipeline']
-        self.model = _dict['model']
 
-    def save(self, path):
-        _dict = {'pipeline': self.pipeline,
-                 'model': self.model}
-        torch.save(_dict, path)
 
 
 class POS_tagger(API):
@@ -64,6 +59,38 @@ class POS_tagger(API):
         self.tag_vocab = _dict["tag_vocab"]
 
 
+
+class CWS(API):
+    def __init__(self, model_path='xxx'):
+        super(CWS, self).__init__()
+        self.load(model_path)
+
+    def predict(self, sentence, pretrain=False):
+
+        if hasattr(self, 'pipeline'):
+            raise ValueError("You have to load model first. Or specify pretrain=True.")
+
+        sentence_list = []
+        # 1. check the type of sentence
+        if isinstance(sentence, str):
+            sentence_list.append(sentence)
+        elif isinstance(sentence, list):
+            sentence_list = sentence
+
+        # 2. build the dataset
+        dataset = DataSet()
+        dataset.add_field('raw_sentence', sentence_list)
+
+        # 3. run the pipeline
+        self.pipeline(dataset)
+
+        output = dataset['output']
+        if isinstance(sentence, str):
+            return output[0]
+        elif isinstance(sentence, list):
+            return output
+
+
 if __name__ == "__main__":
     tagger = POS_tagger()
     print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]]))
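For orientation, here is a minimal usage sketch of the new CWS.predict interface. It is not part of the commit: the model path is hypothetical, and it assumes a saved pipeline whose output processor writes a whitespace-joined segmentation into the 'output' field. Note that the hasattr guard above raises when a pipeline is present; the sketch assumes the intended behaviour of raising only when no pipeline has been loaded.

from fastNLP.api.api import CWS

# Hypothetical path to a saved {'pipeline': ...} dict produced by train_context.py.
cws = CWS(model_path='/path/to/cws_model.pkl')

# A single string returns one segmented string; a list of strings returns a list.
print(cws.predict('我是学生。'))                    # e.g. '我 是 学生 。'
print(cws.predict(['我是学生。', '他是老师。']))    # e.g. ['我 是 学生 。', '他 是 老师 。']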

fastNLP/api/cws.py  +0 -32

@@ -1,32 +0,0 @@
-
-
-from fastNLP.api.api import API
-from fastNLP.core.dataset import DataSet
-
-class CWS(API):
-    def __init__(self, model_path='xxx'):
-        super(CWS, self).__init__()
-        self.load(model_path)
-
-    def predict(self, sentence, pretrain=False):
-
-        if hasattr(self, 'model') and hasattr(self, 'pipeline'):
-            raise ValueError("You have to load model first. Or specify pretrain=True.")
-
-        sentence_list = []
-        # 1. check the type of sentence
-        if isinstance(sentence, str):
-            sentence_list.append(sentence)
-        elif isinstance(sentence, list):
-            sentence_list = sentence
-
-        # 2. build the dataset
-        dataset = DataSet()
-        dataset.add_field('raw_sentence', sentence_list)
-
-        # 3. run the pipeline
-        self.pipeline(dataset)
-
-        # 4. TODO this should be handed to something like an iterator to run the prediction
-
-        # 5. TODO get the result; decide whether it needs to be reversed back, and what post_process to apply

fastNLP/api/processor.py  +12 -11

@@ -1,9 +1,13 @@
+import torch
+from collections import defaultdict
+import re
+
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.core.batch import Batch
+from fastNLP.core.sampler import SequentialSampler
 
 
-import re
-
 class Processor:
     def __init__(self, field_name, new_added_field_name):
         self.field_name = field_name
@@ -172,12 +176,6 @@ class SeqLenProcessor(Processor):
         dataset.set_need_tensor(**{self.new_added_field_name: True})
         return dataset
 
-
-from fastNLP.core.batch import Batch
-from fastNLP.core.sampler import SequentialSampler
-import torch
-from collections import defaultdict
-
 class ModelProcessor(Processor):
     def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
         """
@@ -205,9 +203,12 @@ class ModelProcessor(Processor):
             for key, value in prediction.items():
                 tmp_batch = []
                 value = value.cpu().numpy()
-                for idx, seq_len in enumerate(seq_lens):
-                    tmp_batch.append(value[idx, :seq_len])
-                batch_output[key].extend(tmp_batch)
+                if len(value.shape) == 1 or (len(value.shape)==2 and value.shape[1]==1):
+                    for idx, seq_len in enumerate(seq_lens):
+                        tmp_batch.append(value[idx, :seq_len])
+                    batch_output[key].extend(tmp_batch)
+                else:
+                    batch_output[key].extend(value.tolist())
 
 
             batch_output[self.seq_len_field_name].extend(seq_lens)
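The change above stops slicing every prediction tensor by sequence length: only some shapes keep the per-sequence truncation, the rest are extended as-is with value.tolist(). A stand-alone numpy sketch of the per-sequence truncation itself, with hypothetical values (not from the commit):

import numpy as np

# Padded batch output of shape (batch_size, max_len): true lengths are 3 and 2.
value = np.array([[1, 2, 3, 0],
                  [4, 5, 0, 0]])
seq_lens = [3, 2]

tmp_batch = []
for idx, seq_len in enumerate(seq_lens):
    tmp_batch.append(value[idx, :seq_len])   # drop the padding positions

print([x.tolist() for x in tmp_batch])       # [[1, 2, 3], [4, 5]]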




reproduction/chinese_word_segment/process/cws_processor.py  +2 -2

@@ -216,7 +216,7 @@ class SeqLenProcessor(Processor):
         return dataset
 
 
 class SegApp2OutputProcessor(Processor):
-    def __init__(self, chars_field_name='chars', tag_field_name='pred_tags', new_added_field_name='output'):
+    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
         super(SegApp2OutputProcessor, self).__init__(None, None)
         self.chars_field_name = chars_field_name
@@ -235,6 +235,6 @@ class SegApp2OutputProcessor(Processor):
                 if tag==1:
                     # substituting the original text back is not handled yet
                     words.append(''.join(chars[start_idx:idx+1]))
-                    start_idx = idx
+                    start_idx = idx + 1
             ins[self.new_added_field_name] = ' '.join(words)
 
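The one-line fix above matters because, in this seg-app scheme, tag 1 closes the word ending at idx: the next word must start at idx + 1, while the old start_idx = idx re-reads the boundary character. A small stand-alone sketch with hypothetical tags:

chars = ['我', '是', '学', '生']
tags = [1, 1, 0, 1]   # intended segmentation: 我 / 是 / 学生

def segment(chars, tags, next_start):
    words, start_idx = [], 0
    for idx, tag in enumerate(tags):
        if tag == 1:
            words.append(''.join(chars[start_idx:idx + 1]))
            start_idx = next_start(idx)
    return ' '.join(words)

print(segment(chars, tags, lambda idx: idx))      # old behaviour: '我 我是 是学生'
print(segment(chars, tags, lambda idx: idx + 1))  # fixed behaviour: '我 是 学生'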



reproduction/chinese_word_segment/train_context.py  +15 -13

@@ -20,8 +20,10 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
 from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
 
 
 ds_name = 'msr'
-tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
-dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
+tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name,
+                                                                                            ds_name)
+dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name,
+                                                                                           ds_name)
 
 reader = NaiveCWSReader()


@@ -32,17 +34,17 @@ dev_dataset = reader.load(dev_filename)
 # 1. prepare processors
 fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
 
-sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
+# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
 # sp_proc.add_span_converter(EmailConverter())
 # sp_proc.add_span_converter(MixNumAlphaConverter())
-sp_proc.add_span_converter(AlphaSpanConverter())
-sp_proc.add_span_converter(DigitSpanConverter())
+# sp_proc.add_span_converter(AlphaSpanConverter())
+# sp_proc.add_span_converter(DigitSpanConverter())
 # sp_proc.add_span_converter(TimeConverter())
 
 
 
-char_proc = CWSCharSegProcessor('sentence', 'chars_list')
+char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')
 
-tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
+tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')
 
 bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')
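With sp_proc disabled, char_proc and tag_proc now read 'raw_sentence' directly instead of the 'sentence' field that SpeicalSpanProcessor used to produce. Each of these objects follows the same pattern: read one field from every instance and add a new field when called on a dataset. A toy stand-in for that pattern over dict-like instances (not the fastNLP implementation):

# Toy processor: derive a character list from a sentence field.
class ToyCharSegProcessor:
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
        self.new_added_field_name = new_added_field_name

    def __call__(self, dataset):
        for ins in dataset:   # here dataset is just a list of dicts
            ins[self.new_added_field_name] = list(ins[self.field_name])
        return dataset

toy_dataset = [{'raw_sentence': '我是学生。'}]
ToyCharSegProcessor('raw_sentence', 'chars_list')(toy_dataset)
print(toy_dataset[0]['chars_list'])   # ['我', '是', '学', '生', '。']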


@@ -52,7 +54,7 @@ bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
 # 2. apply processors
 fs2hs_proc(tr_dataset)
 
-sp_proc(tr_dataset)
+# sp_proc(tr_dataset)
 
 char_proc(tr_dataset)
 tag_proc(tr_dataset)
@@ -73,7 +75,7 @@ seq_len_proc(tr_dataset)
 
 # 2.1 process dev_dataset
 fs2hs_proc(dev_dataset)
-sp_proc(dev_dataset)
+# sp_proc(dev_dataset)
 
 char_proc(dev_dataset)
 tag_proc(dev_dataset)
@@ -133,7 +135,7 @@ for num_epoch in range(num_epochs):
     for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
         optimizer.zero_grad()
 
-        pred_dict = cws_model(batch_x)  # B x L x tag_size
+        pred_dict = cws_model(**batch_x)  # B x L x tag_size
 
         seq_lens = pred_dict['seq_lens']
         masks = seq_lens_to_mask(seq_lens).float()
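The call now unpacks the batch dict into keyword arguments, which assumes the model's forward takes the batch fields by name. A minimal sketch of the difference, with hypothetical field names:

# batch_x is a dict of named tensors/arrays produced by the batch iterator.
batch_x = {'chars': [[1, 2]], 'seq_lens': [2]}

def forward(chars, seq_lens):            # stand-in for the model's forward signature
    return {'seq_lens': seq_lens}

forward(**batch_x)    # unpacks into keyword arguments, as in cws_model(**batch_x)
# forward(batch_x)    # old call shape: the whole dict as one positional argument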
@@ -176,7 +178,7 @@ cws_model.load_state_dict(best_state_dict)
 # 4. assemble what needs to be saved
 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
-pp.add_processor(sp_proc)
+# pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(tag_proc)
 pp.add_processor(bigram_proc)
@@ -187,7 +189,7 @@ pp.add_processor(seq_len_proc)
 
 
 
 
-te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
+te_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
 te_dataset = reader.load(te_filename)
 pp(te_dataset)


@@ -216,7 +218,7 @@ output_proc = SegApp2OutputProcessor()
 
 pp = Pipeline()
 pp.add_processor(fs2hs_proc)
-pp.add_processor(sp_proc)
+# pp.add_processor(sp_proc)
 pp.add_processor(char_proc)
 pp.add_processor(bigram_proc)
 pp.add_processor(char_index_proc)

