From 2fbc1d78518d6f75080da8bdab6ddaecd5d3cd87 Mon Sep 17 00:00:00 2001 From: unknown <793736331@qq.com> Date: Sat, 7 Sep 2019 15:22:43 +0800 Subject: [PATCH] change the print format for dataset and instance --- fastNLP/core/dataset.py | 101 ++++++++++++++-------------- fastNLP/core/instance.py | 20 +++--- fastNLP/core/utils.py | 138 +++++++++++++++++++++++++++------------ 3 files changed, 155 insertions(+), 104 deletions(-) diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index ebdc780f..36852b93 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -300,13 +300,14 @@ from .field import FieldArray from .field import SetInputOrTargetException from .instance import Instance from .utils import _get_func_signature +from .utils import pretty_table_printer class DataSet(object): """ fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` """ - + def __init__(self, data=None): """ @@ -326,26 +327,26 @@ class DataSet(object): for ins in data: assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) self.append(ins) - + else: raise ValueError("data only be dict or list type.") - + def __contains__(self, item): return item in self.field_arrays - + def __iter__(self): def iter_func(): for idx in range(len(self)): yield self[idx] - + return iter_func() - + def _inner_iter(self): class Iter_ptr: def __init__(self, dataset, idx): self.dataset = dataset self.idx = idx - + def __getitem__(self, item): assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ self.idx]) @@ -358,13 +359,13 @@ class DataSet(object): def __repr__(self): return self.dataset[self.idx].__repr__() - + def inner_iter_func(): for idx in range(len(self)): yield Iter_ptr(self, idx) - + return inner_iter_func() - + def __getitem__(self, idx): """给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 @@ -397,20 +398,20 @@ class DataSet(object): return dataset else: raise KeyError("Unrecognized type {} for idx 
in __getitem__ method".format(type(idx))) - + def __getattr__(self, item): # Not tested. Don't use !! if item == "field_arrays": raise AttributeError if isinstance(item, str) and item in self.field_arrays: return self.field_arrays[item] - + def __setstate__(self, state): self.__dict__ = state - + def __getstate__(self): return self.__dict__ - + def __len__(self): """Fetch the length of the dataset. @@ -420,16 +421,10 @@ class DataSet(object): return 0 field = iter(self.field_arrays.values()).__next__() return len(field) - - def __inner_repr__(self): - if len(self) < 20: - return ",\n".join([ins.__repr__() for ins in self]) - else: - return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() - + def __repr__(self): - return "DataSet(" + self.__inner_repr__() + ")" - + return str(pretty_table_printer(self)) + def append(self, instance): """ 将一个instance对象append到DataSet后面。 @@ -454,7 +449,7 @@ class DataSet(object): except AppendToTargetOrInputException as e: logger.error(f"Cannot append to field:{name}.") raise e - + def add_fieldarray(self, field_name, fieldarray): """ 将fieldarray添加到DataSet中. @@ -469,7 +464,7 @@ class DataSet(object): raise RuntimeError(f"The field to add must have the same size as dataset. " f"Dataset size {len(self)} != field size {len(fieldarray)}") self.field_arrays[field_name] = fieldarray - + def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): """ 新增一个field @@ -481,14 +476,14 @@ class DataSet(object): :param bool is_target: 新加入的field是否是target :param bool ignore_type: 是否忽略对新加入的field的类型检查 """ - + if len(self.field_arrays) != 0: if len(self) != len(fields): raise RuntimeError(f"The field to add must have the same size as dataset. 
" f"Dataset size {len(self)} != field size {len(fields)}") self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, padder=padder, ignore_type=ignore_type) - + def delete_instance(self, index): """ 删除第index个instance @@ -504,7 +499,7 @@ class DataSet(object): for field in self.field_arrays.values(): field.pop(index) return self - + def delete_field(self, field_name): """ 删除名为field_name的field @@ -538,7 +533,7 @@ class DataSet(object): if isinstance(field_name, str): return field_name in self.field_arrays return False - + def get_field(self, field_name): """ 获取field_name这个field @@ -549,7 +544,7 @@ class DataSet(object): if field_name not in self.field_arrays: raise KeyError("Field name {} not found in DataSet".format(field_name)) return self.field_arrays[field_name] - + def get_all_fields(self): """ 返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` @@ -557,7 +552,7 @@ class DataSet(object): :return dict: 返回如上所述的字典 """ return self.field_arrays - + def get_field_names(self) -> list: """ 返回一个list,包含所有 field 的名字 @@ -565,7 +560,7 @@ class DataSet(object): :return list: 返回如上所述的列表 """ return sorted(self.field_arrays.keys()) - + def get_length(self): """ 获取DataSet的元素数量 @@ -573,7 +568,7 @@ class DataSet(object): :return: int: DataSet中Instance的个数。 """ return len(self) - + def rename_field(self, field_name, new_field_name): """ 将某个field重新命名. 
@@ -587,7 +582,7 @@ class DataSet(object): else: raise KeyError("DataSet has no field named {}.".format(field_name)) return self - + def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): """ 将field_names的field设置为target @@ -614,7 +609,7 @@ class DataSet(object): else: raise KeyError("{} is not a valid field name.".format(name)) return self - + def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): """ 将field_names的field设置为input:: @@ -638,7 +633,7 @@ class DataSet(object): else: raise KeyError("{} is not a valid field name.".format(name)) return self - + def set_ignore_type(self, *field_names, flag=True): """ 将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, @@ -655,7 +650,7 @@ class DataSet(object): else: raise KeyError("{} is not a valid field name.".format(name)) return self - + def set_padder(self, field_name, padder): """ 为field_name设置padder:: @@ -671,7 +666,7 @@ class DataSet(object): raise KeyError("There is no field named {}.".format(field_name)) self.field_arrays[field_name].set_padder(padder) return self - + def set_pad_val(self, field_name, pad_val): """ 为某个field设置对应的pad_val. 
@@ -683,7 +678,7 @@ class DataSet(object): raise KeyError("There is no field named {}.".format(field_name)) self.field_arrays[field_name].set_pad_val(pad_val) return self - + def get_input_name(self): """ 返回所有is_input被设置为True的field名称 @@ -691,7 +686,7 @@ class DataSet(object): :return list: 里面的元素为被设置为input的field名称 """ return [name for name, field in self.field_arrays.items() if field.is_input] - + def get_target_name(self): """ 返回所有is_target被设置为True的field名称 @@ -699,7 +694,7 @@ class DataSet(object): :return list: 里面的元素为被设置为target的field名称 """ return [name for name, field in self.field_arrays.items() if field.is_target] - + def apply_field(self, func, field_name, new_field_name=None, **kwargs): """ 将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 @@ -728,16 +723,16 @@ class DataSet(object): results.append(func(ins[field_name])) except Exception as e: if idx != -1: - logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) + logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) raise e if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(_get_func_signature(func=func))) - + if new_field_name is not None: self._add_apply_field(results, new_field_name, kwargs) - + return results - + def _add_apply_field(self, results, new_field_name, kwargs): """ 将results作为加入到新的field中,field名称为new_field_name @@ -769,7 +764,7 @@ class DataSet(object): self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), is_target=extra_param.get("is_target", None), ignore_type=extra_param.get("ignore_type", False)) - + def apply(self, func, new_field_name=None, **kwargs): """ 将DataSet中每个instance传入到func中,并获取它的返回值. 
@@ -801,13 +796,13 @@ class DataSet(object): # results = [func(ins) for ins in self._inner_iter()] if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(_get_func_signature(func=func))) - + if new_field_name is not None: self._add_apply_field(results, new_field_name, kwargs) - + return results - def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): + def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): """ 将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 @@ -844,7 +839,7 @@ class DataSet(object): return dataset else: return DataSet() - + def split(self, ratio, shuffle=True): """ 将DataSet按照ratio的比例拆分,返回两个DataSet @@ -870,9 +865,9 @@ class DataSet(object): for field_name in self.field_arrays: train_set.field_arrays[field_name].to(self.field_arrays[field_name]) dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) - + return train_set, dev_set - + def save(self, path): """ 保存DataSet. 
@@ -881,7 +876,7 @@ class DataSet(object): """ with open(path, 'wb') as f: pickle.dump(self, f) - + @staticmethod def load(path): r""" diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py index 9460b5e4..3cf7ab45 100644 --- a/fastNLP/core/instance.py +++ b/fastNLP/core/instance.py @@ -3,10 +3,13 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 """ + __all__ = [ "Instance" ] +from .utils import pretty_table_printer + class Instance(object): """ @@ -20,11 +23,11 @@ class Instance(object): >>>ins.add_field("field_3", [3, 3, 3]) >>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) """ - + def __init__(self, **fields): - + self.fields = fields - + def add_field(self, field_name, field): """ 向Instance中增加一个field @@ -41,18 +44,15 @@ class Instance(object): :return: 一个迭代器 """ return self.fields.items() - + def __getitem__(self, name): if name in self.fields: return self.fields[name] else: raise KeyError("{} not found".format(name)) - + def __setitem__(self, name, field): return self.add_field(name, field) - + def __repr__(self): - s = '\'' - return "{" + ",\n".join( - "\'" + field_name + "\': " + str(self.fields[field_name]) + \ - f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" + return str(pretty_table_printer(self)) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index 814e0bd5..dd2afab7 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -1,6 +1,7 @@ """ utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 """ + __all__ = [ "cache_results", "seq_len_to_mask", @@ -12,12 +13,12 @@ import inspect import os import warnings from collections import Counter, namedtuple - import numpy as np import torch import torch.nn as nn from typing import List from ._logger import logger +from prettytable import PrettyTable _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', 'varargs']) @@ 
-25,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require class Option(dict): """a dict can treat keys as attributes""" - + def __getattr__(self, item): try: return self.__getitem__(item) except KeyError: raise AttributeError(item) - + def __setattr__(self, key, value): if key.startswith('__') and key.endswith('__'): raise AttributeError(key) self.__setitem__(key, value) - + def __delattr__(self, item): try: self.pop(item) except KeyError: raise AttributeError(item) - + def __getstate__(self): return self - + def __setstate__(self, state): self.update(state) @@ -112,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): :param int _verbose: 是否打印cache的信息。 :return: """ - + def wrapper_(func): signature = inspect.signature(func) for key, _ in signature.parameters.items(): if key in ('_cache_fp', '_refresh', '_verbose'): raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) - + def wrapper(*args, **kwargs): if '_cache_fp' in kwargs: cache_filepath = kwargs.pop('_cache_fp') @@ -136,7 +137,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): else: verbose = _verbose refresh_flag = True - + if cache_filepath is not None and refresh is False: # load data if os.path.exists(cache_filepath): @@ -145,7 +146,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): if verbose == 1: logger.info("Read cache from {}.".format(cache_filepath)) refresh_flag = False - + if refresh_flag: results = func(*args, **kwargs) if cache_filepath is not None: @@ -155,11 +156,11 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): with open(cache_filepath, 'wb') as f: _pickle.dump(results, f) logger.info("Save cache to {}.".format(cache_filepath)) - + return results - + return wrapper - + return wrapper_ @@ -187,6 +188,7 @@ def _save_model(model, model_name, save_dir, only_param=False): torch.save(model, model_path) model.to(_model_device) + def 
_move_model_to_device(model, device): """ 将model移动到device @@ -211,7 +213,7 @@ def _move_model_to_device(model, device): """ # if isinstance(model, torch.nn.parallel.DistributedDataParallel): # raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") - + if device is None: if isinstance(model, torch.nn.DataParallel): model.cuda() @@ -220,10 +222,10 @@ def _move_model_to_device(model, device): if not torch.cuda.is_available() and ( device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") - + if isinstance(model, torch.nn.DataParallel): raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") - + if isinstance(device, int): assert device > -1, "device can only be non-negative integer" assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( @@ -267,7 +269,7 @@ def _get_model_device(model): """ # TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 assert isinstance(model, nn.Module) - + parameters = list(model.parameters()) if len(parameters) == 0: return None @@ -427,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): """ if not torch.cuda.is_available(): return - + if not isinstance(device, torch.device): raise TypeError(f"device must be `torch.device`, got `{type(device)}`") - + for arg in args: if isinstance(arg, dict): for key, value in arg.items(): @@ -445,10 +447,10 @@ class _CheckError(Exception): _CheckError. Used in losses.LossBase, metrics.MetricBase. 
""" - + def __init__(self, check_res: _CheckRes, func_signature: str): errs = [f'Problems occurred when calling `{func_signature}`'] - + if check_res.varargs: errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") if check_res.missing: @@ -457,9 +459,9 @@ class _CheckError(Exception): errs.append(f"\tduplicated param: {check_res.duplicated}") if check_res.unused: errs.append(f"\tunused param: {check_res.unused}") - + Exception.__init__(self, '\n'.join(errs)) - + self.check_res = check_res self.func_signature = func_signature @@ -479,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re # if check_res.varargs: # errs.append(f"\tvarargs: *{check_res.varargs}") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") - + if check_res.unused: for _unused in check_res.unused: if _unused in target_dict: @@ -490,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re unuseds.append(f"\tunused field: {_unused_field}") if _unused_param: unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward - + module_name = func_signature.split('.')[0] if check_res.missing: errs.append(f"\tmissing param: {check_res.missing}") @@ -511,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re mapped_missing.append(_miss) else: unmapped_missing.append(_miss) - + for _miss in mapped_missing + unmapped_missing: if _miss in dataset: suggestions.append(f"Set `{_miss}` as target.") @@ -524,17 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re else: _tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' 
suggestions.append(_tmp) - + if check_res.duplicated: errs.append(f"\tduplicated param: {check_res.duplicated}.") suggestions.append(f"Delete {check_res.duplicated} in the output of " f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") - + if len(errs) > 0: errs.extend(unuseds) elif check_level == STRICT_CHECK_LEVEL: errs.extend(unuseds) - + if len(errs) > 0: errs.insert(0, f'Problems occurred when calling {func_signature}') sugg_str = "" @@ -561,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re def _check_forward_error(forward_func, batch_x, dataset, check_level): check_res = _check_arg_dict_list(forward_func, batch_x) func_signature = _get_func_signature(forward_func) - + errs = [] suggestions = [] _unused = [] - + # if check_res.varargs: # errs.append(f"\tvarargs: {check_res.varargs}") # suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") @@ -586,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): # _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ # f"rename the field in `unused field:`." suggestions.append(_tmp) - + if check_res.unused: _unused = [f"\tunused field: {check_res.unused}"] if len(errs) > 0: errs.extend(_unused) elif check_level == STRICT_CHECK_LEVEL: errs.extend(_unused) - + if len(errs) > 0: errs.insert(0, f'Problems occurred when calling {func_signature}') sugg_str = "" @@ -641,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None): max_len = int(max_len) if max_len else int(seq_len.max()) broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) mask = broad_cast_seq_len < seq_len.reshape(-1, 1) - + elif isinstance(seq_len, torch.Tensor): assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." 
def pretty_table_printer(dataset_or_ins) -> "PrettyTable":
    """
    Render a DataSet or an Instance as a PrettyTable sized for the terminal.

    Example::

        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: a ``DataSet`` (one table row per instance) or an
        ``Instance`` (a single table row).
    :return: a PrettyTable whose cells are truncated according to the terminal
        width; for a DataSet the number of rows is additionally limited by the
        terminal height, with a final ``...`` row when instances were omitted.
    :raises TypeError: if the argument is neither a DataSet nor an Instance.
    """
    # Dispatch on the class name instead of isinstance(): dataset.py and
    # instance.py import this module, so importing them here would presumably
    # create an import cycle.
    kind = type(dataset_or_ins).__name__
    if kind not in ("DataSet", "Instance"):
        # TypeError is still caught by callers using `except Exception`.
        raise TypeError("only accept DataSet and Instance")

    try:
        terminal = os.get_terminal_size()
        columns, max_rows = terminal.columns, terminal.lines
    except OSError:
        # stdout is not attached to a terminal (pipe, file, notebook, ...).
        columns, max_rows = 144, 11

    table = PrettyTable()

    if kind == "Instance":
        names = list(dataset_or_ins.fields.keys())
        table.field_names = names
        table.add_row([sub_column(dataset_or_ins[name], columns, len(names), name)
                       for name in names])
        return table

    names = list(dataset_or_ins.field_arrays.keys())
    table.field_names = names
    n_fields = len(names)
    total = len(dataset_or_ins)
    for shown, ins in enumerate(dataset_or_ins, start=1):
        table.add_row([sub_column(ins[name], columns, n_fields, name) for name in names])
        if shown > max_rows:
            # Only add the ellipsis row when instances were really left out.
            if shown < total:
                table.add_row(["..."] * n_fields)
            break
    return table


def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    """
    Truncate one table cell so that it fits its share of the terminal width.

    The width budget of a column is the terminal width divided evenly among
    all columns, but never less than the length of the column title.

    :param string: the cell value to truncate (converted with ``str()``)
    :param c: number of terminal columns
    :param c_size: number of fields in the Instance or DataSet
    :param title: column name
    :return: the cell text, shortened and suffixed with ``...`` when it is
        longer than the budget; if the budget is below three characters there
        is no room for the marker and the text is simply cut to the budget.
    """
    # Guard against a zero field count so the division cannot fail.
    budget = max(c // max(c_size, 1), len(title))
    text = str(string)
    if len(text) <= budget:
        return text
    if budget < 3:
        # A negative slice bound here would keep almost the whole string and
        # make the "truncated" cell longer than the original.
        return text[:budget]
    return text[:budget - 3] + "..."