| @@ -147,7 +147,7 @@ class LossBase(object): | |||||
| if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | ||||
| if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
| raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}") | raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}") | ||||
| raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}") | |||||
| raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}") | |||||
| return loss | return loss | ||||
| @@ -219,8 +219,8 @@ class LossInForward(LossBase): | |||||
| if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | ||||
| if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
| raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}") | |||||
| raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}") | |||||
| raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") | |||||
| raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||||
| return loss | return loss | ||||
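
Both loss hunks converge on the same contract; here is a minimal standalone sketch of it (the helper name is illustrative, not part of fastNLP):

```python
import torch

def _check_scalar_loss(loss):
    # Restates the rule enforced in both hunks above: the trainer needs a
    # 0-dimensional (scalar) tensor so loss.backward() works without an
    # explicit gradient argument.
    if not isinstance(loss, torch.Tensor):
        raise TypeError(f"loss expects to be a torch.Tensor, got {type(loss)}")
    if len(loss.size()) != 0:
        raise RuntimeError(f"The size of loss is expected to be torch.Size([]), got {loss.size()}")
    return loss

_check_scalar_loss(torch.tensor(0.5))    # passes: size is torch.Size([])
# _check_scalar_loss(torch.zeros(4))     # raises: size is torch.Size([4])
```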
| @@ -202,12 +202,20 @@ class AccuracyMetric(MetricBase): | |||||
| pred2 = list(pred_dict.values())[1] | pred2 = list(pred_dict.values())[1] | ||||
| if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | ||||
| return fast_param | return fast_param | ||||
| if len(pred1.size())>len(pred2.size()): | |||||
| fast_param['pred'] = pred1 | |||||
| fast_param['seq_lens'] = pred2 | |||||
| if len(pred1.size())<len(pred2.size()) and len(pred1.size())==1: | |||||
| seq_lens = pred1 | |||||
| pred = pred2 | |||||
| elif len(pred1.size())>len(pred2.size()) and len(pred2.size())==1: | |||||
| seq_lens = pred2 | |||||
| pred = pred1 | |||||
| else: | |||||
| return fast_param | |||||
| fast_param['pred'] = pred | |||||
| fast_param['seq_lens'] = seq_lens | |||||
| else: | else: | ||||
| return fast_param | return fast_param | ||||
| fast_param['target'] = targets[0] | fast_param['target'] = targets[0] | ||||
| # TODO need to make sure they all have same batch_size | |||||
| return fast_param | return fast_param | ||||
| def evaluate(self, pred, target, seq_lens=None): | def evaluate(self, pred, target, seq_lens=None): | ||||
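
The rewritten branch infers roles from tensor rank alone: a 1-D tensor paired with a higher-rank tensor is treated as seq_lens, and anything else falls back to the slow name-based mapping. A hedged sketch of just that heuristic (hypothetical helper, not the fastNLP API):

```python
import torch

def _split_pred_seq_lens(t1, t2):
    # Mirrors the new branch: the strictly lower-rank, 1-D tensor is taken
    # as seq_lens and the other as pred; equal or ambiguous ranks give up.
    if len(t1.size()) < len(t2.size()) and len(t1.size()) == 1:
        return t2, t1
    if len(t1.size()) > len(t2.size()) and len(t2.size()) == 1:
        return t1, t2
    return None, None  # ambiguous: caller falls back to name matching

pred, seq_lens = _split_pred_seq_lens(torch.full((8,), 20), torch.zeros(8, 20, 5))
assert pred.dim() == 3 and seq_lens.dim() == 1
```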
| @@ -48,6 +48,8 @@ class Trainer(object): | |||||
| :param str save_path: file path to save models | :param str save_path: file path to save models | ||||
| :param Optimizer optimizer: an optimizer object | :param Optimizer optimizer: an optimizer object | ||||
| :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict. | :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict. | ||||
| `ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means | |||||
| it will raise error if some field are not used. | |||||
| :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | ||||
| of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets | of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets | ||||
| smaller, add a `-` character in front of the string. For example | smaller, add a `-` character in front of the string. For example | ||||
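
For quick reference, a plain restatement of the four levels the expanded docstring describes (nothing here beyond the docstring itself; constant names are illustrative):

```python
DONT_CHECK, IGNORE, WARNING, STRICT = -1, 0, 1, 2

def describe_check_level(level: int) -> str:
    # Maps each check_code_level value to the behaviour documented above.
    return {
        DONT_CHECK: "no code check at all",
        IGNORE: "unused fields are not checked",
        WARNING: "warn if some fields are not used",
        STRICT: "raise an error if some fields are not used",
    }[level]

print(describe_check_level(STRICT))  # raise an error if some fields are not used
```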
| @@ -254,9 +254,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
| else: | else: | ||||
| _unused_param.append(_unused) | _unused_param.append(_unused) | ||||
| if _unused_field: | if _unused_field: | ||||
| unuseds.append([f"\tunused field: {_unused_field}"]) | |||||
| unuseds.append(f"\tunused field: {_unused_field}") | |||||
| if _unused_param: | if _unused_param: | ||||
| unuseds.append([f"\tunused param: {_unused_param}"]) # output from predict or forward | |||||
| unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||||
| if check_res.missing: | if check_res.missing: | ||||
| errs.append(f"\tmissing param: {check_res.missing}") | errs.append(f"\tmissing param: {check_res.missing}") | ||||
| @@ -278,8 +278,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
| _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " | _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " | ||||
| f"target has {list(target_dict.keys())}) or output it " | f"target has {list(target_dict.keys())}) or output it " | ||||
| f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).") | f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).") | ||||
| if _unused_field: | |||||
| _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " | |||||
| # if _unused_field: | |||||
| # _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " | |||||
| suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
| if check_res.duplicated: | if check_res.duplicated: | ||||
| @@ -287,7 +287,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
| suggestions.append(f"Delete {check_res.duplicated} in the output of " | suggestions.append(f"Delete {check_res.duplicated} in the output of " | ||||
| f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | ||||
| if check_level == STRICT_CHECK_LEVEL: | |||||
| if len(errs)>0: | |||||
| errs.extend(unuseds) | |||||
| elif check_level == STRICT_CHECK_LEVEL: | |||||
| errs.extend(unuseds) | errs.extend(unuseds) | ||||
| if len(errs) > 0: | if len(errs) > 0: | ||||
| @@ -330,14 +332,16 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
| suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") | suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") | ||||
| if _miss_out_dataset: | if _miss_out_dataset: | ||||
| _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " | _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " | ||||
| if check_res.unused: | |||||
| _tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \ | |||||
| f"rename the field in `unused field:`." | |||||
| # if check_res.unused: | |||||
| # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||||
| # f"rename the field in `unused field:`." | |||||
| suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
| if check_res.unused: | if check_res.unused: | ||||
| _unused = [f"\tunused field: {check_res.unused}"] | _unused = [f"\tunused field: {check_res.unused}"] | ||||
| if check_level == STRICT_CHECK_LEVEL: | |||||
| if len(errs)>0: | |||||
| errs.extend(_unused) | |||||
| elif check_level == STRICT_CHECK_LEVEL: | |||||
| errs.extend(_unused) | errs.extend(_unused) | ||||
| if len(errs) > 0: | if len(errs) > 0: | ||||
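
Both checkers now share the same reporting rule: unused-field notes ride along whenever real errors already exist, and on their own they only count as errors under strict checking. A self-contained sketch (the STRICT_CHECK_LEVEL value of 2 is assumed from the docstring above):

```python
STRICT_CHECK_LEVEL = 2  # assumed value; the real constant lives in fastNLP's utils

def attach_unused(errs, unused_notes, check_level):
    # Errors already present: append the unused notes as extra context.
    if len(errs) > 0:
        errs.extend(unused_notes)
    # No errors: unused fields alone are fatal only in strict mode.
    elif check_level == STRICT_CHECK_LEVEL:
        errs.extend(unused_notes)
    return errs

assert attach_unused([], ["\tunused field: ['x2']"], check_level=1) == []
assert attach_unused([], ["\tunused field: ['x2']"], check_level=2) != []
assert attach_unused(["\tmissing param: ['y']"], ["\tunused field: ['x2']"], check_level=1) != []
```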
| @@ -7,6 +7,7 @@ import torch.nn.functional as F | |||||
| from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
| from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
| from fastNLP.core.losses import BCELoss | from fastNLP.core.losses import BCELoss | ||||
| from fastNLP.core.losses import LossInForward | |||||
| from fastNLP.core.metrics import AccuracyMetric | from fastNLP.core.metrics import AccuracyMetric | ||||
| from fastNLP.core.optimizer import SGD | from fastNLP.core.optimizer import SGD | ||||
| from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
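
LossInForward, newly imported for the tests below, is the loss object used when the model's forward() already returns the loss under a key (default 'loss'); a hedged sketch of that contract, with the call signature assumed from LossBase rather than confirmed by this diff:

```python
import torch
from fastNLP.core.losses import LossInForward

# Assumed usage: LossInForward pulls the 'loss' key out of the dict returned
# by forward() and validates that it is a scalar tensor (see the hunk at -219).
loss_func = LossInForward()
pred_dict = {'loss': torch.tensor(0.25)}
loss = loss_func(pred_dict, {})  # (pred_dict, target_dict) signature assumed
print(loss)  # tensor(0.2500)
```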
| @@ -142,6 +143,84 @@ class TrainerTestGround(unittest.TestCase): | |||||
| # 应该正确运行 | # 应该正确运行 | ||||
| """ | """ | ||||
| def test_trainer_suggestion4(self): | |||||
| # 检查报错提示能否正确提醒用户 | |||||
| # 这里传入forward需要的数据,是否可以正确提示unused | |||||
| dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
| dataset.set_input('x1', 'x_unused', 'y', flag=True) | |||||
| class Model(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.fc = nn.Linear(5, 4) | |||||
| def forward(self, x1, x2, y): | |||||
| x1 = self.fc(x1) | |||||
| x2 = self.fc(x2) | |||||
| x = x1 + x2 | |||||
| loss = F.cross_entropy(x, y) | |||||
| return {'loss': loss} | |||||
| model = Model() | |||||
| trainer = Trainer( | |||||
| train_data=dataset, | |||||
| model=model, | |||||
| use_tqdm=False, | |||||
| print_every=2 | |||||
| ) | |||||
| def test_trainer_suggestion5(self): | |||||
| # 检查报错提示能否正确提醒用户 | |||||
| # 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错 | |||||
| dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
| dataset.rename_field('x_unused', 'x2') | |||||
| dataset.set_input('x1', 'x2', 'y') | |||||
| dataset.set_target('y') | |||||
| class Model(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.fc = nn.Linear(5, 4) | |||||
| def forward(self, x1, x2, y): | |||||
| x1 = self.fc(x1) | |||||
| x2 = self.fc(x2) | |||||
| x = x1 + x2 | |||||
| loss = F.cross_entropy(x, y) | |||||
| return {'loss': loss} | |||||
| model = Model() | |||||
| trainer = Trainer( | |||||
| train_data=dataset, | |||||
| model=model, | |||||
| use_tqdm=False, | |||||
| print_every=2 | |||||
| ) | |||||
| def test_trainer_suggestion6(self): | |||||
| # 检查报错提示能否正确提醒用户 | |||||
| # 这里传入多余参数,让其duplicate | |||||
| dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
| dataset.rename_field('x_unused', 'x2') | |||||
| dataset.set_input('x1', 'x2', 'y') | |||||
| dataset.set_target('x1') | |||||
| class Model(nn.Module): | |||||
| def __init__(self): | |||||
| super().__init__() | |||||
| self.fc = nn.Linear(5, 4) | |||||
| def forward(self, x1, x2, y): | |||||
| x1 = self.fc(x1) | |||||
| x2 = self.fc(x2) | |||||
| x = x1 + x2 | |||||
| loss = F.cross_entropy(x, y) | |||||
| return {'pred': x} | |||||
| model = Model() | |||||
| trainer = Trainer( | |||||
| train_data=dataset, | |||||
| model=model, | |||||
| dev_data=dataset, | |||||
| metrics=AccuracyMetric(), | |||||
| use_tqdm=False, | |||||
| print_every=2 | |||||
| ) | |||||
| def test_case2(self): | def test_case2(self): | ||||
| # check metrics Wrong | # check metrics Wrong | ||||