@@ -1,4 +1,5 @@
 import torch
+import numpy as np


 class Batch(object):
@@ -45,7 +46,7 @@ class Batch(object):
                 if field.is_target or field.is_input:
                     batch = field.get(indices)
                     if not self.as_numpy:
-                        batch = torch.from_numpy(batch)
+                        batch = to_tensor(batch, field.dtype)
                     if field.is_target:
                         batch_y[field_name] = batch
                     if field.is_input:
@@ -54,3 +55,10 @@ class Batch(object):
             self.curidx = endidx
             return batch_x, batch_y
+
+
+def to_tensor(batch, dtype):
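+    # convert a numpy array to a torch tensor according to the field's numpy dtype:
+    # integer dtypes become a LongTensor, float dtypes a FloatTensor; other dtypes are returned unchanged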
+    if dtype in (np.int8, np.int16, np.int32, np.int64):
+        batch = torch.LongTensor(batch)
+    if dtype in (np.float32, np.float64):
+        batch = torch.FloatTensor(batch)
+    return batch
@@ -39,7 +39,7 @@ class FieldArray(object):
     @staticmethod
     def _map_to_np_type(basic_type):
-        type_mapping = {int: np.int64, float: np.double, str: np.str}
+        type_mapping = {int: np.int64, float: np.float64, str: np.str}
         return type_mapping[basic_type]

     def __repr__(self):
@@ -126,15 +126,30 @@ class NLLLoss(LossBase):

 class LossInForward(LossBase):
     def __init__(self, loss_key='loss'):
         super().__init__()
+        if not isinstance(loss_key, str):
+            raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.")
         self.loss_key = loss_key

     def get_loss(self, **kwargs):
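+        # the model's forward output is expected to contain the loss under `self.loss_key`;
+        # if the key is absent, report it as a missing argument via CheckError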
         if self.loss_key not in kwargs:
-            pass
+            check_res = CheckRes(missing=[self.loss_key],
+                                 unused=[],
+                                 duplicated=[],
+                                 required=[],
+                                 all_needed=[],
+                                 varargs=[])
+            raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss))
-    def __call__(self, output_dict, predict_dict):
-        return self.get_loss(**output_dict)
+    def __call__(self, output_dict, predict_dict, force_check=False):
+        loss = self.get_loss(**output_dict)
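+        # the loss coming out of the model must be a scalar (0-dim) torch.Tensor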
+        if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
+            if not isinstance(loss, torch.Tensor):
+                raise TypeError(f"loss ERROR: loss expects a torch.Tensor but got {type(loss)}")
+            raise RuntimeError(f"loss ERROR: the size of loss expects torch.Size([]) but got {loss.size()}")
+        return loss
 def _prepare_losser(losser):
@@ -10,7 +10,7 @@ from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _check_arg_dict_list
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import CheckError
+from fastNLP.core.utils import _check_function_or_method


 class MetricBase(object):
     def __init__(self):
@@ -20,19 +20,32 @@ class MetricBase(object):
     def evaluate(self, *args, **kwargs):
         raise NotImplementedError
-    def _init_param_map(self, key_map, **kwargs):
-        self.param_map = {}
-        value_counter = defaultdict(0)
-        for key, value in key_map.items():
-            if isinstance(key, str):
-                raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
-            if isinstance(value, str):
-                raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
-            self.param_map[key] = value
+    def _init_param_map(self, key_map=None, **kwargs):
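+        # key_map/kwargs map evaluate()'s argument names to the key names under which those
+        # values are provided; value_counter records which arguments point at each provided
+        # key so that mapping several arguments to the same key can be rejected below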
+        value_counter = defaultdict(set)
+        if key_map is not None:
+            if not isinstance(key_map, dict):
+                raise TypeError("key_map must be `dict`, got {}.".format(type(key_map)))
+            for key, value in key_map.items():
+                if value is None:
+                    self.param_map[key] = key
+                    continue
+                if not isinstance(key, str):
+                    raise TypeError(f"key in key_map must be `str`, not `{type(key)}`.")
+                if not isinstance(value, str):
+                    raise TypeError(f"value in key_map must be `str`, not `{type(value)}`.")
+                self.param_map[key] = value
+                value_counter[value].add(key)
         for key, value in kwargs.items():
+            if value is None:
+                self.param_map[key] = key
+                continue
             if not isinstance(value, str):
                 raise TypeError(f"in {key}={value}, value must be `str`, not `{type(value)}`.")
             self.param_map[key] = value
+            value_counter[value].add(key)
+        for value, key_set in value_counter.items():
+            if len(key_set) > 1:
+                raise ValueError(f"Several params:{key_set} are provided with one output {value}.")
     def __call__(self, output_dict, target_dict, check=False):
         """
@@ -45,8 +58,6 @@ class MetricBase(object):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")

         if not self._checked:
-            # 0. check param_map does not have same value
             # 1. check consistence between signature and param_map
             func_spect = inspect.getfullargspec(self.evaluate)
             func_args = func_spect.args
@@ -58,26 +69,32 @@ class MetricBase(object):
                 if arg not in self.param_map:
                     self.param_map[arg] = arg #This param does not need mapping.
             self._evaluate_args = func_args
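+            # reverse mapping (provided key name -> evaluate() argument name), used to decide
+            # whether a provided key needs re-mapping and to rewrite names in error messages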
+            self._reverse_param_map = {value: key for key, value in self.param_map.items()}

         # need to wrap inputs in dict.
         mapped_output_dict = {}
         mapped_target_dict = {}
         for func_arg in self._evaluate_args:
             input_arg = self.param_map[func_arg]
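+            # if the provided key name is itself a mapped name, fall back to the evaluate()
+            # argument name; otherwise keep the provided key unchanged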
+            if input_arg in self._reverse_param_map:
+                mapped_arg = func_arg
+            else:
+                mapped_arg = input_arg
             if input_arg in output_dict:
-                mapped_output_dict[func_arg] = output_dict[input_arg]
+                mapped_output_dict[mapped_arg] = output_dict[input_arg]
             if input_arg in target_dict:
-                mapped_target_dict[func_arg] = target_dict[input_arg]
+                mapped_target_dict[mapped_arg] = target_dict[input_arg]
         # check duplicated, unused, missing
         if check or not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_output_dict, mapped_output_dict])
-            self._reverse_param_map = {value:key for key, value in check_res.items()}
             for key, value in check_res.items():
                 new_value = list(value)
                 for idx, func_param in enumerate(value):
                     if func_param in self._reverse_param_map:
-                        new_value[idx] = self._reverse_param_map[func_param]
+                        new_value[idx] = self._reverse_param_map[func_param] + f' (assigned to {func_param})'
+                    else:
+                        new_value[idx] = func_param
             if check_res.missing or check_res.duplicated or check_res.varargs:
                 raise CheckError(check_res=check_res,
                                  func_signature=get_func_signature(self.evaluate))
@@ -93,11 +110,55 @@ class MetricBase(object):
         return metrics


-class Metric(MetricBase):
+class FuncMetric(MetricBase):
     def __init__(self, func, key_map, **kwargs):
         super().__init__()
+
+        _check_function_or_method(func=func)
+        self._init_param_map(key_map=key_map, **kwargs)
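+        # the user-supplied callable becomes this metric's evaluate(); its arguments are
+        # filled from the model output / dataset targets according to the mapping above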
+        self.evaluate = func
+
+
+class AccuracyMetric(MetricBase):
+    def __init__(self, predictions=None, targets=None, masks=None, seq_lens=None):
+        super().__init__()
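+        # each constructor argument lets the user rename the corresponding evaluate()
+        # parameter to a key in the model output / dataset targets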
+        self._init_param_map(predictions=predictions, targets=targets,
+                             masks=masks, seq_lens=seq_lens)
+
+    def evaluate(self, predictions, targets, masks=None, seq_lens=None):
+        """
+        :param predictions: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+                torch.Size([]), torch.Size([n_classes,]), torch.Size([max_len,]), torch.Size([max_len, n_classes])
+        :param targets: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+                torch.Size([]), torch.Size([]), torch.Size([max_len,]), torch.Size([max_len, ])
+        :param masks: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+                None, None, torch.Size([max_len,]), torch.Size([max_len, ])
+        :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
+                None, None, torch.Size([1]), torch.Size([1])
+        :return: dict({'acc': float})
+        """
         pass
+    def _check_evaluate_param(self, predictions, targets, masks=None, seq_lens=None):
+        # check the validity of self.evaluate param
+        prediction = predictions[0]
+        target = targets[0]
+        if len(np.shape(prediction)) == len(np.shape(target)):
+            pass
+        if masks is not None:
+            mask = masks[0]
+        if seq_lens is not None:
+            seq_len = seq_lens[0]
 def _prepare_metrics(metrics):
     """
@@ -7,11 +7,11 @@ from torch import nn
 from fastNLP.core.batch import Batch
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.dataset import DataSet
-from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _build_args
 from fastNLP.core.utils import get_func_signature
 from fastNLP.core.utils import _move_dict_value_to_device
 from fastNLP.core.metrics import _prepare_metrics
+from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _check_loss_evaluate


 class Tester(object):
@@ -57,7 +57,7 @@ class Tester(object):
         with torch.no_grad():
             for batch_x, batch_y in data_iterator:
-                _move_dict_value_to_device(self._model_device, batch_x, batch_y)
+                _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
                 prediction = self._data_forward(self._predict_func, batch_x)
                 assert isinstance(prediction, dict)
                 for k, v in prediction.items():
@@ -77,7 +77,7 @@ class Tester(object):
         except CheckError as e:
             prev_func_signature = get_func_signature(self._predict_func)
             _check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
-                                 check_res=e.check_res, output=output, batch_y=truths)
+                                 check_res=e.check_res, output=output, batch_y=truths, check_level=0)

         if self.verbose >= 0:
@@ -1,6 +1,5 @@
 import os
 import time
-import warnings
 from datetime import datetime
 from datetime import timedelta

@@ -9,24 +8,19 @@ from tensorboardX import SummaryWriter
 from torch import nn

 from fastNLP.core.batch import Batch
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.losses import _prepare_losser
-from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.optimizer import Adam
 from fastNLP.core.sampler import RandomSampler
 from fastNLP.core.sampler import SequentialSampler
 from fastNLP.core.tester import Tester
-from fastNLP.core.utils import CheckError
-from fastNLP.core.utils import _build_args
-from fastNLP.core.utils import _check_arg_dict_list
-from fastNLP.core.utils import _move_dict_value_to_device
-from fastNLP.core.utils import get_func_signature
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.losses import _prepare_losser
 from fastNLP.core.metrics import _prepare_metrics
 from fastNLP.core.utils import CheckError
 from fastNLP.core.utils import _check_loss_evaluate
 from fastNLP.core.utils import _check_forward_error
+from fastNLP.core.utils import _build_args
+from fastNLP.core.utils import _move_dict_value_to_device
+from fastNLP.core.utils import get_func_signature


 class Trainer(object):
     """Main Training Loop

@@ -52,6 +46,9 @@ class Trainer(object):
         if metrics and (dev_data is None):
             raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")

+        # check save_path
+        if not (save_path is None or isinstance(save_path, str)):
+            raise ValueError("save_path can only be None or `str`.")
         # prepare evaluate
         metrics = _prepare_metrics(metrics)
@@ -156,7 +153,7 @@ class Trainer(object):
         """
         for batch_x, batch_y in data_iterator:
             # TODO: this may break if the user moves `prediction` to a different device inside the model
-            _move_dict_value_to_device(self._model_device, batch_x, batch_y)
+            _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
             prediction = self._data_forward(model, batch_x)
             loss = self._compute_loss(prediction, batch_y)
             self._grad_backward(loss)
@@ -232,11 +229,12 @@ class Trainer(object):
         return self.losser(predict, truth)

     def _save_model(self, model, model_name, only_param=False):
-        model_name = os.path.join(self.save_path, model_name)
-        if only_param:
-            torch.save(model.state_dict(), model_name)
-        else:
-            torch.save(model, model_name)
+        if self.save_path is not None:
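+            # checkpoints are only written when a save_path was provided; otherwise saving is skipped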
+            model_name = os.path.join(self.save_path, model_name)
+            if only_param:
+                torch.save(model.state_dict(), model_name)
+            else:
+                torch.save(model, model_name)
     def _better_eval_result(self, metrics):
         """Check if the current epoch yields better validation results.

@@ -297,7 +295,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
     for batch_count, (batch_x, batch_y) in enumerate(batch):
-        _move_dict_value_to_device(model_devcie, batch_x, batch_y)
+        _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
         # forward check
         if batch_count==0:
             _check_forward_error(forward_func=model.forward, check_level=check_level,
@@ -335,6 +333,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_
     if dev_data is not None:
         tester = Tester(data=dataset[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics,
                         batch_size=batch_size, verbose=-1)
-        tester.test()
+        evaluate_results = tester.test()
+        # TODO: check whether the returned values are reasonable
@@ -122,13 +122,13 @@ def _check_arg_dict_list(func, args):
     input_args = set(input_arg_count.keys())
     missing = list(require_args - input_args)
     unused = list(input_args - all_args)
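+    # inspect.getfullargspec(...).varargs holds the *args parameter name (or None); report it
+    # as a single-element list so the later checks can flag unsupported positional varargs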
+    varargs = [spect.varargs] if spect.varargs else []
     return CheckRes(missing=missing,
                     unused=unused,
                     duplicated=duplicated,
                     required=list(require_args),
                     all_needed=list(all_args),
-                    varargs=[arg for arg in spect.varargs])
+                    varargs=varargs)
 def get_func_signature(func):
     """
@@ -165,6 +165,7 @@ def get_func_signature(func):
         signature_str = func.__name__ + signature_str
         return signature_str


 def _is_function_or_method(func):
     """
@@ -179,26 +180,8 @@ def _check_function_or_method(func):
     if not _is_function_or_method(func):
         raise TypeError(f"{type(func)} is not a method or function.")


-def _syn_model_data(model, *args):
-    """
-    move data to model's device, element in *args should be dict. This is a inplace change.
-    :param model:
-    :param args:
-    :return:
-    """
-    if len(model.state_dict())==0:
-        raise ValueError("model has no parameter.")
-    device = model.parameters().__next__().device
-    for arg in args:
-        if isinstance(arg, dict):
-            for key, value in arg.items():
-                if isinstance(value, torch.Tensor):
-                    arg[key] = value.to(device)
-        else:
-            raise TypeError("Only support `dict` type right now.")
-
-
-def _move_dict_value_to_device(device, *args):
+def _move_dict_value_to_device(*args, device: torch.device):
     """
     move data to model's device, element in *args should be dict. This is a inplace change.
@@ -240,6 +223,7 @@ class CheckError(Exception):
         self.check_res = check_res
         self.func_signature = func_signature


 IGNORE_CHECK_LEVEL = 0
 WARNING_CHECK_LEVEL = 1
 STRICT_CHECK_LEVEL = 2
@@ -252,8 +236,8 @@ def _check_loss_evaluate(prev_func_signature:str, func_signature:str, check_res:
         errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, "
                     f"please delete it.)")
     if check_res.missing:
-        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(output.keys())}"
-                    f"(from {prev_func_signature}) and {list(batch_y.keys())}(from targets in Dataset).")
+        errs.append(f"\tmissing param: `{check_res.missing}`, provided with `{list(output.keys())}`"
+                    f"(from output of `{prev_func_signature}`) and `{list(batch_y.keys())}`(from targets in Dataset).")
     if check_res.duplicated:
         errs.append(f"\tduplicated param: {check_res.duplicated}, delete {check_res.duplicated} in the output of "
                     f"{check_res.duplicated} or do not set {check_res.duplicated} as targets. ")
@@ -281,7 +265,7 @@ def _check_forward_error(forward_func, batch_x, check_level):
     if check_res.varargs:
         errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
     if check_res.missing:
-        errs.append(f"\tmissing param: {check_res.missing}, only provided with {list(batch_x.keys())}.")
+        errs.append(f"\tmissing param: {check_res.missing}, provided with {list(batch_x.keys())}.")
     if check_res.unused:
         _unused = [f"\tunused param: {check_res.unused}"]
         if check_level == STRICT_CHECK_LEVEL: