|
|
@@ -1,169 +0,0 @@ |
|
|
|
|
|
|
|
from fastNLP.api.pipeline import Pipeline |
|
|
|
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor |
|
|
|
from fastNLP.api.processor import SeqLenProcessor |
|
|
|
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor |
|
|
|
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor |
|
|
|
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor |
|
|
|
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor |
|
|
|
from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor |
|
|
|
|
|
|
|
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader |
|
|
|
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF |
|
|
|
|
|
|
|
|
|
|
|
ds_name = 'msr' |
|
|
|
|
|
|
|
tr_filename = '/home/hyan/ctb3/train.conllx' |
|
|
|
dev_filename = '/home/hyan/ctb3/dev.conllx' |
|
|
|
|
|
|
|
|
|
|
|
reader = ConllCWSReader() |
|
|
|
|
|
|
|
tr_dataset = reader.load(tr_filename, cut_long_sent=True) |
|
|
|
dev_dataset = reader.load(dev_filename) |
|
|
|
|
|
|
|
print("Train {}. Dev: {}".format(len(tr_dataset), len(dev_dataset))) |
|
|
|
|
|
|
|
# 1. 准备processor |
|
|
|
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence') |
|
|
|
|
|
|
|
char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst') |
|
|
|
tag_proc = CWSBMESTagProcessor('raw_sentence', 'target') |
|
|
|
|
|
|
|
bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst') |
|
|
|
|
|
|
|
char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars') |
|
|
|
bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4) |
|
|
|
|
|
|
|
seq_len_proc = SeqLenProcessor('chars') |
|
|
|
|
|
|
|
# 2. 使用processor |
|
|
|
fs2hs_proc(tr_dataset) |
|
|
|
|
|
|
|
char_proc(tr_dataset) |
|
|
|
tag_proc(tr_dataset) |
|
|
|
bigram_proc(tr_dataset) |
|
|
|
|
|
|
|
char_vocab_proc(tr_dataset) |
|
|
|
bigram_vocab_proc(tr_dataset) |
|
|
|
seq_len_proc(tr_dataset) |
|
|
|
|
|
|
|
# 2.1 处理dev_dataset |
|
|
|
fs2hs_proc(dev_dataset) |
|
|
|
|
|
|
|
char_proc(dev_dataset) |
|
|
|
tag_proc(dev_dataset) |
|
|
|
bigram_proc(dev_dataset) |
|
|
|
|
|
|
|
char_vocab_proc(dev_dataset) |
|
|
|
bigram_vocab_proc(dev_dataset) |
|
|
|
seq_len_proc(dev_dataset) |
|
|
|
|
|
|
|
dev_dataset.set_input('target') |
|
|
|
tr_dataset.set_input('target') |
|
|
|
|
|
|
|
|
|
|
|
print("Finish preparing data.") |
|
|
|
|
|
|
|
# 3. 得到数据集可以用于训练了 |
|
|
|
# TODO pretrain的embedding是怎么解决的? |
|
|
|
|
|
|
|
from torch import optim |
|
|
|
|
|
|
|
|
|
|
|
tag_size = tag_proc.tag_size |
|
|
|
|
|
|
|
cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100, |
|
|
|
bigram_vocab_num=bigram_vocab_proc.get_vocab_size(), |
|
|
|
bigram_embed_dim=30, num_bigram_per_char=8, |
|
|
|
hidden_size=200, bidirectional=True, embed_drop_p=0.3, |
|
|
|
num_layers=1, tag_size=tag_size) |
|
|
|
cws_model.cuda() |
|
|
|
|
|
|
|
num_epochs = 5 |
|
|
|
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005) |
|
|
|
|
|
|
|
from fastNLP.core.trainer import Trainer |
|
|
|
from fastNLP.core.sampler import BucketSampler |
|
|
|
from fastNLP.core.metrics import BMESF1PreRecMetric |
|
|
|
|
|
|
|
metric = BMESF1PreRecMetric(target='tags') |
|
|
|
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs, |
|
|
|
batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None, |
|
|
|
optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True) |
|
|
|
|
|
|
|
trainer.train() |
|
|
|
|
|
|
|
# 4. 组装需要存下的内容 |
|
|
|
pp = Pipeline() |
|
|
|
pp.add_processor(fs2hs_proc) |
|
|
|
# pp.add_processor(sp_proc) |
|
|
|
pp.add_processor(char_proc) |
|
|
|
pp.add_processor(tag_proc) |
|
|
|
pp.add_processor(bigram_proc) |
|
|
|
pp.add_processor(char_vocab_proc) |
|
|
|
pp.add_processor(bigram_vocab_proc) |
|
|
|
pp.add_processor(seq_len_proc) |
|
|
|
# pp.add_processor(input_target_proc) |
|
|
|
|
|
|
|
# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name) |
|
|
|
te_filename = '/home/hyan/ctb3/test.conllx' |
|
|
|
te_dataset = reader.load(te_filename) |
|
|
|
pp(te_dataset) |
|
|
|
|
|
|
|
from fastNLP.core.tester import Tester |
|
|
|
|
|
|
|
tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False, |
|
|
|
verbose=1) |
|
|
|
tester.test() |
|
|
|
# |
|
|
|
# batch_size = 64 |
|
|
|
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False) |
|
|
|
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes') |
|
|
|
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100, |
|
|
|
# pre * 100, |
|
|
|
# rec * 100)) |
|
|
|
|
|
|
|
# TODO 这里貌似需要区分test pipeline与infer pipeline |
|
|
|
|
|
|
|
test_context_dict = {'pipeline': pp, |
|
|
|
'model': cws_model} |
|
|
|
# torch.save(test_context_dict, 'models/test_context_crf.pkl') |
|
|
|
|
|
|
|
|
|
|
|
# 5. dev的pp |
|
|
|
# 4. 组装需要存下的内容 |
|
|
|
|
|
|
|
from fastNLP.api.processor import ModelProcessor |
|
|
|
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor |
|
|
|
|
|
|
|
model_proc = ModelProcessor(cws_model) |
|
|
|
output_proc = BMES2OutputProcessor(chars_field_name='chars_lst', tag_field_name='pred') |
|
|
|
|
|
|
|
pp = Pipeline() |
|
|
|
pp.add_processor(fs2hs_proc) |
|
|
|
# pp.add_processor(sp_proc) |
|
|
|
pp.add_processor(char_proc) |
|
|
|
pp.add_processor(bigram_proc) |
|
|
|
char_vocab_proc.set_verbose(0) |
|
|
|
pp.add_processor(char_vocab_proc) |
|
|
|
bigram_vocab_proc.set_verbose(0) |
|
|
|
pp.add_processor(bigram_vocab_proc) |
|
|
|
pp.add_processor(seq_len_proc) |
|
|
|
|
|
|
|
pp.add_processor(model_proc) |
|
|
|
pp.add_processor(output_proc) |
|
|
|
|
|
|
|
|
|
|
|
# TODO 这里貌似需要区分test pipeline与infer pipeline |
|
|
|
import torch |
|
|
|
import datetime |
|
|
|
now = datetime.datetime.now() |
|
|
|
infer_context_dict = {'pipeline': pp, 'tag_proc': tag_proc} |
|
|
|
torch.save(infer_context_dict, 'models/cws_crf_{}_{}.pkl'.format(now.month, now.day)) |
|
|
|
|
|
|
|
|
|
|
|
# TODO 还需要考虑如何替换回原文的问题? |
|
|
|
# 1. 不需要将特殊tag替换 |
|
|
|
# 2. 需要将特殊tag替换回去 |