diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 9f9661fd..ee1354fe 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -10,12 +10,11 @@ from fastNLP.core.utils import _build_args
 class Tester(object):
     """An collection of model inference and evaluation of performance, used over validation/dev set and test set. """
 
-    def __init__(self, data, model, batch_size, use_cuda, save_path="./save/", **kwargs):
+    def __init__(self, data, model, batch_size=16, use_cuda=False):
         super(Tester, self).__init__()
         self.use_cuda = use_cuda
         self.data = data
         self.batch_size = batch_size
-        self.pickle_path = save_path
         if torch.cuda.is_available() and self.use_cuda:
             self._model = model.cuda()
         else:
@@ -53,7 +52,6 @@ class Tester(object):
         eval_results = self._evaluator(**args)
         print("[tester] {}".format(self.print_eval_results(eval_results)))
         self.mode(network, is_test=False)
-        self.metrics = eval_results
         return eval_results
 
     def mode(self, model, is_test=False):
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6e439c47..e5499767 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -27,7 +27,7 @@ class Trainer(object):
     """
 
     def __init__(self, train_data, model, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.001, weight_decay=0),
+                 optimizer=Optimizer("Adam", lr=0.001, weight_decay=0), need_check_code=True,
                  **kwargs):
         super(Trainer, self).__init__()
@@ -37,9 +37,13 @@
         self.n_epochs = int(n_epochs)
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
-        self.save_path = str(save_path)
+        self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
+        self._best_accuracy = 0
+
+        if need_check_code:
+            _check_code(dataset=train_data, model=model, dev_data=dev_data)
 
         model_name = model.__class__.__name__
         assert hasattr(self.model, 'get_loss'), "model {} has to have a 'get_loss' function.".format(model_name)
@@ -56,16 +60,11 @@
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
                                  batch_size=self.batch_size,
-                                 save_path=self.save_path,
                                  use_cuda=self.use_cuda)
 
         for k, v in kwargs.items():
             setattr(self, k, v)
 
-        self.tensorboard_path = os.path.join(self.save_path, 'tensorboard_logs')
-        if os.path.exists(self.tensorboard_path):
-            shutil.rmtree(self.tensorboard_path)
-        self._graph_summaried = False
         self.step = 0
         self.start_time = None  # start timestamp
 
@@ -77,8 +76,6 @@
         :return:
         """
         try:
-            self._summary_writer = SummaryWriter(self.tensorboard_path)
-
             if torch.cuda.is_available() and self.use_cuda:
                 self.model = self.model.cuda()
 
@@ -87,6 +84,9 @@
             start = time.time()
             self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
             print("training epochs started " + self.start_time)
+            if self.save_path is not None:
+                path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
+                self._summary_writer = SummaryWriter(path)
 
             epoch = 1
             while epoch <= self.n_epochs:
@@ -143,7 +143,8 @@
         res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        self.save_model(self.model, 'best_model_' + self.start_time)
+        if self.save_path is not None and self.best_eval_result(res):
+            self.save_model(self.model, 'best_model_' + self.start_time)
 
     def mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
@@ -166,9 +167,6 @@ class Trainer(object):
     def data_forward(self, network, x):
         x = _build_args(network.forward, **x)
         y = network(**x)
-        if not self._graph_summaried:
-            # self._summary_writer.add_graph(network, x, verbose=False)
-            self._graph_summaried = True
         return y
 
     def grad_backward(self, loss):
@@ -199,28 +197,27 @@
         else:
             torch.save(model, model_name)
 
+    def best_eval_result(self, metrics):
+        """Check if the current epoch yields better validation results.
 
-def best_eval_result(self, metrics):
-    """Check if the current epoch yields better validation results.
-
-    :return: bool, True means current results on dev set is the best.
-    """
-    if isinstance(metrics, tuple):
-        loss, metrics = metrics
+        :return: bool, True means current results on dev set is the best.
+        """
+        if isinstance(metrics, tuple):
+            loss, metrics = metrics
 
-    if isinstance(metrics, dict):
-        if len(metrics) == 1:
-            accuracy = list(metrics.values())[0]
+        if isinstance(metrics, dict):
+            if len(metrics) == 1:
+                accuracy = list(metrics.values())[0]
+            else:
+                accuracy = metrics[self.eval_sort_key]
         else:
-            accuracy = metrics[self.eval_sort_key]
-        else:
-            accuracy = metrics
+            accuracy = metrics
 
-    if accuracy > self._best_accuracy:
-        self._best_accuracy = accuracy
-        return True
-    else:
-        return False
+        if accuracy > self._best_accuracy:
+            self._best_accuracy = accuracy
+            return True
+        else:
+            return False
 
 
 DEFAULT_CHECK_BATCH_SIZE = 2
@@ -268,9 +265,6 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
         loss.backward()
         if batch_count + 1 >= DEFAULT_CHECK_BATCH_SIZE:
             break
 
-    if check_level > IGNORE_CHECK_LEVEL:
-        print('Finish checking training process.', flush=True)
-
     if dev_data is not None:
         if not hasattr(model, 'evaluate'):
@@ -310,8 +304,6 @@ def _check_code(dataset, model, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=No
             func_signature = get_func_signature(model.evaluate)
             assert isinstance(metrics, dict), "The return value of {} should be dict.". \
                 format(func_signature)
-        if check_level > IGNORE_CHECK_LEVEL:
-            print("Finish checking evaluate process.", flush=True)
 
 
 def _check_forward_error(model_func, check_level, batch_x):
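
For reference, a minimal usage sketch of the two constructors as they read after this change; train_set, dev_set, test_set, and model are placeholder names, not part of the diff:

    from fastNLP.core.trainer import Trainer
    from fastNLP.core.tester import Tester

    # Trainer now accepts need_check_code and runs _check_code() up front;
    # save_path may be None, which skips both tensorboard logging and model saving.
    trainer = Trainer(train_data=train_set, model=model, dev_data=dev_set,
                      n_epochs=3, batch_size=32,
                      save_path=None,
                      need_check_code=True)
    trainer.train()

    # Tester dropped save_path and **kwargs; batch_size and use_cuda now default.
    tester = Tester(data=test_set, model=model, batch_size=16, use_cuda=False)
    eval_results = tester.test()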
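
One note on the relocated best_eval_result: when the dev metrics dict holds more than one entry it indexes metrics[self.eval_sort_key], and eval_sort_key only exists if the caller passed it through **kwargs (it is bound by the setattr loop in __init__); otherwise multi-metric validation would raise AttributeError. A hedged sketch of supplying it, with the key name "acc" purely illustrative:

    # Hypothetical call: eval_sort_key picks which dev metric is compared
    # across epochs when tester.test() returns more than one metric.
    trainer = Trainer(train_data=train_set, model=model, dev_data=dev_set,
                      eval_sort_key="acc")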