From 3a2833a69e4dc3b6c5fad6b0a6e880b9a3548c58 Mon Sep 17 00:00:00 2001 From: yh Date: Wed, 19 Dec 2018 19:51:55 +0800 Subject: [PATCH] =?UTF-8?q?1.=E8=A7=A3=E5=86=B3trainer=E5=8A=A0=E8=BD=BD?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=9A=84=E9=97=AE=E9=A2=98;=202.=E5=88=A0?= =?UTF-8?q?=E9=99=A4loss=E4=B8=8Emetric=E4=B8=AD=E5=AF=B9*arg=E7=9A=84?= =?UTF-8?q?=E6=8A=A5=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/losses.py | 8 ++++---- fastNLP/core/metrics.py | 8 ++++---- fastNLP/core/trainer.py | 6 +++--- fastNLP/core/utils.py | 12 ++++++------ 4 files changed, 17 insertions(+), 17 deletions(-) diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 757ce465..057b03f2 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -63,9 +63,9 @@ class LossBase(object): f"initialization parameters, or change its signature.") # evaluate should not have varargs. - if func_spect.varargs: - raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " - f"positional argument.).") + # if func_spect.varargs: + # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " + # f"positional argument.).") def _fast_param_map(self, pred_dict, target_dict): """ @@ -148,7 +148,7 @@ class LossBase(object): all_needed=check_res.all_needed, varargs=check_res.varargs) - if check_res.missing or check_res.duplicated or check_res.varargs: + if check_res.missing or check_res.duplicated: raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 34a90d5a..07ebe3fe 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -63,9 +63,9 @@ class MetricBase(object): f"initialization parameters, or change its signature.") # evaluate should not have varargs. - if func_spect.varargs: - raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " - f"positional argument.).") + # if func_spect.varargs: + # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " + # f"positional argument.).") def get_metric(self, reset=True): raise NotImplemented @@ -165,7 +165,7 @@ class MetricBase(object): all_needed=check_res.all_needed, varargs=check_res.varargs) - if check_res.missing or check_res.duplicated or check_res.varargs: + if check_res.missing or check_res.duplicated: raise CheckError(check_res=check_res, func_signature=get_func_signature(self.evaluate)) refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index aa5f978c..bf32fa6c 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -363,11 +363,11 @@ class Trainer(object): def _load_model(self, model, model_name, only_param=False): # TODO: 这个是不是有问题? if self.save_path is not None: - model_name = os.path.join(self.save_path, model_name) + model_path = os.path.join(self.save_path, model_name) if only_param: - states = torch.save(model.state_dict(), model_name) + states = torch.load(model_path) else: - states = torch.save(model, model_name).state_dict() + states = torch.load(model_path).state_dict() model.load_state_dict(states) def _better_eval_result(self, metrics): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 526ade15..e93c95f4 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -242,9 +242,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re _unused_field = [] _unused_param = [] suggestions = [] - if check_res.varargs: - errs.append(f"\tvarargs: *{check_res.varargs}") - suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") + # if check_res.varargs: + # errs.append(f"\tvarargs: *{check_res.varargs}") + # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") if check_res.unused: for _unused in check_res.unused: @@ -344,9 +344,9 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): suggestions = [] _unused = [] - if check_res.varargs: - errs.append(f"\tvarargs: {check_res.varargs}") - suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") + # if check_res.varargs: + # errs.append(f"\tvarargs: {check_res.varargs}") + # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") if check_res.missing: errs.append(f"\tmissing param: {check_res.missing}") _miss_in_dataset = []