Browse Source

1.解决trainer加载模型的问题; 2.删除loss与metric中对*arg的报错

tags/v0.3.0
yh 6 years ago
parent
commit
3a2833a69e
4 changed files with 17 additions and 17 deletions
  1. +4
    -4
      fastNLP/core/losses.py
  2. +4
    -4
      fastNLP/core/metrics.py
  3. +3
    -3
      fastNLP/core/trainer.py
  4. +6
    -6
      fastNLP/core/utils.py

+ 4
- 4
fastNLP/core/losses.py View File

@@ -63,9 +63,9 @@ class LossBase(object):
f"initialization parameters, or change its signature.")

# evaluate should not have varargs.
if func_spect.varargs:
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
f"positional argument.).")
# if func_spect.varargs:
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
# f"positional argument.).")

def _fast_param_map(self, pred_dict, target_dict):
"""
@@ -148,7 +148,7 @@ class LossBase(object):
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated or check_res.varargs:
if check_res.missing or check_res.duplicated:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.get_loss))
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)


+ 4
- 4
fastNLP/core/metrics.py View File

@@ -63,9 +63,9 @@ class MetricBase(object):
f"initialization parameters, or change its signature.")

# evaluate should not have varargs.
if func_spect.varargs:
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
f"positional argument.).")
# if func_spect.varargs:
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
# f"positional argument.).")

def get_metric(self, reset=True):
raise NotImplemented
@@ -165,7 +165,7 @@ class MetricBase(object):
all_needed=check_res.all_needed,
varargs=check_res.varargs)

if check_res.missing or check_res.duplicated or check_res.varargs:
if check_res.missing or check_res.duplicated:
raise CheckError(check_res=check_res,
func_signature=get_func_signature(self.evaluate))
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)


+ 3
- 3
fastNLP/core/trainer.py View File

@@ -363,11 +363,11 @@ class Trainer(object):
def _load_model(self, model, model_name, only_param=False):
# TODO: 这个是不是有问题?
if self.save_path is not None:
model_name = os.path.join(self.save_path, model_name)
model_path = os.path.join(self.save_path, model_name)
if only_param:
states = torch.save(model.state_dict(), model_name)
states = torch.load(model_path)
else:
states = torch.save(model, model_name).state_dict()
states = torch.load(model_path).state_dict()
model.load_state_dict(states)

def _better_eval_result(self, metrics):


+ 6
- 6
fastNLP/core/utils.py View File

@@ -242,9 +242,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
_unused_field = []
_unused_param = []
suggestions = []
if check_res.varargs:
errs.append(f"\tvarargs: *{check_res.varargs}")
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
# if check_res.varargs:
# errs.append(f"\tvarargs: *{check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")

if check_res.unused:
for _unused in check_res.unused:
@@ -344,9 +344,9 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
suggestions = []
_unused = []

if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}")
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
# if check_res.varargs:
# errs.append(f"\tvarargs: {check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")
_miss_in_dataset = []


Loading…
Cancel
Save