"""
Use BERT for English named entity recognition on the CoNLL-2003 dataset.

"""

import sys

sys.path.append('../../../')

from reproduction.sequence_labelling.ner.model.bert_crf import BertCRF
from fastNLP.embeddings import BertEmbedding
from fastNLP import Trainer, Const
from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback
from fastNLP.core.callback import WarmupCallback
from fastNLP.core.optimizer import AdamW
from fastNLP.io import Conll2003NERPipe

from fastNLP import cache_results, EvaluateCallback

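# The BIOES tagging scheme is used consistently below: by the data pipe, the CRF's
# tag vocabulary, and the span-level F1 metric.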
encoding_type = 'bioes'

@cache_results('caches/conll2003.pkl', _refresh=False)
def load_data():
    # Replace with the path to your CoNLL-2003 data
    paths = 'data/conll2003'
    data = Conll2003NERPipe(encoding_type=encoding_type).process_from_file(paths)
    return data

data = load_data()
print(data)

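# Word-level BERT embeddings: only layer 11 of the English bert-base-cased model is
# used, and subword representations are max-pooled into one vector per word.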
embed = BertEmbedding(data.get_vocab(Const.INPUT), model_dir_or_name='en-base-cased',
                      pool_method='max', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5,
                      word_dropout=0.01)

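# Training callbacks: gradient norm clipping, linear learning-rate warmup over the
# first 10% of updates, and extra evaluation on the test set whenever the dev set
# is evaluated.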
callbacks = [
    GradientClipCallback(clip_type='norm', clip_value=1),
    WarmupCallback(warmup=0.1, schedule='linear'),
    EvaluateCallback(data.get_dataset('test'))
]

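# BertCRF adds a CRF tagging layer on top of the BERT embeddings; AdamW with a small
# learning rate is used to fine-tune the whole model.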
model = BertCRF(embed, tag_vocab=data.get_vocab('target'), encoding_type=encoding_type)
optimizer = AdamW(model.parameters(), lr=2e-5)

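# update_every=3 accumulates gradients over 3 batches, giving an effective batch size of 18.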
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
                  device=0, dev_data=data.datasets['dev'], batch_size=6,
                  metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                  loss=None, callbacks=callbacks, num_workers=2, n_epochs=5,
                  check_code_level=0, update_every=3, test_use_tqdm=False)
trainer.train()
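
# Optional: a final evaluation pass on the held-out test set after training. This is a
# minimal sketch (not part of the original script), assuming fastNLP's Tester API.
from fastNLP import Tester

tester = Tester(data.datasets['test'], model,
                metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
                batch_size=6, device=0)
tester.test()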