@@ -300,13 +300,14 @@ from .field import FieldArray | |||
from .field import SetInputOrTargetException | |||
from .instance import Instance | |||
from .utils import _get_func_signature | |||
from .utils import pretty_table_printer | |||
class DataSet(object): | |||
""" | |||
fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` | |||
""" | |||
def __init__(self, data=None): | |||
""" | |||
@@ -326,26 +327,26 @@ class DataSet(object): | |||
for ins in data: | |||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
self.append(ins) | |||
else: | |||
raise ValueError("data only be dict or list type.") | |||
    def __contains__(self, item):
        """Support ``name in dataset``: True iff a field named *item* exists."""
        return item in self.field_arrays
def __iter__(self): | |||
def iter_func(): | |||
for idx in range(len(self)): | |||
yield self[idx] | |||
return iter_func() | |||
def _inner_iter(self): | |||
class Iter_ptr: | |||
def __init__(self, dataset, idx): | |||
self.dataset = dataset | |||
self.idx = idx | |||
def __getitem__(self, item): | |||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||
self.idx]) | |||
@@ -358,13 +359,13 @@ class DataSet(object): | |||
def __repr__(self): | |||
return self.dataset[self.idx].__repr__() | |||
def inner_iter_func(): | |||
for idx in range(len(self)): | |||
yield Iter_ptr(self, idx) | |||
return inner_iter_func() | |||
def __getitem__(self, idx): | |||
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||
@@ -397,20 +398,20 @@ class DataSet(object): | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __getattr__(self, item): | |||
# Not tested. Don't use !! | |||
if item == "field_arrays": | |||
raise AttributeError | |||
if isinstance(item, str) and item in self.field_arrays: | |||
return self.field_arrays[item] | |||
    def __setstate__(self, state):
        # Restore the pickled attribute dict wholesale (pairs with __getstate__).
        self.__dict__ = state
    def __getstate__(self):
        # Pickle the full attribute dict; no fields are excluded.
        return self.__dict__
def __len__(self): | |||
"""Fetch the length of the dataset. | |||
@@ -420,16 +421,10 @@ class DataSet(object): | |||
return 0 | |||
field = iter(self.field_arrays.values()).__next__() | |||
return len(field) | |||
def __inner_repr__(self): | |||
if len(self) < 20: | |||
return ",\n".join([ins.__repr__() for ins in self]) | |||
else: | |||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||
def __repr__(self): | |||
return "DataSet(" + self.__inner_repr__() + ")" | |||
return str(pretty_table_printer(self)) | |||
def append(self, instance): | |||
""" | |||
将一个instance对象append到DataSet后面。 | |||
@@ -454,7 +449,7 @@ class DataSet(object): | |||
except AppendToTargetOrInputException as e: | |||
logger.error(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name, fieldarray): | |||
""" | |||
将fieldarray添加到DataSet中. | |||
@@ -469,7 +464,7 @@ class DataSet(object): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fieldarray)}") | |||
self.field_arrays[field_name] = fieldarray | |||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | |||
""" | |||
新增一个field | |||
@@ -481,14 +476,14 @@ class DataSet(object): | |||
:param bool is_target: 新加入的field是否是target | |||
:param bool ignore_type: 是否忽略对新加入的field的类型检查 | |||
""" | |||
if len(self.field_arrays) != 0: | |||
if len(self) != len(fields): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fields)}") | |||
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, | |||
padder=padder, ignore_type=ignore_type) | |||
def delete_instance(self, index): | |||
""" | |||
删除第index个instance | |||
@@ -504,7 +499,7 @@ class DataSet(object): | |||
for field in self.field_arrays.values(): | |||
field.pop(index) | |||
return self | |||
def delete_field(self, field_name): | |||
""" | |||
删除名为field_name的field | |||
@@ -538,7 +533,7 @@ class DataSet(object): | |||
if isinstance(field_name, str): | |||
return field_name in self.field_arrays | |||
return False | |||
def get_field(self, field_name): | |||
""" | |||
获取field_name这个field | |||
@@ -549,7 +544,7 @@ class DataSet(object): | |||
if field_name not in self.field_arrays: | |||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||
return self.field_arrays[field_name] | |||
def get_all_fields(self): | |||
""" | |||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||
@@ -557,7 +552,7 @@ class DataSet(object): | |||
:return dict: 返回如上所述的字典 | |||
""" | |||
return self.field_arrays | |||
def get_field_names(self) -> list: | |||
""" | |||
返回一个list,包含所有 field 的名字 | |||
@@ -565,7 +560,7 @@ class DataSet(object): | |||
:return list: 返回如上所述的列表 | |||
""" | |||
return sorted(self.field_arrays.keys()) | |||
def get_length(self): | |||
""" | |||
获取DataSet的元素数量 | |||
@@ -573,7 +568,7 @@ class DataSet(object): | |||
:return: int: DataSet中Instance的个数。 | |||
""" | |||
return len(self) | |||
def rename_field(self, field_name, new_field_name): | |||
""" | |||
将某个field重新命名. | |||
@@ -587,7 +582,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("DataSet has no field named {}.".format(field_name)) | |||
return self | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
将field_names的field设置为target | |||
@@ -614,7 +609,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
将field_names的field设置为input:: | |||
@@ -638,7 +633,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True): | |||
""" | |||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | |||
@@ -655,7 +650,7 @@ class DataSet(object): | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_padder(self, field_name, padder): | |||
""" | |||
为field_name设置padder:: | |||
@@ -671,7 +666,7 @@ class DataSet(object): | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_padder(padder) | |||
return self | |||
def set_pad_val(self, field_name, pad_val): | |||
""" | |||
为某个field设置对应的pad_val. | |||
@@ -683,7 +678,7 @@ class DataSet(object): | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_pad_val(pad_val) | |||
return self | |||
def get_input_name(self): | |||
""" | |||
返回所有is_input被设置为True的field名称 | |||
@@ -691,7 +686,7 @@ class DataSet(object): | |||
:return list: 里面的元素为被设置为input的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||
def get_target_name(self): | |||
""" | |||
返回所有is_target被设置为True的field名称 | |||
@@ -699,7 +694,7 @@ class DataSet(object): | |||
:return list: 里面的元素为被设置为target的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
""" | |||
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | |||
@@ -728,16 +723,16 @@ class DataSet(object): | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx != -1: | |||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1)) | |||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) | |||
raise e | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def _add_apply_field(self, results, new_field_name, kwargs): | |||
""" | |||
将results作为加入到新的field中,field名称为new_field_name | |||
@@ -769,7 +764,7 @@ class DataSet(object): | |||
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
""" | |||
将DataSet中每个instance传入到func中,并获取它的返回值. | |||
@@ -801,13 +796,13 @@ class DataSet(object): | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN): | |||
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | |||
""" | |||
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 | |||
@@ -844,7 +839,7 @@ class DataSet(object): | |||
return dataset | |||
else: | |||
return DataSet() | |||
def split(self, ratio, shuffle=True): | |||
""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
@@ -870,9 +865,9 @@ class DataSet(object): | |||
for field_name in self.field_arrays: | |||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
return train_set, dev_set | |||
def save(self, path): | |||
""" | |||
保存DataSet. | |||
@@ -881,7 +876,7 @@ class DataSet(object): | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@staticmethod | |||
def load(path): | |||
r""" | |||
@@ -3,10 +3,13 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 | |||
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | |||
""" | |||
__all__ = [ | |||
"Instance" | |||
] | |||
from .utils import pretty_table_printer | |||
class Instance(object): | |||
""" | |||
@@ -20,11 +23,11 @@ class Instance(object): | |||
>>>ins.add_field("field_3", [3, 3, 3]) | |||
>>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) | |||
""" | |||
    def __init__(self, **fields):
        """Create an instance whose fields are given as keyword arguments."""
        self.fields = fields
def add_field(self, field_name, field): | |||
""" | |||
向Instance中增加一个field | |||
@@ -41,18 +44,15 @@ class Instance(object): | |||
:return: 一个迭代器 | |||
""" | |||
return self.fields.items() | |||
def __getitem__(self, name): | |||
if name in self.fields: | |||
return self.fields[name] | |||
else: | |||
raise KeyError("{} not found".format(name)) | |||
    def __setitem__(self, name, field):
        # Delegate to add_field; propagate its return value unchanged.
        return self.add_field(name, field)
def __repr__(self): | |||
s = '\'' | |||
return "{" + ",\n".join( | |||
"\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" | |||
return str(pretty_table_printer(self)) |
@@ -1,6 +1,7 @@ | |||
""" | |||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | |||
""" | |||
__all__ = [ | |||
"cache_results", | |||
"seq_len_to_mask", | |||
@@ -12,12 +13,12 @@ import inspect | |||
import os | |||
import warnings | |||
from collections import Counter, namedtuple | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from typing import List | |||
from ._logger import logger | |||
from prettytable import PrettyTable | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
@@ -25,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
class Option(dict):
    """A dict subclass whose keys can also be read and written as attributes."""

    def __getattr__(self, item):
        # Attribute reads fall back to dict lookup.
        try:
            return self[item]
        except KeyError:
            raise AttributeError(item)

    def __setattr__(self, key, value):
        # Dunder names stay off-limits so dict internals are never shadowed.
        if key.startswith('__') and key.endswith('__'):
            raise AttributeError(key)
        self[key] = value

    def __delattr__(self, item):
        # Attribute deletion removes the corresponding key.
        try:
            self.pop(item)
        except KeyError:
            raise AttributeError(item)

    def __getstate__(self):
        # Pickle the mapping itself.
        return self

    def __setstate__(self, state):
        # Restore by merging the pickled mapping back in.
        self.update(state)
@@ -112,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
:param int _verbose: 是否打印cache的信息。 | |||
:return: | |||
""" | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
for key, _ in signature.parameters.items(): | |||
if key in ('_cache_fp', '_refresh', '_verbose'): | |||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||
def wrapper(*args, **kwargs): | |||
if '_cache_fp' in kwargs: | |||
cache_filepath = kwargs.pop('_cache_fp') | |||
@@ -136,7 +137,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
else: | |||
verbose = _verbose | |||
refresh_flag = True | |||
if cache_filepath is not None and refresh is False: | |||
# load data | |||
if os.path.exists(cache_filepath): | |||
@@ -145,7 +146,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
if verbose == 1: | |||
logger.info("Read cache from {}.".format(cache_filepath)) | |||
refresh_flag = False | |||
if refresh_flag: | |||
results = func(*args, **kwargs) | |||
if cache_filepath is not None: | |||
@@ -155,11 +156,11 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
with open(cache_filepath, 'wb') as f: | |||
_pickle.dump(results, f) | |||
logger.info("Save cache to {}.".format(cache_filepath)) | |||
return results | |||
return wrapper | |||
return wrapper_ | |||
@@ -187,6 +188,7 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||
torch.save(model, model_path) | |||
model.to(_model_device) | |||
def _move_model_to_device(model, device): | |||
""" | |||
将model移动到device | |||
@@ -211,7 +213,7 @@ def _move_model_to_device(model, device): | |||
""" | |||
# if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | |||
if device is None: | |||
if isinstance(model, torch.nn.DataParallel): | |||
model.cuda() | |||
@@ -220,10 +222,10 @@ def _move_model_to_device(model, device): | |||
if not torch.cuda.is_available() and ( | |||
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | |||
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | |||
if isinstance(model, torch.nn.DataParallel): | |||
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | |||
if isinstance(device, int): | |||
assert device > -1, "device can only be non-negative integer" | |||
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( | |||
@@ -267,7 +269,7 @@ def _get_model_device(model): | |||
""" | |||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 | |||
assert isinstance(model, nn.Module) | |||
parameters = list(model.parameters()) | |||
if len(parameters) == 0: | |||
return None | |||
@@ -427,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
""" | |||
if not torch.cuda.is_available(): | |||
return | |||
if not isinstance(device, torch.device): | |||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | |||
for arg in args: | |||
if isinstance(arg, dict): | |||
for key, value in arg.items(): | |||
@@ -445,10 +447,10 @@ class _CheckError(Exception): | |||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
""" | |||
def __init__(self, check_res: _CheckRes, func_signature: str): | |||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||
if check_res.varargs: | |||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||
if check_res.missing: | |||
@@ -457,9 +459,9 @@ class _CheckError(Exception): | |||
errs.append(f"\tduplicated param: {check_res.duplicated}") | |||
if check_res.unused: | |||
errs.append(f"\tunused param: {check_res.unused}") | |||
Exception.__init__(self, '\n'.join(errs)) | |||
self.check_res = check_res | |||
self.func_signature = func_signature | |||
@@ -479,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
# if check_res.varargs: | |||
# errs.append(f"\tvarargs: *{check_res.varargs}") | |||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||
if check_res.unused: | |||
for _unused in check_res.unused: | |||
if _unused in target_dict: | |||
@@ -490,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
unuseds.append(f"\tunused field: {_unused_field}") | |||
if _unused_param: | |||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||
module_name = func_signature.split('.')[0] | |||
if check_res.missing: | |||
errs.append(f"\tmissing param: {check_res.missing}") | |||
@@ -511,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
mapped_missing.append(_miss) | |||
else: | |||
unmapped_missing.append(_miss) | |||
for _miss in mapped_missing + unmapped_missing: | |||
if _miss in dataset: | |||
suggestions.append(f"Set `{_miss}` as target.") | |||
@@ -524,17 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
else: | |||
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' | |||
suggestions.append(_tmp) | |||
if check_res.duplicated: | |||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | |||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | |||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | |||
if len(errs) > 0: | |||
errs.extend(unuseds) | |||
elif check_level == STRICT_CHECK_LEVEL: | |||
errs.extend(unuseds) | |||
if len(errs) > 0: | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
@@ -561,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||
func_signature = _get_func_signature(forward_func) | |||
errs = [] | |||
suggestions = [] | |||
_unused = [] | |||
# if check_res.varargs: | |||
# errs.append(f"\tvarargs: {check_res.varargs}") | |||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||
@@ -586,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||
# f"rename the field in `unused field:`." | |||
suggestions.append(_tmp) | |||
if check_res.unused: | |||
_unused = [f"\tunused field: {check_res.unused}"] | |||
if len(errs) > 0: | |||
errs.extend(_unused) | |||
elif check_level == STRICT_CHECK_LEVEL: | |||
errs.extend(_unused) | |||
if len(errs) > 0: | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
@@ -641,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
max_len = int(max_len) if max_len else int(seq_len.max()) | |||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
elif isinstance(seq_len, torch.Tensor): | |||
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | |||
batch_size = seq_len.size(0) | |||
@@ -650,7 +652,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | |||
else: | |||
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | |||
return mask | |||
@@ -658,24 +660,25 @@ class _pseudo_tqdm: | |||
""" | |||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||
""" | |||
    def __init__(self, **kwargs):
        # Accept and ignore tqdm-style kwargs; log through the shared logger.
        self.logger = logger
    def write(self, info):
        """Mimic ``tqdm.write`` by logging *info* at INFO level."""
        self.logger.info(info)
    def set_postfix_str(self, info):
        """Mimic ``tqdm.set_postfix_str`` by logging *info* at INFO level."""
        self.logger.info(info)
def __getattr__(self, item): | |||
def pass_func(*args, **kwargs): | |||
pass | |||
return pass_func | |||
    def __enter__(self):
        # Context-manager support so it is interchangeable with ``tqdm(...)``.
        return self
    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing to clean up; ``del self`` only drops the local reference.
        del self
@@ -749,3 +752,56 @@ def get_seq_len(words, pad_value=0): | |||
""" | |||
mask = words.ne(pad_value) | |||
return mask.sum(dim=-1) | |||
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    """Render a DataSet or an Instance as a PrettyTable sized to the terminal.

    Example::

        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        |  field_1  |  field_2  |     field_3     |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: a DataSet or an Instance
    :return: a PrettyTable; over-long cells are truncated by ``sub_column``
        and the row count is capped at the terminal height
    :raises TypeError: if *dataset_or_ins* is neither a DataSet nor an Instance
    """
    x = PrettyTable()
    try:
        sz = os.get_terminal_size()
        column = sz.columns
        row = sz.lines
    except OSError:
        # Not attached to a terminal (e.g. output piped to a file); use defaults.
        column = 144
        row = 11

    # Compare by class name instead of isinstance to avoid circular imports
    # between utils, dataset and instance modules.
    if type(dataset_or_ins).__name__ == "DataSet":
        x.field_names = list(dataset_or_ins.field_arrays.keys())
        c_size = len(x.field_names)
        for ins in dataset_or_ins:
            x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names])
            row -= 1
            if row < 0:
                # More instances than terminal rows: mark the truncation.
                x.add_row(["..." for _ in range(c_size)])
                break
    elif type(dataset_or_ins).__name__ == "Instance":
        x.field_names = list(dataset_or_ins.fields.keys())
        c_size = len(x.field_names)
        x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names])
    else:
        # TypeError is more precise than the former bare ``Exception`` and
        # remains backward compatible with ``except Exception`` handlers.
        raise TypeError("only accept DataSet and Instance, got {}".format(type(dataset_or_ins)))
    return x
def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    """Truncate a cell value so that each table column fits the terminal.

    :param string: the value to render (converted with ``str`` first)
    :param c: total number of terminal columns available
    :param c_size: number of fields (columns) in the table
    :param title: the column header, which sets a minimum column width
    :return: the rendered string, cut to width with a trailing ``...`` if too long
    """
    width = max(int(c / c_size), len(title))
    text = str(string)
    if len(text) <= width:
        return text
    return text[:width - 3] + "..."