Browse Source

add _method_function

tags/v0.2.0^2
yh 6 years ago
parent
commit
2c8bd9575a
1 changed files with 35 additions and 9 deletions
  1. +35
    -9
      fastNLP/core/utils.py

+ 35
- 9
fastNLP/core/utils.py View File

@@ -3,9 +3,9 @@ import inspect
import os import os
from collections import Counter from collections import Counter
from collections import namedtuple from collections import namedtuple
from collections import defaultdict
import torch import torch


CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
def save_pickle(obj, pickle_path, file_name): def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file. """Save an object into a pickle file.


@@ -121,14 +121,11 @@ def _check_arg_dict_list(func, args):
missing = list(require_args - input_args) missing = list(require_args - input_args)
unused = list(input_args - all_args) unused = list(input_args - all_args)


check_res = {}
check_res['missing'] = missing
check_res['unused'] = unused
check_res['duplicated'] = duplicated
check_res['required'] = list(require_args)
check_res['all_needed'] = list(all_args)

return check_res
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))


def get_func_signature(func): def get_func_signature(func):
""" """
@@ -165,6 +162,19 @@ def get_func_signature(func):
signature_str = func.__name__ + signature_str signature_str = func.__name__ + signature_str
return signature_str return signature_str


def _is_function_or_method(func):
"""

:param func:
:return:
"""
if not inspect.ismethod(func) and not inspect.isfunction(func):
return False
return True

def _check_function_or_method(func):
if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")


def _syn_model_data(model, *args): def _syn_model_data(model, *args):
""" """
@@ -204,3 +214,19 @@ def _move_dict_value_to_device(device, *args):
else: else:
raise TypeError("Only support `dict` type right now.") raise TypeError("Only support `dict` type right now.")



class CheckError(Exception):
"""

CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res):
err = ''
if check_res['missing']:
err += f"Missing: {check_res['missing']}\n"
if check_res['duplicated']:
err += f"Duplicated: {check_res['duplicated']}\n"
if check_res['unused']:
err += f"Unused: {check_res['unused']}\n"
Exception.__init__(self, err)
self.check_res = check_res

Loading…
Cancel
Save