Browse Source

change the print format for dataset and instance

tags/v0.4.10
unknown 5 years ago
parent
commit
2fbc1d7851
3 changed files with 155 additions and 104 deletions
  1. +48
    -53
      fastNLP/core/dataset.py
  2. +10
    -10
      fastNLP/core/instance.py
  3. +97
    -41
      fastNLP/core/utils.py

+ 48
- 53
fastNLP/core/dataset.py View File

@@ -300,13 +300,14 @@ from .field import FieldArray
from .field import SetInputOrTargetException from .field import SetInputOrTargetException
from .instance import Instance from .instance import Instance
from .utils import _get_func_signature from .utils import _get_func_signature
from .utils import pretty_table_printer




class DataSet(object): class DataSet(object):
""" """
fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset`
""" """
def __init__(self, data=None): def __init__(self, data=None):
""" """
@@ -326,26 +327,26 @@ class DataSet(object):
for ins in data: for ins in data:
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
self.append(ins) self.append(ins)
else: else:
raise ValueError("data only be dict or list type.") raise ValueError("data only be dict or list type.")
def __contains__(self, item): def __contains__(self, item):
return item in self.field_arrays return item in self.field_arrays
def __iter__(self): def __iter__(self):
def iter_func(): def iter_func():
for idx in range(len(self)): for idx in range(len(self)):
yield self[idx] yield self[idx]
return iter_func() return iter_func()
def _inner_iter(self): def _inner_iter(self):
class Iter_ptr: class Iter_ptr:
def __init__(self, dataset, idx): def __init__(self, dataset, idx):
self.dataset = dataset self.dataset = dataset
self.idx = idx self.idx = idx
def __getitem__(self, item): def __getitem__(self, item):
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
self.idx]) self.idx])
@@ -358,13 +359,13 @@ class DataSet(object):


def __repr__(self): def __repr__(self):
return self.dataset[self.idx].__repr__() return self.dataset[self.idx].__repr__()
def inner_iter_func(): def inner_iter_func():
for idx in range(len(self)): for idx in range(len(self)):
yield Iter_ptr(self, idx) yield Iter_ptr(self, idx)
return inner_iter_func() return inner_iter_func()
def __getitem__(self, idx): def __getitem__(self, idx):
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 """给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。


@@ -397,20 +398,20 @@ class DataSet(object):
return dataset return dataset
else: else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
def __getattr__(self, item): def __getattr__(self, item):
# Not tested. Don't use !! # Not tested. Don't use !!
if item == "field_arrays": if item == "field_arrays":
raise AttributeError raise AttributeError
if isinstance(item, str) and item in self.field_arrays: if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item] return self.field_arrays[item]
def __setstate__(self, state): def __setstate__(self, state):
self.__dict__ = state self.__dict__ = state
def __getstate__(self): def __getstate__(self):
return self.__dict__ return self.__dict__
def __len__(self): def __len__(self):
"""Fetch the length of the dataset. """Fetch the length of the dataset.


@@ -420,16 +421,10 @@ class DataSet(object):
return 0 return 0
field = iter(self.field_arrays.values()).__next__() field = iter(self.field_arrays.values()).__next__()
return len(field) return len(field)
def __inner_repr__(self):
if len(self) < 20:
return ",\n".join([ins.__repr__() for ins in self])
else:
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()

def __repr__(self): def __repr__(self):
return "DataSet(" + self.__inner_repr__() + ")"
return str(pretty_table_printer(self))

def append(self, instance): def append(self, instance):
""" """
将一个instance对象append到DataSet后面。 将一个instance对象append到DataSet后面。
@@ -454,7 +449,7 @@ class DataSet(object):
except AppendToTargetOrInputException as e: except AppendToTargetOrInputException as e:
logger.error(f"Cannot append to field:{name}.") logger.error(f"Cannot append to field:{name}.")
raise e raise e
def add_fieldarray(self, field_name, fieldarray): def add_fieldarray(self, field_name, fieldarray):
""" """
将fieldarray添加到DataSet中. 将fieldarray添加到DataSet中.
@@ -469,7 +464,7 @@ class DataSet(object):
raise RuntimeError(f"The field to add must have the same size as dataset. " raise RuntimeError(f"The field to add must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fieldarray)}") f"Dataset size {len(self)} != field size {len(fieldarray)}")
self.field_arrays[field_name] = fieldarray self.field_arrays[field_name] = fieldarray
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False):
""" """
新增一个field 新增一个field
@@ -481,14 +476,14 @@ class DataSet(object):
:param bool is_target: 新加入的field是否是target :param bool is_target: 新加入的field是否是target
:param bool ignore_type: 是否忽略对新加入的field的类型检查 :param bool ignore_type: 是否忽略对新加入的field的类型检查
""" """
if len(self.field_arrays) != 0: if len(self.field_arrays) != 0:
if len(self) != len(fields): if len(self) != len(fields):
raise RuntimeError(f"The field to add must have the same size as dataset. " raise RuntimeError(f"The field to add must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fields)}") f"Dataset size {len(self)} != field size {len(fields)}")
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input,
padder=padder, ignore_type=ignore_type) padder=padder, ignore_type=ignore_type)
def delete_instance(self, index): def delete_instance(self, index):
""" """
删除第index个instance 删除第index个instance
@@ -504,7 +499,7 @@ class DataSet(object):
for field in self.field_arrays.values(): for field in self.field_arrays.values():
field.pop(index) field.pop(index)
return self return self
def delete_field(self, field_name): def delete_field(self, field_name):
""" """
删除名为field_name的field 删除名为field_name的field
@@ -538,7 +533,7 @@ class DataSet(object):
if isinstance(field_name, str): if isinstance(field_name, str):
return field_name in self.field_arrays return field_name in self.field_arrays
return False return False
def get_field(self, field_name): def get_field(self, field_name):
""" """
获取field_name这个field 获取field_name这个field
@@ -549,7 +544,7 @@ class DataSet(object):
if field_name not in self.field_arrays: if field_name not in self.field_arrays:
raise KeyError("Field name {} not found in DataSet".format(field_name)) raise KeyError("Field name {} not found in DataSet".format(field_name))
return self.field_arrays[field_name] return self.field_arrays[field_name]
def get_all_fields(self): def get_all_fields(self):
""" """
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` 返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray`
@@ -557,7 +552,7 @@ class DataSet(object):
:return dict: 返回如上所述的字典 :return dict: 返回如上所述的字典
""" """
return self.field_arrays return self.field_arrays
def get_field_names(self) -> list: def get_field_names(self) -> list:
""" """
返回一个list,包含所有 field 的名字 返回一个list,包含所有 field 的名字
@@ -565,7 +560,7 @@ class DataSet(object):
:return list: 返回如上所述的列表 :return list: 返回如上所述的列表
""" """
return sorted(self.field_arrays.keys()) return sorted(self.field_arrays.keys())
def get_length(self): def get_length(self):
""" """
获取DataSet的元素数量 获取DataSet的元素数量
@@ -573,7 +568,7 @@ class DataSet(object):
:return: int: DataSet中Instance的个数。 :return: int: DataSet中Instance的个数。
""" """
return len(self) return len(self)
def rename_field(self, field_name, new_field_name): def rename_field(self, field_name, new_field_name):
""" """
将某个field重新命名. 将某个field重新命名.
@@ -587,7 +582,7 @@ class DataSet(object):
else: else:
raise KeyError("DataSet has no field named {}.".format(field_name)) raise KeyError("DataSet has no field named {}.".format(field_name))
return self return self
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
""" """
将field_names的field设置为target 将field_names的field设置为target
@@ -614,7 +609,7 @@ class DataSet(object):
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
return self return self
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
""" """
将field_names的field设置为input:: 将field_names的field设置为input::
@@ -638,7 +633,7 @@ class DataSet(object):
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
return self return self
def set_ignore_type(self, *field_names, flag=True): def set_ignore_type(self, *field_names, flag=True):
""" """
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, 将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查,
@@ -655,7 +650,7 @@ class DataSet(object):
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
return self return self
def set_padder(self, field_name, padder): def set_padder(self, field_name, padder):
""" """
为field_name设置padder:: 为field_name设置padder::
@@ -671,7 +666,7 @@ class DataSet(object):
raise KeyError("There is no field named {}.".format(field_name)) raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_padder(padder) self.field_arrays[field_name].set_padder(padder)
return self return self
def set_pad_val(self, field_name, pad_val): def set_pad_val(self, field_name, pad_val):
""" """
为某个field设置对应的pad_val. 为某个field设置对应的pad_val.
@@ -683,7 +678,7 @@ class DataSet(object):
raise KeyError("There is no field named {}.".format(field_name)) raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_pad_val(pad_val) self.field_arrays[field_name].set_pad_val(pad_val)
return self return self
def get_input_name(self): def get_input_name(self):
""" """
返回所有is_input被设置为True的field名称 返回所有is_input被设置为True的field名称
@@ -691,7 +686,7 @@ class DataSet(object):
:return list: 里面的元素为被设置为input的field名称 :return list: 里面的元素为被设置为input的field名称
""" """
return [name for name, field in self.field_arrays.items() if field.is_input] return [name for name, field in self.field_arrays.items() if field.is_input]
def get_target_name(self): def get_target_name(self):
""" """
返回所有is_target被设置为True的field名称 返回所有is_target被设置为True的field名称
@@ -699,7 +694,7 @@ class DataSet(object):
:return list: 里面的元素为被设置为target的field名称 :return list: 里面的元素为被设置为target的field名称
""" """
return [name for name, field in self.field_arrays.items() if field.is_target] return [name for name, field in self.field_arrays.items() if field.is_target]
def apply_field(self, func, field_name, new_field_name=None, **kwargs): def apply_field(self, func, field_name, new_field_name=None, **kwargs):
""" """
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。
@@ -728,16 +723,16 @@ class DataSet(object):
results.append(func(ins[field_name])) results.append(func(ins[field_name]))
except Exception as e: except Exception as e:
if idx != -1: if idx != -1:
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
raise e raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
if new_field_name is not None: if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs) self._add_apply_field(results, new_field_name, kwargs)
return results return results
def _add_apply_field(self, results, new_field_name, kwargs): def _add_apply_field(self, results, new_field_name, kwargs):
""" """
将results作为加入到新的field中,field名称为new_field_name 将results作为加入到新的field中,field名称为new_field_name
@@ -769,7 +764,7 @@ class DataSet(object):
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
is_target=extra_param.get("is_target", None), is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False)) ignore_type=extra_param.get("ignore_type", False))
def apply(self, func, new_field_name=None, **kwargs): def apply(self, func, new_field_name=None, **kwargs):
""" """
将DataSet中每个instance传入到func中,并获取它的返回值. 将DataSet中每个instance传入到func中,并获取它的返回值.
@@ -801,13 +796,13 @@ class DataSet(object):
# results = [func(ins) for ins in self._inner_iter()] # results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
if new_field_name is not None: if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs) self._add_apply_field(results, new_field_name, kwargs)
return results return results


def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN):
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN):
""" """
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。


@@ -844,7 +839,7 @@ class DataSet(object):
return dataset return dataset
else: else:
return DataSet() return DataSet()
def split(self, ratio, shuffle=True): def split(self, ratio, shuffle=True):
""" """
将DataSet按照ratio的比例拆分,返回两个DataSet 将DataSet按照ratio的比例拆分,返回两个DataSet
@@ -870,9 +865,9 @@ class DataSet(object):
for field_name in self.field_arrays: for field_name in self.field_arrays:
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) train_set.field_arrays[field_name].to(self.field_arrays[field_name])
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name])
return train_set, dev_set return train_set, dev_set
def save(self, path): def save(self, path):
""" """
保存DataSet. 保存DataSet.
@@ -881,7 +876,7 @@ class DataSet(object):
""" """
with open(path, 'wb') as f: with open(path, 'wb') as f:
pickle.dump(self, f) pickle.dump(self, f)
@staticmethod @staticmethod
def load(path): def load(path):
r""" r"""


+ 10
- 10
fastNLP/core/instance.py View File

@@ -3,10 +3,13 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格


""" """

__all__ = [ __all__ = [
"Instance" "Instance"
] ]


from .utils import pretty_table_printer



class Instance(object): class Instance(object):
""" """
@@ -20,11 +23,11 @@ class Instance(object):
>>>ins.add_field("field_3", [3, 3, 3]) >>>ins.add_field("field_3", [3, 3, 3])
>>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))})
""" """
def __init__(self, **fields): def __init__(self, **fields):
self.fields = fields self.fields = fields
def add_field(self, field_name, field): def add_field(self, field_name, field):
""" """
向Instance中增加一个field 向Instance中增加一个field
@@ -41,18 +44,15 @@ class Instance(object):
:return: 一个迭代器 :return: 一个迭代器
""" """
return self.fields.items() return self.fields.items()
def __getitem__(self, name): def __getitem__(self, name):
if name in self.fields: if name in self.fields:
return self.fields[name] return self.fields[name]
else: else:
raise KeyError("{} not found".format(name)) raise KeyError("{} not found".format(name))
def __setitem__(self, name, field): def __setitem__(self, name, field):
return self.add_field(name, field) return self.add_field(name, field)
def __repr__(self): def __repr__(self):
s = '\''
return "{" + ",\n".join(
"\'" + field_name + "\': " + str(self.fields[field_name]) + \
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}"
return str(pretty_table_printer(self))

+ 97
- 41
fastNLP/core/utils.py View File

@@ -1,6 +1,7 @@
""" """
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
""" """

__all__ = [ __all__ = [
"cache_results", "cache_results",
"seq_len_to_mask", "seq_len_to_mask",
@@ -12,12 +13,12 @@ import inspect
import os import os
import warnings import warnings
from collections import Counter, namedtuple from collections import Counter, namedtuple

import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import List from typing import List
from ._logger import logger from ._logger import logger
from prettytable import PrettyTable


_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs']) 'varargs'])
@@ -25,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require


class Option(dict): class Option(dict):
"""a dict can treat keys as attributes""" """a dict can treat keys as attributes"""
def __getattr__(self, item): def __getattr__(self, item):
try: try:
return self.__getitem__(item) return self.__getitem__(item)
except KeyError: except KeyError:
raise AttributeError(item) raise AttributeError(item)
def __setattr__(self, key, value): def __setattr__(self, key, value):
if key.startswith('__') and key.endswith('__'): if key.startswith('__') and key.endswith('__'):
raise AttributeError(key) raise AttributeError(key)
self.__setitem__(key, value) self.__setitem__(key, value)
def __delattr__(self, item): def __delattr__(self, item):
try: try:
self.pop(item) self.pop(item)
except KeyError: except KeyError:
raise AttributeError(item) raise AttributeError(item)
def __getstate__(self): def __getstate__(self):
return self return self
def __setstate__(self, state): def __setstate__(self, state):
self.update(state) self.update(state)


@@ -112,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
:param int _verbose: 是否打印cache的信息。 :param int _verbose: 是否打印cache的信息。
:return: :return:
""" """
def wrapper_(func): def wrapper_(func):
signature = inspect.signature(func) signature = inspect.signature(func)
for key, _ in signature.parameters.items(): for key, _ in signature.parameters.items():
if key in ('_cache_fp', '_refresh', '_verbose'): if key in ('_cache_fp', '_refresh', '_verbose'):
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
if '_cache_fp' in kwargs: if '_cache_fp' in kwargs:
cache_filepath = kwargs.pop('_cache_fp') cache_filepath = kwargs.pop('_cache_fp')
@@ -136,7 +137,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
else: else:
verbose = _verbose verbose = _verbose
refresh_flag = True refresh_flag = True
if cache_filepath is not None and refresh is False: if cache_filepath is not None and refresh is False:
# load data # load data
if os.path.exists(cache_filepath): if os.path.exists(cache_filepath):
@@ -145,7 +146,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
if verbose == 1: if verbose == 1:
logger.info("Read cache from {}.".format(cache_filepath)) logger.info("Read cache from {}.".format(cache_filepath))
refresh_flag = False refresh_flag = False
if refresh_flag: if refresh_flag:
results = func(*args, **kwargs) results = func(*args, **kwargs)
if cache_filepath is not None: if cache_filepath is not None:
@@ -155,11 +156,11 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
with open(cache_filepath, 'wb') as f: with open(cache_filepath, 'wb') as f:
_pickle.dump(results, f) _pickle.dump(results, f)
logger.info("Save cache to {}.".format(cache_filepath)) logger.info("Save cache to {}.".format(cache_filepath))
return results return results
return wrapper return wrapper
return wrapper_ return wrapper_




@@ -187,6 +188,7 @@ def _save_model(model, model_name, save_dir, only_param=False):
torch.save(model, model_path) torch.save(model, model_path)
model.to(_model_device) model.to(_model_device)



def _move_model_to_device(model, device): def _move_model_to_device(model, device):
""" """
将model移动到device 将model移动到device
@@ -211,7 +213,7 @@ def _move_model_to_device(model, device):
""" """
# if isinstance(model, torch.nn.parallel.DistributedDataParallel): # if isinstance(model, torch.nn.parallel.DistributedDataParallel):
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") # raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
if device is None: if device is None:
if isinstance(model, torch.nn.DataParallel): if isinstance(model, torch.nn.DataParallel):
model.cuda() model.cuda()
@@ -220,10 +222,10 @@ def _move_model_to_device(model, device):
if not torch.cuda.is_available() and ( if not torch.cuda.is_available() and (
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
if isinstance(model, torch.nn.DataParallel): if isinstance(model, torch.nn.DataParallel):
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
if isinstance(device, int): if isinstance(device, int):
assert device > -1, "device can only be non-negative integer" assert device > -1, "device can only be non-negative integer"
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format(
@@ -267,7 +269,7 @@ def _get_model_device(model):
""" """
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡
assert isinstance(model, nn.Module) assert isinstance(model, nn.Module)
parameters = list(model.parameters()) parameters = list(model.parameters())
if len(parameters) == 0: if len(parameters) == 0:
return None return None
@@ -427,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
""" """
if not torch.cuda.is_available(): if not torch.cuda.is_available():
return return
if not isinstance(device, torch.device): if not isinstance(device, torch.device):
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
for arg in args: for arg in args:
if isinstance(arg, dict): if isinstance(arg, dict):
for key, value in arg.items(): for key, value in arg.items():
@@ -445,10 +447,10 @@ class _CheckError(Exception):


_CheckError. Used in losses.LossBase, metrics.MetricBase. _CheckError. Used in losses.LossBase, metrics.MetricBase.
""" """
def __init__(self, check_res: _CheckRes, func_signature: str): def __init__(self, check_res: _CheckRes, func_signature: str):
errs = [f'Problems occurred when calling `{func_signature}`'] errs = [f'Problems occurred when calling `{func_signature}`']
if check_res.varargs: if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
if check_res.missing: if check_res.missing:
@@ -457,9 +459,9 @@ class _CheckError(Exception):
errs.append(f"\tduplicated param: {check_res.duplicated}") errs.append(f"\tduplicated param: {check_res.duplicated}")
if check_res.unused: if check_res.unused:
errs.append(f"\tunused param: {check_res.unused}") errs.append(f"\tunused param: {check_res.unused}")
Exception.__init__(self, '\n'.join(errs)) Exception.__init__(self, '\n'.join(errs))
self.check_res = check_res self.check_res = check_res
self.func_signature = func_signature self.func_signature = func_signature


@@ -479,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
# if check_res.varargs: # if check_res.varargs:
# errs.append(f"\tvarargs: *{check_res.varargs}") # errs.append(f"\tvarargs: *{check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
if check_res.unused: if check_res.unused:
for _unused in check_res.unused: for _unused in check_res.unused:
if _unused in target_dict: if _unused in target_dict:
@@ -490,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
unuseds.append(f"\tunused field: {_unused_field}") unuseds.append(f"\tunused field: {_unused_field}")
if _unused_param: if _unused_param:
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward
module_name = func_signature.split('.')[0] module_name = func_signature.split('.')[0]
if check_res.missing: if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}") errs.append(f"\tmissing param: {check_res.missing}")
@@ -511,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
mapped_missing.append(_miss) mapped_missing.append(_miss)
else: else:
unmapped_missing.append(_miss) unmapped_missing.append(_miss)
for _miss in mapped_missing + unmapped_missing: for _miss in mapped_missing + unmapped_missing:
if _miss in dataset: if _miss in dataset:
suggestions.append(f"Set `{_miss}` as target.") suggestions.append(f"Set `{_miss}` as target.")
@@ -524,17 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
else: else:
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
suggestions.append(_tmp) suggestions.append(_tmp)
if check_res.duplicated: if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}.") errs.append(f"\tduplicated param: {check_res.duplicated}.")
suggestions.append(f"Delete {check_res.duplicated} in the output of " suggestions.append(f"Delete {check_res.duplicated} in the output of "
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
if len(errs) > 0: if len(errs) > 0:
errs.extend(unuseds) errs.extend(unuseds)
elif check_level == STRICT_CHECK_LEVEL: elif check_level == STRICT_CHECK_LEVEL:
errs.extend(unuseds) errs.extend(unuseds)
if len(errs) > 0: if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}') errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = "" sugg_str = ""
@@ -561,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
def _check_forward_error(forward_func, batch_x, dataset, check_level): def _check_forward_error(forward_func, batch_x, dataset, check_level):
check_res = _check_arg_dict_list(forward_func, batch_x) check_res = _check_arg_dict_list(forward_func, batch_x)
func_signature = _get_func_signature(forward_func) func_signature = _get_func_signature(forward_func)
errs = [] errs = []
suggestions = [] suggestions = []
_unused = [] _unused = []
# if check_res.varargs: # if check_res.varargs:
# errs.append(f"\tvarargs: {check_res.varargs}") # errs.append(f"\tvarargs: {check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
@@ -586,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
# f"rename the field in `unused field:`." # f"rename the field in `unused field:`."
suggestions.append(_tmp) suggestions.append(_tmp)
if check_res.unused: if check_res.unused:
_unused = [f"\tunused field: {check_res.unused}"] _unused = [f"\tunused field: {check_res.unused}"]
if len(errs) > 0: if len(errs) > 0:
errs.extend(_unused) errs.extend(_unused)
elif check_level == STRICT_CHECK_LEVEL: elif check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused) errs.extend(_unused)
if len(errs) > 0: if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}') errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = "" sugg_str = ""
@@ -641,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None):
max_len = int(max_len) if max_len else int(seq_len.max()) max_len = int(max_len) if max_len else int(seq_len.max())
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
elif isinstance(seq_len, torch.Tensor): elif isinstance(seq_len, torch.Tensor):
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
batch_size = seq_len.size(0) batch_size = seq_len.size(0)
@@ -650,7 +652,7 @@ def seq_len_to_mask(seq_len, max_len=None):
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
else: else:
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
return mask return mask




@@ -658,24 +660,25 @@ class _pseudo_tqdm:
""" """
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
""" """

def __init__(self, **kwargs): def __init__(self, **kwargs):
self.logger = logger self.logger = logger
def write(self, info): def write(self, info):
self.logger.info(info) self.logger.info(info)
def set_postfix_str(self, info): def set_postfix_str(self, info):
self.logger.info(info) self.logger.info(info)
def __getattr__(self, item): def __getattr__(self, item):
def pass_func(*args, **kwargs): def pass_func(*args, **kwargs):
pass pass
return pass_func return pass_func
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
del self del self


@@ -749,3 +752,56 @@ def get_seq_len(words, pad_value=0):
""" """
mask = words.ne(pad_value) mask = words.ne(pad_value)
return mask.sum(dim=-1) return mask.sum(dim=-1)


def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    """
    Render a DataSet or an Instance as a PrettyTable, truncating each cell
    so the whole table fits the current terminal.

    Example::

        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: a DataSet or an Instance to display
    :return: a PrettyTable automatically truncated to the terminal size
    """
    table = PrettyTable()
    try:
        term = os.get_terminal_size()
        width, height = term.columns, term.lines
    except OSError:
        # No attached terminal (e.g. output is piped); fall back to defaults.
        width, height = 144, 11

    # Dispatch on the class name (avoids importing DataSet/Instance here,
    # which would create a circular import with this utils module).
    kind = type(dataset_or_ins).__name__
    if kind == "DataSet":
        table.field_names = list(dataset_or_ins.field_arrays.keys())
        n_fields = len(table.field_names)
        for ins in dataset_or_ins:
            table.add_row([sub_column(ins[name], width, n_fields, name)
                           for name in table.field_names])
            height -= 1
            if height < 0:
                # Table is taller than the terminal: mark the cut-off row.
                table.add_row(["..."] * n_fields)
                break
    elif kind == "Instance":
        table.field_names = list(dataset_or_ins.fields.keys())
        n_fields = len(table.field_names)
        table.add_row([sub_column(dataset_or_ins[name], width, n_fields, name)
                       for name in table.field_names])
    else:
        raise Exception("only accept DataSet and Instance")
    return table


def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    """
    Truncate an over-long column value so the table fits the terminal.

    :param string: the value to display; converted with str() first
    :param c: number of terminal columns available
    :param c_size: number of fields (columns) in the instance/dataset
    :param title: the column header name
    :return: the value, truncated with a trailing "..." if it exceeds
        this column's share of the terminal width
    """
    # Each column gets an equal share of the terminal width, but never less
    # than its own header so the title always stays fully visible.
    avg = max(int(c / c_size), len(title))
    string = str(string)
    if len(string) > avg:
        # Guard against a tiny budget: with avg <= 3 the original slice
        # string[:(avg - 3)] used a non-positive index (e.g. avg=1 gives
        # string[:-2], keeping nearly the whole string). Keep at least the
        # bare ellipsis instead.
        keep = max(avg - 3, 0)
        string = string[:keep] + "..."
    return string

Loading…
Cancel
Save