| @@ -3,9 +3,9 @@ import inspect | |||
| import os | |||
| from collections import Counter | |||
| from collections import namedtuple | |||
| from collections import defaultdict | |||
| import torch | |||
| CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) | |||
| def save_pickle(obj, pickle_path, file_name): | |||
| """Save an object into a pickle file. | |||
| @@ -121,14 +121,11 @@ def _check_arg_dict_list(func, args): | |||
| missing = list(require_args - input_args) | |||
| unused = list(input_args - all_args) | |||
| check_res = {} | |||
| check_res['missing'] = missing | |||
| check_res['unused'] = unused | |||
| check_res['duplicated'] = duplicated | |||
| check_res['required'] = list(require_args) | |||
| check_res['all_needed'] = list(all_args) | |||
| return check_res | |||
| return CheckRes(missing=missing, | |||
| unused=unused, | |||
| duplicated=duplicated, | |||
| required=list(require_args), | |||
| all_needed=list(all_args)) | |||
| def get_func_signature(func): | |||
| """ | |||
| @@ -165,6 +162,19 @@ def get_func_signature(func): | |||
| signature_str = func.__name__ + signature_str | |||
| return signature_str | |||
| def _is_function_or_method(func): | |||
| """ | |||
| :param func: | |||
| :return: | |||
| """ | |||
| if not inspect.ismethod(func) and not inspect.isfunction(func): | |||
| return False | |||
| return True | |||
| def _check_function_or_method(func): | |||
| if not _is_function_or_method(func): | |||
| raise TypeError(f"{type(func)} is not a method or function.") | |||
| def _syn_model_data(model, *args): | |||
| """ | |||
| @@ -204,3 +214,19 @@ def _move_dict_value_to_device(device, *args): | |||
| else: | |||
| raise TypeError("Only support `dict` type right now.") | |||
| class CheckError(Exception): | |||
| """ | |||
| CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
| """ | |||
| def __init__(self, check_res): | |||
| err = '' | |||
| if check_res['missing']: | |||
| err += f"Missing: {check_res['missing']}\n" | |||
| if check_res['duplicated']: | |||
| err += f"Duplicated: {check_res['duplicated']}\n" | |||
| if check_res['unused']: | |||
| err += f"Unused: {check_res['unused']}\n" | |||
| Exception.__init__(self, err) | |||
| self.check_res = check_res | |||