diff --git a/reproduction/chinese_word_segment/train_context.py b/reproduction/chinese_word_segment/train_context.py
new file mode 100644
index 00000000..93e3de50
--- /dev/null
+++ b/reproduction/chinese_word_segment/train_context.py
@@ -0,0 +1,229 @@
+
+from fastNLP.api.pipeline import Pipeline
+from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor
+from fastNLP.api.processor import SeqLenProcessor
+from reproduction.chinese_word_segment.process.cws_processor import CWSCharSegProcessor
+from reproduction.chinese_word_segment.process.cws_processor import CWSBMESTagProcessor
+from reproduction.chinese_word_segment.process.cws_processor import Pre2Post2BigramProcessor
+from reproduction.chinese_word_segment.process.cws_processor import VocabIndexerProcessor
+
+
+from reproduction.chinese_word_segment.cws_io.cws_reader import ConllCWSReader
+from reproduction.chinese_word_segment.models.cws_model import CWSBiLSTMCRF
+
+from reproduction.chinese_word_segment.utils import calculate_pre_rec_f1
+
+ds_name = 'msr'
+
+tr_filename = '/home/hyan/ctb3/train.conllx'
+dev_filename = '/home/hyan/ctb3/dev.conllx'
+
+
+reader = ConllCWSReader()
+
+tr_dataset = reader.load(tr_filename, cut_long_sent=True)
+dev_dataset = reader.load(dev_filename)
+
+print("Train: {}. Dev: {}.".format(len(tr_dataset), len(dev_dataset)))
+
+# 1. Prepare the processors
+fs2hs_proc = FullSpaceToHalfSpaceProcessor('raw_sentence')
+
+char_proc = CWSCharSegProcessor('raw_sentence', 'chars_lst')
+tag_proc = CWSBMESTagProcessor('raw_sentence', 'target')
+
+bigram_proc = Pre2Post2BigramProcessor('chars_lst', 'bigrams_lst')
+
+char_vocab_proc = VocabIndexerProcessor('chars_lst', new_added_filed_name='chars')
+bigram_vocab_proc = VocabIndexerProcessor('bigrams_lst', new_added_filed_name='bigrams', min_freq=4)
+
+seq_len_proc = SeqLenProcessor('chars')
+
+# 2. Run the processors on the training set
+fs2hs_proc(tr_dataset)
+
+char_proc(tr_dataset)
+tag_proc(tr_dataset)
+bigram_proc(tr_dataset)
+
+char_vocab_proc(tr_dataset)
+bigram_vocab_proc(tr_dataset)
+seq_len_proc(tr_dataset)
+
+# 2.1 Run the same processors on dev_dataset
+fs2hs_proc(dev_dataset)
+
+char_proc(dev_dataset)
+tag_proc(dev_dataset)
+bigram_proc(dev_dataset)
+
+char_vocab_proc(dev_dataset)
+bigram_vocab_proc(dev_dataset)
+seq_len_proc(dev_dataset)
+
+dev_dataset.set_input('chars', 'bigrams', 'target')
+tr_dataset.set_input('chars', 'bigrams', 'target')
+dev_dataset.set_target('seq_lens')
+tr_dataset.set_target('seq_lens')
+
+print("Finished preparing data.")
+
+
+# 3. The datasets are now ready for training
+# TODO: how should pretrained embeddings be handled?
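+# A hedged sketch for the TODO above, not part of the original pipeline: read
+# pretrained character vectors into a matrix aligned with the char vocabulary,
+# then copy it into the model's embedding table once the model exists. The
+# vocabulary attribute (char_vocab_proc.vocab), the vector file path, and the
+# embedding attribute name (cws_model.char_embedding) are assumptions, not
+# confirmed API of this codebase.
+#
+# import numpy as np
+#
+# embed_dim = 100
+# vocab = char_vocab_proc.vocab  # assumed handle on the fitted Vocabulary
+# embed_matrix = np.random.uniform(
+#     -0.05, 0.05, (char_vocab_proc.get_vocab_size(), embed_dim)).astype('float32')
+# with open('/path/to/char_vectors.txt', encoding='utf-8') as f:
+#     for line in f:
+#         parts = line.rstrip().split()
+#         if len(parts) != embed_dim + 1:
+#             continue  # skip header or malformed rows
+#         if parts[0] in vocab:
+#             embed_matrix[vocab.to_index(parts[0])] = [float(v) for v in parts[1:]]
+# # after cws_model is constructed (see below):
+# # cws_model.char_embedding.weight.data.copy_(torch.from_numpy(embed_matrix))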
+
+import torch
+from torch import optim
+
+
+tag_size = tag_proc.tag_size
+
+cws_model = CWSBiLSTMCRF(char_vocab_proc.get_vocab_size(), embed_dim=100,
+                         bigram_vocab_num=bigram_vocab_proc.get_vocab_size(),
+                         bigram_embed_dim=100, num_bigram_per_char=8,
+                         hidden_size=200, bidirectional=True, embed_drop_p=0.2,
+                         num_layers=1, tag_size=tag_size)
+cws_model.cuda()
+
+num_epochs = 5
+optimizer = optim.Adagrad(cws_model.parameters(), lr=0.02)
+
+from fastNLP.core.trainer import Trainer
+from fastNLP.core.sampler import BucketSampler
+from fastNLP.core.metrics import BMESF1PreRecMetric
+
+# The BMES tags were written to the 'target' field by CWSBMESTagProcessor above;
+# point the metric at that field (the 'tags' field no longer exists).
+metric = BMESF1PreRecMetric(target='target')
+trainer = Trainer(train_data=tr_dataset, model=cws_model, loss=None, metrics=metric, n_epochs=3,
+                  batch_size=32, print_every=50, validate_every=-1, dev_data=dev_dataset, save_path=None,
+                  optimizer=optimizer, check_code_level=0, metric_key='f', sampler=BucketSampler(), use_tqdm=True)
+
+trainer.train()
+exit(0)
+
+# Manual training loop kept for reference; superseded by the Trainer call above.
+# print_every = 50
+# batch_size = 32
+# tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
+# dev_batcher = Batch(dev_dataset, batch_size, SequentialSampler(), use_cuda=False)
+# num_batch_per_epoch = len(tr_dataset) // batch_size
+# best_f1 = 0
+# best_epoch = 0
+# for num_epoch in range(num_epochs):
+#     print('X' * 10 + ' Epoch: {}/{} '.format(num_epoch + 1, num_epochs) + 'X' * 10)
+#     sys.stdout.flush()
+#     avg_loss = 0
+#     with tqdm(total=num_batch_per_epoch, leave=True) as pbar:
+#         pbar.set_description_str('Epoch:%d' % (num_epoch + 1))
+#         cws_model.train()
+#         for batch_idx, (batch_x, batch_y) in enumerate(tr_batcher, 1):
+#             optimizer.zero_grad()
+#
+#             tags = batch_y['tags'].long()
+#             pred_dict = cws_model(**batch_x, tags=tags)  # B x L x tag_size
+#
+#             seq_lens = pred_dict['seq_lens']
+#             masks = seq_lens_to_mask(seq_lens).float()
+#             tags = tags.to(seq_lens.device)
+#
+#             loss = pred_dict['loss']
+#
+#             # loss = torch.sum(loss_fn(pred_dict['pred_probs'].view(-1, tag_size),
+#             #                          tags.view(-1)) * masks.view(-1)) / torch.sum(masks)
+#             # loss = torch.mean(F.cross_entropy(probs.view(-1, 2), tags.view(-1)) * masks.float())
+#
+#             avg_loss += loss.item()
+#
+#             loss.backward()
+#             # value-clip every gradient to [-5, 5] before the update
+#             for group in optimizer.param_groups:
+#                 for param in group['params']:
+#                     param.grad.clamp_(-5, 5)
+#
+#             optimizer.step()
+#
+#             if batch_idx % print_every == 0:
+#                 pbar.set_postfix_str('batch=%d, avg_loss=%.5f' % (batch_idx, avg_loss / print_every))
+#                 avg_loss = 0
+#                 pbar.update(print_every)
+#     tr_batcher = Batch(tr_dataset, batch_size, BucketSampler(batch_size=batch_size), use_cuda=False)
+#     # evaluate on the dev set
+#     pre, rec, f1 = calculate_pre_rec_f1(cws_model, dev_batcher, type='bmes')
+#     print("f1:{:.2f}, pre:{:.2f}, rec:{:.2f}".format(f1 * 100,
+#                                                      pre * 100,
+#                                                      rec * 100))
+#     if best_f1 < f1:
+#         best_f1 = f1
+#         best_epoch = num_epoch
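+# The Pipeline imported at the top is not used anywhere in the code shown above;
+# presumably it is meant to bundle the fitted processors for inference. A hedged
+# sketch, assuming Pipeline simply applies a processor list in order; the
+# processor selection and the raw_dataset name are illustrative only:
+#
+# pp = Pipeline([fs2hs_proc, char_proc, bigram_proc, char_vocab_proc,
+#                bigram_vocab_proc, seq_len_proc])
+# pp(raw_dataset)  # run every preprocessing step on a fresh dataset in one call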