@@ -1,39 +1,38 @@
-import itertools
 import os
 import time
 import warnings
+from collections import defaultdict
 from datetime import datetime
 from datetime import timedelta
 
 import torch
-from torch import nn
 from tensorboardX import SummaryWriter
+from torch import nn
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.optimizer import Optimizer
+from fastNLP.core.dataset import DataSet
+from fastNLP.core.losses import _prepare_losser
+from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.utils import get_func_signature
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import LossBase
-from fastNLP.core.metrics import MetricBase
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
-from fastNLP.core.utils import CheckError
 
 
 class Trainer(object):
     """Main Training Loop
 
     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1, validate_every=-1,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=-1,
+                 validate_every=-1,
                  dev_data=None, use_cuda=False, save_path="./save",
-                 optimizer=Optimizer("Adam", lr=0.01, weight_decay=0), need_check_code=True,
+                 optimizer=Adam(lr=0.01, weight_decay=0), need_check_code=True,
+                 metric_key=None,
                  **kwargs):
         super(Trainer, self).__init__()
@@ -50,6 +49,13 @@ class Trainer(object):
 
         # prepare evaluate
         metrics = _prepare_metrics(metrics)
 
+        # parse metric_key
+        # increase_better is True. It means the exp result gets better if the indicator increases.
+        # It is true by default.
+        # note: the default metric_key is None, so guard before indexing into it
+        if metric_key is not None:
+            self.increase_better = False if metric_key[0] == "-" else True
+            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+        else:
+            self.increase_better = True
+            self.metric_key = None
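+        # e.g. metric_key="-loss" tracks "loss" and treats smaller as better; "acc" or "+acc" treats larger as better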
+
         # prepare loss
         losser = _prepare_losser(losser)
@@ -67,12 +73,10 @@ class Trainer(object):
         self.save_path = save_path
         self.print_every = int(print_every)
         self.validate_every = int(validate_every)
-        self._best_accuracy = 0
+        self.best_metric_indicator = None
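+        # best validation indicator seen so far; stays None until the first validation pass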
         self._model_device = model.parameters().__next__().device
 
-        # TODO: self._best_accuracy cannot reflect cases where there are multiple metrics
-
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         else:
@@ -102,7 +106,7 @@ class Trainer(object):
         if torch.cuda.is_available() and self.use_cuda:
             self.model = self.model.cuda()
 
-        self.mode(self.model, is_test=False)
+        self._mode(self.model, is_test=False)
 
         start = time.time()
         self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
@@ -112,7 +116,9 @@ class Trainer(object):
                 def __getattr__(self, item):
                     def pass_func(*args, **kwargs):
                         pass
+
                     return pass_func
+
             self._summary_writer = psudoSW()
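+            # psudoSW acts as a no-op SummaryWriter: every attribute lookup returns a function that does nothing,
+            # so all logging calls silently pass when no save_path is given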
         else:
             path = os.path.join(self.save_path, 'tensorboard_logs_{}'.format(self.start_time))
@@ -121,19 +127,20 @@ class Trainer(object):
             epoch = 1
             while epoch <= self.n_epochs:
 
-                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(), as_numpy=False)
+                data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=RandomSampler(),
+                                      as_numpy=False)
 
-                self._train_epoch(data_iterator, self.model, epoch, self.dev_data, start)
+                self._train_epoch(data_iterator, self.model, epoch, start)
 
                 # a positive validate_every overrides validation at the end of epochs
                 if self.dev_data and self.validate_every <= 0:
-                    self.do_validation()
+                    self._do_validation()
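+                    # reached only when validate_every <= 0: validation then runs here, once per epoch,
+                    # instead of every validate_every steps inside _train_epoch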
                 epoch += 1
         finally:
             self._summary_writer.close()
             del self._summary_writer
 
-    def _train_epoch(self, data_iterator, model, epoch, dev_data, start, **kwargs):
+    def _train_epoch(self, data_iterator, model, epoch, start):
         """Training process in one epoch.
 
         kwargs should contain:
@@ -144,10 +151,10 @@ class Trainer(object):
         for batch_x, batch_y in data_iterator:
             # TODO: this may break if the user moves `prediction` to a different device inside the model
             _move_dict_value_to_device(self._model_device, batch_x, batch_y)
-            prediction = self.data_forward(model, batch_x)
-            loss = self.get_loss(prediction, batch_y)
-            self.grad_backward(loss)
-            self.update()
+            prediction = self._data_forward(model, batch_x)
+            loss = self._compute_loss(prediction, batch_y)
+            self._grad_backward(loss)
+            self._update()
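+            # the four calls above form one optimization step: forward pass, loss, backward pass, parameter update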
             self._summary_writer.add_scalar("loss", loss.item(), global_step=self.step)
             for name, param in self.model.named_parameters():
                 if param.requires_grad:
@@ -162,18 +169,19 @@ class Trainer(object):
                     print(print_output)
 
             if self.validate_every > 0 and self.step % self.validate_every == 0:
-                self.do_validation()
+                self._do_validation()
 
             self.step += 1
 
-    def do_validation(self):
+    def _do_validation(self):
         res = self.tester.test()
         for name, num in res.items():
             self._summary_writer.add_scalar("valid_{}".format(name), num, global_step=self.step)
-        if self.save_path is not None and self.best_eval_result(res):
-            self.save_model(self.model, 'best_model_' + self.start_time)
+        if self.save_path is not None and self._better_eval_result(res):
+            self._save_model(self.model,
+                             "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]))
 
-    def mode(self, model, is_test=False):
+    def _mode(self, model, is_test=False):
         """Train mode or Test mode. This is for PyTorch currently.
 
         :param model: a PyTorch model
@@ -185,20 +193,20 @@ class Trainer(object):
         else:
             model.train()
 
-    def update(self):
+    def _update(self):
         """Perform weight update on a model.
 
         """
         self.optimizer.step()
 
-    def data_forward(self, network, x):
+    def _data_forward(self, network, x):
         x = _build_args(network.forward, **x)
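+        # presumably _build_args filters x down to the keyword arguments accepted by network.forward,
+        # so extra fields in the batch dict do not break the call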
         y = network(**x)
         if not isinstance(y, dict):
             raise TypeError(f"The return value of {get_func_signature(network.forward)} should be dict, got {type(y)}.")
         return y
 
-    def grad_backward(self, loss):
+    def _grad_backward(self, loss):
         """Compute gradients by back-propagation.
 
         :param loss: a scalar where back-prop starts
@@ -208,7 +216,7 @@ class Trainer(object):
         self.model.zero_grad()
         loss.backward()
 
-    def get_loss(self, predict, truth):
+    def _compute_loss(self, predict, truth):
         """Compute loss given prediction and ground truth.
 
         :param predict: prediction dict, produced by model.forward
@@ -217,34 +225,59 @@ class Trainer(object):
         """
         return self.losser(predict, truth)
 
-    def save_model(self, model, model_name, only_param=False):
+    def _save_model(self, model, model_name, only_param=False):
         model_name = os.path.join(self.save_path, model_name)
         if only_param:
             torch.save(model.state_dict(), model_name)
         else:
             torch.save(model, model_name)
 
-    def best_eval_result(self, metrics):
+    def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.
 
-        :return: bool, True means current results on dev set is the best.
+        :return bool value: True means current results on the dev set are the best.
         """
         if isinstance(metrics, tuple):
             loss, metrics = metrics
 
         if isinstance(metrics, dict):
             if len(metrics) == 1:
-                accuracy = list(metrics.values())[0]
+                # only single metric, just use it
+                metric_dict = list(metrics.values())[0]
+                metrics_name = list(metrics.keys())[0]
             else:
-                accuracy = metrics[self.eval_sort_key]
-        else:
-            accuracy = metrics
-
-        if accuracy > self._best_accuracy:
-            self._best_accuracy = accuracy
-            return True
-        else:
-            return False
+                metrics_name = self.metrics[0].__class__.__name__
+                if metrics_name not in metrics:
+                    raise RuntimeError(f"{metrics_name} is chosen to do validation, but got {metrics}")
+                metric_dict = metrics[metrics_name]
+
+        if len(metric_dict) == 1:
+            indicator_val, indicator = list(metric_dict.values())[0], list(metric_dict.keys())[0]
+        elif len(metric_dict) > 1 and self.metric_key is None:
+            raise RuntimeError(
+                f"Got multiple metric keys: {metric_dict}, but metric_key is not set. Which one to use?")
+        else:
+            # metric_key is set
+            if self.metric_key not in metric_dict:
+                raise RuntimeError(f"metric key {self.metric_key} not found in {metric_dict}")
+            indicator_val = metric_dict[self.metric_key]
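+        # indicator_val now holds the single number used to decide whether this checkpoint is the best so far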
+
+        is_better = True
+        if self.best_metric_indicator is None:
+            # first-time validation
+            self.best_metric_indicator = indicator_val
+        else:
+            if self.increase_better is True:
+                if indicator_val > self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+            else:
+                if indicator_val < self.best_metric_indicator:
+                    self.best_metric_indicator = indicator_val
+                else:
+                    is_better = False
+        return is_better
 
 
 DEFAULT_CHECK_BATCH_SIZE = 2
@@ -254,6 +287,7 @@ IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
 
+
 def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None,
                 check_level=WARNING_CHECK_LEVEL):
@@ -264,7 +298,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     for batch_count, (batch_x, batch_y) in enumerate(batch):
         _move_dict_value_to_device(model_devcie, batch_x, batch_y)
         # forward check
-        if batch_count==0:
+        if batch_count == 0:
             _check_forward_error(model_func=model.forward, check_level=check_level,
                                  batch_x=batch_x)
@@ -285,17 +319,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
         if batch_count == 0:
             if not isinstance(loss, torch.Tensor):
                 raise TypeError(f"The return value of {get_func_signature(losser.__call__)} should be `torch.Tensor`, "
                                 f"but got `{type(loss)}`.")
-            if len(loss.size())!=0:
+            if len(loss.size()) != 0:
                 raise ValueError(f"The size of return value of {get_func_signature(losser.__call__)} is {loss.size()}, "
                                  f"should be torch.size([])")
         loss.backward()
         model.zero_grad()
-        if batch_count+1>=DEFAULT_CHECK_NUM_BATCH:
+        if batch_count + 1 >= DEFAULT_CHECK_NUM_BATCH:
             break
 
     if dev_data is not None:
-        tester = Tester(data=dataset[:batch_size*DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
+        tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
         tester.test()
@@ -305,18 +339,18 @@ def _check_forward_error(model_func, check_level, batch_x):
     _missing = ''
     _unused = ''
     func_signature = get_func_signature(model_func)
-    if len(check_res['missing'])!=0:
+    if len(check_res['missing']) != 0:
         _missing = "Function {} misses {}, only provided with {}, " \
                    ".\n".format(func_signature, check_res.missing,
                                 list(batch_x.keys()))
-    if len(check_res['unused'])!=0:
+    if len(check_res['unused']) != 0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
             _unused = "{} is not used ".format(check_res.unused)
         _unused += "in function {}.\n".format(func_signature)
     if _missing:
-        if len(_unused)>0 and STRICT_CHECK_LEVEL:
+        # compare check_level against the constant; the bare name STRICT_CHECK_LEVEL is always truthy
+        if len(_unused) > 0 and check_level == STRICT_CHECK_LEVEL:
             _error_str = "(1).{}\n(2).{}".format(_missing, _unused)
         else:
             _error_str = _missing
@@ -329,38 +363,40 @@ def _check_forward_error(model_func, check_level, batch_x):
     elif check_level == WARNING_CHECK_LEVEL:
         warnings.warn(message=_unused)
 
 
-def _check_loss_evaluate(prev_func, func, check_res, output, batch_y, check_level):
+def _check_loss_evaluate(prev_func, func, check_level, output, batch_y):
+    check_res = _check_arg_dict_list(func, [output, batch_y])
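+    # the argument check is now computed here instead of being passed in by the caller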
     _missing = ''
     _unused = ''
     _duplicated = ''
     func_signature = get_func_signature(func)
     prev_func_signature = get_func_signature(prev_func)
-    if len(check_res.missing)>0:
+    if len(check_res.missing) > 0:
         _missing = "function {} misses argument {}, \n\t only provided with {}(from {}) and " \
                    "{}(from target in Dataset)." \
             .format(func_signature, check_res.missing,
                     list(output.keys()), prev_func_signature,
                     list(batch_y.keys()))
-    if len(check_res.unused)>0:
+    if len(check_res.unused) > 0:
         if len(check_res.unused) > 1:
             _unused = "{} are not used ".format(check_res.unused)
         else:
             _unused = "{} is not used ".format(check_res.unused)
         _unused += "in function {}.\n".format(func_signature)
-    if len(check_res.duplicated)>0:
+    if len(check_res.duplicated) > 0:
         if len(check_res.duplicated) > 1:
             _duplicated = "duplicated keys {} are detected when calling function {}. \n\tDon't set {} as target and output " \
                           "them in {} at the same time.".format(check_res.duplicated,
                                                                 func_signature,
                                                                 check_res.duplicated,
                                                                 prev_func_signature)
         else:
             _duplicated = "duplicated key {} is detected when calling function {}. \n\tDon't set {} as target and output " \
                           "it in {} at the same time.".format(check_res.duplicated,
                                                               func_signature,
                                                               check_res.duplicated,
                                                               prev_func_signature)
-    _number_errs = int(len(_missing)!=0) + int(len(_duplicated)!=0) + int(len(_unused)!=0)
+    _number_errs = int(len(_missing) != 0) + int(len(_duplicated) != 0) + int(len(_unused) != 0)
     if _number_errs > 0:
         _error_strs = []
         if _number_errs > 1:
|