|
|
@@ -4,9 +4,10 @@ from datetime import datetime |
|
|
|
import warnings |
|
|
|
from collections import defaultdict |
|
|
|
import os |
|
|
|
import itertools |
|
|
|
import shutil |
|
|
|
|
|
|
|
from tensorboardX import SummaryWriter |
|
|
|
import torch |
|
|
|
|
|
|
|
from fastNLP.core.batch import Batch |
|
|
|
from fastNLP.core.loss import Loss |
|
|
@@ -51,17 +52,18 @@ class Trainer(object): |
|
|
|
self.evaluator = self.model.evaluate |
|
|
|
|
|
|
|
if self.dev_data is not None: |
|
|
|
valid_args = {"batch_size": self.batch_size, "save_path": self.save_path, |
|
|
|
"use_cuda": self.use_cuda, "evaluator": self.evaluator} |
|
|
|
self.tester = Tester(**valid_args) |
|
|
|
self.tester = Tester(model=self.model, |
|
|
|
data=self.dev_data, |
|
|
|
batch_size=self.batch_size, |
|
|
|
save_path=self.save_path, |
|
|
|
use_cuda=self.use_cuda) |
|
|
|
|
|
|
|
for k, v in kwargs.items(): |
|
|
|
setattr(self, k, v) |
|
|
|
|
|
|
|
self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs') |
|
|
|
if os.path.exists(self.tensorboard_path): |
|
|
|
os.rmdir(self.tensorboard_path) |
|
|
|
self._summary_writer = SummaryWriter(self.tensorboard_path) |
|
|
|
shutil.rmtree(self.tensorboard_path) |
|
|
|
self._graph_summaried = False |
|
|
|
self.step = 0 |
|
|
|
self.start_time = None # start timestamp |
|
|
@@ -73,26 +75,32 @@ class Trainer(object): |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
if torch.cuda.is_available() and self.use_cuda: |
|
|
|
self.model = self.model.cuda() |
|
|
|
try: |
|
|
|
self._summary_writer = SummaryWriter(self.tensorboard_path) |
|
|
|
|
|
|
|
self.mode(self.model, is_test=False) |
|
|
|
if torch.cuda.is_available() and self.use_cuda: |
|
|
|
self.model = self.model.cuda() |
|
|
|
|
|
|
|
start = time.time() |
|
|
|
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) |
|
|
|
print("training epochs started " + self.start_time) |
|
|
|
self.mode(self.model, is_test=False) |
|
|
|
|
|
|
|
epoch = 1 |
|
|
|
while epoch <= self.n_epochs: |
|
|
|
start = time.time() |
|
|
|
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) |
|
|
|
print("training epochs started " + self.start_time) |
|
|
|
|
|
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) |
|
|
|
epoch = 1 |
|
|
|
while epoch <= self.n_epochs: |
|
|
|
|
|
|
|
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) |
|
|
|
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) |
|
|
|
|
|
|
|
if self.dev_data: |
|
|
|
self.do_validation() |
|
|
|
self.save_model(self.model, 'training_model_' + self.start_time) |
|
|
|
epoch += 1 |
|
|
|
self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) |
|
|
|
|
|
|
|
if self.dev_data: |
|
|
|
self.do_validation() |
|
|
|
self.save_model(self.model, 'training_model_' + self.start_time) |
|
|
|
epoch += 1 |
|
|
|
finally: |
|
|
|
self._summary_writer.close() |
|
|
|
del self._summary_writer |
|
|
|
|
|
|
|
def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): |
|
|
|
"""Training process in one epoch. |
|
|
@@ -127,7 +135,7 @@ class Trainer(object): |
|
|
|
self.step += 1 |
|
|
|
|
|
|
|
def do_validation(self): |
|
|
|
res = self.tester.test(self.model, self.dev_data) |
|
|
|
res = self.tester.test() |
|
|
|
for name, num in res.items(): |
|
|
|
self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) |
|
|
|
self.save_model(self.model, 'best_model_' + self.start_time) |
|
|
|