|
|
@@ -3,9 +3,9 @@ import inspect |
|
|
|
import os |
|
|
|
from collections import Counter |
|
|
|
from collections import namedtuple |
|
|
|
from collections import defaultdict |
|
|
|
import torch |
|
|
|
|
|
|
|
CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False) |
|
|
|
def save_pickle(obj, pickle_path, file_name): |
|
|
|
"""Save an object into a pickle file. |
|
|
|
|
|
|
@@ -121,14 +121,11 @@ def _check_arg_dict_list(func, args): |
|
|
|
missing = list(require_args - input_args) |
|
|
|
unused = list(input_args - all_args) |
|
|
|
|
|
|
|
check_res = {} |
|
|
|
check_res['missing'] = missing |
|
|
|
check_res['unused'] = unused |
|
|
|
check_res['duplicated'] = duplicated |
|
|
|
check_res['required'] = list(require_args) |
|
|
|
check_res['all_needed'] = list(all_args) |
|
|
|
|
|
|
|
return check_res |
|
|
|
return CheckRes(missing=missing, |
|
|
|
unused=unused, |
|
|
|
duplicated=duplicated, |
|
|
|
required=list(require_args), |
|
|
|
all_needed=list(all_args)) |
|
|
|
|
|
|
|
def get_func_signature(func): |
|
|
|
""" |
|
|
@@ -165,6 +162,19 @@ def get_func_signature(func): |
|
|
|
signature_str = func.__name__ + signature_str |
|
|
|
return signature_str |
|
|
|
|
|
|
|
def _is_function_or_method(func): |
|
|
|
""" |
|
|
|
|
|
|
|
:param func: |
|
|
|
:return: |
|
|
|
""" |
|
|
|
if not inspect.ismethod(func) and not inspect.isfunction(func): |
|
|
|
return False |
|
|
|
return True |
|
|
|
|
|
|
|
def _check_function_or_method(func): |
|
|
|
if not _is_function_or_method(func): |
|
|
|
raise TypeError(f"{type(func)} is not a method or function.") |
|
|
|
|
|
|
|
def _syn_model_data(model, *args): |
|
|
|
""" |
|
|
@@ -204,3 +214,19 @@ def _move_dict_value_to_device(device, *args): |
|
|
|
else: |
|
|
|
raise TypeError("Only support `dict` type right now.") |
|
|
|
|
|
|
|
|
|
|
|
class CheckError(Exception): |
|
|
|
""" |
|
|
|
|
|
|
|
CheckError. Used in losses.LossBase, metrics.MetricBase. |
|
|
|
""" |
|
|
|
def __init__(self, check_res): |
|
|
|
err = '' |
|
|
|
if check_res['missing']: |
|
|
|
err += f"Missing: {check_res['missing']}\n" |
|
|
|
if check_res['duplicated']: |
|
|
|
err += f"Duplicated: {check_res['duplicated']}\n" |
|
|
|
if check_res['unused']: |
|
|
|
err += f"Unused: {check_res['unused']}\n" |
|
|
|
Exception.__init__(self, err) |
|
|
|
self.check_res = check_res |