@@ -1,23 +1,29 @@
 import torch
+import torch.nn.functional as F

+from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _get_arg_list
 from fastNLP.core.utils import _map_args
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_function_or_method


 class LossBase(object):
     def __init__(self):
         # key: name in target function; value: name in output function
         self.param_map = {}
+        self._checked = False

     def get_loss(self, *args, **kwargs):
         raise NotImplementedError

-    def __call__(self, output_dict, target_dict):
+    def __call__(self, output_dict, target_dict, force_check=False):
         """
         :param output_dict: A dict from forward function of the network.
         :param target_dict: A dict from DataSet.batch_y.
+        :param force_check: Boolean. Force to check the mapping functions when it is running.
         :return:
         """
         args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)
@@ -27,50 +33,94 @@ class LossBase(object):
             )

         param_map = self.param_map
-        for keys in args:
-            if keys not in param_map:
-                param_map.update({keys: keys})
-        for keys in defaults:
-            if keys not in param_map:
-                param_map.update({keys: keys})
+        if args is None:
+            raise RuntimeError(
+                f"There is no parameter in function {get_func_signature(self.get_loss)}"
+            )
+        self._checked = self._checked and not force_check
+        if not self._checked:
+            for keys in args:
+                if keys not in param_map:
+                    param_map.update({keys: keys})
+            if defaults is not None:
+                for keys in defaults:
+                    if keys not in param_map:
+                        param_map.update({keys: keys})
+            self.param_map = param_map
         # param map: key= name in get_loss function, value= name in param dict
-        reversed_param_map = {val: key for key, val in param_map}
+        reversed_param_map = {val: key for key, val in param_map.items()}
         # reversed param map: key= name in param dict, value= name in get_loss function

+        duplicated = []
+        missing = []
+        if not self._checked:
+            for keys, val in output_dict.items():
+                if keys in target_dict.keys():
+                    duplicated.append(keys)
+
         param_val_dict = {}
         for keys, val in output_dict.items():
-            if keys not in target_dict.keys():
-                param_val_dict.update({keys: val})
-            else:
-                raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
+            param_val_dict.update({keys: val})
         for keys, val in target_dict.items():
-            if keys not in output_dict.keys():
-                param_val_dict.update({keys: val})
-            else:
-                raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
+            param_val_dict.update({keys: val})

-        for keys in args:
-            if param_map[keys] not in param_val_dict.keys():
-                raise RuntimeError(f"missing param {keys} in function {get_func_signature(self.get_loss)}")
+        if not self._checked:
+            for keys in args:
+                if param_map[keys] not in param_val_dict.keys():
+                    missing.append(keys)
+
+        if len(duplicated) > 0 or len(missing) > 0:
+            raise CheckError(
+                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]),
+                func_signature=get_func_signature(self.get_loss)
+            )
+
+        self._checked = True

         param_map_val = _map_args(reversed_param_map, **param_val_dict)
-        param_value = _build_args(**param_map_val)
+        param_value = _build_args(self.get_loss, **param_map_val)

         loss = self.get_loss(**param_value)

         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss)))
-            raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size())))
+                raise RuntimeError(f"loss ERROR: loss expected a torch.Tensor but got {type(loss)}")
+            raise RuntimeError(f"loss ERROR: the size of loss expected torch.Size([]) but got {loss.size()}")

         return loss


 class NewLoss(LossBase):
     def __init__(self, func, key_map=None, **kwargs):
-        super(NewLoss).__init__()
-        if not callable(func):
-            raise RuntimeError("")
+        super(NewLoss, self).__init__()
+        _check_function_or_method(func)
+        if key_map is not None:
+            if not isinstance(key_map, dict):
+                raise RuntimeError(f"Loss error: key_map expected a dict but got a {type(key_map)}")
+            self.param_map = key_map
+        if len(kwargs) > 0:
+            for key, val in kwargs.items():
+                self.param_map.update({key: val})
+
+        self.get_loss = func
+
+
+class L1Loss(LossBase):
+    def __init__(self):
+        super(L1Loss, self).__init__()
+        self.get_loss = F.l1_loss
+
+
+class BCELoss(LossBase):
+    def __init__(self):
+        super(BCELoss, self).__init__()
+        self.get_loss = F.binary_cross_entropy
+
+
+class NLLLoss(LossBase):
+    def __init__(self):
+        super(NLLLoss, self).__init__()
+        self.get_loss = F.nll_loss
+
+
 class LossInForward(LossBase):
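Usage sketch (not part of the patch) of the new dict-based loss interface: NewLoss wraps an arbitrary loss function, and LossBase.__call__ resolves its arguments from the forward-output dict and the target dict through param_map. The key names 'output' and 'label_seq' below are illustrative assumptions, not fastNLP conventions.

import torch
import torch.nn.functional as F
from fastNLP.core.losses import NewLoss, NLLLoss

# Map the argument names of F.nll_loss ('input', 'target') to the keys that the
# model forward / DataSet.batch_y are assumed to produce in this example.
loss_func = NewLoss(F.nll_loss, {'input': 'output', 'target': 'label_seq'})

output_dict = {'output': F.log_softmax(torch.randn(4, 3), dim=-1)}  # from model.forward
target_dict = {'label_seq': torch.LongTensor([0, 2, 1, 0])}         # from DataSet.batch_y
print(loss_func(output_dict, target_dict))  # a scalar torch.Tensor

# The built-in wrappers keep the functional argument names as dict keys.
print(NLLLoss()({'input': output_dict['output']}, {'target': target_dict['label_seq']}))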
@@ -2,61 +2,28 @@ import torch


 class Optimizer(object):
-    """Wrapper of optimizer from framework
-
-    1. Adam: lr (float), weight_decay (float)
-    2. AdaGrad
-    3. RMSProp
-    4. SGD: lr (float), momentum (float)
-    """
+    def __init__(self, model_params, **kwargs):
+        if model_params is not None and not isinstance(model_params, torch.Tensor):
+            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        self.model_params = model_params
+        self.settings = kwargs

-    def __init__(self, optimizer_name, **kwargs):
-        """
-        :param optimizer_name: str, the name of the optimizer
-        :param kwargs: the arguments
-        """
-        self.optim_name = optimizer_name
-        self.kwargs = kwargs
+
+class SGD(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

-    @property
-    def name(self):
-        """The name of the optimizer.
-        :return: str
-        """
-        return self.optim_name
+    def construct_from_pytorch(self, model_params):
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.SGD(self.model_params, **self.settings)

-    @property
-    def params(self):
-        """The arguments used to create the optimizer.
-        :return: dict of (str, *)
-        """
-        return self.kwargs
+
+class Adam(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

     def construct_from_pytorch(self, model_params):
-        """Construct a optimizer from framework over given model parameters."""
-        if self.optim_name in ["SGD", "sgd"]:
-            if "lr" in self.kwargs:
-                if "momentum" not in self.kwargs:
-                    self.kwargs["momentum"] = 0
-                optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
-            else:
-                raise ValueError("requires learning rate for SGD optimizer")
-        elif self.optim_name in ["adam", "Adam"]:
-            if "lr" in self.kwargs:
-                if "weight_decay" not in self.kwargs:
-                    self.kwargs["weight_decay"] = 0
-                optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
-                                             weight_decay=self.kwargs["weight_decay"])
-            else:
-                raise ValueError("requires learning rate for Adam optimizer")
-        else:
-            raise NotImplementedError
-        return optimizer
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.Adam(self.model_params, **self.settings)
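The reworked optimizer wrappers separate hyper-parameter configuration from binding to model parameters. A minimal sketch of the intended flow, not taken from the patch (the Linear model is only a stand-in):

import torch
from fastNLP.core.optimizer import SGD, Adam

model = torch.nn.Linear(10, 2)  # any module with parameters

# Construction only records the settings; no torch.optim object exists yet.
sgd = SGD(lr=0.01, momentum=0.9)
adam = Adam(lr=0.001, weight_decay=0)

# The actual torch.optim.SGD / torch.optim.Adam is built once the parameters are
# known, which is presumably how Trainer is expected to use these wrappers.
torch_sgd = sgd.construct_from_pytorch(model.parameters())
torch_adam = adam.construct_from_pytorch(model.parameters())
print(type(torch_sgd), type(torch_adam))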
@@ -1,20 +1,22 @@
-import itertools
 import os
 import time
 import warnings
-from collections import defaultdict
 from datetime import datetime
 from datetime import timedelta

 import torch
-from torch import nn
 from tensorboardX import SummaryWriter
+from torch import nn

 from fastNLP.core.batch import Batch
-from fastNLP.core.optimizer import Optimizer
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device
@@ -30,9 +32,12 @@ class Trainer(object):
     """Main Training Loop
     """

-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
+                 validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
+                 optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
+                 metric_key=None,
                  **kwargs):
         super(Trainer, self).__init__()

@@ -49,6 +54,13 @@ class Trainer(object):
         # prepare evaluate
         metrics = _prepare_metrics(metrics)
+
+        # parse metric_key
+        # increase_better is True: the experiment result gets better as the indicator increases.
+        # It is True by default.
+        self.increase_better = False if (metric_key is not None and metric_key[0] == "-") else True
+        self.metric_key = metric_key[1:] if (metric_key is not None and metric_key[0] in ("+", "-")) else metric_key
+
         # prepare loss
         losser = _prepare_losser(losser)
@@ -67,12 +79,10 @@ class Trainer(object):
         self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self._best_accuracy = 0
+        self.best_metric_indicator = None
         self._model_device = model.parameters().__next__().device

-        # TODO self._best_accuracy cannot reflect the case where there are multiple metrics
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
@@ -102,7 +112,7 @@ class Trainer(object):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = self.model.cuda()

-        self.mode(self.model, is_test=False)
+        self._mode(self.model, is_test=False)

         start = time.time()
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))

@@ -112,7 +122,9 @@ class Trainer(object):
                 def __getattr__(self, item):
                     def pass_func(*args, **kwargs):
                         pass
+
                     return pass_func
+
             self._summary_writer = psudoSW()
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
@@ -121,19 +133,20 @@ class Trainer(object):
             epoch = 1
             while epoch <= self.n_epochs:
-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
-                self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
+                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+                                      as_numpy=False)
+                self._train_epoch(data_iterator, self.model, epoch, start)

                 # validate_every override validation at end of epochs
                 if self.dev_data and self.validate_every <= 0:
-                    self.do_validation()
+                    self._do_validation()
                 epoch += 1
         finally:
             self._summary_writer.close()
             del self._summary_writer

-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, start):
         """Training process in one epoch.

            kwargs should contain:
@@ -144,10 +157,10 @@ class Trainer(object): | |||||
for batch_x, batch_y in data_iterator: | for batch_x, batch_y in data_iterator: | ||||
# TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 | # TODO 这里可能会遇到问题,万一用户在model内部修改了prediction的device就会有问题 | ||||
_move_dict_value_to_device(self._model_device, batch_x, batch_y) | _move_dict_value_to_device(self._model_device, batch_x, batch_y) | ||||
prediction = self.data_forward(model, batch_x) | |||||
loss = self.get_loss(prediction, batch_y) | |||||
self.grad_backward(loss) | |||||
self.update() | |||||
prediction = self._data_forward(model, batch_x) | |||||
loss = self._compute_loss(prediction, batch_y) | |||||
self._grad_backward(loss) | |||||
self._update() | |||||
self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step) | ||||
for name, param in self.model.named_parameters(): | for name, param in self.model.named_parameters(): | ||||
if param.requires_grad: | if param.requires_grad: | ||||
@@ -162,18 +175,19 @@ class Trainer(object):
                 print(print_output)

             if self.validate_every > 0 and self.step % self.validate_every == 0:
-                self.do_validation()
+                self._do_validation()

             self.step += 1

-    def do_validation(self):
+    def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        if self.save_path is not None and self.best_eval_result(res):
-            self.save_model(self.model, 'best_model_' + self.start_time)
+        if self.save_path is not None and self._better_eval_result(res):
+            self._save_model(self.model,
+                             "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))

-    def mode(self, model, is_test=False):
+    def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.

         :param model: a PyTorch model
@@ -185,20 +199,20 @@ class Trainer(object): | |||||
else: | else: | ||||
model.train() | model.train() | ||||
def update(self): | |||||
def _update(self): | |||||
"""Perform weight update on a model. | """Perform weight update on a model. | ||||
""" | """ | ||||
self.optimizer.step() | self.optimizer.step() | ||||
def data_forward(self, network, x): | |||||
def _data_forward(self, network, x): | |||||
x = _build_args(network.forward, **x) | x = _build_args(network.forward, **x) | ||||
y = network(**x) | y = network(**x) | ||||
if not isinstance(y, dict): | if not isinstance(y, dict): | ||||
raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.") | ||||
return y | return y | ||||
def grad_backward(self, loss): | |||||
def _grad_backward(self, loss): | |||||
"""Compute gradient with link rules. | """Compute gradient with link rules. | ||||
:param loss: a scalar where back-prop starts | :param loss: a scalar where back-prop starts | ||||
@@ -208,7 +222,7 @@ class Trainer(object):
         self.model.zero_grad()
         loss.backward()

-    def get_loss(self, predict, truth):
+    def _compute_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.

         :param predict: prediction dict, produced by model.forward
@@ -217,34 +231,59 @@ class Trainer(object):
         """
         return self.losser(predict, truth)

-    def save_model(self, model, model_name, only_param=False):
+    def _save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
         if only_param:
             torch.save(model.state_dict(), model_name)
         else:
             torch.save(model, model_name)

-    def best_eval_result(self, metrics):
+    def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.

-        :return: bool, True means current results on dev set is the best.
+        :return bool value: True means the current results on the dev set are the best so far.
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics

         if isinstance(metrics, dict):
             if len(metrics) == 1:
-                accuracy = list(metrics.values())[0]
+                # only single metric, just use it
+                metric_dict = list(metrics.values())[0]
+                metrics_name = list(metrics.keys())[0]
             else:
-                accuracy = metrics[self.eval_sort_key]
-        else:
-            accuracy = metrics
-
-        if accuracy > self._best_accuracy:
-            self._best_accuracy = accuracy
-            return True
-        else:
-            return False
+                metrics_name = self.metrics[0].__class__.__name__
+                if metrics_name not in metrics:
+                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+                metric_dict = metrics[metrics_name]
+
+            if len(metric_dict) == 1:
+                indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+            elif len(metric_dict) > 1 and self.metric_key is None:
+                raise RuntimeError(
+                    f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+            else:
+                # metric_key is set
+                if self.metric_key not in metric_dict:
+                    raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
+                indicator_val = metric_dict[self.metric_key]
+
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better
 DEFAULT_CHECK_BATCH_SIZE = 2

@@ -285,18 +324,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
                                f"should be torch.size([])")
             loss.backward()
         except CheckError as e:
-            _check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
+            pre_func_signature = get_func_signature(model.forward)
+            _check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
                                  check_res=e.check_res, output=output, batch_y=batch_y,
                                  check_level=check_level)
         model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break

     if dev_data is not None:
-        tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
         tester.test()
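The new metric_key argument drives _better_eval_result above: an optional '+' or '-' prefix selects whether larger or smaller values count as an improvement. A standalone restatement of that parsing convention, for illustration only (the helper name is made up and is not part of fastNLP):

def parse_metric_key(metric_key):
    """Mirror the Trainer's convention: return (key, increase_better)."""
    if metric_key is None:
        return None, True  # no key chosen; higher values are assumed better
    increase_better = metric_key[0] != "-"
    key = metric_key[1:] if metric_key[0] in ("+", "-") else metric_key
    return key, increase_better

print(parse_metric_key("+acc"))   # ('acc', True)  : higher accuracy is better
print(parse_metric_key("-loss"))  # ('loss', False): lower loss is better
print(parse_metric_key("f1"))     # ('f1', True)   : no prefix defaults to increasing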
@@ -2,6 +2,7 @@ import math
 import unittest

 import torch as tc
+import torch.nn.functional as F

 import fastNLP.core.losses as loss

@@ -13,7 +14,11 @@ class TestLoss(unittest.TestCase):
        print (".----------------------------------")

-        loss_func = loss.Loss("nll")
+        # loss_func = loss.Loss("nll")
+        print(callable(tc.nn.NLLLoss))
+        loss_func = loss.NewLoss(F.nll_loss)
+
+        nll_loss = loss.NLLLoss()

        #pdb.set_trace()

@@ -35,16 +40,18 @@ class TestLoss(unittest.TestCase):
        y = tc.log(y)

-        los = loss_func(y , gy)
+        los = loss_func({'input': y}, {'target': gy})
+        losses = nll_loss({'input': y}, {'target': gy})

        r = -math.log(.3) - math.log(.3) - math.log(.1)
        r /= 3

        print ("loss = %f" % (los))
        print ("r = %f" % (r))
+        print ("nll_loss = %f" % (losses))

        self.assertEqual(int(los * 1000), int(r * 1000))

-    def test_case_2(self):
+    def _test_case_2(self):
        # verify the correctness of squash()
        print ("----------------------------------")
@@ -74,7 +81,8 @@ class TestLoss(unittest.TestCase):
        #pdb.set_trace()
        y = tc.log(y)

-        los = loss_func(y , gy)
+        #los = loss_func({'input': y}, {'target': gy})
+        los = loss_func(y, gy)
        print ("loss = %f" % (los))

        r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)

@@ -89,7 +97,8 @@ class TestLoss(unittest.TestCase):
        log = math.log

-        loss_func = loss.Loss("nll")
+        #loss_func = loss.Loss("nll")
+        loss_func = loss.NLLLoss()

        #pdb.set_trace()

@@ -117,7 +126,7 @@ class TestLoss(unittest.TestCase):
        yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data
        gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data

-        los = loss_func(yy , gyy)
+        los = loss_func({'input': yy}, {'target': gyy})

        print ("loss = %f" % (los))
@@ -303,5 +312,58 @@ class TestLoss(unittest.TestCase):
        print ("r = %f" % (r))
        self.assertEqual(int(los * 1000), int(r * 1000))

+    def test_case_8(self):
+        def func(a, b):
+            import torch.nn.functional as F
+            return F.cross_entropy(a, b)
+
+        def func2(a, truth):
+            return func(a, truth)
+
+        def func3(predict, truth):
+            return func(predict, truth)
+
+        def func4(a, b, c=2):
+            return (a + b) * c
+
+        def func6(a, b, **kwargs):
+            c = kwargs['c']
+            return (a + b) * c
+
+        import torch
+        from fastNLP.core.losses import LossBase, NewLoss
+
+        get_loss = NewLoss(func, {'a': 'predict', 'b': 'truth'})
+        predict = torch.randn(5, 3)
+        truth = torch.LongTensor([1, 0, 1, 2, 1])
+        loss1 = get_loss({'predict': predict}, {'truth': truth})
+
+        get_loss_2 = NewLoss(func2, {'a': 'predict'})
+        loss2 = get_loss_2({'predict': predict}, {'truth': truth})
+
+        get_loss_3 = NewLoss(func3)
+        loss3 = get_loss_3({'predict': predict}, {'truth': truth})
+
+        print(loss1, loss2, loss3)
+        assert loss1 == loss2 and loss1 == loss3
+
+        get_loss_4 = NewLoss(func4)
+        loss4 = get_loss_4({'a': 1, 'b': 3}, {})
+        print(loss4)
+        assert loss4 == (1 + 3) * 2
+
+        get_loss_5 = NewLoss(func4)
+        loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
+        print(loss5)
+        assert loss5 == (1 + 3) * 4
+
+        get_loss_6 = NewLoss(func6)
+        loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
+        print(loss6)
+        assert loss6 == (1 + 3) * 4
+
+        get_loss_7 = NewLoss(func6, c='cc')
+        loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
+        print(loss7)
+        assert loss7 == (1 + 3) * 4
+

 if __name__ == "__main__":
     unittest.main()
@@ -0,0 +1,21 @@
+import unittest
+
+import torch
+
+from fastNLP.core.optimizer import SGD
+
+
+class TestOptim(unittest.TestCase):
+    def test_case(self):
+        optim = SGD(torch.LongTensor(10))
+        print(optim.__dict__)
+
+        optim_2 = SGD(lr=0.001)
+        print(optim_2.__dict__)
+
+        optim_2 = SGD(lr=0.002, momentum=0.989)
+        print(optim_2.__dict__)
+
+    def test_case_2(self):
+        with self.assertRaises(RuntimeError):
+            _ = SGD(0.001)
@@ -4,3 +4,4 @@ import unittest
 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
         pass