|
|
@@ -7,13 +7,14 @@ from torch.optim import SGD, Adam |
|
|
|
from fastNLP import Const |
|
|
|
from fastNLP import RandomSampler, BucketSampler |
|
|
|
from fastNLP import SpanFPreRecMetric |
|
|
|
from fastNLP import Trainer |
|
|
|
from fastNLP import Trainer, Tester |
|
|
|
from fastNLP.core.metrics import MetricBase |
|
|
|
from reproduction.seqence_labelling.ner.model.dilated_cnn import IDCNN |
|
|
|
from fastNLP.core.utils import Option |
|
|
|
from fastNLP.modules.encoder.embedding import CNNCharEmbedding, StaticEmbedding |
|
|
|
from fastNLP.core.utils import cache_results |
|
|
|
from fastNLP.core.vocabulary import VocabularyOption |
|
|
|
import fitlog |
|
|
|
import sys |
|
|
|
import torch.cuda |
|
|
|
import os |
|
|
@@ -31,7 +32,7 @@ def get_path(path): |
|
|
|
ops = Option( |
|
|
|
batch_size=128, |
|
|
|
num_epochs=100, |
|
|
|
lr=5e-4, |
|
|
|
lr=3e-4, |
|
|
|
repeats=3, |
|
|
|
num_layers=3, |
|
|
|
num_filters=400, |
|
|
@@ -39,18 +40,18 @@ ops = Option( |
|
|
|
gradient_clip=5, |
|
|
|
) |
|
|
|
|
|
|
|
@cache_results('ontonotes-min_freq0-case-cache') |
|
|
|
@cache_results('ontonotes-case-cache') |
|
|
|
def load_data(): |
|
|
|
print('loading data') |
|
|
|
# data = OntoNoteNERDataLoader(encoding_type=encoding_type).process( |
|
|
|
# data_path = get_path('workdir/datasets/ontonotes-v4') |
|
|
|
# lower=False, |
|
|
|
# word_vocab_opt=VocabularyOption(min_freq=0), |
|
|
|
# ) |
|
|
|
data = Conll2003DataLoader(task='ner', encoding_type=encoding_type).process( |
|
|
|
paths=get_path('workdir/datasets/conll03'), |
|
|
|
lower=False, word_vocab_opt=VocabularyOption(min_freq=0) |
|
|
|
data = OntoNoteNERDataLoader(encoding_type=encoding_type).process( |
|
|
|
paths = get_path('workdir/datasets/ontonotes-v4'), |
|
|
|
lower=False, |
|
|
|
word_vocab_opt=VocabularyOption(min_freq=0), |
|
|
|
) |
|
|
|
# data = Conll2003DataLoader(task='ner', encoding_type=encoding_type).process( |
|
|
|
# paths=get_path('workdir/datasets/conll03'), |
|
|
|
# lower=False, word_vocab_opt=VocabularyOption(min_freq=0) |
|
|
|
# ) |
|
|
|
|
|
|
|
# char_embed = CNNCharEmbedding(vocab=data.vocabs['cap_words'], embed_size=30, char_emb_size=30, filter_nums=[30], |
|
|
|
# kernel_sizes=[3]) |
|
|
@@ -88,11 +89,11 @@ model = IDCNN(init_embed=word_embed, |
|
|
|
kernel_size=3, |
|
|
|
use_crf=ops.use_crf, use_projection=True, |
|
|
|
block_loss=True, |
|
|
|
input_dropout=0.5, hidden_dropout=0.0, inner_dropout=0.0) |
|
|
|
input_dropout=0.5, hidden_dropout=0.2, inner_dropout=0.2) |
|
|
|
|
|
|
|
print(model) |
|
|
|
|
|
|
|
callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='norm'),] |
|
|
|
callbacks = [GradientClipCallback(clip_value=ops.gradient_clip, clip_type='value'),] |
|
|
|
metrics = [] |
|
|
|
metrics.append( |
|
|
|
SpanFPreRecMetric( |
|
|
@@ -123,8 +124,9 @@ metrics.append( |
|
|
|
LossMetric(loss=Const.LOSS) |
|
|
|
) |
|
|
|
|
|
|
|
optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=1e-4) |
|
|
|
# scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) |
|
|
|
optimizer = Adam(model.parameters(), lr=ops.lr, weight_decay=0) |
|
|
|
scheduler = LRScheduler(LambdaLR(optimizer, lr_lambda=lambda epoch: 1 / (1 + 0.05 * epoch))) |
|
|
|
callbacks.append(scheduler) |
|
|
|
# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 15))) |
|
|
|
# optimizer = SWATS(model.parameters(), verbose=True) |
|
|
|
# optimizer = Adam(model.parameters(), lr=0.005) |
|
|
@@ -138,3 +140,16 @@ trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=opti |
|
|
|
check_code_level=-1, |
|
|
|
callbacks=callbacks, num_workers=2, n_epochs=ops.num_epochs) |
|
|
|
trainer.train() |
|
|
|
|
|
|
|
torch.save(model, 'idcnn.pt') |
|
|
|
|
|
|
|
tester = Tester( |
|
|
|
data=data.datasets['test'], |
|
|
|
model=model, |
|
|
|
metrics=metrics, |
|
|
|
batch_size=ops.batch_size, |
|
|
|
num_workers=2, |
|
|
|
device=device |
|
|
|
) |
|
|
|
tester.test() |
|
|
|
|