
1. Fix a runtime error in the `unused` field error reporting
2. Fix a bug in loss
3. Adjust fast_param handling in metrics
tags/v0.2.0^2
yh 6 years ago
commit a1a41c2d8b
5 changed files with 108 additions and 15 deletions
1. +3  -3   fastNLP/core/losses.py
2. +11 -3   fastNLP/core/metrics.py
3. +2  -0   fastNLP/core/trainer.py
4. +13 -9   fastNLP/core/utils.py
5. +79 -0   test/core/test_trainer.py

+3 -3  fastNLP/core/losses.py

@@ -147,7 +147,7 @@ class LossBase(object):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
                 raise RuntimeError(f"loss ERROR: loss except a torch.Tensor but get {type(loss)}")
-            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")
+            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size()}")

         return loss


@@ -219,8 +219,8 @@ class LossInForward(LossBase):
         if not (isinstance(loss, torch.Tensor) and len(loss.size()) == 0):
             if not isinstance(loss, torch.Tensor):
-                raise TypeError(f"loss ERROR: loss except a torch.Tensor but got {type(loss)}")
-            raise RuntimeError(f"loss ERROR: the size of loss except torch.Size([]) but got {loss.size}")
+                raise TypeError(f"loss excepts to be a torch.Tensor, got {type(loss)}")
+            raise RuntimeError(f"The size of loss excepts to be torch.Size([]), got {loss.size()}")

         return loss



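Both hunks fix the same f-string bug: `loss.size` without parentheses interpolates the bound method object rather than the shape, so the old error message printed something like `<built-in method size of Tensor object at 0x...>` instead of the offending size. A minimal standalone sketch of the difference (not fastNLP code):

    import torch

    loss = torch.ones(4)     # deliberately not a scalar, so the size check would fire
    print(f"{loss.size}")    # <built-in method size of Tensor object at 0x...> (the old, broken message)
    print(f"{loss.size()}")  # torch.Size([4]) (what the message should report)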

+11 -3  fastNLP/core/metrics.py

@@ -202,12 +202,20 @@ class AccuracyMetric(MetricBase):
                 pred2 = list(pred_dict.values())[1]
                 if not (isinstance(pred1, torch.Tensor) and isinstance(pred2, torch.Tensor)):
                     return fast_param
-                if len(pred1.size())>len(pred2.size()):
-                    fast_param['pred'] = pred1
-                    fast_param['seq_lens'] = pred2
+                if len(pred1.size())<len(pred2.size()) and len(pred1.size())==1:
+                    seq_lens = pred1
+                    pred = pred2
+                elif len(pred1.size())>len(pred2.size()) and len(pred2.size())==1:
+                    seq_lens = pred2
+                    pred = pred1
+                else:
+                    return fast_param
+                fast_param['pred'] = pred
+                fast_param['seq_lens'] = seq_lens
             else:
                 return fast_param
             fast_param['target'] = targets[0]
+        # TODO need to make sure they all have same batch_size
         return fast_param

     def evaluate(self, pred, target, seq_lens=None):

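The reworked branch treats a second tensor as seq_lens only when it is strictly 1-dimensional, instead of merely lower-dimensional than pred; anything ambiguous now falls through to the slow parameter mapping instead of being guessed. A minimal sketch of the new rule (hypothetical tensors, not the fastNLP API):

    import torch

    pred = torch.rand(32, 20, 4)            # (batch, seq_len, n_classes)
    seq_lens = torch.randint(1, 21, (32,))  # (batch,): strictly 1-dim, safe to treat as lengths

    def pick(pred1, pred2):
        # Mirrors the new fast_param logic: the 1-dim tensor is seq_lens, the other is pred.
        if len(pred1.size()) < len(pred2.size()) and len(pred1.size()) == 1:
            return pred2, pred1
        elif len(pred1.size()) > len(pred2.size()) and len(pred2.size()) == 1:
            return pred1, pred2
        return None  # ambiguous: fall back to the normal (slow) mapping

    p, s = pick(seq_lens, pred)
    assert p is pred and s is seq_lens
    assert pick(torch.rand(32, 20), torch.rand(16, 4)) is None  # neither is 1-dim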

+2 -0  fastNLP/core/trainer.py

@@ -48,6 +48,8 @@ class Trainer(object):
     :param str save_path: file path to save models
     :param Optimizer optimizer: an optimizer object
     :param int check_code_level: level of FastNLP code checker. -1: don't check, 0: ignore. 1: warning. 2: strict.
+        `ignore` will not check unused fields; `warning` will warn if some fields are not used; `strict` will
+        raise an error if some fields are not used.
     :param str metric_key: a single indicator used to decide the best model based on metric results. It must be one
         of the keys returned by the FIRST metric in `metrics`. If the overall result gets better if the indicator gets
         smaller, add a `-` character in front of the string. For example

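As a usage sketch (hypothetical dataset and model; parameter names follow the docstring above and the tests below):

    trainer = Trainer(
        train_data=dataset,    # a fastNLP DataSet with input/target fields already set
        model=model,           # an nn.Module whose forward matches the input fields
        check_code_level=2,    # strict: unused fields raise an error before training starts
        save_path='./models',
        use_tqdm=False,
    )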

+13 -9  fastNLP/core/utils.py

@@ -254,9 +254,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         else:
             _unused_param.append(_unused)
     if _unused_field:
-        unuseds.append([f"\tunused field: {_unused_field}"])
+        unuseds.append(f"\tunused field: {_unused_field}")
     if _unused_param:
-        unuseds.append([f"\tunused param: {_unused_param}"]) # output from predict or forward
+        unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward

     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")

@@ -278,8 +278,8 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
             _tmp = (f"You might need to provide {_miss_out_dataset} in DataSet and set it as target(Right now "
                     f"target has {list(target_dict.keys())}) or output it "
                     f"in {prev_func_signature}(Right now output has {list(pred_dict.keys())}).")
-            if _unused_field:
-                _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
+            # if _unused_field:
+            #     _tmp += f"You can use DataSet.rename_field() to rename the field in `unused field:`. "
             suggestions.append(_tmp)

     if check_res.duplicated:

@@ -287,7 +287,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
         suggestions.append(f"Delete {check_res.duplicated} in the output of "
                            f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")

-    if check_level == STRICT_CHECK_LEVEL:
+    if len(errs)>0:
+        errs.extend(unuseds)
+    elif check_level == STRICT_CHECK_LEVEL:
         errs.extend(unuseds)

     if len(errs) > 0:

@@ -330,14 +332,16 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
         suggestions.append(f"You might need to set {_miss_in_dataset} as input. ")
         if _miss_out_dataset:
             _tmp = f"You need to provide {_miss_out_dataset} in DataSet and set it as input. "
-            if check_res.unused:
-                _tmp += f"Or you might find it is in `unused field:`, you can use DataSet.rename_field() to " \
-                    f"rename the field in `unused field:`."
+            # if check_res.unused:
+            #     _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
+            #         f"rename the field in `unused field:`."
             suggestions.append(_tmp)

     if check_res.unused:
         _unused = [f"\tunused field: {check_res.unused}"]
-        if check_level == STRICT_CHECK_LEVEL:
+        if len(errs)>0:
+            errs.extend(_unused)
+        elif check_level == STRICT_CHECK_LEVEL:
             errs.extend(_unused)

     if len(errs) > 0:

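The behavioral change is the same in both functions: unused-field notes used to surface only at the strict level, but they are now also attached whenever other errors are being raised, since an unused field is often the one the user misnamed. A minimal standalone sketch of the new precedence (STRICT_CHECK_LEVEL stands in for fastNLP's constant):

    STRICT_CHECK_LEVEL = 2

    def merge_unused(errs, unuseds, check_level):
        # New behavior: unused-field notes ride along with any real error,
        # and still surface on their own under strict checking.
        if len(errs) > 0:
            errs.extend(unuseds)
        elif check_level == STRICT_CHECK_LEVEL:
            errs.extend(unuseds)
        return errs

    assert merge_unused(["missing param: ['x2']"], ["unused field: ['x_unused']"], 1) == \
        ["missing param: ['x2']", "unused field: ['x_unused']"]
    assert merge_unused([], ["unused field: ['x_unused']"], 1) == []  # warning level alone: no error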

+79 -0  test/core/test_trainer.py

@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from fastNLP.core.dataset import DataSet
 from fastNLP.core.instance import Instance
 from fastNLP.core.losses import BCELoss
+from fastNLP.core.losses import LossInForward
 from fastNLP.core.metrics import AccuracyMetric
 from fastNLP.core.optimizer import SGD
 from fastNLP.core.trainer import Trainer
@@ -142,6 +143,84 @@ class TrainerTestGround(unittest.TestCase):
         # should run correctly
         """


+    def test_trainer_suggestion4(self):
+        # Check that the error message reminds the user correctly.
+        # The data forward needs is passed in here; check whether the unused field is reported correctly.
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.set_input('x1', 'x_unused', 'y', flag=True)
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+
+    def test_trainer_suggestion5(self):
+        # Check that the error message reminds the user correctly.
+        # Extra parameters are passed in to make them duplicated, but since y is never used as a target,
+        # no error is actually raised here.
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2', 'y')
+        dataset.set_target('y')
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'loss': loss}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            use_tqdm=False,
+            print_every=2
+        )
+
+    def test_trainer_suggestion6(self):
+        # Check that the error message reminds the user correctly.
+        # Extra parameters are passed in to make them duplicated.
+        dataset = prepare_fake_dataset2('x1', 'x_unused')
+        dataset.rename_field('x_unused', 'x2')
+        dataset.set_input('x1', 'x2', 'y')
+        dataset.set_target('x1')
+        class Model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.fc = nn.Linear(5, 4)
+            def forward(self, x1, x2, y):
+                x1 = self.fc(x1)
+                x2 = self.fc(x2)
+                x = x1 + x2
+                loss = F.cross_entropy(x, y)
+                return {'pred': x}
+
+        model = Model()
+        trainer = Trainer(
+            train_data=dataset,
+            model=model,
+            dev_data=dataset,
+            metrics=AccuracyMetric(),
+            use_tqdm=False,
+            print_every=2
+        )



     def test_case2(self):
         # check metrics Wrong

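The newly imported LossInForward pairs with models like the ones above, whose forward already returns {'loss': loss}. A hedged sketch of how it is consumed, assuming the LossBase.__call__(pred_dict, target_dict) convention seen in the losses.py hunks; loss_key='loss' is the documented default in later fastNLP releases:

    import torch
    from fastNLP.core.losses import LossInForward

    loss_func = LossInForward(loss_key='loss')       # assumption: key matches forward's output dict
    forward_output = {'loss': torch.tensor(0.618)}   # 0-dim tensor, as the size check requires
    loss = loss_func(forward_output, {})             # call signature assumed from the LossBase hunk above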
