- import sys
-
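- # make the repository root importable when running this script from its own directory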
- sys.path.append('../../..')
-
- from fastNLP.embeddings import CNNCharEmbedding, StaticEmbedding, StackEmbedding
-
- from reproduction.seqence_labelling.ner.model.lstm_cnn_crf import CNNBiLSTMCRF
- from fastNLP import Trainer
- from fastNLP import SpanFPreRecMetric
- from fastNLP import Const
- from torch.optim import SGD
- from torch.optim.lr_scheduler import LambdaLR
- from fastNLP import GradientClipCallback
- from fastNLP import BucketSampler
- from fastNLP.core.callback import EvaluateCallback, LRScheduler
- from fastNLP import cache_results
- from fastNLP.io.pipe.conll import OntoNotesNERPipe
-
- ####### hyper-parameters
- normalize = False       # whether StaticEmbedding normalizes the pre-trained vectors
- lr = 0.01               # initial SGD learning rate
- dropout = 0.5           # dropout used in both the embeddings and the model
- batch_size = 32
- data_name = 'ontonote'  # label for this run; the OntoNotes data itself is loaded below
- ####### hyper-parameters
-
-
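- # BIOES tag scheme, shared by the data pipe, the model's CRF decoder and the metric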
- encoding_type = 'bioes'
-
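- # cache the processed data bundle and embeddings; _refresh=True forces the
- # pipeline to rerun and overwrite 'caches/ontonotes.pkl' on every run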
- @cache_results('caches/ontonotes.pkl', _refresh=True)
- def cache():
-     data = OntoNotesNERPipe(encoding_type=encoding_type).process_from_file('../../../../others/data/v4/english')
-     # 30-dim character-level CNN embedding (a single filter of width 3)
-     char_embed = CNNCharEmbedding(vocab=data.vocabs[Const.INPUT], embed_size=30, char_emb_size=30,
-                                   filter_nums=[30], kernel_sizes=[3], dropout=dropout)
-     # 100-dim GloVe word embedding, fine-tuned during training
-     word_embed = StaticEmbedding(vocab=data.vocabs[Const.INPUT],
-                                  model_dir_or_name='en-glove-6b-100d',
-                                  requires_grad=True,
-                                  normalize=normalize,
-                                  word_dropout=0.01,
-                                  dropout=dropout,
-                                  lower=True,
-                                  min_freq=1)
-     return data, char_embed, word_embed
- data, char_embed, word_embed = cache()
-
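- # print a summary of the loaded datasets and vocabularies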
- print(data)
-
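- # concatenate the GloVe word embedding with the CNN character embedding
- # (feature dimension 100 + 30 = 130)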
- embed = StackEmbedding([word_embed, char_embed])
- model = CNNBiLSTMCRF(embed, hidden_size=1200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET],
-                      encoding_type=encoding_type, dropout=dropout)
-
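- # clip each gradient element to [-5, 5]; additionally evaluate on the test set
- # whenever the dev set is evaluated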
- callbacks = [
-     GradientClipCallback(clip_value=5, clip_type='value'),
-     EvaluateCallback(data.datasets['test'])
- ]
-
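- # SGD with momentum; LambdaLR decays the learning rate to lr / (1 + 0.05 * epoch),
- # stepped once per epoch by the LRScheduler callback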
- optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
- scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch)))
- callbacks.append(scheduler)
-
-
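- # BucketSampler batches sequences of similar length together to reduce padding;
- # SpanFPreRecMetric computes span-level precision/recall/F1 from the BIOES tags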
- trainer = Trainer(train_data=data.get_dataset('train'), model=model, optimizer=optimizer,
-                   sampler=BucketSampler(num_buckets=100),
-                   device=0, dev_data=data.get_dataset('dev'), batch_size=batch_size,
-                   metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
-                   callbacks=callbacks, num_workers=1, n_epochs=100, dev_batch_size=256)
- trainer.train()