|
-
- import sys
- sys.path.append('../../..')
-
- from fastNLP import cache_results
- from reproduction.sequence_labelling.cws.data.cws_shift_pipe import CWSShiftRelayPipe
- from reproduction.sequence_labelling.cws.model.bilstm_shift_relay import ShiftRelayCWSModel
- from fastNLP import Trainer
- from torch.optim import Adam
- from fastNLP import BucketSampler
- from fastNLP import GradientClipCallback
- from reproduction.sequence_labelling.cws.model.metric import RelayMetric
- from fastNLP.embeddings import StaticEmbedding
- from fastNLP import EvaluateCallback
-
# ######### hyper-parameters
L = 4  # passed to both CWSShiftRelayPipe and ShiftRelayCWSModel; presumably the max word length — TODO confirm
hidden_size = 200
num_layers = 1
drop_p = 0.2  # forwarded to ShiftRelayCWSModel(drop_p=...)
lr = 0.008
data_name = 'pku'
# ######### hyper-parameters
device = 0  # forwarded to Trainer(device=...)

# Set to True to ignore an existing cache file and rebuild (and overwrite) it.
refresh_cache = False

# The cache file name includes every setting here that changes what prepare_data()
# returns (dataset and L), so changing L never silently reloads a stale bundle.
cache_fp = 'caches/{}_L{}.pkl'.format(data_name, L)


@cache_results(_cache_fp=cache_fp, _refresh=refresh_cache)
def prepare_data():
    """Build the data bundle and the pretrained character/bigram embeddings.

    The return value is cached to ``cache_fp`` by fastNLP's ``cache_results``.
    Later runs load it from that file instead of calling this function again,
    unless ``refresh_cache`` is True.

    Returns:
        tuple: ``(data_bundle, char_embed, bigram_embed)``.
    """
    data_bundle = CWSShiftRelayPipe(dataset_name=data_name, L=L).process_from_file()
    # Pretrained character (unigram) and bigram embeddings.
    char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.5, word_dropout=0.01,
                                 model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt')
    bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.5, min_freq=3, word_dropout=0.01,
                                   model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt')

    return data_bundle, char_embed, bigram_embed
-
if __name__ == '__main__':
    # The guard keeps DataLoader worker processes (num_workers=1 below) from
    # re-executing the whole training script when they re-import this module,
    # which happens with the 'spawn' multiprocessing start method.
    data, char_embed, bigram_embed = prepare_data()

    model = ShiftRelayCWSModel(char_embed=char_embed, bigram_embed=bigram_embed,
                               hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, L=L)

    sampler = BucketSampler()
    optimizer = Adam(model.parameters(), lr=lr)
    # Clip each gradient element to [-5, 5] so overly large gradients are truncated.
    clipper = GradientClipCallback(clip_value=5, clip_type='value')
    # Additionally report performance on the test set during training.
    evaluator = EvaluateCallback(data.get_dataset('test'))
    callbacks = [clipper, evaluator]

    # loss=None: fastNLP takes the loss from the model's forward output.
    # metric_key='f' selects the dev-set score used to pick the best model; presumably
    # the F value reported by RelayMetric — confirm. save_path=None: no checkpoint is saved.
    trainer = Trainer(data.get_dataset('train'), model, optimizer=optimizer, loss=None, batch_size=128, sampler=sampler,
                      update_every=1, n_epochs=10, print_every=5, dev_data=data.get_dataset('dev'), metrics=RelayMetric(),
                      metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks,
                      check_code_level=0, num_workers=1)
    trainer.train()
|