[core] improve trainer, loss and metricstags/v0.3.0
@@ -63,9 +63,9 @@ class LossBase(object): | |||||
f"initialization parameters, or change its signature.") | f"initialization parameters, or change its signature.") | ||||
# evaluate should not have varargs. | # evaluate should not have varargs. | ||||
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | |||||
f"positional argument.).") | |||||
# if func_spect.varargs: | |||||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.get_loss)}(Do not use " | |||||
# f"positional argument.).") | |||||
def _fast_param_map(self, pred_dict, target_dict): | def _fast_param_map(self, pred_dict, target_dict): | ||||
""" | """ | ||||
@@ -148,7 +148,7 @@ class LossBase(object): | |||||
all_needed=check_res.all_needed, | all_needed=check_res.all_needed, | ||||
varargs=check_res.varargs) | varargs=check_res.varargs) | ||||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||||
if check_res.missing or check_res.duplicated: | |||||
raise CheckError(check_res=check_res, | raise CheckError(check_res=check_res, | ||||
func_signature=get_func_signature(self.get_loss)) | func_signature=get_func_signature(self.get_loss)) | ||||
refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | refined_args = _build_args(self.get_loss, **mapped_pred_dict, **mapped_target_dict) | ||||
@@ -63,9 +63,9 @@ class MetricBase(object): | |||||
f"initialization parameters, or change its signature.") | f"initialization parameters, or change its signature.") | ||||
# evaluate should not have varargs. | # evaluate should not have varargs. | ||||
if func_spect.varargs: | |||||
raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||||
f"positional argument.).") | |||||
# if func_spect.varargs: | |||||
# raise NameError(f"Delete `*{func_spect.varargs}` in {get_func_signature(self.evaluate)}(Do not use " | |||||
# f"positional argument.).") | |||||
def get_metric(self, reset=True): | def get_metric(self, reset=True): | ||||
raise NotImplemented | raise NotImplemented | ||||
@@ -165,7 +165,7 @@ class MetricBase(object): | |||||
all_needed=check_res.all_needed, | all_needed=check_res.all_needed, | ||||
varargs=check_res.varargs) | varargs=check_res.varargs) | ||||
if check_res.missing or check_res.duplicated or check_res.varargs: | |||||
if check_res.missing or check_res.duplicated: | |||||
raise CheckError(check_res=check_res, | raise CheckError(check_res=check_res, | ||||
func_signature=get_func_signature(self.evaluate)) | func_signature=get_func_signature(self.evaluate)) | ||||
refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | refined_args = _build_args(self.evaluate, **mapped_pred_dict, **mapped_target_dict) | ||||
@@ -363,11 +363,11 @@ class Trainer(object): | |||||
def _load_model(self, model, model_name, only_param=False): | def _load_model(self, model, model_name, only_param=False): | ||||
# TODO: 这个是不是有问题? | # TODO: 这个是不是有问题? | ||||
if self.save_path is not None: | if self.save_path is not None: | ||||
model_name = os.path.join(self.save_path, model_name) | |||||
model_path = os.path.join(self.save_path, model_name) | |||||
if only_param: | if only_param: | ||||
states = torch.save(model.state_dict(), model_name) | |||||
states = torch.load(model_path) | |||||
else: | else: | ||||
states = torch.save(model, model_name).state_dict() | |||||
states = torch.load(model_path).state_dict() | |||||
model.load_state_dict(states) | model.load_state_dict(states) | ||||
def _better_eval_result(self, metrics): | def _better_eval_result(self, metrics): | ||||
@@ -242,9 +242,9 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
_unused_field = [] | _unused_field = [] | ||||
_unused_param = [] | _unused_param = [] | ||||
suggestions = [] | suggestions = [] | ||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: *{check_res.varargs}") | |||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
# if check_res.varargs: | |||||
# errs.append(f"\tvarargs: *{check_res.varargs}") | |||||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
if check_res.unused: | if check_res.unused: | ||||
for _unused in check_res.unused: | for _unused in check_res.unused: | ||||
@@ -344,9 +344,9 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
suggestions = [] | suggestions = [] | ||||
_unused = [] | _unused = [] | ||||
if check_res.varargs: | |||||
errs.append(f"\tvarargs: {check_res.varargs}") | |||||
suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
# if check_res.varargs: | |||||
# errs.append(f"\tvarargs: {check_res.varargs}") | |||||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||||
if check_res.missing: | if check_res.missing: | ||||
errs.append(f"\tmissing param: {check_res.missing}") | errs.append(f"\tmissing param: {check_res.missing}") | ||||
_miss_in_dataset = [] | _miss_in_dataset = [] | ||||