From e6864ea7e0f42deff6d50c9e75c639a7a0ddea1f Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 1 Dec 2018 20:27:23 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E6=9B=B4=E6=96=B0embed=5Floader:=20*=20?= =?UTF-8?q?=E6=B7=BB=E5=8A=A0fast=5Fload=5Fembedding=E6=96=B9=E6=B3=95?= =?UTF-8?q?=EF=BC=8C=E7=94=A8vocab=E7=9A=84=E8=AF=8D=E7=B4=A2=E5=BC=95pre-?= =?UTF-8?q?trained=E4=B8=AD=E7=9A=84embedding=20*=20=E5=A6=82=E6=9E=9Cvoca?= =?UTF-8?q?b=E6=9C=89=E8=AF=8D=E6=B2=A1=E5=87=BA=E7=8E=B0=E5=9C=A8pre-trai?= =?UTF-8?q?n=E4=B8=AD=EF=BC=8C=E4=BB=8E=E5=B7=B2=E6=9C=89embedding?= =?UTF-8?q?=E4=B8=AD=E6=AD=A3=E6=80=81=E9=87=87=E6=A0=B7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update embed_loader: * add fast_load_embedding method, to index pre-trained embedding with words in Vocab * If words in Vocab are not exist in pre-trained, sample them from normal distribution computed by current embeddings --- fastNLP/core/trainer.py | 159 +++++++++++++++++++++++++--------------- 1 file changed, 98 insertions(+), 61 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index da8e54f9..54ce2cd9 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,39 +1,38 @@ -import itertools import os import time import warnings -from collections import defaultdict from datetime import datetime from datetime import timedelta import torch -from torch import nn from tensorboardX import SummaryWriter +from torch import nn from fastNLP.core.batch import Batch +from fastNLP.core.dataset import DataSet +from fastNLP.core.losses import _prepare_losser +from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.optimizer import Optimizer from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import SequentialSampler from fastNLP.core.tester import Tester +from fastNLP.core.utils import CheckError from fastNLP.core.utils import _build_args from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import get_func_signature -from fastNLP.core.dataset import DataSet -from fastNLP.core.losses import LossBase -from fastNLP.core.metrics import MetricBase -from fastNLP.core.losses import _prepare_losser -from fastNLP.core.metrics import _prepare_metrics -from fastNLP.core.utils import CheckError class Trainer(object): """Main Training Loop """ - def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, + + def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, + validate_every=-1, dev_data=None, use_cuda=False, save_path="./save", optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, + metric_key=None, **kwargs): super(Trainer, self).__init__() @@ -50,6 +49,13 @@ class Trainer(object): # prepare evaluate metrics = _prepare_metrics(metrics) + + # parse metric_key + # increase_better is True. It means the exp result gets better if the indicator increases. + # It is true by default. + self.increase_better = False if metric_key[0] == "-" else True + self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key + # prepare loss losser = _prepare_losser(losser) @@ -67,7 +73,7 @@ class Trainer(object): self.save_path = save_path self.print_every = int(print_every) self.validate_every = int(validate_every) - self._best_accuracy = 0 + self.best_metric_indicator = None self._model_device = model.parameters().__next__().device @@ -102,7 +108,7 @@ class Trainer(object): if torch.cuda.is_available() and self.use_cuda: self.model = self.model.cuda() - self.mode(self.model, is_test=False) + self._mode(self.model, is_test=False) start = time.time() self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) @@ -112,7 +118,9 @@ class Trainer(object): def __getattr__(self, item): def pass_func(*args, **kwargs): pass + return pass_func + self._summary_writer = psudoSW() else: path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) @@ -121,13 +129,14 @@ class Trainer(object): epoch = 1 while epoch <= self.n_epochs: - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), + as_numpy=False) self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) # validate_every override validation at end of epochs if self.dev_data and self.validate_every <= 0: - self.do_validation() + self._do_validation() epoch += 1 finally: self._summary_writer.close() @@ -144,10 +153,10 @@ class Trainer(object): for batch_x, batch_y in data_iterator: # TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 _move_dict_value_to_device(self._model_device, batch_x, batch_y) - prediction = self.data_forward(model, batch_x) - loss = self.get_loss(prediction, batch_y) - self.grad_backward(loss) - self.update() + prediction = self._data_forward(model, batch_x) + loss = self._compute_loss(prediction, batch_y) + self._grad_backward(loss) + self._update() self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) for name, param in self.model.named_parameters(): if param.requires_grad: @@ -162,18 +171,18 @@ class Trainer(object): print(print_output) if self.validate_every > 0 and self.step % self.validate_every == 0: - self.do_validation() + self._do_validation() self.step += 1 - def do_validation(self): + def _do_validation(self): res = self.tester.test() for name, num in res.items(): self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) - if self.save_path is not None and self.best_eval_result(res): + if self.save_path is not None and self._better_eval_result(res): self.save_model(self.model, 'best_model_' + self.start_time) - def mode(self, model, is_test=False): + def _mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. :param model: a PyTorch model @@ -185,20 +194,20 @@ class Trainer(object): else: model.train() - def update(self): + def _update(self): """Perform weight update on a model. """ self.optimizer.step() - def data_forward(self, network, x): + def _data_forward(self, network, x): x = _build_args(network.forward, **x) y = network(**x) if not isinstance(y, dict): raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") return y - def grad_backward(self, loss): + def _grad_backward(self, loss): """Compute gradient with link rules. :param loss: a scalar where back-prop starts @@ -208,7 +217,7 @@ class Trainer(object): self.model.zero_grad() loss.backward() - def get_loss(self, predict, truth): + def _compute_loss(self, predict, truth): """Compute loss given prediction and ground truth. :param predict: prediction dict, produced by model.forward @@ -224,27 +233,52 @@ class Trainer(object): else: torch.save(model, model_name) - def best_eval_result(self, metrics): + def _better_eval_result(self, metrics): """Check if the current epoch yields better validation results. - :return: bool, True means current results on dev set is the best. + :return bool value: True means current results on dev set is the best. """ if isinstance(metrics, tuple): loss, metrics = metrics if isinstance(metrics, dict): if len(metrics) == 1: - accuracy = list(metrics.values())[0] + # only single metric, just use it + metric_dict = list(metrics.values())[0] + metrics_name = list(metrics.keys())[0] else: - accuracy = metrics[self.eval_sort_key] - else: - accuracy = metrics - - if accuracy > self._best_accuracy: - self._best_accuracy = accuracy - return True - else: - return False + metrics_name = self.metrics[0].__class__.__name__ + if metrics_name not in metrics: + raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") + metric_dict = metrics[metrics_name] + + if len(metric_dict) == 1: + indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] + elif len(metric_dict) > 1 and self.metric_key is None: + raise RuntimeError( + f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") + else: + # metric_key is set + if self.metric_key not in metric_dict: + raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}") + indicator_val = metric_dict[self.metric_key] + + is_better = True + if self.best_metric_indicator is None: + # first-time validation + self.best_metric_indicator = indicator_val + else: + if self.increase_better is True: + if indicator_val > self.best_metric_indicator: + self.best_metric_indicator = indicator_val + else: + is_better = False + else: + if indicator_val < self.best_metric_indicator: + self.best_metric_indicator = indicator_val + else: + is_better = False + return is_better DEFAULT_CHECK_BATCH_SIZE = 2 @@ -254,6 +288,7 @@ IGNORE_CHECK_LEVEL = 0 WARNING_CHECK_LEVEL = 1 STRICT_CHECK_LEVEL = 2 + def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL): @@ -264,7 +299,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ for batch_count, (batch_x, batch_y) in enumerate(batch): _move_dict_value_to_device(model_devcie, batch_x, batch_y) # forward check - if batch_count==0: + if batch_count == 0: _check_forward_error(model_func=model.forward, check_level=check_level, batch_x=batch_x) @@ -285,17 +320,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ if batch_count == 0: if not isinstance(loss, torch.Tensor): raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, " - f"but got `{type(loss)}`.") - if len(loss.size())!=0: + f"but got `{type(loss)}`.") + if len(loss.size()) != 0: raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, " f"should be torch.size([])") loss.backward() model.zero_grad() - if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: + if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: break if dev_data is not None: - tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, + tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, batch_size=batch_size, verbose=-1) tester.test() @@ -305,18 +340,18 @@ def _check_forward_error(model_func, check_level, batch_x): _missing = '' _unused = '' func_signature = get_func_signature(model_func) - if len(check_res['missing'])!=0: + if len(check_res['missing']) != 0: _missing = "Function {} misses {}, only provided with {}, " \ ".\n".format(func_signature, check_res.missing, - list(batch_x.keys())) - if len(check_res['unused'])!=0: + list(batch_x.keys())) + if len(check_res['unused']) != 0: if len(check_res.unused) > 1: _unused = "{} are not used ".format(check_res.unused) else: _unused = "{} is not used ".format(check_res.unused) _unused += "in function {}.\n".format(func_signature) if _missing: - if len(_unused)>0 and STRICT_CHECK_LEVEL: + if len(_unused) > 0 and STRICT_CHECK_LEVEL: _error_str = "(1).{}\n(2).{}".format(_missing, _unused) else: _error_str = _missing @@ -329,38 +364,40 @@ def _check_forward_error(model_func, check_level, batch_x): elif check_level == WARNING_CHECK_LEVEL: warnings.warn(message=_unused) -def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level): + +def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): + check_res = _check_arg_dict_list(func, [output, batch_y]) _missing = '' _unused = '' _duplicated = '' func_signature = get_func_signature(func) prev_func_signature = get_func_signature(prev_func) - if len(check_res.missing)>0: + if len(check_res.missing) > 0: _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \ "{}(from target in Dataset)." \ - .format(func_signature, check_res.missing, - list(output.keys()), prev_func_signature, - list(batch_y.keys())) - if len(check_res.unused)>0: + .format(func_signature, check_res.missing, + list(output.keys()), prev_func_signature, + list(batch_y.keys())) + if len(check_res.unused) > 0: if len(check_res.unused) > 1: _unused = "{} are not used ".format(check_res.unused) else: _unused = "{} is not used ".format(check_res.unused) _unused += "in function {}.\n".format(func_signature) - if len(check_res.duplicated)>0: + if len(check_res.duplicated) > 0: if len(check_res.duplicated) > 1: _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \ "them in {} at the same time.".format(check_res.duplicated, - func_signature, - check_res.duplicated, - prev_func_signature) - else: - _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ - "it in {} at the same time.".format(check_res.duplicated, func_signature, check_res.duplicated, prev_func_signature) - _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) + else: + _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ + "it in {} at the same time.".format(check_res.duplicated, + func_signature, + check_res.duplicated, + prev_func_signature) + _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0) if _number_errs > 0: _error_strs = [] if _number_errs > 1: From e5e7f29d7205a269fd1a922bfd9067f2ead5de81 Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 1 Dec 2018 20:27:23 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E6=9B=B4=E6=96=B0Trainer:=20*=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0Trainer=E5=8F=82=E6=95=B0metric=5Fkey=EF=BC=8C?= =?UTF-8?q?=E6=8C=87=E6=98=8E=E7=94=A8=E6=9D=A5=E5=81=9A=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E9=80=89=E6=8B=A9=E7=9A=84=E6=8C=87=E6=A0=87=E7=9A=84=E5=90=8D?= =?UTF-8?q?=E5=AD=97=20*=20=E5=9C=A8Trainer=E6=B7=BB=E5=8A=A0=E5=A4=84?= =?UTF-8?q?=E7=90=86tester=E8=BF=94=E5=9B=9E=E7=9A=84=E8=AF=84=E4=BB=B7?= =?UTF-8?q?=E6=8C=87=E6=A0=87=E7=9A=84=E9=80=BB=E8=BE=91=EF=BC=8C=E9=80=89?= =?UTF-8?q?=E6=8B=A9=E5=BD=93=E5=89=8D=E6=9C=80=E5=A5=BD=E7=9A=84=E6=A8=A1?= =?UTF-8?q?=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/trainer.py | 168 ++++++++++++++++++++++++---------------- 1 file changed, 102 insertions(+), 66 deletions(-) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index da8e54f9..d4bedb6f 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -1,39 +1,38 @@ -import itertools import os import time import warnings -from collections import defaultdict from datetime import datetime from datetime import timedelta import torch -from torch import nn from tensorboardX import SummaryWriter +from torch import nn from fastNLP.core.batch import Batch +from fastNLP.core.dataset import DataSet +from fastNLP.core.losses import _prepare_losser +from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.optimizer import Optimizer from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import SequentialSampler from fastNLP.core.tester import Tester +from fastNLP.core.utils import CheckError from fastNLP.core.utils import _build_args from fastNLP.core.utils import _check_arg_dict_list from fastNLP.core.utils import _move_dict_value_to_device from fastNLP.core.utils import get_func_signature -from fastNLP.core.dataset import DataSet -from fastNLP.core.losses import LossBase -from fastNLP.core.metrics import MetricBase -from fastNLP.core.losses import _prepare_losser -from fastNLP.core.metrics import _prepare_metrics -from fastNLP.core.utils import CheckError class Trainer(object): """Main Training Loop """ - def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, + + def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, + validate_every=-1, dev_data=None, use_cuda=False, save_path="./save", optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, + metric_key=None, **kwargs): super(Trainer, self).__init__() @@ -50,6 +49,13 @@ class Trainer(object): # prepare evaluate metrics = _prepare_metrics(metrics) + + # parse metric_key + # increase_better is True. It means the exp result gets better if the indicator increases. + # It is true by default. + self.increase_better = False if metric_key[0] == "-" else True + self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key + # prepare loss losser = _prepare_losser(losser) @@ -67,12 +73,10 @@ class Trainer(object): self.save_path = save_path self.print_every = int(print_every) self.validate_every = int(validate_every) - self._best_accuracy = 0 + self.best_metric_indicator = None self._model_device = model.parameters().__next__().device - # TODO self._best_accuracy不能表现出当前的metric多种的情况 - if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer else: @@ -102,7 +106,7 @@ class Trainer(object): if torch.cuda.is_available() and self.use_cuda: self.model = self.model.cuda() - self.mode(self.model, is_test=False) + self._mode(self.model, is_test=False) start = time.time() self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) @@ -112,7 +116,9 @@ class Trainer(object): def __getattr__(self, item): def pass_func(*args, **kwargs): pass + return pass_func + self._summary_writer = psudoSW() else: path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time)) @@ -121,19 +127,20 @@ class Trainer(object): epoch = 1 while epoch <= self.n_epochs: - data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False) + data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), + as_numpy=False) - self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start) + self._train_epoch(data_iterator, self.model, epoch, start) # validate_every override validation at end of epochs if self.dev_data and self.validate_every <= 0: - self.do_validation() + self._do_validation() epoch += 1 finally: self._summary_writer.close() del self._summary_writer - def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs): + def _train_epoch(self, data_iterator, model, epoch, start): """Training process in one epoch. kwargs should contain: @@ -144,10 +151,10 @@ class Trainer(object): for batch_x, batch_y in data_iterator: # TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 _move_dict_value_to_device(self._model_device, batch_x, batch_y) - prediction = self.data_forward(model, batch_x) - loss = self.get_loss(prediction, batch_y) - self.grad_backward(loss) - self.update() + prediction = self._data_forward(model, batch_x) + loss = self._compute_loss(prediction, batch_y) + self._grad_backward(loss) + self._update() self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) for name, param in self.model.named_parameters(): if param.requires_grad: @@ -162,18 +169,19 @@ class Trainer(object): print(print_output) if self.validate_every > 0 and self.step % self.validate_every == 0: - self.do_validation() + self._do_validation() self.step += 1 - def do_validation(self): + def _do_validation(self): res = self.tester.test() for name, num in res.items(): self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) - if self.save_path is not None and self.best_eval_result(res): - self.save_model(self.model, 'best_model_' + self.start_time) + if self.save_path is not None and self._better_eval_result(res): + self.save_model(self.model, + "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) - def mode(self, model, is_test=False): + def _mode(self, model, is_test=False): """Train mode or Test mode. This is for PyTorch currently. :param model: a PyTorch model @@ -185,20 +193,20 @@ class Trainer(object): else: model.train() - def update(self): + def _update(self): """Perform weight update on a model. """ self.optimizer.step() - def data_forward(self, network, x): + def _data_forward(self, network, x): x = _build_args(network.forward, **x) y = network(**x) if not isinstance(y, dict): raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") return y - def grad_backward(self, loss): + def _grad_backward(self, loss): """Compute gradient with link rules. :param loss: a scalar where back-prop starts @@ -208,7 +216,7 @@ class Trainer(object): self.model.zero_grad() loss.backward() - def get_loss(self, predict, truth): + def _compute_loss(self, predict, truth): """Compute loss given prediction and ground truth. :param predict: prediction dict, produced by model.forward @@ -224,27 +232,52 @@ class Trainer(object): else: torch.save(model, model_name) - def best_eval_result(self, metrics): + def _better_eval_result(self, metrics): """Check if the current epoch yields better validation results. - :return: bool, True means current results on dev set is the best. + :return bool value: True means current results on dev set is the best. """ if isinstance(metrics, tuple): loss, metrics = metrics if isinstance(metrics, dict): if len(metrics) == 1: - accuracy = list(metrics.values())[0] + # only single metric, just use it + metric_dict = list(metrics.values())[0] + metrics_name = list(metrics.keys())[0] else: - accuracy = metrics[self.eval_sort_key] - else: - accuracy = metrics - - if accuracy > self._best_accuracy: - self._best_accuracy = accuracy - return True - else: - return False + metrics_name = self.metrics[0].__class__.__name__ + if metrics_name not in metrics: + raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}") + metric_dict = metrics[metrics_name] + + if len(metric_dict) == 1: + indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0] + elif len(metric_dict) > 1 and self.metric_key is None: + raise RuntimeError( + f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?") + else: + # metric_key is set + if self.metric_key not in metric_dict: + raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}") + indicator_val = metric_dict[self.metric_key] + + is_better = True + if self.best_metric_indicator is None: + # first-time validation + self.best_metric_indicator = indicator_val + else: + if self.increase_better is True: + if indicator_val > self.best_metric_indicator: + self.best_metric_indicator = indicator_val + else: + is_better = False + else: + if indicator_val < self.best_metric_indicator: + self.best_metric_indicator = indicator_val + else: + is_better = False + return is_better DEFAULT_CHECK_BATCH_SIZE = 2 @@ -254,6 +287,7 @@ IGNORE_CHECK_LEVEL = 0 WARNING_CHECK_LEVEL = 1 STRICT_CHECK_LEVEL = 2 + def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, check_level=WARNING_CHECK_LEVEL): @@ -264,7 +298,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ for batch_count, (batch_x, batch_y) in enumerate(batch): _move_dict_value_to_device(model_devcie, batch_x, batch_y) # forward check - if batch_count==0: + if batch_count == 0: _check_forward_error(model_func=model.forward, check_level=check_level, batch_x=batch_x) @@ -285,17 +319,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ if batch_count == 0: if not isinstance(loss, torch.Tensor): raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, " - f"but got `{type(loss)}`.") - if len(loss.size())!=0: + f"but got `{type(loss)}`.") + if len(loss.size()) != 0: raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, " f"should be torch.size([])") loss.backward() model.zero_grad() - if batch_count+1>=DEFAULT_CHECK_NUM_BATCH: + if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH: break if dev_data is not None: - tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, + tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, batch_size=batch_size, verbose=-1) tester.test() @@ -305,18 +339,18 @@ def _check_forward_error(model_func, check_level, batch_x): _missing = '' _unused = '' func_signature = get_func_signature(model_func) - if len(check_res['missing'])!=0: + if len(check_res['missing']) != 0: _missing = "Function {} misses {}, only provided with {}, " \ ".\n".format(func_signature, check_res.missing, - list(batch_x.keys())) - if len(check_res['unused'])!=0: + list(batch_x.keys())) + if len(check_res['unused']) != 0: if len(check_res.unused) > 1: _unused = "{} are not used ".format(check_res.unused) else: _unused = "{} is not used ".format(check_res.unused) _unused += "in function {}.\n".format(func_signature) if _missing: - if len(_unused)>0 and STRICT_CHECK_LEVEL: + if len(_unused) > 0 and STRICT_CHECK_LEVEL: _error_str = "(1).{}\n(2).{}".format(_missing, _unused) else: _error_str = _missing @@ -329,38 +363,40 @@ def _check_forward_error(model_func, check_level, batch_x): elif check_level == WARNING_CHECK_LEVEL: warnings.warn(message=_unused) -def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level): + +def _check_loss_evaluate(prev_func, func, check_level, output, batch_y): + check_res = _check_arg_dict_list(func, [output, batch_y]) _missing = '' _unused = '' _duplicated = '' func_signature = get_func_signature(func) prev_func_signature = get_func_signature(prev_func) - if len(check_res.missing)>0: + if len(check_res.missing) > 0: _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \ "{}(from target in Dataset)." \ - .format(func_signature, check_res.missing, - list(output.keys()), prev_func_signature, - list(batch_y.keys())) - if len(check_res.unused)>0: + .format(func_signature, check_res.missing, + list(output.keys()), prev_func_signature, + list(batch_y.keys())) + if len(check_res.unused) > 0: if len(check_res.unused) > 1: _unused = "{} are not used ".format(check_res.unused) else: _unused = "{} is not used ".format(check_res.unused) _unused += "in function {}.\n".format(func_signature) - if len(check_res.duplicated)>0: + if len(check_res.duplicated) > 0: if len(check_res.duplicated) > 1: _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \ "them in {} at the same time.".format(check_res.duplicated, - func_signature, - check_res.duplicated, - prev_func_signature) - else: - _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ - "it in {} at the same time.".format(check_res.duplicated, func_signature, check_res.duplicated, prev_func_signature) - _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0) + else: + _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \ + "it in {} at the same time.".format(check_res.duplicated, + func_signature, + check_res.duplicated, + prev_func_signature) + _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0) if _number_errs > 0: _error_strs = [] if _number_errs > 1: From 8a7077fed259b0f7ce216bdf82f2999f2a90f17e Mon Sep 17 00:00:00 2001 From: FengZiYjun Date: Sat, 1 Dec 2018 22:21:57 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=9B=B4=E6=96=B0Optimizer:=20optimizer.SG?= =?UTF-8?q?D(lr=3Dxxx);=E5=A6=82=E6=9E=9C=E6=B2=A1=E6=9C=89=E4=BC=A0?= =?UTF-8?q?=E5=85=A5parameters=EF=BC=8C=E5=88=99=E5=9C=A8trainer=E4=B8=AD?= =?UTF-8?q?=E5=B8=AE=E4=BB=96=E5=8A=A0=E5=85=A5parameter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/optimizer.py | 69 ++++++++++--------------------------- fastNLP/core/trainer.py | 8 ++--- test/core/test_optimizer.py | 21 +++++++++++ test/core/test_trainer.py | 1 + 4 files changed, 44 insertions(+), 55 deletions(-) create mode 100644 test/core/test_optimizer.py diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py index ff2ee40e..72737b81 100644 --- a/fastNLP/core/optimizer.py +++ b/fastNLP/core/optimizer.py @@ -2,61 +2,28 @@ import torch class Optimizer(object): - """Wrapper of optimizer from framework + def __init__(self, model_params, **kwargs): + if model_params is not None and not isinstance(model_params, torch.Tensor): + raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params))) + self.model_params = model_params + self.settings = kwargs - 1. Adam: lr (float), weight_decay (float) - 2. AdaGrad - 3. RMSProp - 4. SGD: lr (float), momentum (float) - """ +class SGD(Optimizer): + def __init__(self, model_params=None, lr=0.001, momentum=0.9): + super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) - def __init__(self, optimizer_name, **kwargs): - """ - :param optimizer_name: str, the name of the optimizer - :param kwargs: the arguments - - """ - self.optim_name = optimizer_name - self.kwargs = kwargs - - @property - def name(self): - """The name of the optimizer. - - :return: str - """ - return self.optim_name + def construct_from_pytorch(self, model_params): + if self.model_params is None: + self.model_params = model_params + return torch.optim.SGD(self.model_params, **self.settings) - @property - def params(self): - """The arguments used to create the optimizer. - :return: dict of (str, *) - """ - return self.kwargs +class Adam(Optimizer): + def __init__(self, model_params=None, lr=0.001, weight_decay=0.8): + super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay) def construct_from_pytorch(self, model_params): - """Construct a optimizer from framework over given model parameters.""" - - if self.optim_name in ["SGD", "sgd"]: - if "lr" in self.kwargs: - if "momentum" not in self.kwargs: - self.kwargs["momentum"] = 0 - optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"]) - else: - raise ValueError("requires learning rate for SGD optimizer") - - elif self.optim_name in ["adam", "Adam"]: - if "lr" in self.kwargs: - if "weight_decay" not in self.kwargs: - self.kwargs["weight_decay"] = 0 - optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"], - weight_decay=self.kwargs["weight_decay"]) - else: - raise ValueError("requires learning rate for Adam optimizer") - - else: - raise NotImplementedError - - return optimizer + if self.model_params is None: + self.model_params = model_params + return torch.optim.Adam(self.model_params, **self.settings) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index d4bedb6f..fb9ba25b 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -12,7 +12,7 @@ from fastNLP.core.batch import Batch from fastNLP.core.dataset import DataSet from fastNLP.core.losses import _prepare_losser from fastNLP.core.metrics import _prepare_metrics -from fastNLP.core.optimizer import Optimizer +from fastNLP.core.optimizer import Adam from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import SequentialSampler from fastNLP.core.tester import Tester @@ -31,7 +31,7 @@ class Trainer(object): def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1, dev_data=None, use_cuda=False, save_path="./save", - optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True, + optimizer=Adam(lr=0.01, weight_decay=0), need_check_code=True, metric_key=None, **kwargs): super(Trainer, self).__init__() @@ -178,7 +178,7 @@ class Trainer(object): for name, num in res.items(): self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step) if self.save_path is not None and self._better_eval_result(res): - self.save_model(self.model, + self._save_model(self.model, "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) def _mode(self, model, is_test=False): @@ -225,7 +225,7 @@ class Trainer(object): """ return self.losser(predict, truth) - def save_model(self, model, model_name, only_param=False): + def _save_model(self, model, model_name, only_param=False): model_name = os.path.join(self.save_path, model_name) if only_param: torch.save(model.state_dict(), model_name) diff --git a/test/core/test_optimizer.py b/test/core/test_optimizer.py new file mode 100644 index 00000000..26e47d43 --- /dev/null +++ b/test/core/test_optimizer.py @@ -0,0 +1,21 @@ +import unittest + +import torch + +from fastNLP.core.optimizer import SGD + + +class TestOptim(unittest.TestCase): + def test_case(self): + optim = SGD(torch.LongTensor(10)) + print(optim.__dict__) + + optim_2 = SGD(lr=0.001) + print(optim_2.__dict__) + + optim_2 = SGD(lr=0.002, momentum=0.989) + print(optim_2.__dict__) + + def test_case_2(self): + with self.assertRaises(RuntimeError): + _ = SGD(0.001) diff --git a/test/core/test_trainer.py b/test/core/test_trainer.py index 7c0a1a9d..08df6a49 100644 --- a/test/core/test_trainer.py +++ b/test/core/test_trainer.py @@ -4,3 +4,4 @@ import unittest class TestTrainer(unittest.TestCase): def test_case_1(self): pass + From 6d36190be4a221234372e58fd9e45bd03d6a0416 Mon Sep 17 00:00:00 2001 From: xuyige Date: Sat, 1 Dec 2018 22:44:24 +0800 Subject: [PATCH 4/4] update LossBase class --- fastNLP/core/losses.py | 100 ++++++++++++++++++++++++++++++----------- test/core/test_loss.py | 74 +++++++++++++++++++++++++++--- 2 files changed, 143 insertions(+), 31 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index aa1ffb89..66664859 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -1,23 +1,29 @@ import torch +import torch.nn.functional as F +from fastNLP.core.utils import CheckError +from fastNLP.core.utils import CheckRes from fastNLP.core.utils import _get_arg_list from fastNLP.core.utils import _map_args from fastNLP.core.utils import get_func_signature from fastNLP.core.utils import _build_args +from fastNLP.core.utils import _check_function_or_method class LossBase(object): def __init__(self): # key: name in target function; value: name in output function self.param_map = {} + self._checked = False def get_loss(self, *args, **kwargs): raise NotImplementedError - def __call__(self, output_dict, target_dict): + def __call__(self, output_dict, target_dict, force_check=False): """ :param output_dict: A dict from forward function of the network. :param target_dict: A dict from DataSet.batch_y. + :param force_check: Boolean. Force to check the mapping functions when it is running. :return: """ args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss) @@ -27,50 +33,94 @@ class LossBase(object): ) param_map = self.param_map - for keys in args: - if keys not in param_map: - param_map.update({keys: keys}) - for keys in defaults: - if keys not in param_map: - param_map.update({keys: keys}) + if args is None: + raise RuntimeError( + f"There is not any param in function{get_func_signature(self.get_loss)}" + ) + self._checked = self._checked and not force_check + if not self._checked: + for keys in args: + if keys not in param_map: + param_map.update({keys: keys}) + if defaults is not None: + for keys in defaults: + if keys not in param_map: + param_map.update({keys: keys}) + self.param_map = param_map # param map: key= name in get_loss function, value= name in param dict - reversed_param_map = {val: key for key, val in param_map} + reversed_param_map = {val: key for key, val in param_map.items()} # reversed param map: key= name in param dict, value= name in get_loss function + duplicated = [] + missing = [] + if not self._checked: + for keys, val in output_dict.items(): + if keys in target_dict.keys(): + duplicated.append(keys) + param_val_dict = {} for keys, val in output_dict.items(): - if keys not in target_dict.keys(): - param_val_dict.update({keys: val}) - else: - raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys)) + param_val_dict.update({keys: val}) for keys, val in target_dict.items(): - if keys not in output_dict.keys(): - param_val_dict.update({keys: val}) - else: - raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys)) + param_val_dict.update({keys: val}) - for keys in args: - if param_map[keys] not in param_val_dict.keys(): - raise RuntimeError(f"missing param {keys} in function {get_func_signature(self.get_loss)}") + if not self._checked: + for keys in args: + if param_map[keys] not in param_val_dict.keys(): + missing.append(keys) + + if len(duplicated) > 0 or len(missing) > 0: + raise CheckError( + CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]), + func_signature=get_func_signature(self.get_loss) + ) + + self._checked = True param_map_val = _map_args(reversed_param_map, **param_val_dict) - param_value = _build_args(**param_map_val) + param_value = _build_args(self.get_loss, **param_map_val) loss = self.get_loss(**param_value) if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): if not isinstance(loss, torch.Tensor): - raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss))) - raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size()))) + raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}") + raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}") return loss class NewLoss(LossBase): def __init__(self, func, key_map=None, **kwargs): - super(NewLoss).__init__() - if not callable(func): - raise RuntimeError("") + super(NewLoss, self).__init__() + _check_function_or_method(func) + if key_map is not None: + if not isinstance(key_map, dict): + raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}") + self.param_map = key_map + if len(kwargs) > 0: + for key, val in kwargs.items(): + self.param_map.update({key: val}) + + self.get_loss = func + + +class L1Loss(LossBase): + def __init__(self): + super(L1Loss, self).__init__() + self.get_loss = F.l1_loss + + +class BCELoss(LossBase): + def __init__(self): + super(BCELoss, self).__init__() + self.get_loss = F.binary_cross_entropy + + +class NLLLoss(LossBase): + def __init__(self): + super(NLLLoss, self).__init__() + self.get_loss = F.nll_loss class LossInForward(LossBase): diff --git a/test/core/test_loss.py b/test/core/test_loss.py index fdde4f0e..fddc56e9 100644 --- a/test/core/test_loss.py +++ b/test/core/test_loss.py @@ -2,6 +2,7 @@ import math import unittest import torch as tc +import torch.nn.functional as F import fastNLP.core.losses as loss @@ -13,7 +14,11 @@ class TestLoss(unittest.TestCase): print (".----------------------------------") - loss_func = loss.Loss("nll") + # loss_func = loss.Loss("nll") + print(callable(tc.nn.NLLLoss)) + loss_func = loss.NewLoss(F.nll_loss) + + nll_loss = loss.NLLLoss() #pdb.set_trace() @@ -35,16 +40,18 @@ class TestLoss(unittest.TestCase): y = tc.log(y) - los = loss_func(y , gy) + los = loss_func({'input': y}, {'target': gy}) + losses = nll_loss({'input': y}, {'target': gy}) r = -math.log(.3) - math.log(.3) - math.log(.1) r /= 3 print ("loss = %f" % (los)) print ("r = %f" % (r)) + print ("nll_loss = %f" % (losses)) self.assertEqual(int(los * 1000), int(r * 1000)) - def test_case_2(self): + def _test_case_2(self): #验证squash()的正确性 print ("----------------------------------") @@ -74,7 +81,8 @@ class TestLoss(unittest.TestCase): #pdb.set_trace() y = tc.log(y) - los = loss_func(y , gy) + #los = loss_func({'input': y}, {'target': gy}) + los = loss_func(y, gy) print ("loss = %f" % (los)) r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1) @@ -89,7 +97,8 @@ class TestLoss(unittest.TestCase): log = math.log - loss_func = loss.Loss("nll") + #loss_func = loss.Loss("nll") + loss_func = loss.NLLLoss() #pdb.set_trace() @@ -117,7 +126,7 @@ class TestLoss(unittest.TestCase): yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data - los = loss_func(yy , gyy) + los = loss_func({'input': yy}, {'target': gyy}) print ("loss = %f" % (los)) @@ -303,5 +312,58 @@ class TestLoss(unittest.TestCase): print ("r = %f" % (r)) self.assertEqual(int(los * 1000), int(r * 1000)) + def test_case_8(self): + def func(a, b): + import torch.nn.functional as F + return F.cross_entropy(a, b) + + def func2(a, truth): + return func(a, truth) + + def func3(predict, truth): + return func(predict, truth) + + def func4(a, b, c=2): + return (a + b) * c + + def func6(a, b, **kwargs): + c = kwargs['c'] + return (a + b) * c + + import torch + from fastNLP.core.losses import LossBase, NewLoss + + get_loss = NewLoss(func, {'a': 'predict', 'b': 'truth'}) + predict = torch.randn(5, 3) + truth = torch.LongTensor([1, 0, 1, 2, 1]) + loss1 = get_loss({'predict': predict}, {'truth': truth}) + get_loss_2 = NewLoss(func2, {'a': 'predict'}) + loss2 = get_loss_2({'predict': predict}, {'truth': truth}) + get_loss_3 = NewLoss(func3) + loss3 = get_loss_3({'predict': predict}, {'truth': truth}) + print(loss1, loss2, loss3) + assert loss1 == loss2 and loss1 == loss3 + + get_loss_4 = NewLoss(func4) + loss4 = get_loss_4({'a': 1, 'b': 3}, {}) + print(loss4) + assert loss4 == (1 + 3) * 2 + + get_loss_5 = NewLoss(func4) + loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4}) + print(loss5) + assert loss5 == (1 + 3) * 4 + + get_loss_6 = NewLoss(func6) + loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4}) + print(loss6) + assert loss6 == (1 + 3) * 4 + + get_loss_7 = NewLoss(func6, c='cc') + loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4}) + print(loss7) + assert loss7 == (1 + 3) * 4 + + if __name__ == "__main__": unittest.main()