@@ -147,7 +147,7 @@ class LossBase(object):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
                 raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
-            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")
+            raise RuntimeError(f"loss ERROR: the size of loss is expected to be torch.Size([]) but got {loss.size()}")
         return loss
@@ -219,8 +219,8 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}")
-            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")
+                raise TypeError(f"loss is expected to be a torch.Tensor, got {type(loss)}")
+            raise RuntimeError(f"The size of loss is expected to be torch.Size([]), got {loss.size()}")
         return loss
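
Reviewer note on the two hunks above: in an f-string, `loss.size` interpolates the bound
method object itself (something like `<built-in method size of Tensor object ...>`),
while `loss.size()` interpolates the actual shape. A minimal standalone sketch of the
check these hunks fix (not fastNLP code; torch is the only dependency):

    import torch

    def validate_loss(loss):
        # a loss must be a scalar tensor, i.e. of shape torch.Size([])
        if not isinstance(loss, torch.Tensor):
            raise TypeError(f"loss is expected to be a torch.Tensor, got {type(loss)}")
        if len(loss.size()) != 0:
            # loss.size() must be called; bare loss.size would print the method, not the shape
            raise RuntimeError(f"The size of loss is expected to be torch.Size([]), got {loss.size()}")
        return loss

    validate_loss(torch.tensor(0.5))  # OK: zero-dimensional scalar tensor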
@@ -202,12 +202,20 @@ class AccuracyMetric(MetricBase):
             pred2 = list(pred_dict.values())[1]
             if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
                 return fast_param
-            if len(pred1.size())>len(pred2.size()):
-                fast_param['pred'] = pred1
-                fast_param['seq_lens'] = pred2
+            if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
+                seq_lens = pred1
+                pred = pred2
+            elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
+                seq_lens = pred2
+                pred = pred1
+            else:
+                return fast_param
+            fast_param['pred'] = pred
+            fast_param['seq_lens'] = seq_lens
         else:
             return fast_param
         fast_param['target'] = targets[0]
         # TODO need to make sure they all have same batch_size
         return fast_param

     def evaluate(self, pred, target, seq_lens=None):
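
Reviewer note: the rewritten branch identifies which of the two unnamed tensors is the
prediction and which is the sequence-length vector by rank, rather than assuming the
higher-rank tensor always comes first, and bails out to the slow path when the pair is
ambiguous. A minimal sketch of the rule, assuming seq_lens is always 1-dimensional
(the function name is illustrative, not fastNLP API):

    import torch

    def split_pred_and_seq_lens(t1, t2):
        # the sole 1-d tensor is taken as seq_lens, the higher-rank one as pred
        if t1.dim() < t2.dim() and t1.dim() == 1:
            return t2, t1
        elif t1.dim() > t2.dim() and t2.dim() == 1:
            return t1, t2
        return None  # ambiguous: caller falls back to explicit parameter mapping

    pred, seq_lens = split_pred_and_seq_lens(torch.zeros(4, 7, 5), torch.tensor([7, 3, 5, 6]))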
@@ -48,6 +48,8 @@ class Trainer(object):
     :param str save_path: file path to save models
     :param Optimizer optimizer: an optimizer object
     :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.
+        `ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` will
+        raise an error if some fields are not used.
     :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
         of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
         smaller, add a `-` character in front of the string. For example
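
A hedged usage sketch of the documented levels; `dataset` and `model` stand in for real
objects, and 'acc' is only assumed to be a key returned by AccuracyMetric:

    trainer = Trainer(
        train_data=dataset,
        model=model,
        metrics=AccuracyMetric(),
        metric_key="acc",    # prefix with '-' if a smaller value means a better model
        check_code_level=2,  # strict: unused fields raise an error before training starts
    )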
@@ -254,9 +254,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         else:
             _unused_param.append(_unused)
     if _unused_field:
-        unuseds.append([f"\tunused field: {_unused_field}"])
+        unuseds.append(f"\tunused field: {_unused_field}")
     if _unused_param:
-        unuseds.append([f"\tunused param: {_unused_param}"])  # output from predict or forward
+        unuseds.append(f"\tunused param: {_unused_param}")  # output from predict or forward

     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
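
Reviewer note: the old code appended a one-element list, so `unuseds` held lists rather
than strings; assuming the collected messages are eventually joined into one report
string (as `errs.extend(unuseds)` below suggests), that nesting breaks the join:

    unuseds, errs = [], []
    unuseds.append(["\tunused field: ['foo']"])  # old behaviour: a list inside a list
    errs.extend(unuseds)
    try:
        "\n".join(errs)
    except TypeError as e:
        print(e)  # sequence item 0: expected str instance, list found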
@@ -278,8 +278,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
                     f"target has {list(target_dict.keys())}) or output it "
                     f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
-            if _unused_field:
-                _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
+            # if _unused_field:
+            #     _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
             suggestions.append(_tmp)

     if check_res.duplicated:
@@ -287,7 +287,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
                            f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

-    if check_level == STRICT_CHECK_LEVEL:
+    if len(errs) > 0:
+        errs.extend(unuseds)
+    elif check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)

     if len(errs) > 0:
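
Reviewer note: with this change, unused-field and unused-param messages are attached to
the report whenever any real error exists, at every check level; on their own they still
cause a failure only under strict checking. The same pattern is applied in
_check_forward_error in the next hunk.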
@@ -330,14 +332,16 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
             suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
         if _miss_out_dataset:
             _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
-            if check_res.unused:
-                _tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \
-                        f"rename the field in `unused field:`."
+            # if check_res.unused:
+            #     _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
+            #             f"rename the field in `unused field:`."
             suggestions.append(_tmp)

     if check_res.unused:
         _unused = [f"\tunused field: {check_res.unused}"]
-        if check_level == STRICT_CHECK_LEVEL:
+        if len(errs) > 0:
+            errs.extend(_unused)
+        elif check_level == STRICT_CHECK_LEVEL:
             errs.extend(_unused)

     if len(errs) > 0:
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
+from fastNLP.core.losses import LossInForward
 from fastNLP.core.metrics import AccuracyMetric
 from fastNLP.core.optimizer import SGD
 from fastNLP.core.trainer import Trainer
@@ -142,6 +143,84 @@ class TrainerTestGround(unittest.TestCase):
             # should run correctly
         """

+    def test_trainer_suggestion4(self):
+        # check that the error message correctly reminds the user:
+        # the data required by forward is passed in; is the unused field reported correctly?
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.set_input('x1', 'x_unused', 'y', flag=True)
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+
+    def test_trainer_suggestion5(self):
+        # check that the error message correctly reminds the user:
+        # redundant parameters are passed in to make them duplicate; since y is never
+        # used here, no error is actually raised
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2', 'y')
+        dataset.set_target('y')
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+
+    def test_trainer_suggestion6(self):
+        # check that the error message correctly reminds the user:
+        # redundant parameters are passed in to make them duplicate
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2', 'y')
+        dataset.set_target('x1')
+
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'pred': x}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            dev_data=dataset,
+            metrics=AccuracyMetric(),
+            use_tqdm=False,
+            print_every=2
+        )
+
     def test_case2(self):
         # check metrics Wrong
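
The new suggestion tests can be exercised on their own; assuming they live in
test/core/test_trainer.py (the path is an assumption, it is not shown in this diff):

    python -m pytest test/core/test_trainer.py -k suggestion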