Browse Source

add _method_function

tags/v0.2.0^2
yh 6 years ago
parent
commit
2c8bd9575a
1 changed files with 35 additions and 9 deletions
  1. +35
    -9
      fastNLP/core/utils.py

+ 35
- 9
fastNLP/core/utils.py View File

@@ -3,9 +3,9 @@ import inspect
import os
from collections import Counter
from collections import namedtuple
from collections import defaultdict
import torch

CheckRes = namedtuple('CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed'], verbose=False)
def save_pickle(obj, pickle_path, file_name):
"""Save an object into a pickle file.

@@ -121,14 +121,11 @@ def _check_arg_dict_list(func, args):
missing = list(require_args - input_args)
unused = list(input_args - all_args)

check_res = {}
check_res['missing'] = missing
check_res['unused'] = unused
check_res['duplicated'] = duplicated
check_res['required'] = list(require_args)
check_res['all_needed'] = list(all_args)

return check_res
return CheckRes(missing=missing,
unused=unused,
duplicated=duplicated,
required=list(require_args),
all_needed=list(all_args))

def get_func_signature(func):
"""
@@ -165,6 +162,19 @@ def get_func_signature(func):
signature_str = func.__name__ + signature_str
return signature_str

def _is_function_or_method(func):
"""

:param func:
:return:
"""
if not inspect.ismethod(func) and not inspect.isfunction(func):
return False
return True

def _check_function_or_method(func):
if not _is_function_or_method(func):
raise TypeError(f"{type(func)} is not a method or function.")

def _syn_model_data(model, *args):
"""
@@ -204,3 +214,19 @@ def _move_dict_value_to_device(device, *args):
else:
raise TypeError("Only support `dict` type right now.")


class CheckError(Exception):
"""

CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res):
err = ''
if check_res['missing']:
err += f"Missing: {check_res['missing']}\n"
if check_res['duplicated']:
err += f"Duplicated: {check_res['duplicated']}\n"
if check_res['unused']:
err += f"Unused: {check_res['unused']}\n"
Exception.__init__(self, err)
self.check_res = check_res

Loading…
Cancel
Save