@@ -147,7 +147,7 @@ class LossBase(object): | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | ||||
if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}") | raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}") | ||||
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}") | |||||
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}") | |||||
return loss | return loss | ||||
@@ -219,8 +219,8 @@ class LossInForward(LossBase): | |||||
if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0): | ||||
if not isinstance(loss, torch.Tensor): | if not isinstance(loss, torch.Tensor): | ||||
raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}") | |||||
raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}") | |||||
raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}") | |||||
raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}") | |||||
return loss | return loss | ||||
@@ -202,12 +202,20 @@ class AccuracyMetric(MetricBase): | |||||
pred2 = list(pred_dict.values())[1] | pred2 = list(pred_dict.values())[1] | ||||
if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)): | ||||
return fast_param | return fast_param | ||||
if len(pred1.size())>len(pred2.size()): | |||||
fast_param['pred'] = pred1 | |||||
fast_param['seq_lens'] = pred2 | |||||
if len(pred1.size())<len(pred2.size()) and len(pred1.size())==1: | |||||
seq_lens = pred1 | |||||
pred = pred2 | |||||
elif len(pred1.size())>len(pred2.size()) and len(pred2.size())==1: | |||||
seq_lens = pred2 | |||||
pred = pred1 | |||||
else: | |||||
return fast_param | |||||
fast_param['pred'] = pred | |||||
fast_param['seq_lens'] = seq_lens | |||||
else: | else: | ||||
return fast_param | return fast_param | ||||
fast_param['target'] = targets[0] | fast_param['target'] = targets[0] | ||||
# TODO need to make sure they all have same batch_size | |||||
return fast_param | return fast_param | ||||
def evaluate(self, pred, target, seq_lens=None): | def evaluate(self, pred, target, seq_lens=None): | ||||
@@ -48,6 +48,8 @@ class Trainer(object): | |||||
:param str save_path: file path to save models | :param str save_path: file path to save models | ||||
:param Optimizer optimizer: an optimizer object | :param Optimizer optimizer: an optimizer object | ||||
:param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict. | :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict. | ||||
`ignore` will not check unused field; `warning` when warn if some field are not used; `strict` means | |||||
it will raise error if some field are not used. | |||||
:param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one | ||||
of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets | of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets | ||||
smaller, add a `-` character in front of the string. For example | smaller, add a `-` character in front of the string. For example | ||||
@@ -254,9 +254,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
else: | else: | ||||
_unused_param.append(_unused) | _unused_param.append(_unused) | ||||
if _unused_field: | if _unused_field: | ||||
unuseds.append([f"\tunused field: {_unused_field}"]) | |||||
unuseds.append(f"\tunused field: {_unused_field}") | |||||
if _unused_param: | if _unused_param: | ||||
unuseds.append([f"\tunused param: {_unused_param}"]) # output from predict or forward | |||||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||||
if check_res.missing: | if check_res.missing: | ||||
errs.append(f"\tmissing param: {check_res.missing}") | errs.append(f"\tmissing param: {check_res.missing}") | ||||
@@ -278,8 +278,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
_tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " | _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now " | ||||
f"target has {list(target_dict.keys())}) or output it " | f"target has {list(target_dict.keys())}) or output it " | ||||
f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).") | f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).") | ||||
if _unused_field: | |||||
_tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " | |||||
# if _unused_field: | |||||
# _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. " | |||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.duplicated: | if check_res.duplicated: | ||||
@@ -287,7 +287,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | suggestions.append(f"Delete {check_res.duplicated} in the output of " | ||||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | ||||
if check_level == STRICT_CHECK_LEVEL: | |||||
if len(errs)>0: | |||||
errs.extend(unuseds) | |||||
elif check_level == STRICT_CHECK_LEVEL: | |||||
errs.extend(unuseds) | errs.extend(unuseds) | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
@@ -330,14 +332,16 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") | suggestions.append(f"You might need to set {_miss_in_dataset} as input. ") | ||||
if _miss_out_dataset: | if _miss_out_dataset: | ||||
_tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " | _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. " | ||||
if check_res.unused: | |||||
_tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \ | |||||
f"rename the field in `unused field:`." | |||||
# if check_res.unused: | |||||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||||
# f"rename the field in `unused field:`." | |||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.unused: | if check_res.unused: | ||||
_unused = [f"\tunused field: {check_res.unused}"] | _unused = [f"\tunused field: {check_res.unused}"] | ||||
if check_level == STRICT_CHECK_LEVEL: | |||||
if len(errs)>0: | |||||
errs.extend(_unused) | |||||
elif check_level == STRICT_CHECK_LEVEL: | |||||
errs.extend(_unused) | errs.extend(_unused) | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
@@ -7,6 +7,7 @@ import torch.nn.functional as F | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.losses import BCELoss | from fastNLP.core.losses import BCELoss | ||||
from fastNLP.core.losses import LossInForward | |||||
from fastNLP.core.metrics import AccuracyMetric | from fastNLP.core.metrics import AccuracyMetric | ||||
from fastNLP.core.optimizer import SGD | from fastNLP.core.optimizer import SGD | ||||
from fastNLP.core.trainer import Trainer | from fastNLP.core.trainer import Trainer | ||||
@@ -142,6 +143,84 @@ class TrainerTestGround(unittest.TestCase): | |||||
# 应该正确运行 | # 应该正确运行 | ||||
""" | """ | ||||
def test_trainer_suggestion4(self): | |||||
# 检查报错提示能否正确提醒用户 | |||||
# 这里传入forward需要的数据,是否可以正确提示unused | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.set_input('x1', 'x_unused', 'y', flag=True) | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion5(self): | |||||
# 检查报错提示能否正确提醒用户 | |||||
# 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错 | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y') | |||||
dataset.set_target('y') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'loss': loss} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion6(self): | |||||
# 检查报错提示能否正确提醒用户 | |||||
# 这里传入多余参数,让其duplicate | |||||
dataset = prepare_fake_dataset2('x1', 'x_unused') | |||||
dataset.rename_field('x_unused', 'x2') | |||||
dataset.set_input('x1', 'x2', 'y') | |||||
dataset.set_target('x1') | |||||
class Model(nn.Module): | |||||
def __init__(self): | |||||
super().__init__() | |||||
self.fc = nn.Linear(5, 4) | |||||
def forward(self, x1, x2, y): | |||||
x1 = self.fc(x1) | |||||
x2 = self.fc(x2) | |||||
x = x1 + x2 | |||||
loss = F.cross_entropy(x, y) | |||||
return {'pred': x} | |||||
model = Model() | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_case2(self): | def test_case2(self): | ||||
# check metrics Wrong | # check metrics Wrong | ||||