
Resolved merge conflict in trainer

tags/v0.2.0^2
yh 6 years ago
parent commit 1b961f136c
6 changed files with 266 additions and 127 deletions
  1. +75 -25  fastNLP/core/losses.py
  2. +18 -51  fastNLP/core/optimizer.py
  3. +83 -45  fastNLP/core/trainer.py
  4. +68 -6   test/core/test_loss.py
  5. +21 -0   test/core/test_optimizer.py
  6. +1  -0   test/core/test_trainer.py

fastNLP/core/losses.py  (+75 -25)

@@ -1,23 +1,29 @@
 import torch
+import torch.nn.functional as F

+from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import CheckRes
 from fastNLP.core.utils import _get_arg_list
 from fastNLP.core.utils import _map_args
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _check_function_or_method


 class LossBase(object):
     def __init__(self):
         # key: name in target function; value: name in output function
         self.param_map = {}
+        self._checked = False

     def get_loss(self, *args, **kwargs):
         raise NotImplementedError

-    def __call__(self, output_dict, target_dict):
+    def __call__(self, output_dict, target_dict, force_check=False):
         """
         :param output_dict: A dict from forward function of the network.
         :param target_dict: A dict from DataSet.batch_y.
+        :param force_check: Boolean. Force to check the mapping functions when it is running.
         :return:
         """
         args, defaults, defaults_val, varargs, kwargs = _get_arg_list(self.get_loss)

@@ -27,50 +33,94 @@ class LossBase(object):
         )

         param_map = self.param_map
-        for keys in args:
-            if keys not in param_map:
-                param_map.update({keys: keys})
-        for keys in defaults:
-            if keys not in param_map:
-                param_map.update({keys: keys})
+        if args is None:
+            raise RuntimeError(
+                f"There is not any param in function{get_func_signature(self.get_loss)}"
+            )
+        self._checked = self._checked and not force_check
+        if not self._checked:
+            for keys in args:
+                if keys not in param_map:
+                    param_map.update({keys: keys})
+            if defaults is not None:
+                for keys in defaults:
+                    if keys not in param_map:
+                        param_map.update({keys: keys})
+            self.param_map = param_map
         # param map: key= name in get_loss function, value= name in param dict
-        reversed_param_map = {val: key for key, val in param_map}
+        reversed_param_map = {val: key for key, val in param_map.items()}
         # reversed param map: key= name in param dict, value= name in get_loss function

+        duplicated = []
+        missing = []
+        if not self._checked:
+            for keys, val in output_dict.items():
+                if keys in target_dict.keys():
+                    duplicated.append(keys)
+
         param_val_dict = {}
         for keys, val in output_dict.items():
-            if keys not in target_dict.keys():
-                param_val_dict.update({keys: val})
-            else:
-                raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
+            param_val_dict.update({keys: val})
         for keys, val in target_dict.items():
-            if keys not in output_dict.keys():
-                param_val_dict.update({keys: val})
-            else:
-                raise RuntimeError("conflict Error in output dict and target dict with name {}".format(keys))
+            param_val_dict.update({keys: val})

-        for keys in args:
-            if param_map[keys] not in param_val_dict.keys():
-                raise RuntimeError(f"missing param {keys} in function {get_func_signature(self.get_loss)}")
+        if not self._checked:
+            for keys in args:
+                if param_map[keys] not in param_val_dict.keys():
+                    missing.append(keys)
+
+        if len(duplicated) > 0 or len(missing) > 0:
+            raise CheckError(
+                CheckRes(missing=missing, unused=[], duplicated=duplicated, required=[], all_needed=[]),
+                func_signature=get_func_signature(self.get_loss)
+            )
+
+        self._checked = True

         param_map_val = _map_args(reversed_param_map, **param_val_dict)
-        param_value = _build_args(**param_map_val)
+        param_value = _build_args(self.get_loss, **param_map_val)

         loss = self.get_loss(**param_value)

         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise RuntimeError("loss ERROR: loss except a torch.Tensor but get {}".format(type(loss)))
-            raise RuntimeError("loss ERROR: len(loss.size()) except 0 but got {}".format(len(loss.size())))
+                raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
+            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")

         return loss


 class NewLoss(LossBase):
     def __init__(self, func, key_map=None, **kwargs):
-        super(NewLoss).__init__()
-        if not callable(func):
-            raise RuntimeError("")
+        super(NewLoss, self).__init__()
+        _check_function_or_method(func)
+        if key_map is not None:
+            if not isinstance(key_map, dict):
+                raise RuntimeError(f"Loss error: key_map except a {type({})} but got a {type(key_map)}")
+            self.param_map = key_map
+        if len(kwargs) > 0:
+            for key, val in kwargs.items():
+                self.param_map.update({key: val})
+
+        self.get_loss = func


+class L1Loss(LossBase):
+    def __init__(self):
+        super(L1Loss, self).__init__()
+        self.get_loss = F.l1_loss
+
+
+class BCELoss(LossBase):
+    def __init__(self):
+        super(BCELoss, self).__init__()
+        self.get_loss = F.binary_cross_entropy
+
+
+class NLLLoss(LossBase):
+    def __init__(self):
+        super(NLLLoss, self).__init__()
+        self.get_loss = F.nll_loss


 class LossInForward(LossBase):
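A minimal usage sketch (not part of the commit) of the dict-based loss API added above, mirroring test/core/test_loss.py further down; the tensors and key names here are illustrative placeholders:

    import torch
    import torch.nn.functional as F
    from fastNLP.core.losses import NewLoss, NLLLoss

    pred = F.log_softmax(torch.randn(4, 3), dim=1)   # stand-in for a model's forward output
    target = torch.LongTensor([0, 2, 1, 0])          # stand-in for DataSet.batch_y

    # Losses are now called with two dicts (forward output and batch_y) instead of raw tensors.
    nll = NLLLoss()
    print(nll({'input': pred}, {'target': target}))

    # NewLoss wraps any callable; key_map maps the callable's argument names to dict keys.
    wrapped = NewLoss(F.nll_loss, key_map={'input': 'pred', 'target': 'label'})
    print(wrapped({'pred': pred}, {'label': target}))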


fastNLP/core/optimizer.py  (+18 -51)

@@ -2,61 +2,28 @@ import torch




 class Optimizer(object):
-    """Wrapper of optimizer from framework
+    def __init__(self, model_params, **kwargs):
+        if model_params is not None and not isinstance(model_params, torch.Tensor):
+            raise RuntimeError("model parameters should be torch.Tensor, rather than {}".format(type(model_params)))
+        self.model_params = model_params
+        self.settings = kwargs

-    1. Adam: lr (float), weight_decay (float)
-    2. AdaGrad
-    3. RMSProp
-    4. SGD: lr (float), momentum (float)
-
-    """
+
+class SGD(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, momentum=0.9):
+        super(SGD, self).__init__(model_params, lr=lr, momentum=momentum)

-    def __init__(self, optimizer_name, **kwargs):
-        """
-        :param optimizer_name: str, the name of the optimizer
-        :param kwargs: the arguments
-
-        """
-        self.optim_name = optimizer_name
-        self.kwargs = kwargs
-
-    @property
-    def name(self):
-        """The name of the optimizer.
-
-        :return: str
-        """
-        return self.optim_name
+    def construct_from_pytorch(self, model_params):
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.SGD(self.model_params, **self.settings)

-    @property
-    def params(self):
-        """The arguments used to create the optimizer.
-
-        :return: dict of (str, *)
-        """
-        return self.kwargs
+
+class Adam(Optimizer):
+    def __init__(self, model_params=None, lr=0.001, weight_decay=0.8):
+        super(Adam, self).__init__(model_params, lr=lr, weight_decay=weight_decay)

     def construct_from_pytorch(self, model_params):
-        """Construct a optimizer from framework over given model parameters."""
-
-        if self.optim_name in ["SGD", "sgd"]:
-            if "lr" in self.kwargs:
-                if "momentum" not in self.kwargs:
-                    self.kwargs["momentum"] = 0
-                optimizer = torch.optim.SGD(model_params, lr=self.kwargs["lr"], momentum=self.kwargs["momentum"])
-            else:
-                raise ValueError("requires learning rate for SGD optimizer")
-
-        elif self.optim_name in ["adam", "Adam"]:
-            if "lr" in self.kwargs:
-                if "weight_decay" not in self.kwargs:
-                    self.kwargs["weight_decay"] = 0
-                optimizer = torch.optim.Adam(model_params, lr=self.kwargs["lr"],
-                                             weight_decay=self.kwargs["weight_decay"])
-            else:
-                raise ValueError("requires learning rate for Adam optimizer")
-
-        else:
-            raise NotImplementedError
-
-        return optimizer
+        if self.model_params is None:
+            self.model_params = model_params
+        return torch.optim.Adam(self.model_params, **self.settings)
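A short sketch (not part of the commit) of how the rewritten wrappers are intended to be used; the Linear module is a placeholder, and the Trainer is what normally calls construct_from_pytorch:

    import torch
    from fastNLP.core.optimizer import SGD, Adam

    model = torch.nn.Linear(10, 2)   # placeholder module

    # Hyper-parameters are stored at construction time; the actual torch.optim
    # object is built later, once the model parameters are available.
    sgd = SGD(lr=0.01, momentum=0.9)
    torch_sgd = sgd.construct_from_pytorch(model.parameters())

    adam = Adam(lr=0.01, weight_decay=0)
    torch_adam = adam.construct_from_pytorch(model.parameters())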

fastNLP/core/trainer.py  (+83 -45)

@@ -1,20 +1,22 @@
+import itertools
 import os
 import time
 import warnings
+from collections import defaultdict
 from datetime import datetime
 from datetime import timedelta

 import torch
-from torch import nn
 from tensorboardX import SummaryWriter
+from torch import nn

 from fastNLP.core.batch import Batch
-from fastNLP.core.optimizer import Optimizer
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device

@@ -30,9 +32,12 @@ class Trainer(object):
     """Main Training Loop

     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
+
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
+                 validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), check_code_level=0,
+                 optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
+                 metric_key=None,
                  **kwargs):
         super(Trainer, self).__init__()

@@ -49,6 +54,13 @@ class Trainer(object):
         # prepare evaluate
         metrics = _prepare_metrics(metrics)

+        # parse metric_key
+        # increase_better is True. It means the exp result gets better if the indicator increases.
+        # It is true by default.
+        self.increase_better = False if metric_key[0] == "-" else True
+        self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+
         # prepare loss
         losser = _prepare_losser(losser)


@@ -67,12 +79,10 @@ class Trainer(object):
         self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self._best_accuracy = 0
+        self.best_metric_indicator = None

         self._model_device = model.parameters().__next__().device

-        # TODO self._best_accuracy cannot reflect the situation where there are multiple metrics
-
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:

@@ -102,7 +112,7 @@ class Trainer(object):
             if torch.cuda.is_available() and self.use_cuda:
                 self.model = self.model.cuda()

-            self.mode(self.model, is_test=False)
+            self._mode(self.model, is_test=False)

             start = time.time()
             self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))

@@ -112,7 +122,9 @@ class Trainer(object):
                     def __getattr__(self, item):
                         def pass_func(*args, **kwargs):
                             pass
+
                         return pass_func
+
                 self._summary_writer = psudoSW()
             else:
                 path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))

@@ -121,19 +133,20 @@ class Trainer(object):
             epoch = 1
             while epoch <= self.n_epochs:
-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
+                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+                                      as_numpy=False)

-                self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
+                self._train_epoch(data_iterator, self.model, epoch, start)

                 # validate_every override validation at end of epochs
                 if self.dev_data and self.validate_every <= 0:
-                    self.do_validation()
+                    self._do_validation()
                 epoch += 1
         finally:
             self._summary_writer.close()
             del self._summary_writer

-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, start):
         """Training process in one epoch.

         kwargs should contain:

@@ -144,10 +157,10 @@ class Trainer(object):
         for batch_x, batch_y in data_iterator:
             # TODO a problem may arise here: if the user changes the device of prediction inside the model, this will break
             _move_dict_value_to_device(self._model_device, batch_x, batch_y)
-            prediction = self.data_forward(model, batch_x)
-            loss = self.get_loss(prediction, batch_y)
-            self.grad_backward(loss)
-            self.update()
+            prediction = self._data_forward(model, batch_x)
+            loss = self._compute_loss(prediction, batch_y)
+            self._grad_backward(loss)
+            self._update()
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
             for name, param in self.model.named_parameters():
                 if param.requires_grad:
@@ -162,18 +175,19 @@ class Trainer(object):
                     print(print_output)

             if self.validate_every > 0 and self.step % self.validate_every == 0:
-                self.do_validation()
+                self._do_validation()

             self.step += 1

-    def do_validation(self):
+    def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        if self.save_path is not None and self.best_eval_result(res):
-            self.save_model(self.model, 'best_model_' + self.start_time)
+        if self.save_path is not None and self._better_eval_result(res):
+            self._save_model(self.model,
+                             "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))

-    def mode(self, model, is_test=False):
+    def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.

         :param model: a PyTorch model

@@ -185,20 +199,20 @@ class Trainer(object):
         else:
             model.train()

-    def update(self):
+    def _update(self):
         """Perform weight update on a model.

         """
         self.optimizer.step()

-    def data_forward(self, network, x):
+    def _data_forward(self, network, x):
         x = _build_args(network.forward, **x)
         y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y

-    def grad_backward(self, loss):
+    def _grad_backward(self, loss):
         """Compute gradient with link rules.

         :param loss: a scalar where back-prop starts

@@ -208,7 +222,7 @@ class Trainer(object):
         self.model.zero_grad()
         loss.backward()

-    def get_loss(self, predict, truth):
+    def _compute_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.

         :param predict: prediction dict, produced by model.forward

@@ -217,34 +231,59 @@ class Trainer(object):
         """
         return self.losser(predict, truth)

-    def save_model(self, model, model_name, only_param=False):
+    def _save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
         if only_param:
             torch.save(model.state_dict(), model_name)
         else:
             torch.save(model, model_name)

-    def best_eval_result(self, metrics):
+    def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.

-        :return: bool, True means current results on dev set is the best.
+        :return bool value: True means current results on dev set is the best.
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics

         if isinstance(metrics, dict):
             if len(metrics) == 1:
-                accuracy = list(metrics.values())[0]
+                # only single metric, just use it
+                metric_dict = list(metrics.values())[0]
+                metrics_name = list(metrics.keys())[0]
             else:
-                accuracy = metrics[self.eval_sort_key]
-        else:
-            accuracy = metrics
-
-        if accuracy > self._best_accuracy:
-            self._best_accuracy = accuracy
-            return True
-        else:
-            return False
+                metrics_name = self.metrics[0].__class__.__name__
+                if metrics_name not in metrics:
+                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+                metric_dict = metrics[metrics_name]
+
+            if len(metric_dict) == 1:
+                indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+            elif len(metric_dict) > 1 and self.metric_key is None:
+                raise RuntimeError(
+                    f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+            else:
+                # metric_key is set
+                if self.metric_key not in metric_dict:
+                    raise RuntimeError(f"matric key {self.metric_key} not found in {metric_dict}")
+                indicator_val = metric_dict[self.metric_key]
+
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better


 DEFAULT_CHECK_BATCH_SIZE = 2
@@ -285,18 +324,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
f"should be torch.size([])") f"should be torch.size([])")
loss.backward() loss.backward()
except CheckError as e: except CheckError as e:
_check_loss_evaluate(prev_func=model.forward, func=e.func_signature,
pre_func_signature = get_func_signature(model.forward)
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature,
check_res=e.check_res, output=output, batch_y=batch_y, check_res=e.check_res, output=output, batch_y=batch_y,
check_level=check_level) check_level=check_level)
model.zero_grad() model.zero_grad()
if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
break break


if dev_data is not None: if dev_data is not None:
tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
batch_size=batch_size, verbose=-1) batch_size=batch_size, verbose=-1)
tester.test() tester.test()
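A hedged sketch (not part of the commit) of how the new Trainer arguments fit together; `train_set`, `dev_set`, `model`, `loss_fn` and `metrics` are placeholders, and since __init__ indexes metric_key directly, the sketch always passes one:

    from fastNLP.core.optimizer import Adam
    from fastNLP.core.trainer import Trainer

    # metric_key picks the validation number used for model selection;
    # a leading "+" means larger is better, a leading "-" means smaller is better.
    trainer = Trainer(train_data=train_set,                 # placeholder DataSet
                      model=model,                          # forward() must return a dict
                      losser=loss_fn,                       # e.g. a NewLoss / LossInForward instance
                      metrics=metrics,                      # e.g. an accuracy metric instance
                      dev_data=dev_set,                     # placeholder DataSet
                      optimizer=Adam(lr=0.01, weight_decay=0),
                      metric_key="+acc",
                      n_epochs=3, batch_size=32,
                      save_path="./save")
    trainer.train()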







test/core/test_loss.py  (+68 -6)

@@ -2,6 +2,7 @@ import math
 import unittest

 import torch as tc
+import torch.nn.functional as F

 import fastNLP.core.losses as loss

@@ -13,7 +14,11 @@ class TestLoss(unittest.TestCase):

         print (".----------------------------------")

-        loss_func = loss.Loss("nll")
+        # loss_func = loss.Loss("nll")
+        print(callable(tc.nn.NLLLoss))
+        loss_func = loss.NewLoss(F.nll_loss)
+
+        nll_loss = loss.NLLLoss()

         #pdb.set_trace()

@@ -35,16 +40,18 @@ class TestLoss(unittest.TestCase):


         y = tc.log(y)
-        los = loss_func(y , gy)
+        los = loss_func({'input': y}, {'target': gy})
+        losses = nll_loss({'input': y}, {'target': gy})

         r = -math.log(.3) - math.log(.3) - math.log(.1)
         r /= 3
         print ("loss = %f" % (los))
         print ("r = %f" % (r))
+        print ("nll_loss = %f" % (losses))

         self.assertEqual(int(los * 1000), int(r * 1000))

-    def test_case_2(self):
+    def _test_case_2(self):
         # verify the correctness of squash()
         print ("----------------------------------")

@@ -74,7 +81,8 @@ class TestLoss(unittest.TestCase):
         #pdb.set_trace()

         y = tc.log(y)
-        los = loss_func(y , gy)
+        #los = loss_func({'input': y}, {'target': gy})
+        los = loss_func(y, gy)
         print ("loss = %f" % (los))

         r = -log(.3) - log(.3) - log(.1) - log(.3) - log(.7) - log(.1)

@@ -89,7 +97,8 @@ class TestLoss(unittest.TestCase):

         log = math.log

-        loss_func = loss.Loss("nll")
+        #loss_func = loss.Loss("nll")
+        loss_func = loss.NLLLoss()

         #pdb.set_trace()

@@ -117,7 +126,7 @@ class TestLoss(unittest.TestCase):

         yy = tc.nn.utils.rnn.pack_padded_sequence(y , lens , batch_first = True).data
         gyy = tc.nn.utils.rnn.pack_padded_sequence(gy , lens , batch_first = True).data
-        los = loss_func(yy , gyy)
+        los = loss_func({'input': yy}, {'target': gyy})
         print ("loss = %f" % (los))

@@ -303,5 +312,58 @@ class TestLoss(unittest.TestCase):
         print ("r = %f" % (r))
         self.assertEqual(int(los * 1000), int(r * 1000))

+    def test_case_8(self):
+        def func(a, b):
+            import torch.nn.functional as F
+            return F.cross_entropy(a, b)
+
+        def func2(a, truth):
+            return func(a, truth)
+
+        def func3(predict, truth):
+            return func(predict, truth)
+
+        def func4(a, b, c=2):
+            return (a + b) * c
+
+        def func6(a, b, **kwargs):
+            c = kwargs['c']
+            return (a + b) * c
+
+        import torch
+        from fastNLP.core.losses import LossBase, NewLoss
+
+        get_loss = NewLoss(func, {'a': 'predict', 'b': 'truth'})
+        predict = torch.randn(5, 3)
+        truth = torch.LongTensor([1, 0, 1, 2, 1])
+        loss1 = get_loss({'predict': predict}, {'truth': truth})
+        get_loss_2 = NewLoss(func2, {'a': 'predict'})
+        loss2 = get_loss_2({'predict': predict}, {'truth': truth})
+        get_loss_3 = NewLoss(func3)
+        loss3 = get_loss_3({'predict': predict}, {'truth': truth})
+        print(loss1, loss2, loss3)
+        assert loss1 == loss2 and loss1 == loss3
+
+        get_loss_4 = NewLoss(func4)
+        loss4 = get_loss_4({'a': 1, 'b': 3}, {})
+        print(loss4)
+        assert loss4 == (1 + 3) * 2
+
+        get_loss_5 = NewLoss(func4)
+        loss5 = get_loss_5({'a': 1, 'b': 3}, {'c': 4})
+        print(loss5)
+        assert loss5 == (1 + 3) * 4
+
+        get_loss_6 = NewLoss(func6)
+        loss6 = get_loss_6({'a': 1, 'b': 3}, {'c': 4})
+        print(loss6)
+        assert loss6 == (1 + 3) * 4
+
+        get_loss_7 = NewLoss(func6, c='cc')
+        loss7 = get_loss_7({'a': 1, 'b': 3}, {'cc': 4})
+        print(loss7)
+        assert loss7 == (1 + 3) * 4
+

 if __name__ == "__main__":
     unittest.main()

test/core/test_optimizer.py  (+21 -0)

@@ -0,0 +1,21 @@
+import unittest
+
+import torch
+
+from fastNLP.core.optimizer import SGD
+
+
+class TestOptim(unittest.TestCase):
+    def test_case(self):
+        optim = SGD(torch.LongTensor(10))
+        print(optim.__dict__)
+
+        optim_2 = SGD(lr=0.001)
+        print(optim_2.__dict__)
+
+        optim_2 = SGD(lr=0.002, momentum=0.989)
+        print(optim_2.__dict__)
+
+    def test_case_2(self):
+        with self.assertRaises(RuntimeError):
+            _ = SGD(0.001)

test/core/test_trainer.py  (+1 -0)

@@ -4,3 +4,4 @@ import unittest
 class TestTrainer(unittest.TestCase):
     def test_case_1(self):
         pass

