Basically finished the predict method for CWS

tags/v0.2.0
yh, 6 years ago
commit 9667c524a4
5 changed files with 61 additions and 63 deletions
1. fastNLP/api/api.py (+32 / -5)
2. fastNLP/api/cws.py (+0 / -32)
3. fastNLP/api/processor.py (+12 / -11)
4. reproduction/chinese_word_segment/process/cws_processor.py (+2 / -2)
5. reproduction/chinese_word_segment/train_context.py (+15 / -13)

fastNLP/api/api.py (+32 / -5)

@@ -17,12 +17,7 @@ class API:
    def load(self, name):
        _dict = torch.load(name)
        self.pipeline = _dict['pipeline']
        self.model = _dict['model']

    def save(self, path):
        _dict = {'pipeline': self.pipeline,
                 'model': self.model}
        torch.save(_dict, path)


class POS_tagger(API):
@@ -64,6 +59,38 @@ class POS_tagger(API):
        self.tag_vocab = _dict["tag_vocab"]



class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):

        if not hasattr(self, 'pipeline'):
            raise ValueError("You have to load a model first, or specify pretrain=True.")

        sentence_list = []
        # 1. Check the type of sentence
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. Build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)

        output = dataset['output']
        if isinstance(sentence, str):
            return output[0]
        elif isinstance(sentence, list):
            return output


if __name__ == "__main__":
    tagger = POS_tagger()
    print(tagger.predict([["我", "是", "学生", "。"], ["我", "是", "学生", "。"]]))
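
A quick usage sketch of the new CWS API (the saved-model path and the example sentences below are made up; per the code above, predict returns one segmented string for a string input and a list for a list input):

from fastNLP.api.api import CWS

# Hypothetical path to a file written by API.save (a dict with 'pipeline' and 'model').
cws = CWS(model_path='/path/to/cws_model.pkl')

# A single sentence in, a single space-separated segmentation out ...
print(cws.predict('我是学生。'))
# ... and a list in, a list of segmentations out.
print(cws.predict(['我是学生。', '今天天气真好。']))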

fastNLP/api/cws.py (+0 / -32)

@@ -1,32 +0,0 @@


from fastNLP.api.api import API
from fastNLP.core.dataset import DataSet

class CWS(API):
    def __init__(self, model_path='xxx'):
        super(CWS, self).__init__()
        self.load(model_path)

    def predict(self, sentence, pretrain=False):

        if hasattr(self, 'model') and hasattr(self, 'pipeline'):
            raise ValueError("You have to load model first. Or specify pretrain=True.")

        sentence_list = []
        # 1. Check the type of sentence
        if isinstance(sentence, str):
            sentence_list.append(sentence)
        elif isinstance(sentence, list):
            sentence_list = sentence

        # 2. Build the DataSet
        dataset = DataSet()
        dataset.add_field('raw_sentence', sentence_list)

        # 3. Run the pipeline
        self.pipeline(dataset)

        # 4. TODO: this should be handed to something iterator-like to produce the prediction
        # 5. TODO: for the result, decide whether the text needs converting back, plus post_process steps

fastNLP/api/processor.py (+12 / -11)

@@ -1,9 +1,13 @@
import torch
from collections import defaultdict
import re

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler


class Processor:
    def __init__(self, field_name, new_added_field_name):
        self.field_name = field_name
@@ -172,12 +176,6 @@ class SeqLenProcessor(Processor):
        dataset.set_need_tensor(**{self.new_added_field_name: True})
        return dataset


from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler
import torch
from collections import defaultdict

class ModelProcessor(Processor):
    def __init__(self, model, seq_len_field_name='seq_lens', batch_size=32):
        """
@@ -205,9 +203,12 @@ class ModelProcessor(Processor):
        for key, value in prediction.items():
            tmp_batch = []
            value = value.cpu().numpy()
            if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
                # one value per sentence: keep it whole
                batch_output[key].extend(value.tolist())
            else:
                # one value per token: trim each row to its true sequence length
                for idx, seq_len in enumerate(seq_lens):
                    tmp_batch.append(value[idx, :seq_len])
                batch_output[key].extend(tmp_batch)

        batch_output[self.seq_len_field_name].extend(seq_lens)
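
A self-contained sketch (with made-up shapes and values) of what the reworked ModelProcessor loop does: 1-D or (B, 1) sentence-level outputs are kept whole, while 2-D token-level outputs are trimmed to each sentence's true length so padding positions are dropped:

import numpy as np
from collections import defaultdict

batch_output = defaultdict(list)
seq_lens = [3, 2]  # true lengths of two sentences in a batch padded to length 4
prediction = {
    'pred_tags': np.arange(8).reshape(2, 4),  # token-level output, shape (B, L)
    'pred_probs': np.array([[0.9], [0.8]]),   # sentence-level output, shape (B, 1)
}
for key, value in prediction.items():
    if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
        batch_output[key].extend(value.tolist())            # keep whole
    else:
        for idx, seq_len in enumerate(seq_lens):
            batch_output[key].append(value[idx, :seq_len])  # drop padding

print(batch_output['pred_tags'])   # [array([0, 1, 2]), array([4, 5])]
print(batch_output['pred_probs'])  # [[0.9], [0.8]]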



reproduction/chinese_word_segment/process/cws_processor.py (+2 / -2)

@@ -216,7 +216,7 @@ class SeqLenProcessor(Processor):
        return dataset

class SegApp2OutputProcessor(Processor):
    def __init__(self, chars_field_name='chars', tag_field_name='pred_tags', new_added_field_name='output'):
    def __init__(self, chars_field_name='chars_list', tag_field_name='pred_tags', new_added_field_name='output'):
        super(SegApp2OutputProcessor, self).__init__(None, None)

        self.chars_field_name = chars_field_name
@@ -235,6 +235,6 @@ class SegApp2OutputProcessor(Processor):
            if tag == 1:
                # restoring the original (pre-conversion) text is not handled yet
                words.append(''.join(chars[start_idx:idx + 1]))
                start_idx = idx
                start_idx = idx + 1
        ins[self.new_added_field_name] = ' '.join(words)
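
Why `start_idx = idx + 1` is the right fix: in the SegApp tagging scheme a tag of 1 marks the last character of a word, so the next word must start at the following index; with the old `start_idx = idx`, every subsequent word would wrongly re-include the final character of the previous one. A standalone sketch of the decoding loop (chars and tags are made-up inputs):

chars = ['我', '是', '学', '生', '。']
tags = [1, 1, 0, 1, 1]  # 1 = this character ends a word

words = []
start_idx = 0
for idx, tag in enumerate(tags):
    if tag == 1:
        words.append(''.join(chars[start_idx:idx + 1]))
        start_idx = idx + 1  # the next word begins after the one that just ended

print(' '.join(words))  # -> 我 是 学生 。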


reproduction/chinese_word_segment/train_context.py (+15 / -13)

@@ -20,8 +20,10 @@ from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMSegApp
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1

ds_name = 'msr'
tr_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
dev_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)
tr_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_train.txt'.format(ds_name, ds_name)
dev_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_dev.txt'.format(ds_name, ds_name)

reader = NaiveCWSReader()

@@ -32,17 +34,17 @@ dev_dataset = reader.load(dev_filename)
# 1. Prepare the processors
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc = SpeicalSpanProcessor('raw_sentence', 'sentence')
# sp_proc.add_span_converter(EmailConverter())
# sp_proc.add_span_converter(MixNumAlphaConverter())
sp_proc.add_span_converter(AlphaSpanConverter())
sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(AlphaSpanConverter())
# sp_proc.add_span_converter(DigitSpanConverter())
# sp_proc.add_span_converter(TimeConverter())


char_proc = CWSCharSegProcessor('sentence', 'chars_list')
char_proc = CWSCharSegProcessor('raw_sentence', 'chars_list')

tag_proc = CWSSegAppTagProcessor('sentence', 'tags')
tag_proc = CWSSegAppTagProcessor('raw_sentence', 'tags')

bigram_proc = Pre2Post2BigramProcessor('chars_list', 'bigrams_list')

@@ -52,7 +54,7 @@ bigram_vocab_proc = VocabProcessor('bigrams_list', min_count=4)
# 2. Apply the processors
fs2hs_proc(tr_dataset)

sp_proc(tr_dataset)
# sp_proc(tr_dataset)

char_proc(tr_dataset)
tag_proc(tr_dataset)
@@ -73,7 +75,7 @@ seq_len_proc(tr_dataset)

# 2.1 Process dev_dataset
fs2hs_proc(dev_dataset)
sp_proc(dev_dataset)
# sp_proc(dev_dataset)

char_proc(dev_dataset)
tag_proc(dev_dataset)
@@ -133,7 +135,7 @@ for num_epoch in range(num_epochs):
    for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
        optimizer.zero_grad()

        pred_dict = cws_model(batch_x)  # B x L x tag_size
        pred_dict = cws_model(**batch_x)  # B x L x tag_size

        seq_lens = pred_dict['seq_lens']
        masks = seq_lens_to_mask(seq_lens).float()
@@ -176,7 +178,7 @@ cws_model.load_state_dict(best_state_dict)
# 4. Assemble what needs to be saved
pp = Pipeline()
pp.add_processor(fs2hs_proc)
pp.add_processor(sp_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
@@ -187,7 +189,7 @@ pp.add_processor(seq_len_proc)



te_filename = '/hdd/fudanNLP/CWS/Multi_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/CWS/Mutil_Criterion/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_dataset = reader.load(te_filename)
pp(te_dataset)

@@ -216,7 +218,7 @@ output_proc = SegApp2OutputProcessor()

pp = Pipeline()
pp.add_processor(fs2hs_proc)
pp.add_processor(sp_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_index_proc)
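
To round off the picture: the inference Pipeline assembled above is meant to be persisted together with the trained model, in the {'pipeline': ..., 'model': ...} dict format that API.save writes and API.load (and hence CWS) reads. A sketch continuing the script above (the output file name is hypothetical):

import torch

# pp and cws_model come from the training script above.
torch.save({'pipeline': pp, 'model': cws_model}, 'cws_msr_model.pkl')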

