
Merge pull request #113 from FengZiYjun/trainer

[core] improve trainer, loss and metrics
Tag: v0.3.0
Xipeng Qiu, 6 years ago
parent commit: 7371593fd9
4 changed files with 17 additions and 17 deletions
  1. fastNLP/core/losses.py (+4, -4)
  2. fastNLP/core/metrics.py (+4, -4)
  3. fastNLP/core/trainer.py (+3, -3)
  4. fastNLP/core/utils.py (+6, -6)

fastNLP/core/losses.py (+4, -4)

@@ -63,9 +63,9 @@ class LossBase(object):
                                 f"initialization parameters, or change its signature.")
 
         # evaluate should not have varargs.
-        if func_spect.varargs:
-            raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
-                            f"positional argument.).")
+        # if func_spect.varargs:
+        #     raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use "
+        #                     f"positional argument.).")
 
     def _fast_param_map(self, pred_dict, target_dict):
         """
@@ -148,7 +148,7 @@ class LossBase(object):
                              all_needed=check_res.all_needed,
                              varargs=check_res.varargs)
 
-        if check_res.missing or check_res.duplicated or check_res.varargs:
+        if check_res.missing or check_res.duplicated:
             raise CheckError(check_res=check_res,
                              func_signature=get_func_signature(self.get_loss))
         refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict)
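Both hunks relax the same rule: `get_loss` implementations may now declare `*args`. The signature check is commented out, and `check_res.varargs` no longer triggers a `CheckError` at call time. A hypothetical illustration of what the old check rejected, assuming `func_spect` comes from `inspect.getfullargspec` (which the attribute name suggests); `MyLoss` is not fastNLP code:

import inspect

class MyLoss:
    def get_loss(self, pred, target, *extras):  # *extras is the varargs slot
        pass

func_spect = inspect.getfullargspec(MyLoss.get_loss)
print(func_spect.varargs)  # 'extras' -- before this commit, a non-None value
                           # here raised NameError during initialization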


fastNLP/core/metrics.py (+4, -4)

@@ -63,9 +63,9 @@ class MetricBase(object):
                                 f"initialization parameters, or change its signature.")
 
         # evaluate should not have varargs.
-        if func_spect.varargs:
-            raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
-                            f"positional argument.).")
+        # if func_spect.varargs:
+        #     raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use "
+        #                     f"positional argument.).")
 
     def get_metric(self, reset=True):
         raise NotImplemented
@@ -165,7 +165,7 @@ class MetricBase(object):
                              all_needed=check_res.all_needed,
                              varargs=check_res.varargs)
 
-        if check_res.missing or check_res.duplicated or check_res.varargs:
+        if check_res.missing or check_res.duplicated:
             raise CheckError(check_res=check_res,
                              func_signature=get_func_signature(self.evaluate))
         refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict)
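`MetricBase` gets the mirror-image change for `evaluate`. The runtime gate now reduces to missing or duplicated arguments only; a minimal sketch of the relaxed condition, using a stand-in `CheckRes` (field names follow the diff, everything else is assumption):

from collections import namedtuple

# Stand-in for fastNLP's check result; only the fields the diff touches.
CheckRes = namedtuple("CheckRes", ["missing", "duplicated", "varargs"])

def should_raise(check_res):
    # varargs was dropped from this condition by the commit
    return bool(check_res.missing or check_res.duplicated)

print(should_raise(CheckRes([], [], "args")))        # False: varargs alone is fine now
print(should_raise(CheckRes(["target"], [], None)))  # True: missing params still raise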


fastNLP/core/trainer.py (+3, -3)

@@ -363,11 +363,11 @@ class Trainer(object):
     def _load_model(self, model, model_name, only_param=False):
         # TODO: is there a problem with this?
         if self.save_path is not None:
-            model_name = os.path.join(self.save_path, model_name)
+            model_path = os.path.join(self.save_path, model_name)
             if only_param:
-                states = torch.save(model.state_dict(), model_name)
+                states = torch.load(model_path)
             else:
-                states = torch.save(model, model_name).state_dict()
+                states = torch.load(model_path).state_dict()
             model.load_state_dict(states)
 
     def _better_eval_result(self, metrics):
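This is the substantive bug fix: `_load_model` previously called `torch.save`, which would have overwritten the checkpoint it was supposed to read (and since `torch.save` returns None, the `.state_dict()` call would also have crashed). A standalone sketch of the corrected load path, with names assumed for illustration (fastNLP's version is a `Trainer` method):

import os
import torch

def load_model(model, save_path, model_name, only_param=False):
    model_path = os.path.join(save_path, model_name)
    if only_param:
        # checkpoint holds a bare state_dict
        states = torch.load(model_path)
    else:
        # checkpoint holds a whole nn.Module; extract its parameters
        states = torch.load(model_path).state_dict()
    model.load_state_dict(states)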


fastNLP/core/utils.py (+6, -6)

@@ -242,9 +242,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
     _unused_field = []
     _unused_param = []
     suggestions = []
-    if check_res.varargs:
-        errs.append(f"\tvarargs: *{check_res.varargs}")
-        suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
+    # if check_res.varargs:
+    #     errs.append(f"\tvarargs: *{check_res.varargs}")
+    #     suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
 
     if check_res.unused:
         for _unused in check_res.unused:
@@ -344,9 +344,9 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
     suggestions = []
     _unused = []
 
-    if check_res.varargs:
-        errs.append(f"\tvarargs: {check_res.varargs}")
-        suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
+    # if check_res.varargs:
+    #     errs.append(f"\tvarargs: {check_res.varargs}")
+    #     suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
     if check_res.missing:
         errs.append(f"\tmissing param: {check_res.missing}")
         _miss_in_dataset = []
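With the varargs checks disabled upstream, the matching diagnostics in both helpers are silenced so the error report stays consistent. A simplified sketch of the report-building pattern these functions use, showing only the `missing` branch that mirrors the diff's context lines (the suggestion text is illustrative, not fastNLP's wording):

def build_report(check_res):
    errs, suggestions = [], []
    # the varargs branch is now commented out, matching the relaxed checks
    if check_res.missing:
        errs.append(f"\tmissing param: {check_res.missing}")
        suggestions.append(f"Provide {check_res.missing} via the input dictionaries.")
    return errs, suggestions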

