import sys
sys.path.append('../..')

import torch
from torch.optim import Adam

from fastNLP.core.callback import Callback, GradientClipCallback
from fastNLP.core.trainer import Trainer

from reproduction.coreference_resolution.data_load.cr_loader import CRLoader
from reproduction.coreference_resolution.model.config import Config
from reproduction.coreference_resolution.model.model_re import Model
from reproduction.coreference_resolution.model.softmax_loss import SoftmaxLoss
from reproduction.coreference_resolution.model.metric import CRMetric
from fastNLP import cache_results

# Uncomment to make cuDNN deterministic (slower, but reproducible runs):
# torch.backends.cudnn.benchmark = False
# torch.backends.cudnn.deterministic = True

class LRCallback(Callback):
    """Decay each parameter group's learning rate by a fixed factor every 100 steps."""

    def __init__(self, parameters, decay_rate=1e-3):
        super().__init__()
        self.paras = parameters
        self.decay_rate = decay_rate

    def on_step_end(self):
        if self.step % 100 == 0:
            for para in self.paras:
                para['lr'] = para['lr'] * (1 - self.decay_rate)
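
# Net effect of LRCallback: after T optimizer steps the learning rate is
# roughly lr0 * (1 - decay_rate) ** (T // 100), i.e. a stepwise exponential
# decay (about 0.1% shaved off every 100 steps at the default decay_rate).
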
if __name__ == "__main__":
    config = Config()
    print(config)

    # Process the raw data once and cache the result to cache.pkl;
    # later runs load the pickle instead of re-processing.
    @cache_results('cache.pkl')
    def cache():
        cr_train_dev_test = CRLoader()
        data_info = cr_train_dev_test.process({'train': config.train_path,
                                               'dev': config.dev_path,
                                               'test': config.test_path})
        return data_info

    data_info = cache()
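    # If the paths in Config change, the cache above goes stale: delete
    # cache.pkl to force re-processing, or, assuming a fastNLP version whose
    # cache_results supports it, call cache(_refresh=True) instead.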
- print("数据集划分:\ntrain:", str(len(data_info.datasets["train"])),
- "\ndev:" + str(len(data_info.datasets["dev"])) + "\ntest:" + str(len(data_info.datasets["test"])))
- # print(data_info)
    model = Model(data_info.vocabs, config)
    print(model)

    loss = SoftmaxLoss()
    metric = CRMetric()
    optim = Adam(model.parameters(), lr=config.lr)
    lr_decay_callback = LRCallback(optim.param_groups, config.lr_decay)
    # check_code_level=-1 disables fastNLP's pre-training forward check;
    # batch_size=1 since each training sample is a whole document.
    # metric_key='f' makes the Trainer keep the checkpoint with the best 'f'
    # (F1) value reported by CRMetric. Adjust save_path to your environment.
    trainer = Trainer(model=model, train_data=data_info.datasets["train"],
                      dev_data=data_info.datasets["dev"],
                      loss=loss, metrics=metric, check_code_level=-1, sampler=None,
                      batch_size=1, device=torch.device("cuda:" + config.cuda),
                      metric_key='f', n_epochs=config.epoch, optimizer=optim,
                      save_path='/remote-home/xxliu/pycharm/fastNLP/fastNLP/reproduction/coreference_resolution/save',
                      callbacks=[lr_decay_callback, GradientClipCallback(clip_value=5)])
    print()
    trainer.train()
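
    # A minimal sketch of scoring the trained model on the test split with
    # fastNLP's Tester; the import and exact call are assumptions based on
    # the standard fastNLP API rather than part of the original script:
    # from fastNLP import Tester
    # tester = Tester(data_info.datasets["test"], model, metrics=CRMetric(),
    #                 batch_size=1, device=torch.device("cuda:" + config.cuda))
    # tester.test()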