from fastNLP.api.pipeline import Pipeline
from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
from fastNLP.api.processor import SeqLenProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF
from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
ds_name = 'msr'
tr_filename = '/home/hyan/ctb3/train.conllx'
dev_filename = '/home/hyan/ctb3/dev.conllx'

reader = ConllCWSReader()

tr_dataset = reader.load(tr_filename, cut_long_sent=True)
dev_dataset = reader.load(dev_filename)
print("Train: {}, Dev: {}".format(len(tr_dataset), len(dev_dataset)))
# 1. Prepare the processors.
fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')

char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst')
tag_proc = CWSBMESTagProcessor('raw_sentence', 'target')

bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst')

char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars')
bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4)
seq_len_proc = SeqLenProcessor('chars')
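# CWSBMESTagProcessor emits the standard BMES scheme: B(egin), M(iddle) and
# E(nd) mark a character's position inside a multi-character word, while
# S(ingle) marks one-character words, e.g. "中国 人" -> tags ['B', 'E', 'S'].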
# 2. Apply the processors to the training set.
fs2hs_proc(tr_dataset)
char_proc(tr_dataset)
tag_proc(tr_dataset)
bigram_proc(tr_dataset)
char_vocab_proc(tr_dataset)
bigram_vocab_proc(tr_dataset)
seq_len_proc(tr_dataset)
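# Note: the two VocabIndexerProcessors build their vocabularies while indexing
# the training set above, so applying them to dev below indexes with those same
# train-side vocabularies (assumed from VocabIndexerProcessor's behavior).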
# 2.1 Process dev_dataset.
fs2hs_proc(dev_dataset)
char_proc(dev_dataset)
tag_proc(dev_dataset)
bigram_proc(dev_dataset)
char_vocab_proc(dev_dataset)
bigram_vocab_proc(dev_dataset)
seq_len_proc(dev_dataset)

dev_dataset.set_input('chars', 'bigrams', 'target')
tr_dataset.set_input('chars', 'bigrams', 'target')
dev_dataset.set_target('seq_lens')
tr_dataset.set_target('seq_lens')
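# In fastNLP, set_input marks the fields passed to model.forward and set_target
# marks the fields handed to losses/metrics. 'target' is an input here because
# the CRF model computes its loss inside forward; exposing 'seq_lens' as a
# target lets evaluation mask out padding (the latter is an assumption about
# this particular setup).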
| print("Finish preparing data.") | |||||
| # 3. 得到数据集可以用于训练了 | |||||
| # TODO pretrain的embedding是怎么解决的? | |||||
| import torch | |||||
| from torch import optim | |||||
| tag_size = tag_proc.tag_size | |||||
| cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100, | |||||
| bigram_vocab_num=bigram_vocab_proc.get_vocab_size(), | |||||
| bigram_embed_dim=100, num_bigram_per_char=8, | |||||
| hidden_size=200, bidirectional=True, embed_drop_p=0.2, | |||||
| num_layers=1, tag_size=tag_size) | |||||
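# Pre2Post2BigramProcessor draws bigram features from a window of two
# characters before and two after each position; num_bigram_per_char=8 must
# match however many features that construction yields per character (the
# exact layout is defined in the processor, not spelled out here).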
cws_model.cuda()

num_epochs = 5
optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)

from fastNLP.core.trainer import Trainer
from fastNLP.core.sampler import BucketSampler
from fastNLP.core.metrics import BMESF1PreRecMetric
# The gold tags live in the 'target' field produced by CWSBMESTagProcessor.
metric = BMESF1PreRecMetric(target='target')
trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3,
                  batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
                  optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)
trainer.train()
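# metric_key='f' makes the Trainer track the best epoch by F1, validate_every=-1
# validates once per epoch, and BucketSampler batches sentences of similar
# length to reduce padding (semantics assumed from fastNLP's Trainer/BucketSampler).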
# exit(0)  # uncomment to stop right after training; the pipeline assembly below should normally run
#
# print_every = 50
# batch_size = 32
# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
# num_batch_per_epoch = len(tr_dataset) // batch_size
# best_f1 = 0
# best_epoch = 0
# for num_epoch in range(num_epochs):
#     print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
#     sys.stdout.flush()
#     avg_loss = 0
#     with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
#         pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
#         cws_model.train()
#         for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
#             optimizer.zero_grad()
#
#             tags = batch_y['tags'].long()
#             pred_dict = cws_model(**batch_x, tags=tags)  # B x L x tag_size
#
#             seq_lens = pred_dict['seq_lens']
#             masks = seq_lens_to_mask(seq_lens).float()
#             tags = tags.to(seq_lens.device)
#
#             loss = pred_dict['loss']
#
#             # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
#             #                          tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
#             # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
#
#             avg_loss += loss.item()
#
#             loss.backward()
#             for group in optimizer.param_groups:
#                 for param in group['params']:
#                     param.grad.clamp_(-5, 5)
#
#             optimizer.step()
#
#             if batch_idx % print_every == 0:
#                 pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
#                 avg_loss = 0
#                 pbar.update(print_every)
#     tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
#     # Validation set
#     pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes')
#     print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
#                                                      pre * 100,
#                                                      rec * 100))
#     if best_f1 < f1:
#         best_f1 = f1
#         # Cache the best parameters; they may be saved later.
#         best_state_dict = {
#             key: value.clone() for key, value in
#             cws_model.state_dict().items()
#         }
#         best_epoch = num_epoch
#
# cws_model.load_state_dict(best_state_dict)
# 4. Assemble the components that need to be saved for testing.
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(tag_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)
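# Pipeline runs its processors in registration order, so pp(te_dataset) below
# is roughly this sketch:
#   for proc in [fs2hs_proc, char_proc, tag_proc, bigram_proc,
#                char_vocab_proc, bigram_vocab_proc, seq_len_proc]:
#       proc(te_dataset)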
# te_filename = '/hdd/fudanNLP/CWS/CWS_semiCRF/all_data/{}/middle_files/{}_test.txt'.format(ds_name, ds_name)
te_filename = '/home/hyan/ctb3/test.conllx'
te_dataset = reader.load(te_filename)
pp(te_dataset)

from fastNLP.core.tester import Tester
tester = Tester(data=te_dataset, model=cws_model, metrics=metric, batch_size=64, use_cuda=True,
                verbose=1)  # use_cuda=True to match the model, which was moved to the GPU above
tester.test()
#
# batch_size = 64
# te_batcher = Batch(te_dataset, batch_size, SequentialSampler(), use_cuda=False)
# pre, rec, f1 = calculate_pre_rec_f1(cws_model, te_batcher, type='bmes')
# print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
#                                                  pre * 100,
#                                                  rec * 100))

# TODO: it seems the test pipeline and the infer pipeline need to be kept separate.
test_context_dict = {'pipeline': pp,
                     'model': cws_model}
torch.save(test_context_dict, 'models/test_context_crf.pkl')
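# Loading the saved context back (a sketch using the path saved above):
#   saved = torch.load('models/test_context_crf.pkl')
#   pp, cws_model = saved['pipeline'], saved['model']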
# 5. Assemble the inference pipeline that needs to be saved.
from fastNLP.api.processor import ModelProcessor
from reproduction.chinese_word_segment.process.cws_processor import BMES2OutputProcessor

model_proc = ModelProcessor(cws_model)
output_proc = BMES2OutputProcessor()
pp = Pipeline()
pp.add_processor(fs2hs_proc)
# pp.add_processor(sp_proc)
pp.add_processor(char_proc)
pp.add_processor(bigram_proc)
pp.add_processor(char_vocab_proc)
pp.add_processor(bigram_vocab_proc)
pp.add_processor(seq_len_proc)
pp.add_processor(model_proc)
pp.add_processor(output_proc)

# TODO: it seems the test pipeline and the infer pipeline need to be kept separate.
infer_context_dict = {'pipeline': pp}
# torch.save(infer_context_dict, 'models/cws_crf.pkl')
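# Inference usage sketch; note tag_proc is absent from this pipeline since no
# gold tags exist at inference time. It assumes the input is a DataSet with a
# 'raw_sentence' field and that BMES2OutputProcessor writes the segmented
# string back onto the dataset (field names here are assumptions):
#   from fastNLP.core.dataset import DataSet
#   infer_data = DataSet({'raw_sentence': ['中国进出口银行与中国银行加强合作']})
#   pp(infer_data)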
# TODO: still need to work out how to restore the original text:
# 1. cases where special tags need not be replaced
# 2. cases where special tags must be replaced back