diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index 757ce465..057b03f2 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -63,9 +63,9 @@ class LossBase(object): f"initialization parameters, or change its signature.") # evaluate should not have varargs. - if func_spect.varargs: - raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " - f"positional argument.).") + # if func_spect.varargs: + # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " + # f"positional argument.).") def _fast_param_map(self, pred_dict, target_dict): """ @@ -148,7 +148,7 @@ class LossBase(object): all_needed=check_res.all_needed, varargs=check_res.varargs) - if check_res.missing or check_res.duplicated or check_res.varargs: + if check_res.missing or check_res.duplicated: raise CheckError(check_res=check_res, func_signature=get_func_signature(self.get_loss)) refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 34a90d5a..07ebe3fe 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -63,9 +63,9 @@ class MetricBase(object): f"initialization parameters, or change its signature.") # evaluate should not have varargs. - if func_spect.varargs: - raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " - f"positional argument.).") + # if func_spect.varargs: + # raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " + # f"positional argument.).") def get_metric(self, reset=True): raise NotImplemented @@ -165,7 +165,7 @@ class MetricBase(object): all_needed=check_res.all_needed, varargs=check_res.varargs) - if check_res.missing or check_res.duplicated or check_res.varargs: + if check_res.missing or check_res.duplicated: raise CheckError(check_res=check_res, func_signature=get_func_signature(self.evaluate)) refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index aa5f978c..bf32fa6c 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -363,11 +363,11 @@ class Trainer(object): def _load_model(self, model, model_name, only_param=False): # TODO: 这个是不是有问题? if self.save_path is not None: - model_name = os.path.join(self.save_path, model_name) + model_path = os.path.join(self.save_path, model_name) if only_param: - states = torch.save(model.state_dict(), model_name) + states = torch.load(model_path) else: - states = torch.save(model, model_name).state_dict() + states = torch.load(model_path).state_dict() model.load_state_dict(states) def _better_eval_result(self, metrics): diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 526ade15..e93c95f4 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -242,9 +242,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re _unused_field = [] _unused_param = [] suggestions = [] - if check_res.varargs: - errs.append(f"\tvarargs: *{check_res.varargs}") - suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") + # if check_res.varargs: + # errs.append(f"\tvarargs: *{check_res.varargs}") + # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") if check_res.unused: for _unused in check_res.unused: @@ -344,9 +344,9 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): suggestions = [] _unused = [] - if check_res.varargs: - errs.append(f"\tvarargs: {check_res.varargs}") - suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") + # if check_res.varargs: + # errs.append(f"\tvarargs: {check_res.varargs}") + # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") if check_res.missing: errs.append(f"\tmissing param: {check_res.missing}") _miss_in_dataset = []