@@ -300,13 +300,14 @@ from .field import FieldArray | |||||
from .field import SetInputOrTargetException | from .field import SetInputOrTargetException | ||||
from .instance import Instance | from .instance import Instance | ||||
from .utils import _get_func_signature | from .utils import _get_func_signature | ||||
from .utils import pretty_table_printer | |||||
class DataSet(object): | class DataSet(object): | ||||
""" | """ | ||||
fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` | fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` | ||||
""" | """ | ||||
def __init__(self, data=None): | def __init__(self, data=None): | ||||
""" | """ | ||||
@@ -326,26 +327,26 @@ class DataSet(object): | |||||
for ins in data: | for ins in data: | ||||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | ||||
self.append(ins) | self.append(ins) | ||||
else: | else: | ||||
raise ValueError("data only be dict or list type.") | raise ValueError("data only be dict or list type.") | ||||
def __contains__(self, item): | def __contains__(self, item): | ||||
return item in self.field_arrays | return item in self.field_arrays | ||||
def __iter__(self): | def __iter__(self): | ||||
def iter_func(): | def iter_func(): | ||||
for idx in range(len(self)): | for idx in range(len(self)): | ||||
yield self[idx] | yield self[idx] | ||||
return iter_func() | return iter_func() | ||||
def _inner_iter(self): | def _inner_iter(self): | ||||
class Iter_ptr: | class Iter_ptr: | ||||
def __init__(self, dataset, idx): | def __init__(self, dataset, idx): | ||||
self.dataset = dataset | self.dataset = dataset | ||||
self.idx = idx | self.idx = idx | ||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | ||||
self.idx]) | self.idx]) | ||||
@@ -358,13 +359,13 @@ class DataSet(object): | |||||
def __repr__(self): | def __repr__(self): | ||||
return self.dataset[self.idx].__repr__() | return self.dataset[self.idx].__repr__() | ||||
def inner_iter_func(): | def inner_iter_func(): | ||||
for idx in range(len(self)): | for idx in range(len(self)): | ||||
yield Iter_ptr(self, idx) | yield Iter_ptr(self, idx) | ||||
return inner_iter_func() | return inner_iter_func() | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | """给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | ||||
@@ -397,20 +398,20 @@ class DataSet(object): | |||||
return dataset | return dataset | ||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
# Not tested. Don't use !! | # Not tested. Don't use !! | ||||
if item == "field_arrays": | if item == "field_arrays": | ||||
raise AttributeError | raise AttributeError | ||||
if isinstance(item, str) and item in self.field_arrays: | if isinstance(item, str) and item in self.field_arrays: | ||||
return self.field_arrays[item] | return self.field_arrays[item] | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
self.__dict__ = state | self.__dict__ = state | ||||
def __getstate__(self): | def __getstate__(self): | ||||
return self.__dict__ | return self.__dict__ | ||||
def __len__(self): | def __len__(self): | ||||
"""Fetch the length of the dataset. | """Fetch the length of the dataset. | ||||
@@ -420,16 +421,10 @@ class DataSet(object): | |||||
return 0 | return 0 | ||||
field = iter(self.field_arrays.values()).__next__() | field = iter(self.field_arrays.values()).__next__() | ||||
return len(field) | return len(field) | ||||
def __inner_repr__(self): | |||||
if len(self) < 20: | |||||
return ",\n".join([ins.__repr__() for ins in self]) | |||||
else: | |||||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||||
def __repr__(self): | def __repr__(self): | ||||
return "DataSet(" + self.__inner_repr__() + ")" | |||||
return str(pretty_table_printer(self)) | |||||
def append(self, instance): | def append(self, instance): | ||||
""" | """ | ||||
将一个instance对象append到DataSet后面。 | 将一个instance对象append到DataSet后面。 | ||||
@@ -454,7 +449,7 @@ class DataSet(object): | |||||
except AppendToTargetOrInputException as e: | except AppendToTargetOrInputException as e: | ||||
logger.error(f"Cannot append to field:{name}.") | logger.error(f"Cannot append to field:{name}.") | ||||
raise e | raise e | ||||
def add_fieldarray(self, field_name, fieldarray): | def add_fieldarray(self, field_name, fieldarray): | ||||
""" | """ | ||||
将fieldarray添加到DataSet中. | 将fieldarray添加到DataSet中. | ||||
@@ -469,7 +464,7 @@ class DataSet(object): | |||||
raise RuntimeError(f"The field to add must have the same size as dataset. " | raise RuntimeError(f"The field to add must have the same size as dataset. " | ||||
f"Dataset size {len(self)} != field size {len(fieldarray)}") | f"Dataset size {len(self)} != field size {len(fieldarray)}") | ||||
self.field_arrays[field_name] = fieldarray | self.field_arrays[field_name] = fieldarray | ||||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | ||||
""" | """ | ||||
新增一个field | 新增一个field | ||||
@@ -481,14 +476,14 @@ class DataSet(object): | |||||
:param bool is_target: 新加入的field是否是target | :param bool is_target: 新加入的field是否是target | ||||
:param bool ignore_type: 是否忽略对新加入的field的类型检查 | :param bool ignore_type: 是否忽略对新加入的field的类型检查 | ||||
""" | """ | ||||
if len(self.field_arrays) != 0: | if len(self.field_arrays) != 0: | ||||
if len(self) != len(fields): | if len(self) != len(fields): | ||||
raise RuntimeError(f"The field to add must have the same size as dataset. " | raise RuntimeError(f"The field to add must have the same size as dataset. " | ||||
f"Dataset size {len(self)} != field size {len(fields)}") | f"Dataset size {len(self)} != field size {len(fields)}") | ||||
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, | self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, | ||||
padder=padder, ignore_type=ignore_type) | padder=padder, ignore_type=ignore_type) | ||||
def delete_instance(self, index): | def delete_instance(self, index): | ||||
""" | """ | ||||
删除第index个instance | 删除第index个instance | ||||
@@ -504,7 +499,7 @@ class DataSet(object): | |||||
for field in self.field_arrays.values(): | for field in self.field_arrays.values(): | ||||
field.pop(index) | field.pop(index) | ||||
return self | return self | ||||
def delete_field(self, field_name): | def delete_field(self, field_name): | ||||
""" | """ | ||||
删除名为field_name的field | 删除名为field_name的field | ||||
@@ -538,7 +533,7 @@ class DataSet(object): | |||||
if isinstance(field_name, str): | if isinstance(field_name, str): | ||||
return field_name in self.field_arrays | return field_name in self.field_arrays | ||||
return False | return False | ||||
def get_field(self, field_name): | def get_field(self, field_name): | ||||
""" | """ | ||||
获取field_name这个field | 获取field_name这个field | ||||
@@ -549,7 +544,7 @@ class DataSet(object): | |||||
if field_name not in self.field_arrays: | if field_name not in self.field_arrays: | ||||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | raise KeyError("Field name {} not found in DataSet".format(field_name)) | ||||
return self.field_arrays[field_name] | return self.field_arrays[field_name] | ||||
def get_all_fields(self): | def get_all_fields(self): | ||||
""" | """ | ||||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | 返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | ||||
@@ -557,7 +552,7 @@ class DataSet(object): | |||||
:return dict: 返回如上所述的字典 | :return dict: 返回如上所述的字典 | ||||
""" | """ | ||||
return self.field_arrays | return self.field_arrays | ||||
def get_field_names(self) -> list: | def get_field_names(self) -> list: | ||||
""" | """ | ||||
返回一个list,包含所有 field 的名字 | 返回一个list,包含所有 field 的名字 | ||||
@@ -565,7 +560,7 @@ class DataSet(object): | |||||
:return list: 返回如上所述的列表 | :return list: 返回如上所述的列表 | ||||
""" | """ | ||||
return sorted(self.field_arrays.keys()) | return sorted(self.field_arrays.keys()) | ||||
def get_length(self): | def get_length(self): | ||||
""" | """ | ||||
获取DataSet的元素数量 | 获取DataSet的元素数量 | ||||
@@ -573,7 +568,7 @@ class DataSet(object): | |||||
:return: int: DataSet中Instance的个数。 | :return: int: DataSet中Instance的个数。 | ||||
""" | """ | ||||
return len(self) | return len(self) | ||||
def rename_field(self, field_name, new_field_name): | def rename_field(self, field_name, new_field_name): | ||||
""" | """ | ||||
将某个field重新命名. | 将某个field重新命名. | ||||
@@ -587,7 +582,7 @@ class DataSet(object): | |||||
else: | else: | ||||
raise KeyError("DataSet has no field named {}.".format(field_name)) | raise KeyError("DataSet has no field named {}.".format(field_name)) | ||||
return self | return self | ||||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | ||||
""" | """ | ||||
将field_names的field设置为target | 将field_names的field设置为target | ||||
@@ -614,7 +609,7 @@ class DataSet(object): | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
return self | return self | ||||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | ||||
""" | """ | ||||
将field_names的field设置为input:: | 将field_names的field设置为input:: | ||||
@@ -638,7 +633,7 @@ class DataSet(object): | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
return self | return self | ||||
def set_ignore_type(self, *field_names, flag=True): | def set_ignore_type(self, *field_names, flag=True): | ||||
""" | """ | ||||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | 将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | ||||
@@ -655,7 +650,7 @@ class DataSet(object): | |||||
else: | else: | ||||
raise KeyError("{} is not a valid field name.".format(name)) | raise KeyError("{} is not a valid field name.".format(name)) | ||||
return self | return self | ||||
def set_padder(self, field_name, padder): | def set_padder(self, field_name, padder): | ||||
""" | """ | ||||
为field_name设置padder:: | 为field_name设置padder:: | ||||
@@ -671,7 +666,7 @@ class DataSet(object): | |||||
raise KeyError("There is no field named {}.".format(field_name)) | raise KeyError("There is no field named {}.".format(field_name)) | ||||
self.field_arrays[field_name].set_padder(padder) | self.field_arrays[field_name].set_padder(padder) | ||||
return self | return self | ||||
def set_pad_val(self, field_name, pad_val): | def set_pad_val(self, field_name, pad_val): | ||||
""" | """ | ||||
为某个field设置对应的pad_val. | 为某个field设置对应的pad_val. | ||||
@@ -683,7 +678,7 @@ class DataSet(object): | |||||
raise KeyError("There is no field named {}.".format(field_name)) | raise KeyError("There is no field named {}.".format(field_name)) | ||||
self.field_arrays[field_name].set_pad_val(pad_val) | self.field_arrays[field_name].set_pad_val(pad_val) | ||||
return self | return self | ||||
def get_input_name(self): | def get_input_name(self): | ||||
""" | """ | ||||
返回所有is_input被设置为True的field名称 | 返回所有is_input被设置为True的field名称 | ||||
@@ -691,7 +686,7 @@ class DataSet(object): | |||||
:return list: 里面的元素为被设置为input的field名称 | :return list: 里面的元素为被设置为input的field名称 | ||||
""" | """ | ||||
return [name for name, field in self.field_arrays.items() if field.is_input] | return [name for name, field in self.field_arrays.items() if field.is_input] | ||||
def get_target_name(self): | def get_target_name(self): | ||||
""" | """ | ||||
返回所有is_target被设置为True的field名称 | 返回所有is_target被设置为True的field名称 | ||||
@@ -699,7 +694,7 @@ class DataSet(object): | |||||
:return list: 里面的元素为被设置为target的field名称 | :return list: 里面的元素为被设置为target的field名称 | ||||
""" | """ | ||||
return [name for name, field in self.field_arrays.items() if field.is_target] | return [name for name, field in self.field_arrays.items() if field.is_target] | ||||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | def apply_field(self, func, field_name, new_field_name=None, **kwargs): | ||||
""" | """ | ||||
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | 将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | ||||
@@ -728,16 +723,16 @@ class DataSet(object): | |||||
results.append(func(ins[field_name])) | results.append(func(ins[field_name])) | ||||
except Exception as e: | except Exception as e: | ||||
if idx != -1: | if idx != -1: | ||||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) | |||||
raise e | raise e | ||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | ||||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | ||||
if new_field_name is not None: | if new_field_name is not None: | ||||
self._add_apply_field(results, new_field_name, kwargs) | self._add_apply_field(results, new_field_name, kwargs) | ||||
return results | return results | ||||
def _add_apply_field(self, results, new_field_name, kwargs): | def _add_apply_field(self, results, new_field_name, kwargs): | ||||
""" | """ | ||||
将results作为加入到新的field中,field名称为new_field_name | 将results作为加入到新的field中,field名称为new_field_name | ||||
@@ -769,7 +764,7 @@ class DataSet(object): | |||||
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | ||||
is_target=extra_param.get("is_target", None), | is_target=extra_param.get("is_target", None), | ||||
ignore_type=extra_param.get("ignore_type", False)) | ignore_type=extra_param.get("ignore_type", False)) | ||||
def apply(self, func, new_field_name=None, **kwargs): | def apply(self, func, new_field_name=None, **kwargs): | ||||
""" | """ | ||||
将DataSet中每个instance传入到func中,并获取它的返回值. | 将DataSet中每个instance传入到func中,并获取它的返回值. | ||||
@@ -801,13 +796,13 @@ class DataSet(object): | |||||
# results = [func(ins) for ins in self._inner_iter()] | # results = [func(ins) for ins in self._inner_iter()] | ||||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | ||||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | ||||
if new_field_name is not None: | if new_field_name is not None: | ||||
self._add_apply_field(results, new_field_name, kwargs) | self._add_apply_field(results, new_field_name, kwargs) | ||||
return results | return results | ||||
def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): | |||||
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | |||||
""" | """ | ||||
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 | 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 | ||||
@@ -844,7 +839,7 @@ class DataSet(object): | |||||
return dataset | return dataset | ||||
else: | else: | ||||
return DataSet() | return DataSet() | ||||
def split(self, ratio, shuffle=True): | def split(self, ratio, shuffle=True): | ||||
""" | """ | ||||
将DataSet按照ratio的比例拆分,返回两个DataSet | 将DataSet按照ratio的比例拆分,返回两个DataSet | ||||
@@ -870,9 +865,9 @@ class DataSet(object): | |||||
for field_name in self.field_arrays: | for field_name in self.field_arrays: | ||||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | ||||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | ||||
return train_set, dev_set | return train_set, dev_set | ||||
def save(self, path): | def save(self, path): | ||||
""" | """ | ||||
保存DataSet. | 保存DataSet. | ||||
@@ -881,7 +876,7 @@ class DataSet(object): | |||||
""" | """ | ||||
with open(path, 'wb') as f: | with open(path, 'wb') as f: | ||||
pickle.dump(self, f) | pickle.dump(self, f) | ||||
@staticmethod | @staticmethod | ||||
def load(path): | def load(path): | ||||
r""" | r""" | ||||
@@ -3,10 +3,13 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 | |||||
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"Instance" | "Instance" | ||||
] | ] | ||||
from .utils import pretty_table_printer | |||||
class Instance(object): | class Instance(object): | ||||
""" | """ | ||||
@@ -20,11 +23,11 @@ class Instance(object): | |||||
>>>ins.add_field("field_3", [3, 3, 3]) | >>>ins.add_field("field_3", [3, 3, 3]) | ||||
>>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) | >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) | ||||
""" | """ | ||||
def __init__(self, **fields): | def __init__(self, **fields): | ||||
self.fields = fields | self.fields = fields | ||||
def add_field(self, field_name, field): | def add_field(self, field_name, field): | ||||
""" | """ | ||||
向Instance中增加一个field | 向Instance中增加一个field | ||||
@@ -41,18 +44,15 @@ class Instance(object): | |||||
:return: 一个迭代器 | :return: 一个迭代器 | ||||
""" | """ | ||||
return self.fields.items() | return self.fields.items() | ||||
def __getitem__(self, name): | def __getitem__(self, name): | ||||
if name in self.fields: | if name in self.fields: | ||||
return self.fields[name] | return self.fields[name] | ||||
else: | else: | ||||
raise KeyError("{} not found".format(name)) | raise KeyError("{} not found".format(name)) | ||||
def __setitem__(self, name, field): | def __setitem__(self, name, field): | ||||
return self.add_field(name, field) | return self.add_field(name, field) | ||||
def __repr__(self): | def __repr__(self): | ||||
s = '\'' | |||||
return "{" + ",\n".join( | |||||
"\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" | |||||
return str(pretty_table_printer(self)) |
@@ -1,6 +1,7 @@ | |||||
""" | """ | ||||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | ||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
"cache_results", | "cache_results", | ||||
"seq_len_to_mask", | "seq_len_to_mask", | ||||
@@ -12,12 +13,12 @@ import inspect | |||||
import os | import os | ||||
import warnings | import warnings | ||||
from collections import Counter, namedtuple | from collections import Counter, namedtuple | ||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
import torch.nn as nn | import torch.nn as nn | ||||
from typing import List | from typing import List | ||||
from ._logger import logger | from ._logger import logger | ||||
from prettytable import PrettyTable | |||||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | ||||
'varargs']) | 'varargs']) | ||||
@@ -25,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||||
class Option(dict): | class Option(dict): | ||||
"""a dict can treat keys as attributes""" | """a dict can treat keys as attributes""" | ||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
try: | try: | ||||
return self.__getitem__(item) | return self.__getitem__(item) | ||||
except KeyError: | except KeyError: | ||||
raise AttributeError(item) | raise AttributeError(item) | ||||
def __setattr__(self, key, value): | def __setattr__(self, key, value): | ||||
if key.startswith('__') and key.endswith('__'): | if key.startswith('__') and key.endswith('__'): | ||||
raise AttributeError(key) | raise AttributeError(key) | ||||
self.__setitem__(key, value) | self.__setitem__(key, value) | ||||
def __delattr__(self, item): | def __delattr__(self, item): | ||||
try: | try: | ||||
self.pop(item) | self.pop(item) | ||||
except KeyError: | except KeyError: | ||||
raise AttributeError(item) | raise AttributeError(item) | ||||
def __getstate__(self): | def __getstate__(self): | ||||
return self | return self | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
self.update(state) | self.update(state) | ||||
@@ -112,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
:param int _verbose: 是否打印cache的信息。 | :param int _verbose: 是否打印cache的信息。 | ||||
:return: | :return: | ||||
""" | """ | ||||
def wrapper_(func): | def wrapper_(func): | ||||
signature = inspect.signature(func) | signature = inspect.signature(func) | ||||
for key, _ in signature.parameters.items(): | for key, _ in signature.parameters.items(): | ||||
if key in ('_cache_fp', '_refresh', '_verbose'): | if key in ('_cache_fp', '_refresh', '_verbose'): | ||||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | ||||
def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
if '_cache_fp' in kwargs: | if '_cache_fp' in kwargs: | ||||
cache_filepath = kwargs.pop('_cache_fp') | cache_filepath = kwargs.pop('_cache_fp') | ||||
@@ -136,7 +137,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
else: | else: | ||||
verbose = _verbose | verbose = _verbose | ||||
refresh_flag = True | refresh_flag = True | ||||
if cache_filepath is not None and refresh is False: | if cache_filepath is not None and refresh is False: | ||||
# load data | # load data | ||||
if os.path.exists(cache_filepath): | if os.path.exists(cache_filepath): | ||||
@@ -145,7 +146,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
if verbose == 1: | if verbose == 1: | ||||
logger.info("Read cache from {}.".format(cache_filepath)) | logger.info("Read cache from {}.".format(cache_filepath)) | ||||
refresh_flag = False | refresh_flag = False | ||||
if refresh_flag: | if refresh_flag: | ||||
results = func(*args, **kwargs) | results = func(*args, **kwargs) | ||||
if cache_filepath is not None: | if cache_filepath is not None: | ||||
@@ -155,11 +156,11 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||||
with open(cache_filepath, 'wb') as f: | with open(cache_filepath, 'wb') as f: | ||||
_pickle.dump(results, f) | _pickle.dump(results, f) | ||||
logger.info("Save cache to {}.".format(cache_filepath)) | logger.info("Save cache to {}.".format(cache_filepath)) | ||||
return results | return results | ||||
return wrapper | return wrapper | ||||
return wrapper_ | return wrapper_ | ||||
@@ -187,6 +188,7 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||||
torch.save(model, model_path) | torch.save(model, model_path) | ||||
model.to(_model_device) | model.to(_model_device) | ||||
def _move_model_to_device(model, device): | def _move_model_to_device(model, device): | ||||
""" | """ | ||||
将model移动到device | 将model移动到device | ||||
@@ -211,7 +213,7 @@ def _move_model_to_device(model, device): | |||||
""" | """ | ||||
# if isinstance(model, torch.nn.parallel.DistributedDataParallel): | # if isinstance(model, torch.nn.parallel.DistributedDataParallel): | ||||
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | # raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | ||||
if device is None: | if device is None: | ||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
model.cuda() | model.cuda() | ||||
@@ -220,10 +222,10 @@ def _move_model_to_device(model, device): | |||||
if not torch.cuda.is_available() and ( | if not torch.cuda.is_available() and ( | ||||
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | ||||
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | ||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | ||||
if isinstance(device, int): | if isinstance(device, int): | ||||
assert device > -1, "device can only be non-negative integer" | assert device > -1, "device can only be non-negative integer" | ||||
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( | assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( | ||||
@@ -267,7 +269,7 @@ def _get_model_device(model): | |||||
""" | """ | ||||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 | # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 | ||||
assert isinstance(model, nn.Module) | assert isinstance(model, nn.Module) | ||||
parameters = list(model.parameters()) | parameters = list(model.parameters()) | ||||
if len(parameters) == 0: | if len(parameters) == 0: | ||||
return None | return None | ||||
@@ -427,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||||
""" | """ | ||||
if not torch.cuda.is_available(): | if not torch.cuda.is_available(): | ||||
return | return | ||||
if not isinstance(device, torch.device): | if not isinstance(device, torch.device): | ||||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | ||||
for arg in args: | for arg in args: | ||||
if isinstance(arg, dict): | if isinstance(arg, dict): | ||||
for key, value in arg.items(): | for key, value in arg.items(): | ||||
@@ -445,10 +447,10 @@ class _CheckError(Exception): | |||||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | _CheckError. Used in losses.LossBase, metrics.MetricBase. | ||||
""" | """ | ||||
def __init__(self, check_res: _CheckRes, func_signature: str): | def __init__(self, check_res: _CheckRes, func_signature: str): | ||||
errs = [f'Problems occurred when calling `{func_signature}`'] | errs = [f'Problems occurred when calling `{func_signature}`'] | ||||
if check_res.varargs: | if check_res.varargs: | ||||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | ||||
if check_res.missing: | if check_res.missing: | ||||
@@ -457,9 +459,9 @@ class _CheckError(Exception): | |||||
errs.append(f"\tduplicated param: {check_res.duplicated}") | errs.append(f"\tduplicated param: {check_res.duplicated}") | ||||
if check_res.unused: | if check_res.unused: | ||||
errs.append(f"\tunused param: {check_res.unused}") | errs.append(f"\tunused param: {check_res.unused}") | ||||
Exception.__init__(self, '\n'.join(errs)) | Exception.__init__(self, '\n'.join(errs)) | ||||
self.check_res = check_res | self.check_res = check_res | ||||
self.func_signature = func_signature | self.func_signature = func_signature | ||||
@@ -479,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
# if check_res.varargs: | # if check_res.varargs: | ||||
# errs.append(f"\tvarargs: *{check_res.varargs}") | # errs.append(f"\tvarargs: *{check_res.varargs}") | ||||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | ||||
if check_res.unused: | if check_res.unused: | ||||
for _unused in check_res.unused: | for _unused in check_res.unused: | ||||
if _unused in target_dict: | if _unused in target_dict: | ||||
@@ -490,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
unuseds.append(f"\tunused field: {_unused_field}") | unuseds.append(f"\tunused field: {_unused_field}") | ||||
if _unused_param: | if _unused_param: | ||||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | ||||
module_name = func_signature.split('.')[0] | module_name = func_signature.split('.')[0] | ||||
if check_res.missing: | if check_res.missing: | ||||
errs.append(f"\tmissing param: {check_res.missing}") | errs.append(f"\tmissing param: {check_res.missing}") | ||||
@@ -511,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
mapped_missing.append(_miss) | mapped_missing.append(_miss) | ||||
else: | else: | ||||
unmapped_missing.append(_miss) | unmapped_missing.append(_miss) | ||||
for _miss in mapped_missing + unmapped_missing: | for _miss in mapped_missing + unmapped_missing: | ||||
if _miss in dataset: | if _miss in dataset: | ||||
suggestions.append(f"Set `{_miss}` as target.") | suggestions.append(f"Set `{_miss}` as target.") | ||||
@@ -524,17 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
else: | else: | ||||
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' | _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' | ||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.duplicated: | if check_res.duplicated: | ||||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | errs.append(f"\tduplicated param: {check_res.duplicated}.") | ||||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | suggestions.append(f"Delete {check_res.duplicated} in the output of " | ||||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
errs.extend(unuseds) | errs.extend(unuseds) | ||||
elif check_level == STRICT_CHECK_LEVEL: | elif check_level == STRICT_CHECK_LEVEL: | ||||
errs.extend(unuseds) | errs.extend(unuseds) | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | errs.insert(0, f'Problems occurred when calling {func_signature}') | ||||
sugg_str = "" | sugg_str = "" | ||||
@@ -561,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | def _check_forward_error(forward_func, batch_x, dataset, check_level): | ||||
check_res = _check_arg_dict_list(forward_func, batch_x) | check_res = _check_arg_dict_list(forward_func, batch_x) | ||||
func_signature = _get_func_signature(forward_func) | func_signature = _get_func_signature(forward_func) | ||||
errs = [] | errs = [] | ||||
suggestions = [] | suggestions = [] | ||||
_unused = [] | _unused = [] | ||||
# if check_res.varargs: | # if check_res.varargs: | ||||
# errs.append(f"\tvarargs: {check_res.varargs}") | # errs.append(f"\tvarargs: {check_res.varargs}") | ||||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | ||||
@@ -586,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | ||||
# f"rename the field in `unused field:`." | # f"rename the field in `unused field:`." | ||||
suggestions.append(_tmp) | suggestions.append(_tmp) | ||||
if check_res.unused: | if check_res.unused: | ||||
_unused = [f"\tunused field: {check_res.unused}"] | _unused = [f"\tunused field: {check_res.unused}"] | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
errs.extend(_unused) | errs.extend(_unused) | ||||
elif check_level == STRICT_CHECK_LEVEL: | elif check_level == STRICT_CHECK_LEVEL: | ||||
errs.extend(_unused) | errs.extend(_unused) | ||||
if len(errs) > 0: | if len(errs) > 0: | ||||
errs.insert(0, f'Problems occurred when calling {func_signature}') | errs.insert(0, f'Problems occurred when calling {func_signature}') | ||||
sugg_str = "" | sugg_str = "" | ||||
@@ -641,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||||
max_len = int(max_len) if max_len else int(seq_len.max()) | max_len = int(max_len) if max_len else int(seq_len.max()) | ||||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | ||||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | ||||
elif isinstance(seq_len, torch.Tensor): | elif isinstance(seq_len, torch.Tensor): | ||||
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | ||||
batch_size = seq_len.size(0) | batch_size = seq_len.size(0) | ||||
@@ -650,7 +652,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | ||||
else: | else: | ||||
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | ||||
return mask | return mask | ||||
@@ -658,24 +660,25 @@ class _pseudo_tqdm: | |||||
""" | """ | ||||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | ||||
""" | """ | ||||
def __init__(self, **kwargs): | def __init__(self, **kwargs): | ||||
self.logger = logger | self.logger = logger | ||||
def write(self, info): | def write(self, info): | ||||
self.logger.info(info) | self.logger.info(info) | ||||
def set_postfix_str(self, info): | def set_postfix_str(self, info): | ||||
self.logger.info(info) | self.logger.info(info) | ||||
def __getattr__(self, item): | def __getattr__(self, item): | ||||
def pass_func(*args, **kwargs): | def pass_func(*args, **kwargs): | ||||
pass | pass | ||||
return pass_func | return pass_func | ||||
def __enter__(self): | def __enter__(self): | ||||
return self | return self | ||||
def __exit__(self, exc_type, exc_val, exc_tb): | def __exit__(self, exc_type, exc_val, exc_tb): | ||||
del self | del self | ||||
@@ -749,3 +752,56 @@ def get_seq_len(words, pad_value=0): | |||||
""" | """ | ||||
mask = words.ne(pad_value) | mask = words.ne(pad_value) | ||||
return mask.sum(dim=-1) | return mask.sum(dim=-1) | ||||
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    """
    Render a DataSet or an Instance as a PrettyTable, truncating each column
    so the whole table fits the current terminal.

    Example for an Instance::

        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: a DataSet or an Instance object
    :raises TypeError: if the argument is neither a DataSet nor an Instance
    :return: a PrettyTable clipped to the terminal width/height
    """
    table = PrettyTable()
    try:
        # Terminal size is unavailable when stdout is not a tty (pipes,
        # notebooks, CI); fall back to a fixed 144x11 layout.
        size = os.get_terminal_size()
        column, row = size.columns, size.lines
    except OSError:
        column, row = 144, 11
    # Dispatch on the class name rather than isinstance() — presumably to
    # avoid a circular import with the modules defining DataSet/Instance.
    kind = type(dataset_or_ins).__name__
    if kind == "DataSet":
        table.field_names = list(dataset_or_ins.field_arrays.keys())
        c_size = len(table.field_names)
        for ins in dataset_or_ins:
            table.add_row([sub_column(ins[k], column, c_size, k) for k in table.field_names])
            row -= 1
            if row < 0:
                # Ran out of terminal lines: elide the remaining rows.
                table.add_row(["..." for _ in range(c_size)])
                break
    elif kind == "Instance":
        table.field_names = list(dataset_or_ins.fields.keys())
        c_size = len(table.field_names)
        table.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in table.field_names])
    else:
        # TypeError is a subclass of Exception, so existing `except Exception`
        # callers keep working while the error type becomes precise.
        raise TypeError("only accept DataSet and Instance")
    return table
def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    """
    Truncate an over-long cell value so each column fits its terminal share.

    :param string: the value to render (coerced with ``str`` first)
    :param c: total number of terminal columns available
    :param c_size: number of fields (columns) in the instance/dataset
    :param title: the column header; a column is never narrower than its title
    :return: the value, cut to the per-column budget with a ``...`` suffix
    """
    # Per-column budget: an even share of the terminal width, but at least
    # wide enough for the header itself.
    avg = max(c // c_size, len(title))
    string = str(string)
    if len(string) > avg:
        # max(avg - 3, 0) guards against a negative slice when the budget is
        # smaller than the three-character ellipsis (previously this produced
        # output longer than the budget).
        string = string[:max(avg - 3, 0)] + "..."
    return string