Browse Source

fastNLP V0.3

tags/v0.3.0^2
FengZiYjun 5 years ago
parent
commit
8df5bce938
2 changed files with 4 additions and 173 deletions
  1. +0
    -169
      reproduction/chinese_word_segment/train_context.py
  2. +4
    -4
      tutorials/README.md

+ 0
- 169
reproduction/chinese_word_segment/train_context.py View File

@@ -1,169 +0,0 @@

from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import SeqLenProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.chinese_word_segment.process.cws_processor import InputTargetProcessor

from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF


ds_name = 'msr'

tr_filename = '/home/hyan/ctb3/train.conllx'
dev_filename = '/home/hyan/ctb3/dev.conllx'


reader = ConllCWSReader()

tr_dataset = reader.load(tr_filename, cut_long_sent=True)
dev_dataset = reader.load(dev_filename)

print("Train {}. Dev: {}".format(len(tr_dataset), len(dev_dataset)))

# 1. 准备processor
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst')
tag_proc = CWSBMESTagProcessor('raw_sentence', 'target')

bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst')

char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars')
bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4)

seq_len_proc = SeqLenProcessor('chars')

# 2. 使用processor
fs2hs_proc(tr_dataset)

char_proc(tr_dataset)
tag_proc(tr_dataset)
bigram_proc(tr_dataset)

char_vocab_proc(tr_dataset)
bigram_vocab_proc(tr_dataset)
seq_len_proc(tr_dataset)

# 2.1 处理dev_dataset
fs2hs_proc(dev_dataset)

char_proc(dev_dataset)
tag_proc(dev_dataset)
bigram_proc(dev_dataset)

char_vocab_proc(dev_dataset)
bigram_vocab_proc(dev_dataset)
seq_len_proc(dev_dataset)

dev_dataset.set_input('target')
tr_dataset.set_input('target')


print("Finish preparing data.")

# 3. 得到数据集可以用于训练了
# TODO pretrain的embedding是怎么解决的?

from torch import optim


tag_size = tag_proc.tag_size

cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
bigram_embed_dim=30, num_bigram_per_char=8,
hidden_size=200, bidirectional=True, embed_drop_p=0.3,
num_layers=1, tag_size=tag_size)
cws_model.cuda()

num_epochs = 5
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.005)

from fastNLP.core.trainer import Trainer
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.metrics import BMESF1PreRecMetric

metric = BMESF1PreRecMetric(target='tags')
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=num_epochs,
batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)

trainer.train()

# 4. 组装需要存下的内容
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)
# pp.add_processor(input_target_proc)

# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(te_filename)
pp(te_dataset)

from fastNLP.core.tester import Tester

tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=False,
verbose=1)
tester.test()
#
# batch_size = 64
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
# pre * 100,
# rec * 100))

# TODO 这里貌似需要区分test pipeline与infer pipeline

test_context_dict = {'pipeline': pp,
'model': cws_model}
# torch.save(test_context_dict, 'models/test_context_crf.pkl')


# 5. dev的pp
# 4. 组装需要存下的内容

from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor

model_proc = ModelProcessor(cws_model)
output_proc = BMES2OutputProcessor(chars_field_name='chars_lst', tag_field_name='pred')

pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
char_vocab_proc.set_verbose(0)
pp.add_processor(char_vocab_proc)
bigram_vocab_proc.set_verbose(0)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)

pp.add_processor(model_proc)
pp.add_processor(output_proc)


# TODO 这里貌似需要区分test pipeline与infer pipeline
import torch
import datetime
now = datetime.datetime.now()
infer_context_dict = {'pipeline': pp, 'tag_proc': tag_proc}
torch.save(infer_context_dict, 'models/cws_crf_{}_{}.pkl'.format(now.month, now.day))


# TODO 还需要考虑如何替换回原文的问题?
# 1. 不需要将特殊tag替换
# 2. 需要将特殊tag替换回去

+ 4
- 4
tutorials/README.md View File

@@ -1,12 +1,12 @@
# fastNLP 教程 # fastNLP 教程


### 上手教程 Quick Start ### 上手教程 Quick Start
- 一分钟上手:`fastnlp_1min_tutorial.ipynb` ![Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_1min_tutorial.ipynb)
- 十分钟上手:`fastnlp_10min_tutorial.ipynb` ![Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_10min_tutorial.ipynb)
- 一分钟上手:`fastnlp_1min_tutorial.ipynb` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_1min_tutorial.ipynb)
- 十分钟上手:`fastnlp_10min_tutorial.ipynb` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_10min_tutorial.ipynb)


### 进阶教程 Advanced Tutorial ### 进阶教程 Advanced Tutorial
- `fastnlp_advanced_tutorial/advance_tutorial.ipynb` ![Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb)
- `fastnlp_advanced_tutorial/advance_tutorial.ipynb` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/fastnlp_advanced_tutorial/advance_tutorial.ipynb)




### 开发者指南 Developer Guide ### 开发者指南 Developer Guide
- `tutorial_for_developer.md` ![Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/tutorial_for_developer.md)
- `tutorial_for_developer.md` [Click Here](https://github.com/fastnlp/fastNLP/tree/master/tutorials/tutorial_for_developer.md)

Loading…
Cancel
Save