2. Rename `update_every` to `print_every` in Trainer, since `update_every` could be misread as controlling optimizer updates. 3. FieldArray `content` can now be initialized from an np.ndarray.
@@ -17,6 +17,12 @@ class FieldArray(object):
         :param bool is_input: If True, this FieldArray is used to the model input.
         """
         self.name = name
+        if isinstance(content, list):
+            content = content
+        elif isinstance(content, np.ndarray):
+            content = content.tolist()
+        else:
+            raise TypeError("content in FieldArray can only be list or numpy.ndarray, got {}.".format(type(content)))
         self.content = content
         self.padding_val = padding_val
         self.is_target = is_target
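The hunk above lets `FieldArray` take its `content` either as a plain list (stored as given) or as a `numpy.ndarray` (converted with `.tolist()`); any other type now raises a `TypeError`. A minimal sketch of the new behaviour, assuming the import path and the defaults of `padding_val`/`is_input`/`is_target`, which are not visible in this hunk:

    import numpy as np
    from fastNLP.core.fieldarray import FieldArray   # import path assumed

    # list content is kept as given
    words = FieldArray("words", content=[[1, 2, 3], [4, 5]], is_input=True)

    # ndarray content is converted to a nested list internally
    labels = FieldArray("labels", content=np.array([0, 1, 1, 0]), is_target=True)

    # anything else is rejected
    try:
        FieldArray("bad", content=(1, 2, 3))
    except TypeError as e:
        print(e)   # content in FieldArray can only be list or numpy.ndarray, got <class 'tuple'>.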
@@ -61,8 +61,7 @@ class MetricBase(object):
             if func_param not in func_args:
                 raise NameError(
                     f"Parameter `{func_param}` is not in {get_func_signature(self.evaluate)}. Please check the "
-                    f"initialization parameters, or change the signature of"
-                    f" {get_func_signature(self.evaluate)}.")
+                    f"initialization parameters, or change its signature.")
         # evaluate should not have varargs.
         if func_spect.varargs:
@@ -79,13 +78,14 @@ class MetricBase(object):
         such as pred_dict has one element, target_dict has one element
         :param pred_dict:
         :param target_dict:
-        :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
+        :return: dict, if dict is not {}, pass it to self.evaluate. Otherwise do mapping.
         """
+        fast_param = {}
         if len(self.param_map) == 2 and len(pred_dict) == 1 and len(target_dict) == 1:
             return pred_dict.values[0] and target_dict.values[0]
-        return None
+        return fast_param
-    def __call__(self, pred_dict, target_dict, check=False):
+    def __call__(self, pred_dict, target_dict):
         """
         This method will call self.evaluate method.
@@ -96,20 +96,19 @@ class MetricBase(object):
         (4) whether params in output_dict, target_dict are not used by evaluate.(Might cause warning)
         Besides, before passing params into self.evaluate, this function will filter out params from output_dict and
         target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering
-        will be conducted)
+        will be conducted.)
+        This function also support _fast_param_map.
         :param pred_dict: usually the output of forward or prediction function
         :param target_dict: usually features set as target..
-        :param check: boolean, if check is True, it will force check `varargs, missing, unused, duplicated`.
         :return:
         """
         if not callable(self.evaluate):
             raise TypeError(f"{self.__class__.__name__}.evaluate has to be callable, not {type(self.evaluate)}.")
-        if not check:
-            fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
-            if fast_param is not None:
-                self.evaluate(*fast_param)
-                return
+        fast_param = self._fast_param_map(pred_dict=pred_dict, target_dict=target_dict)
+        if fast_param:
+            self.evaluate(**fast_param)
+            return
         if not self._checked:
             # 1. check consistence between signature and param_map
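With the `check` argument removed, `MetricBase.__call__` always consults `_fast_param_map` first: when it returns a non-empty dict, `evaluate` is called with those keyword arguments and the whole `param_map`/signature checking machinery is skipped. A rough usage sketch built on `AccuracyMetric` (rewritten further down in this diff); the import path is assumed and the printed value is only indicative:

    import torch
    from fastNLP.core.metrics import AccuracyMetric   # import path assumed

    metric = AccuracyMetric()                  # no pred=/target= mapping supplied
    pred_dict = {"output": torch.zeros(4, 3)}  # key names do not matter on the fast path
    target_dict = {"label": torch.zeros(4)}

    # one tensor on each side -> _fast_param_map returns {'pred': ..., 'target': ...}
    # and evaluate() is called directly, without any name mapping or checking
    metric(pred_dict=pred_dict, target_dict=target_dict)
    print(metric.get_metric())                 # e.g. {'acc': 1.0}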
@@ -147,7 +146,7 @@ class MetricBase(object):
                 duplicated.append(input_arg)
         # missing
-        if check or not self._checked:
+        if not self._checked:
             check_res = _check_arg_dict_list(self.evaluate, [mapped_pred_dict, mapped_target_dict])
             # only check missing.
             missing = check_res.missing
@@ -175,40 +174,49 @@ class MetricBase(object):
 class AccuracyMetric(MetricBase):
-    def __init__(self, pred=None, target=None, masks=None, seq_lens=None):
+    def __init__(self, pred=None, target=None, seq_lens=None):
         super().__init__()
-        self._init_param_map(pred=pred, target=target,
-                             masks=masks, seq_lens=seq_lens)
+        self._init_param_map(pred=pred, target=target, seq_lens=seq_lens)
         self.total = 0
         self.acc_count = 0
-    def _fast_call_evaluate(self, pred_dict, target_dict):
+    def _fast_param_map(self, pred_dict, target_dict):
         """
         Only used as inner function. When the pred_dict, target is unequivocal. Don't need users to pass key_map.
         such as pred_dict has one element, target_dict has one element
         :param pred_dict:
         :param target_dict:
-        :return: boolean, whether to go on codes in self.__call__(). When False, don't go on.
+        :return: dict, if dict is not None, pass it to self.evaluate. Otherwise do mapping.
         """
-        if len(pred_dict) == 1 and len(target_dict) == 1:
-            pred = list(pred_dict.values())[0]
-            target = list(target_dict.values())[0]
-            self.evaluate(pred=pred, target=target)
-            return True
-        return False
-    def evaluate(self, pred, target, masks=None, seq_lens=None):
+        fast_param = {}
+        targets = list(target_dict.values())
+        if len(targets)==1 and isinstance(targets[0], torch.Tensor):
+            if len(pred_dict)==1:
+                pred = list(pred_dict.values())[0]
+                fast_param['pred'] = pred
+            elif len(pred_dict)==2:
+                pred1 = list(pred_dict.values())[0]
+                pred2 = list(pred_dict.values())[1]
+                if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
+                    return fast_param
+                if len(pred1.size())>len(pred2.size()):
+                    fast_param['pred'] = pred1
+                    fast_param['seq_lens'] = pred2
+                else:
+                    return fast_param
+            fast_param['target'] = targets[0]
+        return fast_param
+    def evaluate(self, pred, target, seq_lens=None):
         """
         :param pred: List of (torch.Tensor, or numpy.ndarray). Element's shape can be:
                 torch.Size([B,]), torch.Size([B, n_classes]), torch.Size([B, max_len]), torch.Size([B, max_len, n_classes])
         :param target: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 torch.Size([B,]), torch.Size([B,]), torch.Size([B, max_len]), torch.Size([B, max_len])
-        :param masks: List of (torch.Tensor, or numpy.ndarray). Element's can be:
-                None, None, torch.Size([B, max_len], torch.Size([B, max_len])
         :param seq_lens: List of (torch.Tensor, or numpy.ndarray). Element's can be:
                 None, None, torch.Size([B], torch.Size([B]). ignored if masks are provided.
         :return: dict({'acc': float})
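In the new `AccuracyMetric._fast_param_map`, when there is exactly one target tensor, a single prediction tensor is mapped straight to `pred`; if the prediction dict holds two tensors and the first has more dimensions than the second, they become `pred` and `seq_lens` respectively, and every other ambiguous combination falls back to the regular `param_map` route by returning an empty dict. A hedged sketch (import path assumed; dict insertion order matters here, since the first value is compared against the second):

    import torch
    from fastNLP.core.metrics import AccuracyMetric    # import path assumed

    metric = AccuracyMetric()
    pred_dict = {
        "logits": torch.zeros(4, 3, 5),              # rank 3 -> mapped to `pred`
        "lengths": torch.LongTensor([3, 3, 2, 1]),   # rank 1 -> mapped to `seq_lens`
    }
    target_dict = {"truth": torch.zeros(4, 3)}       # single tensor -> mapped to `target`

    metric(pred_dict=pred_dict, target_dict=target_dict)
    print(metric.get_metric())   # accuracy computed over the unpadded positions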
@@ -221,15 +229,14 @@ class AccuracyMetric(MetricBase):
             raise TypeError(f"`target` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(target)}.")
-        if masks is not None and not isinstance(masks, torch.Tensor):
-            raise TypeError(f"`masks` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
-                            f"got {type(masks)}.")
-        elif seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
+        if seq_lens is not None and not isinstance(seq_lens, torch.Tensor):
             raise TypeError(f"`seq_lens` in {get_func_signature(self.evaluate)} must be torch.Tensor,"
                             f"got {type(seq_lens)}.")
-        if masks is None and seq_lens is not None:
+        if seq_lens is not None:
             masks = seq_lens_to_masks(seq_lens=seq_lens, float=True)
+        else:
+            masks = None
         if pred.size() == target.size():
             pass
@@ -29,7 +29,7 @@ class Trainer(object):
     """Main Training Loop
     """
-    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, update_every=50,
+    def __init__(self, train_data, model, losser=None, metrics=None, n_epochs=3, batch_size=32, print_every=50,
                  validate_every=-1, dev_data=None, use_cuda=False, save_path=None,
                  optimizer=Adam(lr=0.01, weight_decay=0), check_code_level=0,
                  metric_key=None, sampler=RandomSampler(), use_tqdm=True):
@@ -41,7 +41,7 @@ class Trainer(object):
         :param MetricBase or List[MetricBase] metrics: a metric object or a list of metrics
         :param int n_epochs: the number of training epochs
         :param int batch_size: batch size for training and validation
-        :param int update_every: step interval to print next training information. Default: -1(no print).
+        :param int print_every: step interval to print next training information. Default: -1(no print).
         :param int validate_every: step interval to do next validation. Default: -1(validate every epoch).
         :param DataSet dev_data: the validation data
         :param use_cuda:
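The rename is purely about logging frequency: `print_every` controls how often progress is reported, while the optimizer still steps on every batch. A sketch of a minimal call, mirroring the `test_trainer_suggestion2` case added further down (the `prepare_fake_dataset2` helper is defined in those tests; everything else uses the defaults from the signature above):

    from torch import nn
    import torch.nn.functional as F
    from fastNLP.core.trainer import Trainer

    class TinyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(5, 4)

        def forward(self, x1, x2, y):
            x = self.fc(x1) + self.fc(x2)
            return {'loss': F.cross_entropy(x, y)}

    dataset = prepare_fake_dataset2('x1', 'x2')      # helper from the tests below
    dataset.set_input('x1', 'x2', 'y', flag=True)

    trainer = Trainer(train_data=dataset,
                      model=TinyModel(),
                      print_every=2,    # report every 2 steps; gradient updates are unaffected
                      use_tqdm=False)
    trainer.train()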
@@ -106,7 +106,7 @@ class Trainer(object):
         self.batch_size = int(batch_size)
         self.use_cuda = bool(use_cuda)
         self.save_path = save_path
-        self.print_every = int(update_every)
+        self.print_every = int(print_every)
         self.validate_every = int(validate_every)
         self.best_metric_indicator = None
         self.sampler = sampler
@@ -214,7 +214,7 @@ class CheckError(Exception):
     """
     def __init__(self, check_res: CheckRes, func_signature: str):
-        errs = [f'The following problems occurred when calling `{func_signature}`']
+        errs = [f'Problems occurred when calling `{func_signature}`']
         if check_res.varargs:
             errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
@@ -276,8 +276,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
                         f"target is {list(target_dict.keys())}).")
         if _miss_out_dataset:
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
-                    f"target is {list(target_dict.keys())}) or output it "
-                    f"in {prev_func_signature}(Right now it outputs {list(pred_dict.keys())}).")
+                    f"target has {list(target_dict.keys())}) or output it "
+                    f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
             if _unused_field:
                 _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
             suggestions.append(_tmp)
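The reworded suggestion points users at `DataSet.rename_field()` when the needed field exists under a different name. Roughly, assuming the usual (old_name, new_name) argument order for that method:

    # the metric expects a `target` field, but the DataSet calls it `label`
    data_set.rename_field("label", "target")   # argument order assumed
    data_set.set_target("target", flag=True)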
@@ -291,7 +291,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         errs.extend(unuseds)
     if len(errs) > 0:
-        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        errs.insert(0, f'Problems occurred when calling {func_signature}')
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
@@ -341,7 +341,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         errs.extend(_unused)
     if len(errs) > 0:
-        errs.insert(0, f'The following problems occurred when calling {func_signature}')
+        errs.insert(0, f'Problems occurred when calling {func_signature}')
         sugg_str = ""
         if len(suggestions) > 1:
             for idx, sugg in enumerate(suggestions):
@@ -356,7 +356,7 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         warnings.warn(message=_unused_warn)
-def seq_lens_to_masks(seq_lens, float=True):
+def seq_lens_to_masks(seq_lens, float=False):
     """
     Convert seq_lens to masks.
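`seq_lens_to_masks` now defaults to a non-float mask; callers that need a float mask for multiplication, such as `AccuracyMetric.evaluate` above, pass `float=True` explicitly. A small behaviour sketch under the usual seq-len-to-mask semantics (import path and the exact default dtype are assumptions, they are not shown in this diff):

    import torch
    from fastNLP.core.utils import seq_lens_to_masks   # import path assumed

    seq_lens = torch.LongTensor([3, 1])
    seq_lens_to_masks(seq_lens)              # [[1, 1, 1], [1, 0, 0]] as an integer/bool mask (dtype assumed)
    seq_lens_to_masks(seq_lens, float=True)  # [[1., 1., 1.], [1., 0., 0.]] as a float mask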
@@ -6,131 +6,126 @@ import torch
 import numpy as np
 class TestAccuracyMetric(unittest.TestCase):
-    # def test_AccuracyMetric1(self):
-    # # (1) only input, targets passed
-    # pred_dict = {"pred": torch.zeros(4, 3)}
-    # target_dict = {'target': torch.zeros(4)}
-    # metric = AccuracyMetric()
-    #
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    # print(metric.get_metric())
-    #
-    # def test_AccuracyMetric2(self):
-    # # (2) with corrupted size
-    # try:
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4)}
-    # metric = AccuracyMetric()
-    #
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    # print(metric.get_metric())
-    # except Exception as e:
-    # print(e)
-    # return
-    # self.assertTrue(True, False), "No exception catches."
-    #
-    # def test_AccuracyMetric3(self):
-    # # (3) with check=False , the second batch is corrupted size
-    # try:
-    # metric = AccuracyMetric()
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4, 3)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #
-    # print(metric.get_metric())
-    # except Exception as e:
-    # print(e)
-    # return
-    # self.assertTrue(True, False), "No exception catches."
-    #
-    # def test_AccuracyMetric4(self):
-    # # (4) with check=True , the second batch is corrupted size
-    # try:
-    # metric = AccuracyMetric()
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4, 3)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict)
-    #
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    #
-    # print(metric.get_metric())
-    #
-    # except Exception as e:
-    # print(e)
-    # return
-    # self.assertTrue(True, False), "No exception catches."
+    def test_AccuracyMetric1(self):
+        # (1) only input, targets passed
+        pred_dict = {"pred": torch.zeros(4, 3)}
+        target_dict = {'target': torch.zeros(4)}
+        metric = AccuracyMetric()
+        metric(pred_dict=pred_dict, target_dict=target_dict, )
+        print(metric.get_metric())
     #
-    # def test_AccuaryMetric5(self):
-    # # (5) check reset
-    # metric = AccuracyMetric()
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4, 3)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict)
-    # self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    def test_AccuracyMetric2(self):
+        # (2) with corrupted size
+        try:
+            pred_dict = {"pred": torch.zeros(4, 3, 2)}
+            target_dict = {'target': torch.zeros(4)}
+            metric = AccuracyMetric()
+            metric(pred_dict=pred_dict, target_dict=target_dict, )
+            print(metric.get_metric())
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
     #
-    # pred_dict = {"pred": torch.zeros(4, 3, 2)}
-    # target_dict = {'target': torch.zeros(4, 3)+1}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    # self.assertDictEqual(metric.get_metric(), {'acc':0})
+    def test_AccuracyMetric3(self):
+        # (3) the second batch is corrupted size
+        try:
+            metric = AccuracyMetric()
+            pred_dict = {"pred": torch.zeros(4, 3, 2)}
+            target_dict = {'target': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+            pred_dict = {"pred": torch.zeros(4, 3, 2)}
+            target_dict = {'target': torch.zeros(4)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+            print(metric.get_metric())
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
     #
-    # def test_AccuaryMetric6(self):
-    # # (6) check numpy array is not acceptable
-    # try:
-    # metric = AccuracyMetric()
-    # pred_dict = {"pred": np.zeros((4, 3, 2))}
-    # target_dict = {'target': np.zeros((4, 3))}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    # self.assertDictEqual(metric.get_metric(), {'acc': 1})
-    # except Exception as e:
-    # print(e)
-    # return
-    # self.assertTrue(True, False), "No exception catches."
-    # def test_AccuaryMetric7(self):
-    # # (7) check map, match
-    # metric = AccuracyMetric(pred='predictions', target='targets')
-    # pred_dict = {"predictions": torch.zeros(4, 3, 2)}
-    # target_dict = {'targets': torch.zeros(4, 3)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    # self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    def test_AccuaryMetric4(self):
+        # (5) check reset
+        metric = AccuracyMetric()
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)+1}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc':0})
+    def test_AccuaryMetric5(self):
+        # (5) check reset
+        metric = AccuracyMetric()
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(reset=False), {'acc': 1})
+        pred_dict = {"pred": torch.zeros(4, 3, 2)}
+        target_dict = {'target': torch.zeros(4, 3)+1}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc':0.5})
     #
-    # def test_AccuaryMetric8(self):
-    # # (8) check map, does not match
-    # try:
-    # metric = AccuracyMetric(pred='predictions', target='targets')
-    # pred_dict = {"prediction": torch.zeros(4, 3, 2)}
-    # target_dict = {'targets': torch.zeros(4, 3)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict, check=True)
-    # self.assertDictEqual(metric.get_metric(), {'acc': 1})
-    # except Exception as e:
-    # print(e)
-    # return
-    # self.assertTrue(True, False), "No exception catches."
-    # def test_AccuaryMetric9(self):
-    # # (9) check map, include unused
-    # try:
-    # metric = AccuracyMetric(pred='predictions', target='targets')
-    # pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
-    # target_dict = {'targets': torch.zeros(4, 3)}
-    # metric(pred_dict=pred_dict, target_dict=target_dict)
-    # self.assertDictEqual(metric.get_metric(), {'acc': 1})
-    # except Exception as e:
-    # print(e)
-    # return
-    # self.assertTrue(True, False), "No exception catches."
+    def test_AccuaryMetric6(self):
+        # (6) check numpy array is not acceptable
+        try:
+            metric = AccuracyMetric()
+            pred_dict = {"pred": np.zeros((4, 3, 2))}
+            target_dict = {'target': np.zeros((4, 3))}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
+    def test_AccuaryMetric7(self):
+        # (7) check map, match
+        metric = AccuracyMetric(pred='predictions', target='targets')
+        pred_dict = {"predictions": torch.zeros(4, 3, 2)}
+        target_dict = {'targets': torch.zeros(4, 3)}
+        metric(pred_dict=pred_dict, target_dict=target_dict)
+        self.assertDictEqual(metric.get_metric(), {'acc': 1})
+    def test_AccuaryMetric8(self):
+        # (8) check map, does not match. use stop_fast_param to stop fast param map
+        try:
+            metric = AccuracyMetric(pred='predictions', target='targets')
+            pred_dict = {"prediction": torch.zeros(4, 3, 2), "stop_fast_param":1}
+            target_dict = {'targets': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict, )
+            self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
+    def test_AccuaryMetric9(self):
+        # (9) check map, include unused
+        try:
+            metric = AccuracyMetric(pred='prediction', target='targets')
+            pred_dict = {"prediction": torch.zeros(4, 3, 2), 'unused':1}
+            target_dict = {'targets': torch.zeros(4, 3)}
+            metric(pred_dict=pred_dict, target_dict=target_dict)
+            self.assertDictEqual(metric.get_metric(), {'acc': 1})
+        except Exception as e:
+            print(e)
+            return
+        self.assertTrue(True, False), "No exception catches."
     def test_AccuaryMetric10(self):
         # (10) check _fast_metric
         try:
             metric = AccuracyMetric()
-            pred_dict = {"predictions": torch.zeros(4, 3, 2)}
+            pred_dict = {"predictions": torch.zeros(4, 3, 2), "masks": torch.zeros(4, 3)}
             target_dict = {'targets': torch.zeros(4, 3)}
             metric(pred_dict=pred_dict, target_dict=target_dict)
             self.assertDictEqual(metric.get_metric(), {'acc': 1})
@@ -1,6 +1,8 @@
 import unittest
 import numpy as np
+from torch import nn
+import torch.nn.functional as F
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
@@ -11,19 +13,29 @@ from fastNLP.core.trainer import Trainer
 from fastNLP.models.base_model import NaiveClassifier
-class TrainerTestGround(unittest.TestCase):
-    def test_case(self):
-        mean = np.array([-3, -3])
-        cov = np.array([[1, 0], [0, 1]])
-        class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+def prepare_fake_dataset():
+    mean = np.array([-3, -3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
-        mean = np.array([3, 3])
-        cov = np.array([[1, 0], [0, 1]])
-        class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+    mean = np.array([3, 3])
+    cov = np.array([[1, 0], [0, 1]])
+    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
-        data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
-                           [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+    return data_set
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=100)
+    data = {'y': ys}
+    for arg in args:
+        data[arg] = np.random.randn(size, 5)
+    return DataSet(data=data)
+class TrainerTestGround(unittest.TestCase):
+    def test_case(self):
+        data_set = prepare_fake_dataset()
         data_set.set_input("x", flag=True)
         data_set.set_target("y", flag=True)
@@ -36,10 +48,101 @@ class TrainerTestGround(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           n_epochs=10,
                           batch_size=32,
-                          update_every=1,
-                          validate_every=10,
+                          print_every=50,
+                          validate_every=-1,
                           dev_data=dev_set,
                           optimizer=SGD(lr=0.1),
                           check_code_level=2,
                           use_tqdm=True)
-        trainer.train()
+        trainer.train()
+    def test_trainer_suggestion1(self):
+        # Check that the error message gives the user the right hint.
+        # The data required by forward is not passed in here; the trainer should tell the user how to set it up.
+        dataset = prepare_fake_dataset2('x')
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model
+        )
+        """
+        # The error message we expect to get:
+        NameError:
+        The following problems occurred when calling Model.forward(self, x1, x2, y)
+            missing param: ['y', 'x1', 'x2']
+            Suggestion: (1). You might need to set ['y'] as input.
+                        (2). You need to provide ['x1', 'x2'] in DataSet and set it as input.
+        """
+    def test_trainer_suggestion2(self):
+        # Check that the error message gives the user the right hint.
+        # The data required by forward is passed in here; training should run.
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2', 'y', flag=True)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+        trainer.train()
+        """
+        # Should run without errors.
+        """
+    def test_trainer_suggestion3(self):
+        # Check that the error message gives the user the right hint.
+        # The data required by forward is passed in, but forward does not return a 'loss' key.
+        dataset = prepare_fake_dataset2('x1', 'x2')
+        dataset.set_input('x1', 'x2', 'y', flag=True)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'wrong_loss_key': loss}
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+        trainer.train()
+        """
+        # Should run without errors.
+        """
+    def test_case2(self):
+        # check metrics Wrong
+        data_set = prepare_fake_dataset2('x1', 'x2')