Browse Source

Change the print format for DataSet and Instance.

tags/v0.4.10
unknown 5 years ago
parent
commit
2fbc1d7851
3 changed files with 155 additions and 104 deletions
  1. +48
    -53
      fastNLP/core/dataset.py
  2. +10
    -10
      fastNLP/core/instance.py
  3. +97
    -41
      fastNLP/core/utils.py

+ 48
- 53
fastNLP/core/dataset.py View File

@@ -300,13 +300,14 @@ from .field import FieldArray
from .field import SetInputOrTargetException
from .instance import Instance
from .utils import _get_func_signature
from .utils import pretty_table_printer


class DataSet(object):
"""
fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset`
"""
def __init__(self, data=None):
"""
@@ -326,26 +327,26 @@ class DataSet(object):
for ins in data:
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins))
self.append(ins)
else:
raise ValueError("data only be dict or list type.")
def __contains__(self, item):
return item in self.field_arrays
def __iter__(self):
def iter_func():
for idx in range(len(self)):
yield self[idx]
return iter_func()
def _inner_iter(self):
class Iter_ptr:
def __init__(self, dataset, idx):
self.dataset = dataset
self.idx = idx
def __getitem__(self, item):
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[
self.idx])
@@ -358,13 +359,13 @@ class DataSet(object):

def __repr__(self):
return self.dataset[self.idx].__repr__()
def inner_iter_func():
for idx in range(len(self)):
yield Iter_ptr(self, idx)
return inner_iter_func()
def __getitem__(self, idx):
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。

@@ -397,20 +398,20 @@ class DataSet(object):
return dataset
else:
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
def __getattr__(self, item):
# Not tested. Don't use !!
if item == "field_arrays":
raise AttributeError
if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item]
def __setstate__(self, state):
self.__dict__ = state
def __getstate__(self):
return self.__dict__
def __len__(self):
"""Fetch the length of the dataset.

@@ -420,16 +421,10 @@ class DataSet(object):
return 0
field = iter(self.field_arrays.values()).__next__()
return len(field)
def __inner_repr__(self):
if len(self) < 20:
return ",\n".join([ins.__repr__() for ins in self])
else:
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__()

def __repr__(self):
return "DataSet(" + self.__inner_repr__() + ")"
return str(pretty_table_printer(self))

def append(self, instance):
"""
将一个instance对象append到DataSet后面。
@@ -454,7 +449,7 @@ class DataSet(object):
except AppendToTargetOrInputException as e:
logger.error(f"Cannot append to field:{name}.")
raise e
def add_fieldarray(self, field_name, fieldarray):
"""
将fieldarray添加到DataSet中.
@@ -469,7 +464,7 @@ class DataSet(object):
raise RuntimeError(f"The field to add must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fieldarray)}")
self.field_arrays[field_name] = fieldarray
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False):
"""
新增一个field
@@ -481,14 +476,14 @@ class DataSet(object):
:param bool is_target: 新加入的field是否是target
:param bool ignore_type: 是否忽略对新加入的field的类型检查
"""
if len(self.field_arrays) != 0:
if len(self) != len(fields):
raise RuntimeError(f"The field to add must have the same size as dataset. "
f"Dataset size {len(self)} != field size {len(fields)}")
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input,
padder=padder, ignore_type=ignore_type)
def delete_instance(self, index):
"""
删除第index个instance
@@ -504,7 +499,7 @@ class DataSet(object):
for field in self.field_arrays.values():
field.pop(index)
return self
def delete_field(self, field_name):
"""
删除名为field_name的field
@@ -538,7 +533,7 @@ class DataSet(object):
if isinstance(field_name, str):
return field_name in self.field_arrays
return False
def get_field(self, field_name):
"""
获取field_name这个field
@@ -549,7 +544,7 @@ class DataSet(object):
if field_name not in self.field_arrays:
raise KeyError("Field name {} not found in DataSet".format(field_name))
return self.field_arrays[field_name]
def get_all_fields(self):
"""
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray`
@@ -557,7 +552,7 @@ class DataSet(object):
:return dict: 返回如上所述的字典
"""
return self.field_arrays
def get_field_names(self) -> list:
"""
返回一个list,包含所有 field 的名字
@@ -565,7 +560,7 @@ class DataSet(object):
:return list: 返回如上所述的列表
"""
return sorted(self.field_arrays.keys())
def get_length(self):
"""
获取DataSet的元素数量
@@ -573,7 +568,7 @@ class DataSet(object):
:return: int: DataSet中Instance的个数。
"""
return len(self)
def rename_field(self, field_name, new_field_name):
"""
将某个field重新命名.
@@ -587,7 +582,7 @@ class DataSet(object):
else:
raise KeyError("DataSet has no field named {}.".format(field_name))
return self
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
"""
将field_names的field设置为target
@@ -614,7 +609,7 @@ class DataSet(object):
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
"""
将field_names的field设置为input::
@@ -638,7 +633,7 @@ class DataSet(object):
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_ignore_type(self, *field_names, flag=True):
"""
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查,
@@ -655,7 +650,7 @@ class DataSet(object):
else:
raise KeyError("{} is not a valid field name.".format(name))
return self
def set_padder(self, field_name, padder):
"""
为field_name设置padder::
@@ -671,7 +666,7 @@ class DataSet(object):
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_padder(padder)
return self
def set_pad_val(self, field_name, pad_val):
"""
为某个field设置对应的pad_val.
@@ -683,7 +678,7 @@ class DataSet(object):
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_pad_val(pad_val)
return self
def get_input_name(self):
"""
返回所有is_input被设置为True的field名称
@@ -691,7 +686,7 @@ class DataSet(object):
:return list: 里面的元素为被设置为input的field名称
"""
return [name for name, field in self.field_arrays.items() if field.is_input]
def get_target_name(self):
"""
返回所有is_target被设置为True的field名称
@@ -699,7 +694,7 @@ class DataSet(object):
:return list: 里面的元素为被设置为target的field名称
"""
return [name for name, field in self.field_arrays.items() if field.is_target]
def apply_field(self, func, field_name, new_field_name=None, **kwargs):
"""
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。
@@ -728,16 +723,16 @@ class DataSet(object):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
return results
def _add_apply_field(self, results, new_field_name, kwargs):
"""
将results作为加入到新的field中,field名称为new_field_name
@@ -769,7 +764,7 @@ class DataSet(object):
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None),
is_target=extra_param.get("is_target", None),
ignore_type=extra_param.get("ignore_type", False))
def apply(self, func, new_field_name=None, **kwargs):
"""
将DataSet中每个instance传入到func中,并获取它的返回值.
@@ -801,13 +796,13 @@ class DataSet(object):
# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
if new_field_name is not None:
self._add_apply_field(results, new_field_name, kwargs)
return results

def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN):
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN):
"""
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。

@@ -844,7 +839,7 @@ class DataSet(object):
return dataset
else:
return DataSet()
def split(self, ratio, shuffle=True):
"""
将DataSet按照ratio的比例拆分,返回两个DataSet
@@ -870,9 +865,9 @@ class DataSet(object):
for field_name in self.field_arrays:
train_set.field_arrays[field_name].to(self.field_arrays[field_name])
dev_set.field_arrays[field_name].to(self.field_arrays[field_name])
return train_set, dev_set
def save(self, path):
"""
保存DataSet.
@@ -881,7 +876,7 @@ class DataSet(object):
"""
with open(path, 'wb') as f:
pickle.dump(self, f)
@staticmethod
def load(path):
r"""


+ 10
- 10
fastNLP/core/instance.py View File

@@ -3,10 +3,13 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格

"""

__all__ = [
"Instance"
]

from .utils import pretty_table_printer


class Instance(object):
"""
@@ -20,11 +23,11 @@ class Instance(object):
>>>ins.add_field("field_3", [3, 3, 3])
>>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))})
"""
def __init__(self, **fields):
    # Each keyword argument becomes one named field of this sample.
    self.fields = fields
def add_field(self, field_name, field):
"""
向Instance中增加一个field
@@ -41,18 +44,15 @@ class Instance(object):
:return: 一个迭代器
"""
return self.fields.items()
def __getitem__(self, name):
    """Return the value stored under field ``name``.

    :param name: the field name to look up
    :raises KeyError: if no field with that name exists
    """
    if name not in self.fields:
        raise KeyError("{} not found".format(name))
    return self.fields[name]
def __setitem__(self, name, field):
    # `ins[name] = value` is sugar for add_field(name, value).
    return self.add_field(name, field)
def __repr__(self):
s = '\''
return "{" + ",\n".join(
"\'" + field_name + "\': " + str(self.fields[field_name]) + \
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}"
return str(pretty_table_printer(self))

+ 97
- 41
fastNLP/core/utils.py View File

@@ -1,6 +1,7 @@
"""
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
"""

__all__ = [
"cache_results",
"seq_len_to_mask",
@@ -12,12 +13,12 @@ import inspect
import os
import warnings
from collections import Counter, namedtuple

import numpy as np
import torch
import torch.nn as nn
from typing import List
from ._logger import logger
from prettytable import PrettyTable

_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])
@@ -25,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require

class Option(dict):
    """A dict subclass whose keys are also readable and writable as attributes."""

    def __getattr__(self, item):
        # Attribute reads fall through to dict lookup.
        if item in self:
            return self[item]
        raise AttributeError(item)

    def __setattr__(self, key, value):
        # Dunder names stay reserved for the object protocol.
        is_dunder = key.startswith('__') and key.endswith('__')
        if is_dunder:
            raise AttributeError(key)
        self[key] = value

    def __delattr__(self, item):
        # Attribute deletion removes the matching dict entry.
        if item not in self:
            raise AttributeError(item)
        del self[item]

    def __getstate__(self):
        # Pickle the mapping itself; there is no separate instance state.
        return self

    def __setstate__(self, state):
        self.update(state)

@@ -112,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
:param int _verbose: 是否打印cache的信息。
:return:
"""
def wrapper_(func):
signature = inspect.signature(func)
for key, _ in signature.parameters.items():
if key in ('_cache_fp', '_refresh', '_verbose'):
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))
def wrapper(*args, **kwargs):
if '_cache_fp' in kwargs:
cache_filepath = kwargs.pop('_cache_fp')
@@ -136,7 +137,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
else:
verbose = _verbose
refresh_flag = True
if cache_filepath is not None and refresh is False:
# load data
if os.path.exists(cache_filepath):
@@ -145,7 +146,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
if verbose == 1:
logger.info("Read cache from {}.".format(cache_filepath))
refresh_flag = False
if refresh_flag:
results = func(*args, **kwargs)
if cache_filepath is not None:
@@ -155,11 +156,11 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
with open(cache_filepath, 'wb') as f:
_pickle.dump(results, f)
logger.info("Save cache to {}.".format(cache_filepath))
return results
return wrapper
return wrapper_


@@ -187,6 +188,7 @@ def _save_model(model, model_name, save_dir, only_param=False):
torch.save(model, model_path)
model.to(_model_device)


def _move_model_to_device(model, device):
"""
将model移动到device
@@ -211,7 +213,7 @@ def _move_model_to_device(model, device):
"""
# if isinstance(model, torch.nn.parallel.DistributedDataParallel):
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.")
if device is None:
if isinstance(model, torch.nn.DataParallel):
model.cuda()
@@ -220,10 +222,10 @@ def _move_model_to_device(model, device):
if not torch.cuda.is_available() and (
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')):
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.")
if isinstance(model, torch.nn.DataParallel):
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.")
if isinstance(device, int):
assert device > -1, "device can only be non-negative integer"
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format(
@@ -267,7 +269,7 @@ def _get_model_device(model):
"""
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡
assert isinstance(model, nn.Module)
parameters = list(model.parameters())
if len(parameters) == 0:
return None
@@ -427,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False):
"""
if not torch.cuda.is_available():
return
if not isinstance(device, torch.device):
raise TypeError(f"device must be `torch.device`, got `{type(device)}`")
for arg in args:
if isinstance(arg, dict):
for key, value in arg.items():
@@ -445,10 +447,10 @@ class _CheckError(Exception):

_CheckError. Used in losses.LossBase, metrics.MetricBase.
"""
def __init__(self, check_res: _CheckRes, func_signature: str):
errs = [f'Problems occurred when calling `{func_signature}`']
if check_res.varargs:
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)")
if check_res.missing:
@@ -457,9 +459,9 @@ class _CheckError(Exception):
errs.append(f"\tduplicated param: {check_res.duplicated}")
if check_res.unused:
errs.append(f"\tunused param: {check_res.unused}")
Exception.__init__(self, '\n'.join(errs))
self.check_res = check_res
self.func_signature = func_signature

@@ -479,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
# if check_res.varargs:
# errs.append(f"\tvarargs: *{check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
if check_res.unused:
for _unused in check_res.unused:
if _unused in target_dict:
@@ -490,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
unuseds.append(f"\tunused field: {_unused_field}")
if _unused_param:
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward
module_name = func_signature.split('.')[0]
if check_res.missing:
errs.append(f"\tmissing param: {check_res.missing}")
@@ -511,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
mapped_missing.append(_miss)
else:
unmapped_missing.append(_miss)
for _miss in mapped_missing + unmapped_missing:
if _miss in dataset:
suggestions.append(f"Set `{_miss}` as target.")
@@ -524,17 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
else:
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.'
suggestions.append(_tmp)
if check_res.duplicated:
errs.append(f"\tduplicated param: {check_res.duplicated}.")
suggestions.append(f"Delete {check_res.duplicated} in the output of "
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ")
if len(errs) > 0:
errs.extend(unuseds)
elif check_level == STRICT_CHECK_LEVEL:
errs.extend(unuseds)
if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = ""
@@ -561,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re
def _check_forward_error(forward_func, batch_x, dataset, check_level):
check_res = _check_arg_dict_list(forward_func, batch_x)
func_signature = _get_func_signature(forward_func)
errs = []
suggestions = []
_unused = []
# if check_res.varargs:
# errs.append(f"\tvarargs: {check_res.varargs}")
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.")
@@ -586,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level):
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \
# f"rename the field in `unused field:`."
suggestions.append(_tmp)
if check_res.unused:
_unused = [f"\tunused field: {check_res.unused}"]
if len(errs) > 0:
errs.extend(_unused)
elif check_level == STRICT_CHECK_LEVEL:
errs.extend(_unused)
if len(errs) > 0:
errs.insert(0, f'Problems occurred when calling {func_signature}')
sugg_str = ""
@@ -641,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None):
max_len = int(max_len) if max_len else int(seq_len.max())
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1))
mask = broad_cast_seq_len < seq_len.reshape(-1, 1)
elif isinstance(seq_len, torch.Tensor):
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}."
batch_size = seq_len.size(0)
@@ -650,7 +652,7 @@ def seq_len_to_mask(seq_len, max_len=None):
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1))
else:
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.")
return mask


@@ -658,24 +660,25 @@ class _pseudo_tqdm:
"""
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
"""

def __init__(self, **kwargs):
    # kwargs are accepted (and ignored) so call sites written for tqdm still work.
    self.logger = logger
def write(self, info):
    # tqdm-compatible: route the message to the fastNLP logger instead of a bar.
    self.logger.info(info)
def set_postfix_str(self, info):
    # tqdm-compatible: log the postfix text instead of updating a progress bar.
    self.logger.info(info)
def __getattr__(self, item):
    # Any other tqdm method (update, close, ...) becomes a silent no-op,
    # so code written against the real tqdm API keeps running unchanged.
    def pass_func(*args, **kwargs):
        pass

    return pass_func
def __enter__(self):
    # Support `with` blocks the same way tqdm does.
    return self
def __exit__(self, exc_type, exc_val, exc_tb):
    # `del self` only drops the local reference; nothing needs releasing here.
    del self

@@ -749,3 +752,56 @@ def get_seq_len(words, pad_value=0):
"""
mask = words.ne(pad_value)
return mask.sum(dim=-1)


def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    """
    Render a DataSet or an Instance as a PrettyTable, truncating each column
    so the table fits the current terminal.

    Example with an Instance::

        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: a DataSet or an Instance
    :return: a PrettyTable, automatically truncated to the terminal size;
        datasets taller than the terminal are cut off with an ``...`` row
    :raises TypeError: if the argument is neither a DataSet nor an Instance
    """
    table = PrettyTable()
    try:
        sz = os.get_terminal_size()
        column = sz.columns
        row = sz.lines
    except OSError:
        # No controlling terminal (e.g. output is piped); fall back to defaults.
        column = 144
        row = 11
    # NOTE(review): comparing by class name instead of isinstance() — presumably
    # to avoid a circular import between utils, dataset and instance; verify.
    if type(dataset_or_ins).__name__ == "DataSet":
        table.field_names = list(dataset_or_ins.field_arrays.keys())
        c_size = len(table.field_names)
        for ins in dataset_or_ins:
            table.add_row([sub_column(ins[k], column, c_size, k) for k in table.field_names])
            row -= 1
            if row < 0:
                # The dataset is taller than the terminal: cut it off here.
                table.add_row(["..." for _ in range(c_size)])
                break
    elif type(dataset_or_ins).__name__ == "Instance":
        table.field_names = list(dataset_or_ins.fields.keys())
        c_size = len(table.field_names)
        table.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in table.field_names])

    else:
        # TypeError (a subclass of Exception) is the precise error for a bad
        # argument type; callers catching Exception are unaffected.
        raise TypeError("only accept DataSet and Instance, got {}".format(type(dataset_or_ins)))
    return table


def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    """
    Truncate an over-long cell value so each column fits in the terminal.

    :param string: the value to render (converted with ``str()``, so any type is accepted)
    :param c: total number of terminal columns available
    :param c_size: number of fields (columns) in the instance/dataset
    :param title: the column header; a column is never narrower than its title
    :return: the (possibly truncated) string, with ``...`` appended when cut
    """
    # Share the terminal width evenly among columns, but keep at least
    # enough room for the column title itself.
    avg = max(int(c / c_size), len(title))
    string = str(string)
    if len(string) > avg:
        # Clamp at 0: with avg < 3 a negative slice bound would drop characters
        # from the *end* of the string and make the result longer than `avg`.
        string = string[:max(avg - 3, 0)] + "..."
    return string

Loading…
Cancel
Save