@@ -0,0 +1,5 @@ | |||
__all__ = [ | |||
'AutoCollator', | |||
'Collator' | |||
] | |||
from .collator import AutoCollator, Collator |
@@ -0,0 +1,379 @@ | |||
__all__ = [ | |||
'AutoCollator', | |||
'Collator', | |||
] | |||
from abc import ABCMeta, abstractmethod | |||
from typing import Any, Dict, List, Callable, Union | |||
from numbers import Number | |||
import warnings | |||
import numpy as np | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
class ApplyResultException(Exception): | |||
def __init__(self, msg, index=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index  # index of the sample where the problem occurred
class SetInputOrTargetException(Exception): | |||
def __init__(self, msg, index=None, field_name=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index  # index of the sample where the problem occurred
self.field_name = field_name  # name of the field being processed
def _get_ele_type_and_dim(cell: Any, dim=0): | |||
r""" | |||
Identify the element type of ``cell`` and the number of dimensions it has.
numpy scalar types: https://docs.scipy.org/doc/numpy-1.13.0/reference/arrays.scalars.html
:param cell: the value to inspect
:param dim: the dimension count accumulated so far
:return: a tuple of (element type, number of dimensions)
""" | |||
if isinstance(cell, (str, Number, np.bool_)): | |||
if hasattr(cell, 'dtype'): | |||
return cell.dtype.type, dim | |||
return type(cell), dim | |||
elif isinstance(cell, list): | |||
dim += 1 | |||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
types = set([i for i, j in res]) | |||
dims = set([j for i, j in res]) | |||
if len(types) > 1: | |||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
elif len(types) == 0: | |||
raise SetInputOrTargetException("Empty value encountered.") | |||
if len(dims) > 1: | |||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
return types.pop(), dims.pop() | |||
elif isinstance(cell, torch.Tensor): | |||
return cell.dtype, cell.dim() + dim  # a 0-dim tensor (e.g. the result of torch.mean) has dim() == 0
elif isinstance(cell, paddle.Tensor): | |||
return cell.dtype, cell.dim() + dim | |||
elif isinstance(cell, np.ndarray): | |||
if cell.dtype != np.dtype('O'):  # a non-object dtype means the array is well-formatted
return cell.dtype.type, cell.ndim + dim  # dtype.type is e.g. np.int32, np.float64
# otherwise keep iterating into the nested elements
dim += 1 | |||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
types = set([i for i, j in res]) | |||
dims = set([j for i, j in res]) | |||
if len(types) > 1: | |||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
elif len(types) == 0: | |||
raise SetInputOrTargetException("Empty value encountered.") | |||
if len(dims) > 1: | |||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
return types.pop(), dims.pop() | |||
else:  # tuple, set, dict and any other unsupported types
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | |||
def _get_ds_type_dim(ds: dict): | |||
# infer the element type and dimension of every field from the first row of the dataset
field_dtype, field_dim = {}, {} | |||
for field_name, field_content in ds.items(): | |||
type_0, dim_0 = _get_ele_type_and_dim(field_content) | |||
field_dtype[field_name], field_dim[field_name] = type_0, dim_0 | |||
return field_dtype, field_dim | |||
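# Illustrative sketch (values invented, not executed): how the helpers above
# infer element type and dimension.
#
#     _get_ele_type_and_dim(3)                # -> (int, 0)
#     _get_ele_type_and_dim([[1, 2], [3]])    # -> (int, 2)
#     _get_ds_type_dim({'words': [1, 2, 3], 'target': 0})
#     # -> ({'words': int, 'target': int}, {'words': 1, 'target': 0})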
class Collator(metaclass=ABCMeta): | |||
r""" | |||
Helper class that manages collate_fn on behalf of a DataLoader.
""" | |||
def __init__(self): | |||
super(Collator, self).__init__() | |||
self.collate_fn = [] | |||
@abstractmethod | |||
def __call__(self, ins_lst: List) -> Any: | |||
raise NotImplementedError | |||
@abstractmethod | |||
def set_pad_val(self, *field_names: str, value=0): | |||
raise NotImplementedError | |||
class _MultiCollator: | |||
""" | |||
Container that manages all collators.
Collators are applied in order, so a collate_fn added later overrides the output of earlier ones for the same fields.
""" | |||
def __init__(self, collate_fns: Union[Callable, List[Callable], None]): | |||
if collate_fns is None: | |||
collate_fns = [] | |||
if isinstance(collate_fns, Callable): | |||
collate_fns = [collate_fns] | |||
self._collators: list = collate_fns | |||
def __call__(self, ins_lst) -> Dict: | |||
out, list_out = {}, [] | |||
for idx, _collate_fn in enumerate(self._collators): | |||
res = _collate_fn(ins_lst) | |||
if isinstance(res, Dict): | |||
out.update(res) | |||
else: | |||
list_out.append(res) | |||
# else: | |||
# raise ValueError(f"the return type of collate_fn {idx} is {type(res)}, but require is dict") | |||
if len(out) > 0 and len(list_out) > 0: | |||
raise ValueError("The collate_fns returned a mix of dict and non-dict results; they must all return the same kind of value.")
if len(list_out) == 1: | |||
list_out = list_out[-1] | |||
# print(list_out) | |||
return out if len(out) > 0 else list_out | |||
def get_collators(self): | |||
return self._collators | |||
def add_collator(self, collator: Callable): | |||
self._collators.append(collator) | |||
def set_as_numpy(self, as_numpy: bool): | |||
""" | |||
If an AutoCollator is present, as_numpy controls the type of the values it returns.
:param as_numpy: return numpy arrays if True, framework tensors otherwise
:return: self
""" | |||
for collator in self._collators: | |||
if isinstance(collator, AutoCollator): | |||
collator.set_as_numpy(as_numpy) | |||
return self | |||
def set_pad_val(self, *field_names, val=0): | |||
""" | |||
If an AutoCollator is present, set the padding value used for the given fields.
:param field_names: names of the dataset fields
:param val: the padding value
:return: self
""" | |||
flag = True | |||
for collator in self._collators: | |||
if isinstance(collator, AutoCollator): | |||
collator.set_pad_val(*field_names, val=val) | |||
flag = False | |||
if flag: | |||
warnings.warn("AutoCollator has been removed, set_pad_val is unavailable!")
return self | |||
def set_input(self, *field_names): | |||
""" | |||
Set the field names the AutoCollator should keep; fields that are not set are filtered out by default.
:param field_names: names of the fields to keep
:return: self
""" | |||
flag = True | |||
for collator in self._collators: | |||
if isinstance(collator, AutoCollator): | |||
collator.set_input(*field_names) | |||
flag = False | |||
if flag: | |||
warnings.warn("AutoCollator has been removed, set_input is unavailable!")
return self | |||
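# Usage sketch (hypothetical collate functions): _MultiCollator chains several
# collators and merges their dict outputs, so a later collate_fn can override
# fields produced by an earlier one.
#
#     multi = _MultiCollator([lambda batch: {'x': [ins['x'] for ins in batch]},
#                             lambda batch: {'y': [ins['y'] for ins in batch]}])
#     multi([{'x': 1, 'y': 2}, {'x': 3, 'y': 4}])
#     # -> {'x': [1, 3], 'y': [2, 4]}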
class AutoCollator(Collator): | |||
def __init__(self, as_numpy: bool): | |||
super(AutoCollator, self).__init__() | |||
self.pad_field_value = {}  # user-defined padding value per field, defaults to 0
self.need_inputs = []  # names of the fields to keep
self.field_dtypes = None  # dtype of the elements in each field
self.field_dims = None  # dimension of the elements in each field
self.as_numpy = as_numpy | |||
def __call__(self, ins_lst: List[Dict]) -> dict: | |||
if len(self.need_inputs) == 0: | |||
raise ValueError("No input fields have been set, please call set_input() first!")
# Case 1: the fields to keep were specified through set_input.
# Case 2: decide whether each field can be padded based on its data type.
if self.field_dtypes is None and self.field_dims is None: | |||
self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0]) | |||
pack_ins_lst, pad_ins_lst = {field_name: [] | |||
for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {} | |||
# group the instances in the list by field name
for per_ins in ins_lst: | |||
for field_name, _field_content in per_ins.items(): | |||
if field_name in self.need_inputs: | |||
pack_ins_lst[field_name].append(_field_content) | |||
pad_field_kv = {field_name: 0 for field_name in self.need_inputs} | |||
pad_field_kv.update(self.pad_field_value) | |||
self.pad_field_value = pad_field_kv | |||
if len(self.pad_field_value.keys()) > 0: | |||
# drop the fields that should not be padded; fields passed to set_input but missing are ignored
drop_field_names = [] | |||
for k, v in self.pad_field_value.items(): | |||
if v is None: | |||
drop_field_names.append(k) | |||
# drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields)) | |||
for field_name in drop_field_names: | |||
field_array = pack_ins_lst.pop(field_name) | |||
pad_ins_lst[field_name] = np.array(field_array) | |||
for field_name, field_array in pack_ins_lst.items(): | |||
content = pad_content(field_array, field_name, self.field_dtypes[field_name], | |||
self.field_dims[field_name], | |||
self.pad_field_value[field_name], | |||
as_numpy=self.as_numpy) | |||
pad_ins_lst[field_name] = content | |||
# else: | |||
# # 取出每列的数据,根据类型判断是否能 pad | |||
# for field_name, field_array in pack_ins_lst.items(): | |||
# pad_field_array = pad_content(field_array, field_name, self.field_dtypes[field_name], | |||
# self.field_dims[field_name], | |||
# pad_val=0, as_numpy=self.as_numpy) | |||
# pad_ins_lst[field_name] = pad_field_array | |||
return pad_ins_lst | |||
def set_pad_val(self, *field_names, val=0): | |||
for field_name in field_names: | |||
self.pad_field_value[field_name] = val | |||
def set_as_numpy(self, as_numpy: bool): | |||
self.as_numpy = as_numpy | |||
def set_input(self, *field_names): | |||
for field_name in field_names: | |||
self.need_inputs.append(field_name) | |||
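# Usage sketch ('words' is a hypothetical field name): AutoCollator keeps only
# the fields registered via set_input and pads each of them to the longest
# sample in the batch with the value given to set_pad_val.
#
#     collator = AutoCollator(as_numpy=True)
#     collator.set_input('words')
#     collator.set_pad_val('words', val=-1)
#     batch = collator([{'words': [1, 2]}, {'words': [3, 4, 5]}])
#     # batch['words'] -> np.array([[1, 2, -1], [3, 4, 5]])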
def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool): | |||
if field_type: | |||
# do not pad, simply return an np.array
if field_dim > 3: | |||
return np.array(content) | |||
# element type is numeric, e.g. np.int64, np.float64, int, float
if isinstance(field_type, type) and \ | |||
(issubclass(field_type, np.number) or issubclass(field_type, Number)): | |||
if field_dim == 0: | |||
array = np.array(content, dtype=field_type) | |||
elif field_dim == 1: | |||
max_len = max(map(len, content)) | |||
array = np.full((len(content), max_len), pad_val, dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
array[i, :len(content_i)] = content_i | |||
elif field_dim == 2: | |||
max_len = max(map(len, content)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in content]) | |||
array = np.full((len(content), max_len, max_word_len), pad_val, dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
for j, content_ii in enumerate(content_i): | |||
array[i, j, :len(content_ii)] = content_ii | |||
else: | |||
shape = np.shape(content) | |||
if len(shape) == 4:  # every sample has the same size in each dimension
array = np.array(content, dtype=field_type) | |||
else: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
if as_numpy is False: | |||
array = torch.tensor(array) | |||
return array | |||
# element type is a torch dtype such as torch.float
elif str(field_type).startswith('torch'): | |||
if field_dim == 0: | |||
tensor = torch.tensor(content).to(field_type) | |||
elif field_dim == 1: | |||
max_len = max(map(len, content)) | |||
tensor = torch.full((len(content), max_len), fill_value=pad_val, dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
tensor[i, :len(content_i)] = content_i.clone().detach() | |||
elif field_dim == 2: | |||
max_len = max(map(len, content)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in content]) | |||
tensor = torch.full((len(content), max_len, max_word_len), fill_value=pad_val, | |||
dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
for j, content_ii in enumerate(content_i): | |||
tensor[i, j, :len(content_ii)] = content_ii.clone().detach() | |||
else: | |||
shapes = set([np.shape(content_i) for content_i in content]) | |||
if len(shapes) > 1: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
shape = shapes.pop() | |||
if len(shape) == 3: | |||
tensor = torch.full([len(content)] + list(shape), fill_value=pad_val, | |||
dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
tensor[i] = content_i.clone().detach().to(field_type) | |||
else: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
return tensor | |||
# TODO add jittor/paddle support?
elif str(field_type).startswith('paddle'): | |||
if field_dim == 0: | |||
tensor = paddle.Tensor(content).to(field_type) | |||
elif field_dim == 1: | |||
max_len = max(map(len, content)) | |||
tensor = paddle.full((len(content), max_len), fill_value=pad_val, dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
tensor[i, :len(content_i)] = content_i.clone().detach() | |||
elif field_dim == 2: | |||
max_len = max(map(len, content)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in content]) | |||
tensor = paddle.full((len(content), max_len, max_word_len), fill_value=pad_val, | |||
dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
for j, content_ii in enumerate(content_i): | |||
tensor[i, j, :len(content_ii)] = content_ii.clone().detach() | |||
else: | |||
shapes = set([np.shape(content_i) for content_i in content]) | |||
if len(shapes) > 1: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
shape = shapes.pop() | |||
if len(shape) == 3: | |||
tensor = paddle.full([len(content)] + list(shape), fill_value=pad_val, | |||
dtype=field_type) | |||
for i, content_i in enumerate(content): | |||
tensor[i] = content_i.clone().detach().to(field_type) | |||
else: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
return tensor | |||
else: | |||
return np.array(content)  # no padding applied
else: | |||
return np.array(content) |
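# Behaviour sketch for pad_content on a 2-dim numeric field (values invented):
# inner lists are padded to the longest word and samples to the longest
# sentence, producing a (batch, max_len, max_word_len) array.
#
#     pad_content([[[1, 2], [3]], [[4]]], 'chars', int,
#                 field_dim=2, pad_val=0, as_numpy=True)
#     # -> np.array([[[1, 2], [3, 0]], [[4, 0], [0, 0]]])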
@@ -0,0 +1,14 @@ | |||
__all__ = [ | |||
'MixDataLoader', | |||
'TorchDataLoader', | |||
'PaddleDataLoader', | |||
'JittorDataLoader', | |||
'prepare_jittor_dataloader', | |||
'prepare_paddle_dataloader', | |||
'prepare_torch_dataloader' | |||
] | |||
from .mix_dataloader import MixDataLoader | |||
from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader | |||
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader | |||
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader |
@@ -0,0 +1,7 @@ | |||
__all__ = [ | |||
'FDataLoader' | |||
] | |||
class FDataLoader: | |||
pass |
@@ -0,0 +1,7 @@ | |||
__all__ = [ | |||
"JittorDataLoader", | |||
'prepare_jittor_dataloader' | |||
] | |||
from .fdl import JittorDataLoader, prepare_jittor_dataloader |
@@ -0,0 +1,138 @@ | |||
__all__ = [ | |||
'JittorDataLoader', | |||
'prepare_jittor_dataloader' | |||
] | |||
from typing import Callable, Optional, List | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
from jittor.dataset.utils import collate_batch | |||
from jittor.dataset import Dataset | |||
else: | |||
from fastNLP.core.dataset import DataSet as Dataset | |||
from fastNLP.core.utils.jittor_utils import jittor_collate_wraps | |||
from fastNLP.core.collators import AutoCollator | |||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||
from fastNLP.core.dataset import DataSet as FDataSet | |||
class _JittorDataset(Dataset): | |||
""" | |||
Wrap the user-provided dataset so that JittorDataLoader can drive a custom dataset through jittor's Dataset interface.
""" | |||
def __init__(self, dataset) -> None: | |||
super(_JittorDataset, self).__init__() | |||
self.dataset = dataset | |||
def __getitem__(self, item): | |||
return (item, self.dataset[item]) | |||
def __len__(self) -> int: | |||
return len(self.dataset) | |||
# def __getattr__(self, item): | |||
# # jittor的Dataset没有的方法而用户的dataset存在且实现了getattribute方法,此时用户可以调用 | |||
# try: | |||
# self.dataset.__getattribute__(item) | |||
# except Exception as e: | |||
# raise e | |||
class JittorDataLoader: | |||
""" | |||
DataLoader for the jittor framework. It provides auto-collate functionality and supports any dataset that implements __getitem__ and __len__.
""" | |||
def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False, | |||
drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024, | |||
stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False, | |||
collate_fn: Callable = None) -> None: | |||
""" | |||
:param dataset: a dataset that implements __getitem__ and __len__
:param batch_size: batch size
:param shuffle: whether to shuffle the dataset
:param drop_last: whether to drop the last batch if it is smaller than batch_size
:param num_workers: number of worker processes; 0 disables multiprocessing
:param buffer_size:
:param stop_grad:
:param keep_numpy_array:
:param endless:
:param collate_fn: callable that packs the fetched instances into a batch
""" | |||
# TODO support fastNLP dataset
# TODO verify support for replacing the sampler (to be done later)
# check whether the provided dataset is a fastNLP DataSet (as opposed to a jittor dataset)
if isinstance(dataset, FDataSet): | |||
collator = dataset.get_collator().set_as_numpy(as_numpy=True) | |||
else: | |||
collator = None | |||
self.dataset = _JittorDataset(dataset) | |||
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, | |||
keep_numpy_array=keep_numpy_array, endless=endless) | |||
if isinstance(self.dataset.dataset, Dataset): | |||
self.dataset.dataset.set_attrs(batch_size=1) | |||
# if the user provides a collate_fn, it replaces jittor's default collate_batch function
self.collate_fn = collate_fn | |||
if self.collate_fn is None: | |||
self.collate_fn = collate_batch | |||
self.auto_collator = collator | |||
self.cur_batch_indices = None | |||
def __iter__(self): | |||
# TODO after the first iteration, setting collate_fn has no effect
if self.cur_batch_indices is None: | |||
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(jittor_collate_wraps(self.collate_fn, | |||
self.auto_collator))) | |||
for indices, data in self.dataset.__iter__(): | |||
self.cur_batch_indices = indices | |||
yield data | |||
def __len__(self): | |||
if self.dataset.drop_last: | |||
return len(self.dataset) // self.dataset.batch_size | |||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | |||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
""" | |||
Set the padding value (default 0) for each given field. The method only takes effect when an AutoCollator exists; otherwise one is added automatically.
When val is None, none of the given fields will be padded.
:param field_names: names of the fields to pad
:param val: padding value, defaults to 0
:return:
""" | |||
if self.auto_collator is None: | |||
self.auto_collator = AutoCollator(as_numpy=True) | |||
self.auto_collator.set_pad_val(*field_names, val=val) | |||
def set_input(self, *field_names) -> None: | |||
""" | |||
Fields registered as inputs are passed to the AutoCollator; fields that are not registered are filtered out by default.
:param field_names: names of the fields to keep
:return:
""" | |||
if self.auto_collator is None: | |||
self.auto_collator = AutoCollator(as_numpy=True) | |||
self.auto_collator.set_input(*field_names) | |||
def get_batch_indices(self) -> List[int]: | |||
""" | |||
Return the indices of the samples in the current batch.
:return: | |||
""" | |||
return self.cur_batch_indices | |||
def prepare_jittor_dataloader(): | |||
... |
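# Usage sketch (assumes jittor is installed and a fastNLP DataSet with made-up
# field names): the dataloader reuses the dataset's AutoCollator, so
# set_input/set_pad_val control which fields appear in the padded batches.
#
#     ds = FDataSet({'words': [[1, 2], [3, 4, 5]], 'target': [0, 1]})
#     dl = JittorDataLoader(ds, batch_size=2)
#     dl.set_input('words', 'target')
#     dl.set_pad_val('words', val=0)
#     for batch in dl:
#         ...  # batch['words'] is padded to the longest sample in the batch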
@@ -0,0 +1,194 @@ | |||
__all__ = [ | |||
'MixDataLoader' | |||
] | |||
from typing import Optional, Callable, List, Union, Tuple, Dict, Sequence | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet, Instance | |||
from fastNLP.core.samplers import PollingSampler, MixSequentialSampler, DopedSampler | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||
class _MixDataset: | |||
""" | |||
Treat all datasets as one large mixed dataset; __getitem__ maps a global index back to the dataset it belongs to.
""" | |||
def __init__(self, datasets: list = None) -> None: | |||
""" | |||
:param datasets: list of datasets
""" | |||
self.datasets = datasets | |||
# cumulative lengths of the datasets, used to locate which dataset an index falls into
self.lens = [] | |||
index = 0 | |||
for item in self.datasets: | |||
index += len(item) | |||
self.lens.append(index) | |||
def __getitem__(self, idx: Union[int, List[int]]) -> Union[Tuple[Instance, int], Tuple[DataSet, int]]: | |||
""" | |||
:param idx: | |||
:return: | |||
""" | |||
if isinstance(idx, int): | |||
if idx >= self.lens[-1]: | |||
raise ValueError(f"idx: {idx} out of range") | |||
# find which dataset the index belongs to and return its position
ds_index = np.searchsorted(self.lens, idx, side='right') | |||
if ds_index > 0: | |||
idx -= self.lens[ds_index - 1] | |||
return self.datasets[ds_index][idx], ds_index | |||
elif isinstance(idx, list): | |||
# a list of indices is expected to come from a single dataset, otherwise an error may occur
dataset = DataSet() | |||
ds_index = 0 | |||
for i in idx: | |||
assert isinstance(i, int), "Only int index allowed." | |||
instance, ds_index = self[i] | |||
dataset.append(instance) | |||
return dataset, ds_index | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __len__(self) -> int: | |||
return self.lens[-1] | |||
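# Index arithmetic sketch: with two datasets of length 3 and 5, self.lens is
# [3, 8]; a global index of 4 falls into the second dataset at local index 1,
# so __getitem__(4) returns (self.datasets[1][1], 1).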
class _MixCollateFn: | |||
""" | |||
When there are multiple auto-collators and collate_fns, decide which pair should be applied to a given batch.
""" | |||
def __init__(self, collate_fns: Optional[Union[List[Callable], Callable]] = None, | |||
auto_collators: Optional[List[Callable]] = None) -> None: | |||
if isinstance(collate_fns, Sequence): | |||
self.collate_fns = lambda idx, lst: collate_fns[idx](lst) | |||
elif callable(collate_fns): | |||
self.collate_fns = lambda idx, lst: collate_fns(lst) | |||
else: | |||
self.collate_fns = lambda idx, lst: lst | |||
self.auto_collators = auto_collators | |||
def __call__(self, ins_list: List) -> Dict: | |||
""" | |||
Each call assumes ins_list was sampled from a single dataset, so there is only one ds_index per batch.
:param ins_list: list of (instance, dataset index) pairs
:return:
""" | |||
_ins_list, _ds_index = [], 0 | |||
for ins, _ds_index in ins_list: | |||
_ins_list.append(ins) | |||
# apply the auto-collator first
if self.auto_collators is not None: | |||
_ins_list = self.auto_collators[_ds_index](_ins_list) | |||
_ins_list = self.collate_fns(_ds_index, _ins_list) | |||
return _ins_list | |||
class MixDataLoader(DataLoader): | |||
""" | |||
MixDataLoader covers the following three situations:
1. Given a dict or list of datasets, sample them sequentially: exhaust the first dataset, then move on to the second, and so on until every dataset has been consumed.
2. Given a dict or list of datasets, sample randomly from any of them and mix the samples into a single batch, until all datasets have been consumed.
3. Given a dict or list of datasets, poll them in turn: draw one batch from a dataset, then one batch from the next, and
repeat until one dataset (or all of them) is exhausted.
""" | |||
def __init__(self, datasets: Union[List, Dict] = None, mode: Union[str, "Sampler"] = 'sequential', | |||
collate_fn: Union[List[Callable], Callable, Dict[str, Callable]] = None, | |||
sampler: Union[List["Sampler"], Dict[str, "Sampler"]] = None, | |||
num_workers: int = 0, batch_size: int = 16, drop_last=False, | |||
ds_ratio: Union[str, List[float], None, Dict[str, float]] = None, | |||
pin_memory: bool = True) -> None: | |||
""" | |||
:param datasets: list or dict of datasets
:param mode: one of four options; the first three, "sequential", "mix" and "polling", correspond to the three situations above.
When mode is a Sampler it is user-defined; in that case sampler, ds_ratio, batch_size and drop_last are ignored, and the
Sampler should be an iterable that yields a List[int] on every iteration.
:param collate_fn: callable that packs the fetched data into a batch.
When it is a single callable, the samples of every dataset go through it;
when it is a List[Callable], datasets should also be a List; the callable applied to a batch is chosen by the dataset index
returned from __getitem__, matching the position in datasets;
when it is a Dict[str, Callable], datasets must also be a Dict with one-to-one matching keys.
:param sampler: the sampler instance used for sampling inside each dataset.
When sampler is None, a SequentialSampler is created for every dataset in datasets;
when sampler is a List[Sampler], datasets must also be a List, matched by position;
when sampler is a Dict[str, Sampler], datasets must also be a Dict with one-to-one matching keys.
:param num_workers: number of worker processes; 0 disables multiprocessing
:param batch_size: batch size, shared by every dataset in datasets
:param drop_last: whether to drop the last batch if it is smaller than batch_size
:param ds_ratio: when ds_ratio is None, the datasets are not resampled;
when ds_ratio is 'truncate_to_least', every dataset is truncated to the length of the shortest one;
when ds_ratio is 'pad_to_most', the shorter datasets are resampled until they match the length of the longest one;
when ds_ratio is a List[float], datasets must also be a List and every entry is the sampling ratio for the corresponding
dataset; it must be greater than 0 and may exceed 1, in which case the dataset is oversampled accordingly;
when ds_ratio is a Dict[str, float], datasets must also be a Dict with one-to-one matching keys.
""" | |||
# if datasets is a Dict, the other arguments such as collate_fn must be a Dict or a Callable
if not isinstance(datasets, Dict) and (isinstance(collate_fn, Callable) or isinstance(collate_fn, Dict)) and \
isinstance(sampler, Dict):
raise ValueError("When sampler is a Dict, datasets must also be a Dict.")
if isinstance(collate_fn, list): | |||
if len(collate_fn) != len(datasets): | |||
raise ValueError("The length of collate_fn does not match the number of datasets!")
if isinstance(sampler, list): | |||
if len(sampler) != len(datasets): | |||
raise ValueError("The length of sampler does not match the number of datasets!")
# convert the Dict into a List so that _MixCollateFn can handle it
if isinstance(collate_fn, Dict): | |||
collate_fn = [fn for _, fn in collate_fn.items()] | |||
# datasets may be fastNLP DataSets or a mixture of types, so each one has to be checked
if isinstance(datasets, Dict): | |||
dataset = [ds for _, ds in datasets.items()] | |||
else: | |||
dataset = datasets | |||
auto_collators = [] | |||
for per_ds in dataset: | |||
if isinstance(per_ds, DataSet): | |||
auto_collators.append(per_ds.get_collator()) | |||
else: | |||
# if there is no matching collator, fall back to a collator that does nothing
auto_collators.append(lambda x: x) | |||
# at this point collate_fn is either a list or a single callable; wrap it together with the auto-collators
collate_fn = _MixCollateFn(collate_fn, auto_collators) | |||
if mode == 'sequential': | |||
batch_sampler = MixSequentialSampler(datasets, batch_size=batch_size, sampler=sampler, | |||
drop_last=drop_last, ds_ratio=ds_ratio) | |||
elif mode == 'polling': | |||
batch_sampler = PollingSampler(datasets, batch_size=batch_size, sampler=sampler, | |||
drop_last=drop_last, ds_ratio=ds_ratio) | |||
elif mode == 'mix': | |||
batch_sampler = DopedSampler(datasets, batch_size=batch_size, sampler=sampler, | |||
drop_last=drop_last, ds_ratio=ds_ratio) | |||
elif isinstance(mode, Sampler): | |||
batch_sampler = mode | |||
else: | |||
raise ValueError(f"mode must be 'sequential', 'polling', 'mix' or a Sampler instance, but got {mode}")
super(MixDataLoader, self).__init__( | |||
_MixDataset(datasets=dataset), batch_size=1, shuffle=False, sampler=None, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, | |||
pin_memory=pin_memory, drop_last=False, timeout=0, | |||
worker_init_fn=None, multiprocessing_context=None, generator=None, | |||
prefetch_factor=2, persistent_workers=False | |||
) | |||
def __iter__(self): | |||
return super().__iter__() |
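# Usage sketch (hypothetical fastNLP DataSets): under mode='polling' one batch
# is drawn from each dataset in turn, whereas mode='mix' would interleave
# samples from different datasets inside a single batch.
#
#     train_1 = DataSet({'x': [[1, 2], [3]], 'y': [0, 1]})
#     train_2 = DataSet({'x': [[4, 5, 6]], 'y': [1]})
#     dl = MixDataLoader(datasets=[train_1, train_2], mode='polling', batch_size=2)
#     for batch in dl:
#         ...  # under 'polling', every batch comes from a single dataset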
@@ -0,0 +1,6 @@ | |||
__all__ = [ | |||
'prepare_paddle_dataloader', | |||
'PaddleDataLoader' | |||
] | |||
from .fdl import PaddleDataLoader, prepare_paddle_dataloader |
@@ -0,0 +1,192 @@ | |||
__all__ = [ | |||
'PaddleDataLoader', | |||
'prepare_paddle_dataloader' | |||
] | |||
from typing import Callable, List, Optional, Union, Dict, Sequence | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
if _NEED_IMPORT_PADDLE: | |||
from paddle.io import DataLoader, Dataset | |||
from paddle.fluid.dataloader.collate import default_collate_fn | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | |||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | |||
from fastNLP.core.collators.collator import _MultiCollator | |||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||
from fastNLP.core.dataset import DataSet as FDataSet | |||
class _PaddleDataset(Dataset): | |||
""" | |||
Wrap the user-provided dataset so that PaddleDataLoader can drive a custom dataset through paddle's DataLoader.
""" | |||
def __init__(self, dataset) -> None: | |||
super(_PaddleDataset, self).__init__() | |||
self.dataset = dataset | |||
def __getitem__(self, item): | |||
return (item, self.dataset[item]) | |||
def __len__(self) -> int: | |||
return len(self.dataset) | |||
def __getattr__(self, item): | |||
try: | |||
self.dataset.__getattribute__(item) | |||
except Exception as e: | |||
raise e | |||
class PaddleDataLoader(DataLoader): | |||
def __init__(self, dataset, feed_list=None, places=None, | |||
return_list: bool = True, batch_sampler=None, | |||
batch_size: int = 1, shuffle: bool = False, | |||
drop_last: bool = False, collate_fn: Callable = None, | |||
num_workers: int = 0, use_buffer_reader: bool = True, | |||
use_shared_memory: bool = True, timeout: int = 0, | |||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | |||
if not isinstance(dataset, _PaddleDataset): | |||
dataset = _PaddleDataset(dataset) | |||
super(PaddleDataLoader, self).__init__(dataset=dataset, feed_list=feed_list, places=places, | |||
return_list=return_list, batch_sampler=batch_sampler, | |||
batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | |||
collate_fn=None, num_workers=num_workers, | |||
use_buffer_reader=use_buffer_reader, use_shared_memory=use_shared_memory, | |||
timeout=timeout, worker_init_fn=worker_init_fn, | |||
persistent_workers=persistent_workers) | |||
if isinstance(dataset.dataset, FDataSet): | |||
self._collate_fn = dataset.dataset.get_collator() | |||
self._collate_fn.set_as_numpy(as_numpy=True) | |||
if collate_fn is not None: | |||
self._collate_fn.add_collator(collate_fn) | |||
else: | |||
self._collate_fn = _MultiCollator(collate_fn) | |||
# _collate_fn = _MultiCollator(AutoCollator(as_numpy=True)) | |||
# if collate_fn is not None: | |||
# _collate_fn.add_collator(collate_fn) | |||
# self._collate_fn = _collate_fn | |||
self.cur_batch_indices = None | |||
def __iter__(self): | |||
# if there is neither an auto-collator nor a user-defined collate_fn, fall back to the dataloader's default collate_fn to pack the data
if len(self._collate_fn.get_collators()) == 0: | |||
self._collate_fn.add_collator(default_collate_fn) | |||
# self._collate_fn = default_collate_fn | |||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||
for indices, data in super().__iter__(): | |||
self.cur_batch_indices = indices | |||
yield data | |||
def __getattr__(self, item): | |||
""" | |||
Expose the wrapped dataset's methods and attributes on the dataloader, so that after instantiation users can call dataset methods such as apply directly on it.
:param item: name of the attribute to look up
:return:
""" | |||
try: | |||
return self.dataset.__getattr__(item) | |||
except AttributeError as e: | |||
raise e | |||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
""" | |||
Set the padding value (default 0) for each given field. The method only takes effect when an AutoCollator exists; otherwise one is added automatically.
When val is None, none of the given fields will be padded.
:param field_names: names of the fields to pad
:param val: padding value, defaults to 0
:return:
""" | |||
for field_name in field_names: | |||
self._collate_fn.set_pad_val(field_name, val=val) | |||
def set_input(self, *field_names) -> None: | |||
""" | |||
Fields registered as inputs are passed to the AutoCollator; fields that are not registered are filtered out by default.
:param field_names: names of the fields to keep
:return:
""" | |||
self._collate_fn.set_input(*field_names) | |||
def set_collator(self, collator: Callable) -> None: | |||
""" | |||
Set the collate_fn; calling this replaces every existing collate_fn, including the AutoCollator.
:param collator: user-defined callable
:return:
""" | |||
self._collate_fn = _MultiCollator(collator) | |||
def add_collator(self, collator) -> None: | |||
""" | |||
Append a collate_fn; it is added after the existing collate_fns.
:param collator: callable to append
:return:
""" | |||
self._collate_fn.add_collator(collator) | |||
def get_batch_indices(self) -> List[int]: | |||
""" | |||
Return the indices of the samples in the current batch.
:return: | |||
""" | |||
return self.cur_batch_indices | |||
def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||
return_list: bool = True, batch_sampler=None, | |||
train_batch_size: int = 1, shuffle: bool = False, | |||
drop_last: bool = False, collate_fn: Callable = None, | |||
num_workers: int = 0, use_buffer_reader: bool = True, | |||
use_shared_memory: bool = True, timeout: int = 0, | |||
worker_init_fn: Callable = None, persistent_workers=False, | |||
non_train_batch_size: int = 16, | |||
input_fields: Union[List[str], str] = None)\ | |||
-> Union[Sequence[PaddleDataLoader], Dict[str, PaddleDataLoader], PaddleDataLoader]: | |||
if isinstance(input_fields, str): | |||
input_fields = [input_fields] | |||
if isinstance(ds_or_db, Dataset): | |||
... | |||
elif isinstance(ds_or_db, Sequence): | |||
ds_seq = [] | |||
for ds in ds_or_db: | |||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
dl.set_input(*input_fields) | |||
ds_seq.append(dl) | |||
return ds_seq | |||
elif isinstance(ds_or_db, Dict): | |||
ds_dict = {} | |||
for name, ds in ds_or_db.items(): | |||
if 'train' in name: | |||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=train_batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
else: | |||
dl = PaddleDataLoader(ds, feed_list=feed_list, places=places, return_list=return_list, | |||
batch_sampler=batch_sampler, batch_size=non_train_batch_size, shuffle=shuffle, | |||
drop_last=drop_last, collate_fn=collate_fn, num_workers=num_workers, | |||
use_shared_memory=use_shared_memory, use_buffer_reader=use_buffer_reader, | |||
timeout=timeout, worker_init_fn=worker_init_fn, persistent_workers=persistent_workers) | |||
dl.set_input(*input_fields) | |||
ds_dict[name] = dl | |||
return ds_dict | |||
else: | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") |
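# Usage sketch (assumes paddle is installed; train_ds/dev_ds and the field
# names are illustrative): prepare_paddle_dataloader builds one PaddleDataLoader
# per dataset in the mapping and registers the given input fields on each.
#
#     dls = prepare_paddle_dataloader({'train': train_ds, 'dev': dev_ds},
#                                     train_batch_size=32, non_train_batch_size=64,
#                                     input_fields=['words', 'target'])
#     train_dl, dev_dl = dls['train'], dls['dev']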
@@ -0,0 +1,6 @@ | |||
__all__ = [ | |||
"TorchDataLoader", | |||
"prepare_torch_dataloader" | |||
] | |||
from .fdl import TorchDataLoader, prepare_torch_dataloader |
@@ -0,0 +1,300 @@ | |||
__all__ = [ | |||
'TorchDataLoader', | |||
'prepare_torch_dataloader' | |||
] | |||
from typing import Optional, Callable, Sequence, List, Union, Tuple, Dict, Mapping | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.collators import AutoCollator | |||
from fastNLP.core.collators.collator import _MultiCollator | |||
from fastNLP.core.utils.utils import indice_collate_wrapper | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader, Sampler | |||
from torch.utils.data._utils.collate import default_collate | |||
else: | |||
from ..fdataloader import FDataLoader as DataLoader | |||
class _FDataSet: | |||
""" | |||
Wrapper around a Dataset that changes __getitem__ to also return the index. Note that the wrapped dataset must implement
__getattribute__ for its methods to be reachable through _FDataSet.
""" | |||
def __init__(self, dataset) -> None: | |||
self.dataset = dataset | |||
def __getitem__(self, item: Union[int, list]) -> Tuple: | |||
return (item, self.dataset[item]) | |||
def __getattr__(self, item): | |||
try: | |||
return self.dataset.__getattribute__(item) | |||
except AttributeError as e: | |||
raise e | |||
def __len__(self) -> int: | |||
return len(self.dataset) | |||
class TorchDataLoader(DataLoader): | |||
""" | |||
DataLoader for the pytorch framework. When used together with a fastNLP DataSet, an AutoCollator pads the data automatically,
and the provided methods let users tune the collate_fn behaviour.
""" | |||
def __init__(self, dataset, batch_size: int = 1, | |||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, as_numpy: bool = False) -> None: | |||
""" | |||
:param dataset: a data container that implements __getitem__ and __len__
:param batch_size: batch size, effective when batch_sampler is None
:param shuffle: whether to shuffle the dataset
:param sampler: sampler instance
:param batch_sampler: batch_sampler instance, an iterable that yields a list of indices per batch
:param num_workers: number of worker processes; 0 disables multiprocessing
:param collate_fn: callable that packs the fetched instances into a batch
:param pin_memory:
:param drop_last: whether to drop the last batch if it is smaller than batch_size
:param timeout:
:param worker_init_fn:
:param multiprocessing_context:
:param generator:
:param prefetch_factor:
:param persistent_workers:
:param as_numpy: whether batches are returned as numpy arrays; otherwise torch.Tensor
""" | |||
if not isinstance(dataset, _FDataSet): | |||
dataset = _FDataSet(dataset) | |||
super().__init__(dataset=dataset, batch_size=batch_size, shuffle=shuffle, sampler=sampler, | |||
batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=None, | |||
pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, | |||
persistent_workers=persistent_workers) | |||
if isinstance(dataset.dataset, DataSet):  # a fastNLP DataSet is being used
self._collate_fn = dataset.dataset.get_collator() | |||
self._collate_fn.set_as_numpy(as_numpy) | |||
if collate_fn is not None and collate_fn is not default_collate: | |||
# avoid adding torch's default collate when DDP re-initialises the dataloader
self._collate_fn.add_collator(collate_fn) | |||
else: | |||
self._collate_fn = _MultiCollator(collate_fn) | |||
self.cur_batch_indices = None
self.as_numpy = as_numpy | |||
def __getattr__(self, item): | |||
""" | |||
Expose the wrapped dataset's methods and attributes on the dataloader, so that after instantiation users can call dataset methods such as apply directly on it.
:param item: name of the attribute to look up
:return:
""" | |||
try: | |||
return self.dataset.__getattr__(item) | |||
except AttributeError as e: | |||
raise e | |||
def __iter__(self): | |||
# if there is neither an auto-collator nor a user-defined collate_fn, fall back to the dataloader's own collate_fn to pack the data
if len(self._collate_fn.get_collators()) == 0: | |||
self._collate_fn.add_collator(self.collate_fn) | |||
self.collate_fn = indice_collate_wrapper(self._collate_fn) | |||
for indices, data in super().__iter__(): | |||
self.cur_batch_indices = indices | |||
yield data | |||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
""" | |||
Set the padding value (default 0) for each given field. The method only takes effect when an AutoCollator exists; otherwise one is added automatically.
When val is None, none of the given fields will be padded.
:param field_names: names of the fields to pad
:param val: padding value, defaults to 0
:return:
""" | |||
flag = False | |||
for collator in self._collate_fn.get_collators(): | |||
if isinstance(collator, AutoCollator): | |||
flag = True | |||
break | |||
if flag is False: | |||
self._collate_fn.add_collator(AutoCollator(self.as_numpy)) | |||
for field_name in field_names: | |||
self._collate_fn.set_pad_val(field_name, val=val) | |||
def set_input(self, *field_names) -> None: | |||
""" | |||
Fields registered as inputs are passed to the AutoCollator; fields that are not registered are filtered out by default.
:param field_names: names of the fields to keep
:return:
""" | |||
flag = False | |||
for collator in self._collate_fn.get_collators(): | |||
if isinstance(collator, AutoCollator): | |||
flag = True | |||
break | |||
if flag is False: | |||
self._collate_fn.add_collator(AutoCollator(self.as_numpy)) | |||
self._collate_fn.set_input(*field_names) | |||
def set_collator(self, collator: Callable) -> None: | |||
""" | |||
Set the collate_fn; calling this replaces every existing collate_fn, including the AutoCollator.
:param collator: user-defined callable
:return:
""" | |||
self._collate_fn = _MultiCollator(collator) | |||
def add_collator(self, collator) -> None: | |||
""" | |||
Append a collate_fn; it is added after the existing collate_fns.
:param collator: callable to append
:return:
""" | |||
self._collate_fn.add_collator(collator) | |||
def get_batch_indices(self) -> List[int]: | |||
""" | |||
Return the indices of the samples in the current batch.
:return: | |||
""" | |||
return self.cur_batch_indices | |||
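# Usage sketch (field names are illustrative): with a fastNLP DataSet the
# dataloader reuses the dataset's AutoCollator, so only the fields registered
# via set_input appear in the batch, padded with the value from set_pad_val.
#
#     ds = DataSet({'words': [[1, 2], [3, 4, 5]], 'target': [0, 1]})
#     dl = TorchDataLoader(ds, batch_size=2)
#     dl.set_input('words', 'target')
#     dl.set_pad_val('words', val=0)
#     batch = next(iter(dl))
#     # batch['words'] -> tensor([[1, 2, 0], [3, 4, 5]]), batch['target'] -> tensor([0, 1])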
def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | |||
batch_size: int = 1, | |||
shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, | |||
batch_sampler: Optional["Sampler[Sequence[int]]"] = None, | |||
num_workers: int = 0, collate_fn: Optional[Callable] = None, | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, | |||
non_train_batch_size: int = 16, as_numpy: bool = False, | |||
input_fields: Union[List, str] = None)\ | |||
-> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: | |||
""" | |||
Given a dataset or data_bundle, build and return the corresponding TorchDataLoader instance(s).
:param input_fields: field names to register as inputs
:param ds_or_db: a dataset or a data_bundle
:param batch_size: batch size, effective when batch_sampler is None
:param shuffle: whether to shuffle the dataset
:param sampler: sampler instance
:param batch_sampler: batch_sampler instance, an iterable that yields a list of indices per batch
:param num_workers: number of worker processes; 0 disables multiprocessing
:param collate_fn: callable that packs the fetched instances into a batch
:param pin_memory:
:param drop_last: whether to drop the last batch if it is smaller than batch_size
:param timeout:
:param worker_init_fn:
:param multiprocessing_context:
:param generator:
:param prefetch_factor:
:param persistent_workers:
:param non_train_sampler: Sampler used for non-'train' datasets, and for every dataset after the first one when a Sequence is given
:param non_train_batch_size:
:param as_numpy: whether batches are returned as numpy arrays; otherwise torch.Tensor where appropriate.
""" | |||
# TODO the dict and sequence cases still need dedicated handling
if isinstance(input_fields, str): | |||
input_fields = [input_fields] | |||
if isinstance(ds_or_db, DataSet): | |||
dl = TorchDataLoader(dataset=ds_or_db, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
dl.set_input(*input_fields) | |||
return dl | |||
elif isinstance(ds_or_db, DataBundle): | |||
dl_bundle = {} | |||
for name, ds in ds_or_db.iter_datasets(): | |||
if 'train' in name: | |||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
else: | |||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
dl_bundle[name].set_input(*input_fields) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Sequence): | |||
dl_bundle = [] | |||
for idx, ds in enumerate(ds_or_db): | |||
if idx == 0: | |||
dl_bundle.append( | |||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
) | |||
else: | |||
dl_bundle.append( | |||
TorchDataLoader(dataset=ds, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
) | |||
for dl in dl_bundle: | |||
dl.set_input(*input_fields) | |||
return dl_bundle | |||
elif isinstance(ds_or_db, Mapping): | |||
dl_bundle = {} | |||
for name, ds in ds_or_db.items(): | |||
if 'train' in name: | |||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=batch_size, | |||
shuffle=shuffle, sampler=sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
else: | |||
dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, | |||
shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, | |||
num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, | |||
drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, | |||
multiprocessing_context=multiprocessing_context, generator=generator, | |||
prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, | |||
as_numpy=as_numpy) | |||
dl_bundle[name].set_input(*input_fields) | |||
return dl_bundle | |||
else: | |||
raise ValueError(f"ds_or_db: {ds_or_db} must be fastnlp dataset or data_bundle or sequence or mapping!") |
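# Usage sketch (train_ds/dev_ds and the field name are illustrative): when a
# mapping is passed, one TorchDataLoader is built per entry; splits whose name
# contains 'train' use batch_size, the others use non_train_batch_size.
#
#     dls = prepare_torch_dataloader({'train': train_ds, 'dev': dev_ds},
#                                    batch_size=32, non_train_batch_size=64,
#                                    input_fields='words')
#     train_dl, dev_dl = dls['train'], dls['dev']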
@@ -0,0 +1,10 @@ | |||
__all__ = [ | |||
'DataSet', | |||
'FieldArray', | |||
'Instance', | |||
'ApplyResultException' | |||
] | |||
from .dataset import DataSet, ApplyResultException | |||
from .field import FieldArray | |||
from .instance import Instance |
@@ -0,0 +1,818 @@ | |||
r""" | |||
""" | |||
__all__ = [ | |||
"DataSet", | |||
"ApplyResultException" | |||
] | |||
import _pickle as pickle | |||
from copy import deepcopy | |||
from typing import Optional, List, Callable, Union, Dict, Any | |||
from functools import partial | |||
import warnings | |||
import numpy as np | |||
from threading import Thread | |||
try: | |||
import multiprocess as mp | |||
from multiprocess import RLock | |||
except: | |||
pass | |||
from .field import FieldArray | |||
from .instance import Instance | |||
from fastNLP.core.utils.utils import pretty_table_printer, deprecated | |||
from fastNLP.core.collators import AutoCollator | |||
from fastNLP.core.utils.rich_progress import f_rich_progress | |||
from fastNLP.core.collators.collator import _MultiCollator | |||
class ApplyResultException(Exception): | |||
def __init__(self, msg, index=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index  # index of the sample where the problem occurred
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, | |||
pipe=None, desc: str = None) -> list: | |||
""" | |||
Helper that applies func over a dataset, wrapped so it can be used with multiprocessing.
:param ds: the dataset
:param _apply_field: name of the field to process
:param func: user-defined function
:param pipe: pipe used to report progress to the parent process
:param desc: description shown on the progress bar
:param show_progress_bar: whether to show the subprocess progress bar
:return:
""" | |||
if show_progress_bar: | |||
desc = desc if desc else f"Main" | |||
pg_main = f_rich_progress.add_task(description=desc, total=len(ds), visible=show_progress_bar) | |||
results = [] | |||
idx = -1 | |||
try: | |||
# for idx, ins in tqdm(enumerate(ds), total=len(ds), position=0, desc=desc, disable=not show_progress_bar): | |||
for idx, ins in enumerate(ds): | |||
if _apply_field is not None: | |||
results.append(func(ins[_apply_field])) | |||
else: | |||
results.append(func(ins)) | |||
if pipe is not None: | |||
pipe.send([idx + 1]) | |||
if show_progress_bar: | |||
f_rich_progress.update(pg_main, advance=1) | |||
except BaseException as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
finally: | |||
if show_progress_bar: | |||
f_rich_progress.destroy_task(pg_main) | |||
return results | |||
def _progress_bar(parent, total_len: int, desc: str = None, show_progress_bar: bool = True) -> None: | |||
""" | |||
Display the main-process progress bar when multiprocessing is used.
:param parent: pipe connected to the worker processes
:param total_len: total length of the dataset
:param desc: description shown on the progress bar
:param show_progress_bar: whether to show the progress bar
:return:
""" | |||
desc = desc if desc else "Main" | |||
main_pro = f_rich_progress.add_task(description=desc, total=total_len, visible=show_progress_bar) | |||
# pb_main = tqdm(total=total_len, desc=desc, position=0) | |||
nums = 0 | |||
while True: | |||
msg = parent.recv()[0] | |||
if msg is not None: | |||
f_rich_progress.update(main_pro, advance=1) | |||
nums += 1 | |||
if nums == total_len: | |||
break | |||
# pb_main.close() | |||
class DataSet: | |||
r""" | |||
fastNLP's data container; see the documentation of :mod:`fastNLP.core.dataset` for detailed usage.
""" | |||
def __init__(self, data: Union[List[Instance], Dict[str, List[Any]], None] = None): | |||
r""" | |||
:param data: if a dict, the value of every key should be a list of equal length; if a list,
every element should be an :class:`~fastNLP.Instance` with the same fields.
""" | |||
self.field_arrays = {} | |||
self.collate_fns: _MultiCollator = _MultiCollator(AutoCollator(as_numpy=False)) | |||
if data is not None: | |||
if isinstance(data, Dict): | |||
length_set = set() | |||
for key, value in data.items(): | |||
length_set.add(len(value)) | |||
assert len(length_set) == 1, "Arrays must all be same length." | |||
for key, value in data.items(): | |||
self.add_field(field_name=key, fields=value) | |||
elif isinstance(data, List): | |||
for ins in data: | |||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
self.append(ins) | |||
else: | |||
raise ValueError("data can only be a dict or a list.")
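# Construction sketch (field names are illustrative): both forms below build a
# two-instance DataSet with the fields 'words' and 'target'.
#
#     ds = DataSet({'words': [[1, 2], [3, 4, 5]], 'target': [0, 1]})
#     ds = DataSet([Instance(words=[1, 2], target=0),
#                   Instance(words=[3, 4, 5], target=1)])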
def __contains__(self, item): | |||
return item in self.field_arrays | |||
def __iter__(self): | |||
for idx in range(len(self)): | |||
yield self[idx] | |||
def _inner_iter(self): | |||
class Iter_ptr: | |||
def __init__(self, dataset, idx): | |||
self.dataset = dataset | |||
self.idx = idx | |||
def __getitem__(self, item): | |||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||
self.idx]) | |||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||
return self.dataset.field_arrays[item][self.idx] | |||
def __setitem__(self, key, value): | |||
raise TypeError("You cannot modify value directly.") | |||
def items(self): | |||
ins = self.dataset[self.idx] | |||
return ins.items() | |||
def __repr__(self): | |||
return self.dataset[self.idx].__repr__() | |||
def inner_iter_func(): | |||
for idx in range(len(self)): | |||
yield Iter_ptr(self, idx) | |||
return inner_iter_func() | |||
def __getitem__(self, idx: Union[int, slice, str, list]): | |||
r"""Given an int index, return an Instance; given a slice, return a new DataSet containing that slice.
:param idx: can be an int, slice, str or list.
:return: If `idx` is an int, return an Instance object; if a slice or a list of ints, return a DataSet;
if a str, return the FieldArray of that field.
""" | |||
if isinstance(idx, int): | |||
return Instance(**{name: self.field_arrays[name][idx] for name in self.field_arrays}) | |||
elif isinstance(idx, slice): | |||
if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)): | |||
raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}") | |||
data_set = DataSet() | |||
for field_name, field in self.field_arrays.items(): | |||
data_set.add_field(field_name=field_name, fields=field.content[idx]) | |||
return data_set | |||
elif isinstance(idx, str): | |||
if idx not in self: | |||
raise KeyError("No such field called {} in DataSet.".format(idx)) | |||
return self.field_arrays[idx] | |||
elif isinstance(idx, list): | |||
dataset = DataSet() | |||
for i in idx: | |||
assert isinstance(i, int), "Only int index allowed." | |||
instance = self[i] | |||
dataset.append(instance) | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __getattribute__(self, item): | |||
return object.__getattribute__(self, item) | |||
def __getattr__(self, item): | |||
# Not tested. Don't use !! | |||
if item == "field_arrays": | |||
raise AttributeError | |||
if isinstance(item, str) and item in self.field_arrays: | |||
return self.field_arrays[item] | |||
def __setstate__(self, state): | |||
self.__dict__ = state | |||
def __getstate__(self): | |||
return self.__dict__ | |||
def __len__(self): | |||
r"""Fetch the length of the dataset. | |||
:return length: | |||
""" | |||
if len(self.field_arrays) == 0: | |||
return 0 | |||
field = iter(self.field_arrays.values()).__next__() | |||
return len(field) | |||
def __repr__(self): | |||
return str(pretty_table_printer(self)) | |||
def append(self, instance: Instance) -> None: | |||
r""" | |||
Append an Instance to the end of the DataSet.
:param ~fastNLP.Instance instance: if the DataSet is not empty, the instance must have exactly the same fields as the DataSet.
""" | |||
if len(self.field_arrays) == 0: | |||
# DataSet has no field yet | |||
for name, field in instance.items(): | |||
# field = field.tolist() if isinstance(field, np.ndarray) else field | |||
self.field_arrays[name] = FieldArray(name, [field])  # first sample, must be wrapped in a list
else: | |||
if len(self.field_arrays) != len(instance.fields): | |||
raise ValueError( | |||
"DataSet object has {} fields, but attempt to append an Instance object with {} fields." | |||
.format(len(self.field_arrays), len(instance.fields))) | |||
for name, field in instance.items(): | |||
assert name in self.field_arrays | |||
try: | |||
self.field_arrays[name].append(field) | |||
except Exception as e: | |||
print(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name: str, fieldarray: FieldArray) -> None: | |||
r""" | |||
Add a FieldArray to the DataSet.
:param str field_name: name of the new field
:param ~fastNLP.core.FieldArray fieldarray: content of the field to add
:return: | |||
""" | |||
if not isinstance(fieldarray, FieldArray): | |||
raise TypeError("Only fastNLP.FieldArray supported.") | |||
if len(self) != len(fieldarray): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fieldarray)}") | |||
fieldarray.name = field_name | |||
self.field_arrays[field_name] = fieldarray | |||
def add_field(self, field_name: str, fields: list) -> None: | |||
r""" | |||
Add a new field. Note that the length of fields must match the length of the dataset.
:param str field_name: name of the new field
:param list fields: content of the new field
""" | |||
if len(self.field_arrays) != 0: | |||
if len(self) != len(fields): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fields)}") | |||
self.field_arrays[field_name] = FieldArray(field_name, fields) | |||
def delete_instance(self, index: int): | |||
r""" | |||
Delete the instance at the given index.
:param int index: index of the instance to delete, starting from 0.
""" | |||
assert isinstance(index, int), "Only integer supported." | |||
if len(self) <= index: | |||
raise IndexError("{} is too large for a DataSet with {} instances.".format(index, len(self)))
if len(self) == 1: | |||
self.field_arrays.clear() | |||
else: | |||
for field in self.field_arrays.values(): | |||
field.pop(index) | |||
return self | |||
def delete_field(self, field_name: str): | |||
r""" | |||
Delete the field named field_name.
:param str field_name: name of the field to delete.
""" | |||
if self.has_field(field_name): | |||
self.field_arrays.pop(field_name) | |||
else: | |||
raise KeyError(f"Field:{field_name} not found in DataSet.") | |||
return self | |||
def copy_field(self, field_name: str, new_field_name: str): | |||
r""" | |||
Deep-copy the field named field_name into new_field_name.
:param str field_name: the field to copy.
:param str new_field_name: name of the copied field.
:return: self | |||
""" | |||
if not self.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} not found in DataSet.") | |||
fieldarray = deepcopy(self.get_field(field_name)) | |||
fieldarray.name = new_field_name | |||
self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray) | |||
return self | |||
def has_field(self, field_name: str) -> bool: | |||
r""" | |||
Check whether the DataSet has a field named field_name.
:param str field_name: name of the field
:return bool: whether a field named field_name exists
""" | |||
if isinstance(field_name, str): | |||
return field_name in self.field_arrays | |||
return False | |||
def get_field(self, field_name: str) -> FieldArray: | |||
r""" | |||
Return the field named field_name.
:param str field_name: name of the field
:return: :class:`~fastNLP.FieldArray` | |||
""" | |||
if field_name not in self.field_arrays: | |||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||
return self.field_arrays[field_name] | |||
def get_all_fields(self) -> dict: | |||
r""" | |||
Return a dict whose keys are field names and whose values are the corresponding :class:`~fastNLP.FieldArray` objects.
:return dict: the dictionary described above
""" | |||
return self.field_arrays | |||
def get_field_names(self) -> list: | |||
r""" | |||
返回一个list,包含所有 field 的名字 | |||
:return list: 返回如上所述的列表 | |||
""" | |||
return sorted(self.field_arrays.keys()) | |||
def get_length(self) -> int: | |||
r""" | |||
获取DataSet的元素数量 | |||
:return: int: DataSet中Instance的个数。 | |||
""" | |||
return len(self) | |||
def rename_field(self, field_name: str, new_field_name: str): | |||
r""" | |||
将某个field重新命名. | |||
:param str field_name: 原来的field名称。 | |||
:param str new_field_name: 修改为new_name。 | |||
""" | |||
if field_name in self.field_arrays: | |||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | |||
self.field_arrays[new_field_name].name = new_field_name | |||
else: | |||
raise KeyError("DataSet has no field named {}.".format(field_name)) | |||
return self | |||
def apply_field(self, func: Callable, field_name: str = None,
new_field_name: str = None, num_proc: int = 0, | |||
progress_desc: str = None, show_progress_bar: bool = True): | |||
r""" | |||
Pass the content of the field named `field_name` of every instance in the DataSet to ``func`` and collect its return values.
:param num_proc: number of worker processes.
:param field_name: name of the field whose content is passed to ``func``.
:param func: callable; its input is the content of the field named `field_name` of one instance.
:param new_field_name: store the values returned by ``func`` in the field `new_field_name`; an existing field with the
same name is overwritten. If None, no new field is created.
:param progress_desc: description shown on the progress bar; the worker falls back to 'Main' when it is not given.
:param show_progress_bar: whether to display a progress bar; defaults to True.
""" | |||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | |||
if not self.has_field(field_name=field_name): | |||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||
try: | |||
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, | |||
progress_desc=progress_desc, _apply_field=field_name) | |||
except BaseException as e: | |||
raise e | |||
if new_field_name is not None: | |||
self.add_field(field_name=new_field_name, fields=results) | |||
return results | |||
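# Illustrative usage sketch for ``apply_field`` (not part of the original source; the
# dataset contents and field names are hypothetical, and the dict-of-lists constructor
# mirrors the one used by ``from_pandas`` further below):
#
#     from fastNLP.core.dataset import DataSet
#     ds = DataSet({"raw_words": ["a b c", "d e"], "label": [0, 1]})
#     ds.apply_field(lambda x: x.split(), field_name="raw_words", new_field_name="words")
#     # ds now holds a "words" field: [["a", "b", "c"], ["d", "e"]]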
def apply_field_more(self, func: Callable = None, field_name: str = None, | |||
modify_fields: bool = True, num_proc: int = 0, | |||
progress_desc: str = None, show_progress_bar: bool = True): | |||
r""" | |||
Pass the content of the field named `field_name` of every ``Instance`` in the ``DataSet`` to ``func`` and collect
its return values. ``func`` may return results for one or more fields at once.
.. note::
The difference between ``apply_field_more`` and ``apply_field`` mirrors the difference between ``apply_more`` and
``apply`` described in :meth:`~fastNLP.DataSet.apply_more`.
:param num_proc: number of worker processes.
:param field_name: name of the field whose content is passed to ``func``.
:param func: callable; its input is the content of the field named `field_name` of one instance and its return
value must be a dict whose keys are field names and whose values are the corresponding results.
:param modify_fields: whether to write the results back into the fields of the `DataSet`; defaults to True.
:param show_progress_bar: whether to display a progress bar; defaults to True.
:param progress_desc: description shown on the progress bar when show_progress_bar is True.
:return Dict[str, List]: a dict mapping each returned field name to the list of collected values.
""" | |||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | |||
if not self.has_field(field_name=field_name): | |||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||
idx = -1 | |||
results = {} | |||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | |||
show_progress_bar=show_progress_bar, _apply_field=field_name) | |||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | |||
if not isinstance(apply_out[0], dict): | |||
raise Exception("The result of func is not a dict") | |||
for key, value in apply_out[0].items(): | |||
results[key] = [value] | |||
# 尝试合并所有dict数据, idx+1 的原因是第一条数据不可能出现错误,默认第一条数据为准 | |||
try: | |||
for idx, per_out in enumerate(apply_out[1:]): | |||
if len(set(results.keys()) - set(per_out.keys())): | |||
raise ApplyResultException("apply results have different fields", idx + 1) | |||
for key, value in per_out.items(): | |||
results[key].append(value) | |||
except Exception as e: | |||
if idx != -1: | |||
if isinstance(e, ApplyResultException): | |||
print(e.msg) | |||
print("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
raise e | |||
if modify_fields is True: | |||
for field, result in results.items(): | |||
self.add_field(field_name=field, fields=result) | |||
return results | |||
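# Illustrative sketch of ``apply_field_more`` (hypothetical data, not from the original
# source): the function receives only the content of "raw_words" and returns a dict, so
# several fields are produced in one pass.
#
#     def make_features(raw_words):
#         words = raw_words.split()
#         return {"words": words, "seq_len": len(words)}
#
#     ds.apply_field_more(make_features, field_name="raw_words")  # modify_fields=True writes both fields back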
def _apply_process(self, num_proc: int = 0, func: Callable = None, | |||
show_progress_bar: bool = True, _apply_field: str = None, | |||
progress_desc: str = 'Main') -> list: | |||
""" | |||
:param num_proc: 进程的数量 | |||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | |||
:param _apply_field: 需要传进去func的数据集的field_name | |||
:param show_progress_bar: 是否展示progress进度条,默认为展示 | |||
:param progress_desc: 进度条的描述字符,默认为 'Main'
""" | |||
if num_proc == 0: | |||
results = _apply_single(ds=self, _apply_field=_apply_field, func=func, | |||
desc=progress_desc, show_progress_bar=show_progress_bar) | |||
else: | |||
# TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 | |||
results = [] | |||
if num_proc > len(self): | |||
num_proc = len(self) | |||
print( | |||
f"num_proc must be <= {len(self)}. Reducing num_proc to {num_proc} for dataset of size {len(self)}." | |||
) | |||
# 划分数据集 | |||
shard_len = len(self) // num_proc | |||
num_left_sample = len(self) % num_proc | |||
start = 0 | |||
shard_data = [] | |||
for _i in range(num_proc): | |||
end = shard_len + int(_i<num_left_sample) + start | |||
shard_data.append(self[start:end]) | |||
start = end | |||
# 配置管道,线程以实现 main progress 能够实时更新。 | |||
parent, child = mp.Pipe() | |||
main_thread = Thread(target=_progress_bar, args=(parent, len(self), progress_desc, | |||
show_progress_bar)) | |||
partial_single_map = partial(_apply_single, _apply_field=_apply_field, func=func, | |||
pipe=child, show_progress_bar=False) | |||
# 开启进程池,线程 | |||
main_thread.start() | |||
pool = mp.Pool(processes=num_proc) | |||
pool_outs = [pool.apply_async(partial_single_map, kwds={'ds': ds, "proc_id": proc_id}) | |||
for proc_id, ds in enumerate(shard_data)] | |||
pool.close() | |||
pool.join() | |||
main_thread.join() | |||
for async_result in pool_outs: | |||
data = async_result.get() | |||
results.extend(data) | |||
return results | |||
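# Worked example of the sharding arithmetic above (illustrative): with len(self) == 10 and
# num_proc == 3, shard_len == 3 and num_left_sample == 1, so the shards cover the index
# ranges [0:4], [4:7] and [7:10] -- the remainder is handed to the first processes, one
# extra instance each.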
def apply_more(self, func: Callable = None, modify_fields: bool = True, | |||
num_proc: int = 0, progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | |||
.. note:: | |||
``apply_more`` 与 ``apply`` 的区别: | |||
1. ``apply_more`` 可以返回多个 field 的结果, ``apply`` 只可以返回一个field 的结果; | |||
2. ``apply_more`` 的返回值是一个字典,每个 key-value 对中的 key 表示 field 的名字,value 表示计算结果; | |||
3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。 | |||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param num_proc: 进程的数量 | |||
:param show_progress_bar: 是否使用 tqdm 显示预处理进度
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||
:return Dict[str:Field]: 返回一个字典 | |||
""" | |||
# 返回 dict , 检查是否一直相同 | |||
assert callable(func), "The func you provide is not callable." | |||
assert len(self) != 0, "Null DataSet cannot use apply()." | |||
assert num_proc >= 0, "num_proc must >= 0" | |||
idx = -1 | |||
results = {} | |||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | |||
show_progress_bar=show_progress_bar) | |||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | |||
if not isinstance(apply_out[0], dict): | |||
raise Exception("The result of func is not a dict") | |||
for key, value in apply_out[0].items(): | |||
results[key] = [value] | |||
# 尝试合并所有dict数据, idx+1 的原因是第一条数据不可能出现错误,已经将第一条数据取出来 | |||
try: | |||
for idx, per_out in enumerate(apply_out[1:]): | |||
if len(set(results.keys()) - set(per_out.keys())): | |||
raise ApplyResultException("apply results have different fields", idx + 1) | |||
for key, value in per_out.items(): | |||
results[key].append(value) | |||
except Exception as e: | |||
if idx != -1: | |||
if isinstance(e, ApplyResultException): | |||
print(e.msg) | |||
print("Exception happens at the `{}`th instance.".format(idx + 1)) | |||
raise e | |||
if modify_fields is True: | |||
for field, result in results.items(): | |||
self.add_field(field_name=field, fields=result) | |||
return results | |||
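# Illustrative sketch contrasting ``apply_more`` with ``apply`` (hypothetical field names,
# not from the original source): ``func`` receives a whole Instance here.
#
#     def add_stats(instance):
#         words = instance["raw_words"].split()
#         return {"words": words, "seq_len": len(words)}
#
#     ds.apply_more(add_stats)                                                  # writes back both fields
#     ds.apply(lambda ins: len(ins["raw_words"]), new_field_name="char_len")    # one field only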
def apply(self, func: Callable = None, new_field_name: str = None, | |||
num_proc: int = 0, show_progress_bar: bool = True, progress_desc: str = ''): | |||
""" | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,它的返回值会被收集为一个 list;若指定了 `new_field_name`,则作为该 field 的内容。
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param num_proc: 进程的数量。 | |||
:param show_progress_bar: 是否显示进度条。 | |||
:param progress_desc: progress bar 显示的值,默认为空。 | |||
""" | |||
assert callable(func), "The func you provide is not callable." | |||
assert len(self) != 0, "Null DataSet cannot use apply()." | |||
assert num_proc >= 0, "num_proc must be an integer >= 0." | |||
try: | |||
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, | |||
progress_desc=progress_desc) | |||
except BaseException as e: | |||
raise e | |||
if new_field_name is not None: | |||
self.add_field(field_name=new_field_name, fields=results) | |||
return results | |||
def add_seq_len(self, field_name: str, new_field_name='seq_len'): | |||
r""" | |||
将使用len()直接对field_name中每个元素作用,将其结果作为sequence length, 并放入seq_len这个field。 | |||
:param field_name: str. | |||
:param new_field_name: str. 新的field_name | |||
:return: | |||
""" | |||
if self.has_field(field_name=field_name): | |||
self.apply_field(len, field_name, new_field_name=new_field_name) | |||
else: | |||
raise KeyError(f"Field:{field_name} not found.") | |||
return self | |||
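# Example (illustrative, assuming a "words" field that holds token lists):
#
#     ds.add_seq_len("words")
#
# is equivalent to ds.apply_field(len, "words", new_field_name="seq_len").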
def drop(self, func: Callable, inplace=True): | |||
r""" | |||
func接受一个Instance,返回bool值。返回值为True时,该Instance会被移除或者不会包含在返回的DataSet中。 | |||
:param callable func: 接受一个Instance作为参数,返回bool值。为True时删除该instance | |||
:param bool inplace: 是否在当前DataSet中直接删除instance;如果为False,将返回一个新的DataSet。 | |||
:return: DataSet | |||
""" | |||
if inplace: | |||
results = [ins for ins in self if not func(ins)] | |||
for name, old_field in self.field_arrays.items(): | |||
self.field_arrays[name].content = [ins[name] for ins in results] | |||
return self | |||
else: | |||
results = [ins for ins in self if not func(ins)] | |||
if len(results) != 0: | |||
dataset = DataSet(results) | |||
return dataset | |||
else: | |||
return DataSet() | |||
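# Illustrative sketch of ``drop`` (hypothetical field name): remove every instance whose
# sequence is empty.
#
#     ds.drop(lambda ins: ins["seq_len"] == 0, inplace=True)            # filters in place
#     kept = ds.drop(lambda ins: ins["seq_len"] == 0, inplace=False)    # returns a new DataSet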
def split(self, ratio: float, shuffle=True): | |||
r""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
:param float ratio: 0<ratio<1, 返回的第一个DataSet拥有 `ratio` 这么多数据,第二个DataSet拥有`(1-ratio)`这么多数据 | |||
:param bool shuffle: 在split前是否shuffle一下。如果为False,返回的第一个dataset就是当前dataset中前`ratio`比例的数据。
:return: [ :class:`~fastNLP.DataSet` , :class:`~fastNLP.DataSet` ]
""" | |||
assert len(self) > 1, f'DataSet with {len(self)} instance cannot be split.' | |||
assert isinstance(ratio, float) | |||
assert 0 < ratio < 1 | |||
all_indices = [_ for _ in range(len(self))] | |||
if shuffle: | |||
np.random.shuffle(all_indices) | |||
split = int(ratio * len(self)) | |||
if split == 0: | |||
error_msg = f'Dev DataSet has {split} instance after split.' | |||
print(error_msg) | |||
raise IndexError(error_msg) | |||
dev_indices = all_indices[:split] | |||
train_indices = all_indices[split:] | |||
dev_set = DataSet() | |||
train_set = DataSet() | |||
for idx in dev_indices: | |||
dev_set.append(self[idx]) | |||
for idx in train_indices: | |||
train_set.append(self[idx]) | |||
return dev_set, train_set | |||
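# Illustrative sketch of ``split`` (hypothetical sizes; note that the first returned
# DataSet is the smaller `ratio` part in this implementation):
#
#     dev_ds, train_ds = ds.split(0.1, shuffle=True)   # needs len(ds) large enough that int(0.1 * len(ds)) > 0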
def save(self, path: str) -> None: | |||
r""" | |||
保存DataSet. | |||
:param str path: 将DataSet存在哪个路径 | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@staticmethod | |||
def load(path: str): | |||
r""" | |||
从保存的DataSet pickle文件的路径中读取DataSet | |||
:param str path: 从哪里读取DataSet | |||
:return: 读取后的 :class:`~fastNLP.DataSet` 。
""" | |||
with open(path, 'rb') as f: | |||
d = pickle.load(f) | |||
assert isinstance(d, DataSet), "The object is not DataSet, but {}.".format(type(d)) | |||
return d | |||
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | |||
""" | |||
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target | |||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 | |||
当前dataset含有field,则会报错。 | |||
:param DataSet, dataset: 需要和当前dataset concat的dataset | |||
:param bool, inplace: 是否直接将dataset组合到当前dataset中 | |||
:param dict, field_mapping: 当传入的dataset中的field名称和当前dataset不一致时,需要通过field_mapping把输入的dataset中的 | |||
field名称映射到当前field. field_mapping为dict类型,key为dataset中的field名称,value是需要映射成的名称 | |||
:return: DataSet | |||
""" | |||
assert isinstance(dataset, DataSet), "Can only concat two datasets." | |||
fns_in_this_dataset = set(self.get_field_names()) | |||
fns_in_other_dataset = dataset.get_field_names() | |||
reverse_field_mapping = {} | |||
if field_mapping is not None: | |||
fns_in_other_dataset = [field_mapping.get(fn, fn) for fn in fns_in_other_dataset] | |||
reverse_field_mapping = {v: k for k, v in field_mapping.items()} | |||
fns_in_other_dataset = set(fns_in_other_dataset) | |||
fn_not_seen = list(fns_in_this_dataset - fns_in_other_dataset) | |||
if fn_not_seen: | |||
raise RuntimeError(f"The following fields are not provided in the dataset:{fn_not_seen}") | |||
if inplace: | |||
ds = self | |||
else: | |||
ds = deepcopy(self) | |||
for fn in fns_in_this_dataset: | |||
ds.get_field(fn).content.extend(deepcopy(dataset.get_field(reverse_field_mapping.get(fn, fn)).content)) | |||
return ds | |||
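# Illustrative sketch of ``concat`` with ``field_mapping`` (hypothetical field names): the
# other dataset stores its text under "sentence", which is mapped onto this dataset's
# "raw_words" field before its contents are appended.
#
#     ds.concat(other_ds, inplace=True, field_mapping={"sentence": "raw_words"})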
@classmethod | |||
def from_pandas(cls, df): | |||
""" | |||
从pandas.DataFrame中读取数据转为Dataset | |||
:param df: | |||
:return: | |||
""" | |||
df_dict = df.to_dict(orient='list') | |||
return cls(df_dict) | |||
def to_pandas(self): | |||
""" | |||
将dataset转为pandas.DataFrame类型的数据 | |||
:return: | |||
""" | |||
import pandas as pd | |||
dict_ = {key: value.content for key, value in self.field_arrays.items()} | |||
return pd.DataFrame.from_dict(dict_) | |||
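# Illustrative round trip through pandas (assumes pandas is installed; data is hypothetical):
#
#     import pandas as pd
#     df = pd.DataFrame({"raw_words": ["a b", "c d"], "label": [0, 1]})
#     ds = DataSet.from_pandas(df)
#     df_back = ds.to_pandas()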
# TODO 应该有返回值的吧 | |||
def to_csv(self, path: str) -> None: | |||
""" | |||
将dataset保存为csv文件 | |||
:param path: | |||
:return: | |||
""" | |||
df = self.to_pandas() | |||
df.to_csv(path, encoding="utf-8") | |||
def add_collate_fn(self, collate_fn: Callable) -> None: | |||
""" | |||
添加collate_fn函数,调用该函数后会将其添加到已有的collate_fn后面 | |||
:param collate_fn: Callable的函数 | |||
:return: | |||
""" | |||
self.collate_fns.add_collator(collate_fn) | |||
def set_collate_fn(self, collate_fn: Callable) -> None: | |||
""" | |||
设置collate_fn函数,调用该函数后覆盖当前所有的collate_fn,包括Auto_Collate | |||
:param collate_fn: | |||
:return: | |||
""" | |||
self.collate_fns = _MultiCollator(collate_fn) | |||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
""" | |||
设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效 | |||
当val=None时,意味着给定的field_names都不需要尝试padding | |||
:param field_names: dataset存在的field_name | |||
:param val: 默认为0 | |||
:return: | |||
""" | |||
for field_name in field_names: | |||
self.collate_fns.set_pad_val(field_name, val=val) | |||
def set_input(self, *field_names) -> None: | |||
""" | |||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||
:param field_names: | |||
:return: | |||
""" | |||
self.collate_fns.set_input(*field_names) | |||
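# Illustrative collator configuration (hypothetical field names): only fields marked as
# input are handed to the AutoCollator, and the padding value of "words" is set explicitly.
# These calls only take effect when the auto collator is present, as noted above.
#
#     ds.set_input("words", "seq_len")
#     ds.set_pad_val("words", val=0)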
def get_collator(self) -> _MultiCollator: | |||
""" | |||
获取dataset绑定的collate_fn,其中包括auto_collate | |||
:return: | |||
""" | |||
return self.collate_fns | |||
@deprecated() | |||
def set_target(self, *field_names) -> None: | |||
""" | |||
被设置为inputs的field_names,会输入到AutoCollator中,未被设置默认过滤掉 | |||
:param field_names: | |||
:return: | |||
""" | |||
self.collate_fns.set_input(*field_names) | |||
class IterableDataset: | |||
pass | |||
@@ -0,0 +1,229 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
'FieldArray' | |||
] | |||
from collections import Counter | |||
from typing import Any, Union, List, Callable
from collections.abc import Iterable
import numpy as np | |||
class FieldArray: | |||
def __init__(self, name: str, content): | |||
if len(content) == 0: | |||
raise RuntimeError("Empty fieldarray is not allowed.") | |||
_content = content | |||
try: | |||
_content = list(_content) | |||
except BaseException as e: | |||
print(f"Cannot convert content(of type:{type(content)}) into list.") | |||
raise e | |||
self.name = name | |||
self.content = _content | |||
def append(self, val: Any) -> None: | |||
r""" | |||
:param val: 把该val append到fieldarray。 | |||
:return: | |||
""" | |||
self.content.append(val) | |||
def pop(self, index: int) -> None: | |||
r""" | |||
删除该field中index处的元素 | |||
:param int index: 从0开始的数据下标。 | |||
:return: | |||
""" | |||
self.content.pop(index) | |||
def __getitem__(self, indices: Union[int, List[int]]): | |||
return self.get(indices) | |||
def __setitem__(self, idx: int, val: Any): | |||
assert isinstance(idx, int) | |||
if idx == -1: | |||
idx = len(self) - 1 | |||
assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}" | |||
self.content[idx] = val | |||
def get(self, indices: Union[int, List[int]]): | |||
r""" | |||
根据给定的indices返回内容。 | |||
:param int,List[int] indices: 获取indices对应的内容。 | |||
:return: 根据给定的indices返回的内容,可能是单个值或ndarray | |||
""" | |||
if isinstance(indices, int): | |||
if indices == -1: | |||
indices = len(self) - 1 | |||
assert 0 <= indices < len(self) | |||
return self.content[indices] | |||
try: | |||
contents = [self.content[i] for i in indices] | |||
except BaseException as e: | |||
raise e | |||
return np.array(contents) | |||
def __len__(self): | |||
r""" | |||
Returns the size of FieldArray. | |||
:return int length: | |||
""" | |||
return len(self.content) | |||
def split(self, sep: str = None, inplace: bool = True): | |||
r""" | |||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。
:param sep: 分割符,如果为None则直接调用str.split()。 | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[List[str]] or self | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
new_contents.append(cell.split(sep)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def int(self, inplace: bool = True): | |||
r""" | |||
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[int], List[List[int]], self | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([int(value) for value in cell]) | |||
else: | |||
new_contents.append(int(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def float(self, inplace=True): | |||
r""" | |||
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([float(value) for value in cell]) | |||
else: | |||
new_contents.append(float(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def bool(self, inplace=True): | |||
r""" | |||
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: | |||
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([bool(value) for value in cell]) | |||
else: | |||
new_contents.append(bool(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def lower(self, inplace=True): | |||
r""" | |||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
:return: List[str], List[List[str]], self
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([value.lower() for value in cell]) | |||
else: | |||
new_contents.append(cell.lower()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def upper(self, inplace=True): | |||
r""" | |||
将本field中的值调用cell.upper(). 支持field中内容为以下两种情况(1)['a', 'b', ...](即field中每个值为str的),
(2) [['a', 'b', ..], ['c', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。
:return: List[str], List[List[str]], self
""" | |||
new_contents = [] | |||
for index, cell in enumerate(self.content): | |||
try: | |||
if isinstance(cell, list): | |||
new_contents.append([value.upper() for value in cell]) | |||
else: | |||
new_contents.append(cell.upper()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def value_count(self): | |||
r""" | |||
返回该field下不同value的数量。多用于统计label数量 | |||
:return: Counter, key是label,value是出现次数 | |||
""" | |||
count = Counter() | |||
def cum(cells): | |||
if isinstance(cells, Iterable) and not isinstance(cells, str):
for cell_ in cells: | |||
cum(cell_) | |||
else: | |||
count[cells] += 1 | |||
for cell in self.content: | |||
cum(cell) | |||
return count | |||
def _after_process(self, new_contents: list, inplace: bool): | |||
r""" | |||
当调用处理函数之后,决定是否要替换field。 | |||
:param new_contents: | |||
:param inplace: | |||
:return: self或者生成的content | |||
""" | |||
if inplace: | |||
self.content = new_contents | |||
return self | |||
else: | |||
return new_contents |
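# Illustrative FieldArray usage (hypothetical content, not from the original source):
# element-wise conversion followed by a label count.
#
#     fa = FieldArray("label", ["1", "0", "1"])
#     fa.int(inplace=True)        # content becomes [1, 0, 1]
#     counts = fa.value_count()   # Counter({1: 2, 0: 1})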
@@ -0,0 +1,71 @@ | |||
r""" | |||
instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可以认为是一个Instance类型的对象。 | |||
便于理解的例子可以参考文档 :mod:`fastNLP.core.dataset` 中的表格 | |||
""" | |||
__all__ = [ | |||
"Instance" | |||
] | |||
from fastNLP.core.utils.utils import pretty_table_printer | |||
class Instance: | |||
r""" | |||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | |||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示::
    instance = Instance(sentence="This is a demo sentence", label='positive')
"""
def __init__(self, **fields): | |||
self.fields = fields | |||
def add_field(self, field_name: str, field: any): | |||
r""" | |||
向Instance中增加一个field | |||
:param str field_name: 新增field的名称 | |||
:param Any field: 新增field的内容 | |||
""" | |||
self.fields[field_name] = field | |||
def items(self): | |||
r""" | |||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||
:return: 一个迭代器 | |||
""" | |||
return self.fields.items() | |||
def keys(self): | |||
r""" | |||
返回一个迭代器,内容是field_name | |||
:return: 一个迭代器 | |||
""" | |||
return self.fields.keys() | |||
def values(self): | |||
r""" | |||
返回一个迭代器,内容是field_value | |||
:return: 一个迭代器 | |||
""" | |||
return self.fields.values() | |||
def __contains__(self, item): | |||
return item in self.fields | |||
def __getitem__(self, name): | |||
if name in self.fields: | |||
return self.fields[name] | |||
else: | |||
raise KeyError("{} not found".format(name)) | |||
def __setitem__(self, name, field): | |||
return self.add_field(name, field) | |||
def __repr__(self): | |||
return str(pretty_table_printer(self)) |
@@ -0,0 +1,9 @@ | |||
__all__ = [ | |||
"JittorDriver", | |||
"JittorSingleDriver", | |||
"JittorMPIDriver", | |||
] | |||
from .jittor_driver import JittorDriver | |||
from .single_device import JittorSingleDriver | |||
from .mpi import JittorMPIDriver |
@@ -0,0 +1,31 @@ | |||
from typing import Union, List | |||
from fastNLP.core.drivers.jittor_driver.jittor_driver import JittorDriver | |||
from fastNLP.core.drivers.jittor_driver.single_device import JittorSingleDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
def initialize_jittor_driver(driver: str, device: Union[str, int, List[int]], model: jittor.Module, **kwargs) -> JittorDriver: | |||
r""" | |||
用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | |||
在这个函数中,我们会根据用户设置的device来确定JittorDriver的mode。 | |||
:param driver: 该参数的值应为以下之一:["jittor"]; | |||
:param device: jittor运行的设备 | |||
:param model: 训练或者评测的具体的模型; | |||
:param kwargs: | |||
:return: 一个具体的基于 jittor 的 `Driver` 实例(当前恒为 :class:`JittorSingleDriver`);
""" | |||
if driver not in {"jittor"}: | |||
raise ValueError("Parameter `driver` can only be one of these values: ['jittor'].") | |||
# TODO 实现更详细的判断 | |||
if driver == "jittor": | |||
return JittorSingleDriver(model, device, **kwargs) | |||
else: | |||
raise NotImplementedError |
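# Illustrative call (a sketch of how a Trainer might request the driver; `my_module` is a
# hypothetical jittor.Module and is not defined in this file):
#
#     driver = initialize_jittor_driver("jittor", device=None, model=my_module)
#     # currently always returns a JittorSingleDriver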
@@ -0,0 +1,155 @@ | |||
import os | |||
import warnings | |||
from typing import Optional, Callable, Dict | |||
from .utils import _build_fp16_env | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.dataloaders import JittorDataLoader | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import apply_to_collection | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from jittor import Module | |||
from jittor.optim import Optimizer | |||
_reduces = { | |||
'max': jt.max, | |||
'min': jt.min, | |||
'mean': jt.mean, | |||
'sum': jt.sum | |||
} | |||
class JittorDriver(Driver): | |||
r""" | |||
Jittor 框架的 Driver | |||
""" | |||
def __init__(self, model, fp16: bool = False, **kwargs): | |||
if not isinstance(model, Module): | |||
raise ValueError(f"Parameter `model` can not be `{type(model)}` in `JittorDriver`, it should be exactly " | |||
f"`jittor.Module` type.") | |||
super(JittorDriver, self).__init__(model) | |||
self.model = model | |||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||
self.grad_scaler = _grad_scaler() | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
# 在fastnlp中实现了JittorDataLoader | |||
# TODO: 是否允许传入Dataset? | |||
if is_train: | |||
if not isinstance(dataloader, JittorDataLoader): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'JittorDataLoader' type, not {type(dataloader)}.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
else: | |||
for each_dataloader in dataloader.values(): | |||
if not isinstance(each_dataloader, JittorDataLoader): | |||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'JittorDataLoader' " | |||
f"type, not {type(each_dataloader)}.") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
for each_optimizer in optimizers: | |||
if not isinstance(each_optimizer, Optimizer): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'jittor.optim.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
def check_evaluator_mode(self, mode: str): | |||
model = self.unwrap_model() | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning(
"Your model does not have a 'validate_step' method but has a 'test_step' method. Since you "
"are using 'mode=validate', we are going to use 'test_step' as a substitute for "
"'validate_step'.")
else:
if not hasattr(model, "test_step"):
if hasattr(model, "validate_step"):
logger.warning("Your model does not have a 'test_step' method but has a 'validate_step' method. Since you "
"are using 'mode=test', we are going to use 'validate_step' as a substitute for "
"'test_step'.")
def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): | |||
""" | |||
保存模型 | |||
""" | |||
if model_save_fn is not None: | |||
outputs = model_save_fn(filepath) | |||
if outputs is not None: | |||
jt.save(outputs, filepath) | |||
else:
if only_state_dict:
states = self.model.state_dict()
else:
warnings.warn("Saving the whole model is not supported now in Jittor. Save state dict instead.")
states = self.model.state_dict()
jt.save(states, filepath)
def load_model(self, filepath: str): | |||
""" | |||
加载模型的加载函数; | |||
:param filepath: 保存文件的文件位置(需要包括文件名);
:return: 加载后的state_dict | |||
""" | |||
if not os.path.exists(filepath): | |||
raise FileNotFoundError("Checkpoint at {} not found.".format(filepath)) | |||
return jt.load(filepath) | |||
def save(self): | |||
... | |||
def load(self): | |||
... | |||
def get_evaluate_context(self): | |||
return jt.no_grad | |||
def get_model_device(self): | |||
return self.model_device | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce=None): | |||
if tensor is None: | |||
return None | |||
def _translate(_data): | |||
# 如果只含有一个元素,则返回元素本身,而非list | |||
if _data.numel() == 1: | |||
return _data.item() | |||
if reduce is None: | |||
return _data.tolist() | |||
return _reduces[reduce](_data).item() | |||
return apply_to_collection( | |||
data=tensor, | |||
dtype=jt.Var, | |||
function=_translate | |||
) | |||
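# Illustrative sketch of ``tensor_to_numeric`` (assumes jittor is importable and that
# ``jt.array`` builds a Var from a Python list; values are hypothetical):
#
#     JittorDriver.tensor_to_numeric(jt.array([1.0, 2.0, 3.0]))                  # -> [1.0, 2.0, 3.0]
#     JittorDriver.tensor_to_numeric(jt.array([1.0, 2.0, 3.0]), reduce="mean")   # -> 2.0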
def set_model_mode(self, mode: str): | |||
assert mode in {"train", "eval"} | |||
getattr(self.model, mode)() | |||
@property | |||
def data_device(self): | |||
return self.model_device | |||
def move_data_to_device(self, batch: 'jt.Var'): | |||
""" | |||
jittor暂时没有提供数据迁移的函数,因此这个函数只是简单地返回batch | |||
""" | |||
return batch | |||
# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): | |||
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | |||
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
# dataloader.batch_sampler.set_epoch(cur_epoch_idx) |
@@ -0,0 +1,100 @@ | |||
import os | |||
from typing import Optional, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleIterator | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
__all__ = [ | |||
"JittorMPIDriver", | |||
] | |||
class JittorMPIDriver(JittorDriver): | |||
def __init__( | |||
self, | |||
model, | |||
parallel_device: None, | |||
is_pull_by_jittor_run: bool = False, | |||
fp16: bool = False, | |||
**kwargs | |||
): | |||
super(JittorMPIDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
self.is_pull_by_jittor_run = is_pull_by_jittor_run | |||
self.parallel_device = parallel_device | |||
self.outside_mpi = False | |||
def setup(self): | |||
pass | |||
def configure_mpi(self): | |||
pass | |||
@property | |||
def world_size(self) -> int: | |||
return self._world_size | |||
@world_size.setter | |||
def world_size(self, size: int): | |||
self._world_size = size | |||
@property | |||
def global_rank(self) -> int: | |||
return self._global_rank | |||
@global_rank.setter | |||
def global_rank(self, rank: int) -> None: | |||
self._global_rank = rank | |||
@property | |||
def local_rank(self) -> int: | |||
return int(os.environ.get("LOCAL_RANK", 0)) | |||
@property | |||
def data_device(self): | |||
if self.outside_mpi: | |||
return self._data_device | |||
return self.model_device | |||
def train_step(self, batch): | |||
return self._train_step(batch) | |||
def validate_step(self, batch): | |||
return self._validate_step(batch) | |||
def test_step(self, batch): | |||
return self._test_step(batch) | |||
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False): | |||
pass | |||
def backward(self, loss): | |||
self.grad_scaler.scale(loss).backward() | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
def is_global_zero(self): | |||
return self.global_rank == 0 | |||
def get_no_sync_context(self): | |||
return self.model.no_sync | |||
def unwrap_model(self): | |||
pass | |||
def get_local_rank(self) -> int: | |||
return self.local_rank | |||
def barrier(self): | |||
pass | |||
def is_distributed(self): | |||
return True |
@@ -0,0 +1,127 @@ | |||
from typing import Dict, Union | |||
from .jittor_driver import JittorDriver | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
__all__ = [ | |||
"JittorSingleDriver", | |||
] | |||
class JittorSingleDriver(JittorDriver): | |||
r""" | |||
用于 cpu 和 单卡 gpu 运算 | |||
TODO: jittor 的 fp16 | |||
""" | |||
def __init__(self, model, device=None, fp16: bool = False, **kwargs): | |||
super(JittorSingleDriver, self).__init__(model, fp16) | |||
self.model_device = device | |||
self.local_rank = 0 | |||
self.global_rank = 0 | |||
self.world_size = 1 | |||
if hasattr(self.model, "train_step"): | |||
self._train_step = self.model.train_step | |||
self._train_signature_fn = None | |||
else: | |||
self._train_step = self.model | |||
model = self.unwrap_model() | |||
self._train_signature_fn = model.execute | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
self._validate_signature_fn = self.model.test_step | |||
else: | |||
self._validate_step = self.model | |||
model = self.unwrap_model() | |||
self._validate_signature_fn = model.execute | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
self._test_signature_fn = self.model.validate_step | |||
else: | |||
self._test_step = self.model | |||
model = self.unwrap_model() | |||
self._test_signature_fn = model.execute | |||
def train_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
def step(self): | |||
""" | |||
jittor optimizers 的step函数可以传入参数loss | |||
此时会同时进行 zero_grad 和 backward | |||
为了统一,这里暂不使用这样的方式 | |||
""" | |||
for optimizer in self.optimizers: | |||
optimizer.step() | |||
def backward(self, loss): | |||
for optimizer in self.optimizers: | |||
optimizer.backward(loss) | |||
def zero_grad(self, set_to_none=False): | |||
for optimizer in self.optimizers: | |||
optimizer.zero_grad() | |||
def validate_step(self, batch): | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
def test_step(self, batch): | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
def unwrap_model(self): | |||
return self.model | |||
def is_distributed(self): | |||
return False | |||
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): | |||
# reproducible 的相关功能暂时没有实现 | |||
if isinstance(dist_sampler, ReproducibleBatchSampler): | |||
raise NotImplementedError | |||
dataloader.batch_sampler = dist_sampler
if isinstance(dist_sampler, ReproducibleIterator): | |||
raise NotImplementedError | |||
dataloader.batch_sampler.sampler = dist_sampler | |||
if reproducible: | |||
raise NotImplementedError | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
) | |||
dataloader.batch_sampler = batch_sampler | |||
return dataloader | |||
else: | |||
return dataloader |
@@ -0,0 +1,55 @@ | |||
from contextlib import ExitStack | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
class DummyGradScaler: | |||
""" | |||
用于仿造的GradScaler对象,防止重复写大量的if判断 | |||
""" | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def get_scale(self): | |||
return 1.0 | |||
def is_enabled(self): | |||
return False | |||
def scale(self, outputs): | |||
return outputs | |||
def step(self, optimizer, *args, **kwargs): | |||
optimizer.step(*args, **kwargs) | |||
def update(self, new_scale=None): | |||
pass | |||
def unscale_(self, optimizer): | |||
pass | |||
def load_state_dict(self, state_dict): | |||
pass | |||
def state_dict(self): | |||
return {} | |||
def _build_fp16_env(dummy=False): | |||
if dummy: | |||
auto_cast = ExitStack | |||
GradScaler = DummyGradScaler | |||
else: | |||
raise NotImplementedError("JittorDriver does not support fp16 now.") | |||
# if not jt.flags.use_cuda: | |||
# raise RuntimeError("No cuda") | |||
# if paddle.device.cuda.get_device_capability(0)[0] < 7: | |||
# log.warning( | |||
# "NOTE: your device does NOT support faster training with fp16, " | |||
# "please switch to FP32 which is likely to be faster" | |||
# ) | |||
# from paddle.amp import auto_cast, GradScaler | |||
return auto_cast, GradScaler |
@@ -0,0 +1,11 @@ | |||
__all__ = [ | |||
"PaddleDriver", | |||
"PaddleSingleDriver", | |||
"PaddleFleetDriver", | |||
"paddle_seed_everything", | |||
] | |||
from .paddle_driver import PaddleDriver | |||
from .single_device import PaddleSingleDriver | |||
from .fleet import PaddleFleetDriver | |||
from .utils import paddle_seed_everything |
@@ -0,0 +1,426 @@ | |||
import os | |||
from functools import partial | |||
from typing import List, Union, Optional, Dict | |||
from .paddle_driver import PaddleDriver | |||
from .fleet_launcher import FleetLauncher | |||
from .utils import ( | |||
_FleetWrappingModel, | |||
ForwardState, | |||
_MODE_PARAMETER, | |||
get_host_name_ip, | |||
get_device_from_visible, | |||
reset_seed, | |||
) | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.utils import ( | |||
auto_param_call, | |||
check_user_specific_params, | |||
paddle_move_data_to_device, | |||
is_in_paddle_dist, | |||
) | |||
from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler | |||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle import DataParallel | |||
import paddle.distributed.fleet as fleet | |||
import paddle.distributed as dist | |||
from paddle.io import BatchSampler | |||
from paddle.optimizer import Optimizer | |||
from paddle.fluid.reader import _DatasetKind | |||
from paddle.fluid.dygraph import parallel_helper | |||
__all__ = [ | |||
"PaddleFleetDriver", | |||
] | |||
# if os.path.exists(self.gloo_rendezvous_dir): | |||
# shutil.rmtree(self.gloo_rendezvous_dir) | |||
class PaddleFleetDriver(PaddleDriver): | |||
def __init__( | |||
self, | |||
model, | |||
parallel_device: Optional[Union[List[int], int]], | |||
is_pull_by_paddle_run: bool = False, | |||
fp16: bool = False, | |||
**kwargs | |||
): | |||
""" | |||
采用fleet接口进行并行paddle训练的driver | |||
PaddleFleetDriver 目前考虑支持的三种启动方式: | |||
1. 用户自己不进行 fleet 的任何操作,直接使用我们的 Trainer,并且只运行一个 main 脚本,这时是由我们自己使用 open_subprocesses 拉起 | |||
多个进程,然后由 Driver 自己进行初始化 | |||
2. 其它情况同 1,但是用户自己使用 python -m paddle.distributed.launch 拉起; | |||
3. 用户自己在外面初始化 Fleet,并且通过 python -m paddle.distributed.launch 拉起; | |||
注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动; | |||
如果用户自己在外面初始化了 fleet,那么 | |||
parallel_device 为 None; | |||
data_device 为 表示单卡的一个参数; | |||
dist.is_initialized 为 true; | |||
""" | |||
super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
# 如果不是通过 launch 启动,要求用户必须传入 parallel_device | |||
if not is_pull_by_paddle_run and parallel_device is None: | |||
raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused " | |||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||
# 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 | |||
self.is_pull_by_paddle_run = is_pull_by_paddle_run | |||
self.parallel_device = parallel_device | |||
# 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu | |||
if is_pull_by_paddle_run: | |||
self._model_device = parallel_device | |||
else: | |||
self._model_device = parallel_device[self.local_rank] | |||
# 如果用户自己在外面初始化了并行模型; | |||
self.outside_fleet = False | |||
# 检测 paddle 分布式的环境变量 | |||
if parallel_helper._is_parallel_ctx_initialized(): | |||
# 如果用户自己在外面初始化了 DDP,那么我们要求用户传入的模型一定是已经由 DistributedDataParallel 包裹后的模型; | |||
if not isinstance(model, DataParallel): | |||
raise RuntimeError( | |||
"It is not allowed to input a normal model instead of `paddle.DataParallel` when" | |||
"you initialize the paddle distribued process out of our control.") | |||
self.outside_fleet = True | |||
# 用户只有将模型上传到对应机器上后才能用 DataParallel 包裹,因此如果用户在外面初始化了 Fleet,那么在 PaddleFleetDriver 中 | |||
# 我们就直接将 model_device 置为 None; | |||
self._model_device = None | |||
def _running_fn_(batch, step_fn, signature_fn): | |||
if isinstance(batch, Dict): | |||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||
else: | |||
return step_fn(batch)
model = model._layers | |||
if hasattr(model, "train_step"): | |||
logger.warning( | |||
"Notice your model is a `paddle.DataParallel` model. And your " | |||
"model also implements the `train_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
# self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
logger.warning( | |||
"Notice your model is a `paddle.DataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
# self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
logger.warning( | |||
"Notice your model is a `paddle.DataParallel` model. And your " | |||
"model also implements the `test_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `test_step` and you should note that.") | |||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | |||
self._data_device = kwargs.get("_data_device", None) | |||
if self._data_device is not None: | |||
if isinstance(self._data_device, int): | |||
if self._data_device < 0: | |||
raise ValueError("Parameter `_data_device` can not be smaller than 0.") | |||
_could_use_device_num = paddle.device.cuda.device_count() | |||
if self._data_device >= _could_use_device_num: | |||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||
self._data_device = f"gpu:{self._data_device}" | |||
elif not isinstance(self._data_device, str): | |||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||
if self.outside_fleet and paddle.device.get_device() != self._data_device: | |||
logger.warning("`Parameter data_device` is not equal to paddle.deivce.get_device(), " | |||
"please keep them equal to avoid some potential bugs.") | |||
if not self.outside_fleet and parallel_device is None: | |||
raise ValueError("Parameter `parallel_device` can not be None when using `PaddleFleetDriver`. This error is caused " | |||
"when your value of parameter `device` is `None` in your `Trainer` instance.") | |||
# 可能需要放在参数里 | |||
self.strategy = kwargs.get("strategy", fleet.DistributedStrategy()) | |||
self.is_collective = kwargs.get("is_collective", True) | |||
if not self.is_collective: | |||
raise NotImplementedError("FastNLP dose not support `parameters server` for distributed training now.") | |||
self.role_maker = kwargs.get("role_maker", None) | |||
self._master_port = None | |||
self.world_size = None | |||
self.global_rank = 0 | |||
self._configured = False # 防止重复调用 configure_ddp() 函数使用 | |||
self._has_setup = False # 防止重复调用 setup() 函数 | |||
self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) | |||
check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) | |||
# TODO 对这些参数的检查 | |||
if self.local_rank == 0 and not is_in_paddle_dist(): | |||
# 由于使用driver时模型一定会被初始化,因此在一开始程序一定会占用一部分显存来存放模型,然而这部分显存没有 | |||
# 发挥任何作用。 | |||
logger.warning(f"The program will use some extra space on {paddle.device.get_device()} to place your model since the model " | |||
"has already been initialized.") | |||
self.output_from_new_proc = kwargs.get("output_from_new_proc", "only_error") | |||
assert isinstance(self.output_from_new_proc, str), "Parameter `output_from_new_proc` can only be `str` type." | |||
if self.output_from_new_proc not in {"all", "ignore", "only_error"}: | |||
os.makedirs(name=self.output_from_new_proc, exist_ok=True) | |||
self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) | |||
def setup(self): | |||
""" | |||
在主进程拉起其它子进程,将主进程作为rank 0 | |||
""" | |||
if self._has_setup: | |||
return | |||
self._has_setup = True | |||
# 如果用户需要使用多机模式,那么一定进入到这里; | |||
if self.is_pull_by_paddle_run: | |||
if self.outside_fleet: | |||
# 已经初始化了多机环境 | |||
self.set_from_fleet_environment() | |||
else: | |||
# 用户没有初始化多机环境 | |||
# TODO 绕一下 | |||
# dist.get_world_size() 只能在初始化之后进行调用; | |||
self.world_size = int(os.environ.get("PADDLE_TRAINERS_NUM")) | |||
self.global_rank = int(os.environ.get("PADDLE_TRAINER_ID")) | |||
reset_seed() | |||
logger.warning(f"\nworld size, global rank: {self.world_size}, {self.global_rank}\n") | |||
fleet.init(self.role_maker, self.is_collective, self.strategy) | |||
else: | |||
# 在用户只使用了一个分布式 trainer 的情况下 | |||
# 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False | |||
# parallel_device 是 list, | |||
# if self.local_rank == 0 and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
if not parallel_helper._is_parallel_ctx_initialized(): | |||
# 没有初始化分布式环境,且是主进程 | |||
self.init_fleet_and_set() | |||
# 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; | |||
else: | |||
# 已经设置过一次,保证参数必须是一样的 | |||
pre_gpus = os.environ[FASTNLP_DISTRIBUTED_CHECK] | |||
pre_gpus = [int(x) for x in pre_gpus.split(",")]
if sorted(pre_gpus) != sorted(self.parallel_device): | |||
raise RuntimeError("Notice you are using `PaddleFleetDriver` after one instantiated `PaddleFleetDriver`, it is not" | |||
"allowed that your second `PaddleFleetDriver` has a new setting of parameters `parallel_device`.") | |||
if not self.outside_fleet: | |||
# self.model.to(self.model_device) | |||
self.configure_fleet() | |||
# 初始化 self._pids,从而使得每一个进程都能接受到 rank0 的 send 操作; | |||
# TODO 不用.to会怎么样? | |||
self._pids = [] | |||
dist.all_gather(self._pids, paddle.to_tensor(os.getpid(), dtype="int32")) | |||
# TODO LOCAL_WORLD_SIZE | |||
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE")) if "LOCAL_WORLD_SIZE" in os.environ else None | |||
if local_world_size is None: | |||
local_world_size = paddle.to_tensor(self.local_rank, dtype="int32") | |||
dist.all_reduce(local_world_size, op=dist.ReduceOp.MAX) | |||
local_world_size = local_world_size.item() + 1 | |||
node_rank = self.global_rank // local_world_size | |||
self._pids = self._pids[node_rank*local_world_size: (node_rank+1)*local_world_size] | |||
self._pids = self.tensor_to_numeric(self._pids) | |||
def init_fleet_and_set(self): | |||
""" | |||
使用 FleetLauncher 拉起子进程 | |||
""" | |||
if self.local_rank == 0: | |||
# 是 rank0 的话,则拉起其它子进程 | |||
launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) | |||
launcher.launch() | |||
# 设置参数和初始化分布式环境 | |||
reset_seed() | |||
fleet.init(self.role_maker, self.is_collective, self.strategy) | |||
self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) | |||
self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) | |||
# 正常情况下不会Assert出问题,但还是保险一下 | |||
assert self.global_rank is not None | |||
assert self.world_size is not None | |||
assert self.world_size == len(self.parallel_device) | |||
def set_from_fleet_environment(self): | |||
""" | |||
当用户使用了 `python -m paddle.distributed.launch xxx.py` 启动时,我们需要 | |||
根据 paddle 设置的环境变量来获得各种属性 | |||
""" | |||
self.world_size = dist.get_world_size() | |||
self.global_rank = dist.get_rank() | |||
def barrier(self): | |||
dist.barrier() | |||
def configure_fleet(self): | |||
if not self._configured and not isinstance(self.model, DataParallel): | |||
self.model = DataParallel( | |||
_FleetWrappingModel(self.model), | |||
**self._fleet_kwargs | |||
) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
self._configured = True | |||
@property | |||
def world_size(self) -> int: | |||
return self._world_size | |||
@world_size.setter | |||
def world_size(self, size: int) -> None: | |||
self._world_size = size | |||
@property | |||
def global_rank(self) -> int: | |||
return self._global_rank | |||
@global_rank.setter | |||
def global_rank(self, rank: int) -> None: | |||
self._global_rank = rank | |||
@property | |||
def local_rank(self) -> int: | |||
return int(os.getenv("PADDLE_RANK_IN_NODE", "0")) | |||
@property | |||
def model_device(self): | |||
# 我认为这里的两个 device 应该返回真实值,对 CUDA_VISIBLE_DEVICES 的转换应该在相应的 to 函数完成
# 否则会造成用户的困惑 | |||
return self._model_device | |||
@property | |||
def data_device(self): | |||
if self.outside_fleet: | |||
return self._data_device | |||
return self.model_device | |||
def train_step(self, batch): | |||
return self._train_step(batch) | |||
def validate_step(self, batch): | |||
return self._validate_step(batch) | |||
def test_step(self, batch): | |||
return self._test_step(batch) | |||
def replace_sampler(self, dataloader, dist_sampler: Optional[Union[str, ReproducibleIterator]] = "dist", reproducible: bool = False): | |||
# 暂时不支持iterableDataset | |||
assert dataloader.dataset_kind != _DatasetKind.ITER, \ | |||
"FastNLP does not support `IteratorDataset` now." | |||
if isinstance(dist_sampler, ReproducibleIterator): | |||
dataloader.batch_sampler.sampler = dist_sampler | |||
return dataloader | |||
# paddle 的 BatchSampler 和 DataLoader 没有 shuffle 成员,只能根据 sampler 判断 | |||
# 但是其子类 DistributedBatchSampler 却有 shuffle 成员 | |||
# 因此用 type() 进行严格的判断 | |||
if type(dataloader.batch_sampler) == BatchSampler: | |||
shuffle = isinstance(dataloader.batch_sampler.sampler, RandomSampler) | |||
else: | |||
shuffle = dataloader.batch_sampler.shuffle | |||
# trainer, evaluator | |||
if dist_sampler is None: | |||
if reproducible: | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our " | |||
"control.") | |||
else: | |||
return dataloader | |||
# trainer | |||
elif dist_sampler == "dist": | |||
# 如果用户的 trainer.use_dist_sampler 为 True,那么此时其是否进行断点重训,不影响这里的行为; | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
dataloader.batch_sampler.sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
return dataloader | |||
else: | |||
sampler = RandomSampler( | |||
dataset=dataloader.dataset, | |||
shuffle=shuffle, | |||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank, | |||
pad=True | |||
) | |||
dataloader.batch_sampler.sampler = sampler | |||
return dataloader | |||
# evaluator | |||
elif dist_sampler == "unrepeatdist": | |||
sampler = UnrepeatedDistributedSampler( | |||
dataset=dataloader.dataset, | |||
shuffle=shuffle, | |||
seed=int(os.environ.get("FASTNLP_SEED", 0)) | |||
) | |||
sampler.set_distributed( | |||
num_replicas=self.world_size, | |||
rank=self.global_rank | |||
) | |||
dataloader.batch_sampler.sampler = sampler | |||
return dataloader | |||
else: | |||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
def backward(self, loss): | |||
self.grad_scaler.scale(loss).backward() | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
def is_global_zero(self): | |||
return self.global_rank == 0 | |||
def get_no_sync_context(self): | |||
return self.model.no_sync | |||
def unwrap_model(self): | |||
_layers = self.model._layers | |||
if isinstance(_layers, _FleetWrappingModel): | |||
return _layers.model | |||
else: | |||
return _layers | |||
def get_local_rank(self) ->int: | |||
return self.local_rank | |||
def is_distributed(self): | |||
return True | |||
def move_data_to_device(self, batch: 'paddle.Tensor'): | |||
device = self.data_device | |||
# 因为设置了CUDA_VISIBLE_DEVICES,在子进程中可能会引起错误 | |||
if FASTNLP_DISTRIBUTED_CHECK in os.environ: | |||
device = get_device_from_visible(device) | |||
return paddle_move_data_to_device(batch, device) | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
""" | |||
paddle存在设置分布式optimizers的函数,返回值为fleet.meta_optimizers.HybridParallelOptimizer | |||
重写是为了防止单卡下也传入了分布式的优化器 | |||
""" | |||
DistributedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer
for each_optimizer in optimizers:
if not isinstance(each_optimizer, (Optimizer, DistributedOptimizer)):
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") |
@@ -0,0 +1,176 @@ | |||
import os | |||
import sys | |||
import __main__ | |||
import tempfile | |||
import copy | |||
from typing import List | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.envs.env import ( | |||
FASTNLP_DISTRIBUTED_CHECK, | |||
FASTNLP_LOG_LEVEL, | |||
FASTNLP_GLOBAL_SEED, | |||
USER_CUDA_VISIBLE_DEVICES, | |||
) | |||
from .utils import ( | |||
find_free_ports, | |||
reset_seed, | |||
) | |||
# 记录各个进程信息 | |||
class SubTrainer(object): | |||
""" | |||
和fastNLP的Trainer没有关系,仅用于统计节点内不同训练进程的一些信息
""" | |||
def __init__(self, endpoint=None, rank=None): | |||
self.devices = [] | |||
self.endpoint = endpoint | |||
self.rank = rank | |||
class FleetLauncher: | |||
""" | |||
复原了 paddle 的 launch_collective 函数,将其集成到一个类里 | |||
仅支持单机多卡的启动 | |||
""" | |||
def __init__( | |||
self, | |||
devices: List[int], | |||
output_from_new_proc: str = "only_error" | |||
): | |||
self.devices = devices | |||
self.output_from_new_proc = output_from_new_proc | |||
self.setup() | |||
def setup(self): | |||
self.set_endpoints() | |||
self.sub_trainers = self.get_process_info() | |||
def launch(self) -> int: | |||
# 设置环境变量 | |||
self.global_envs = self.get_global_env() | |||
self.open_subprocess() | |||
reset_seed() | |||
def open_subprocess(self): | |||
if __main__.__spec__ is None: | |||
# Script called as `python a/b/c.py` | |||
# when user is using hydra find the absolute path | |||
path_lib = os.path.abspath | |||
# pull out the commands used to run the script and resolve the abs file path | |||
command = sys.argv | |||
try: | |||
full_path = path_lib(command[0]) | |||
except Exception: | |||
full_path = os.path.abspath(command[0]) | |||
command[0] = full_path | |||
# use the same python interpreter and actually running | |||
command = [sys.executable] + command | |||
else: # Script called as `python -m a.b.c` | |||
command = [sys.executable, "-m", __main__.__spec__._name] + sys.argv[1:] | |||
current_env = copy.copy(self.global_envs) | |||
for idx, t in enumerate(self.sub_trainers): | |||
proc_env = { | |||
# global_rank | |||
"PADDLE_TRAINER_ID": f"{t.rank}", | |||
"PADDLE_CURRENT_ENDPOINT": f"{t.endpoint}", | |||
# rank | |||
"PADDLE_RANK_IN_NODE": f"{idx}", | |||
"PADDLE_LOCAL_DEVICE_IDS": | |||
",".join([str(g) for g in t.devices]), | |||
} | |||
if len(t.devices) > 0: | |||
proc_env["FLAGS_selected_gpus"] = "%s" % ",".join( | |||
[str(g) for g in t.devices]) | |||
proc_env["FLAGS_selected_devices"] = "%s" % ",".join( | |||
[str(g) for g in t.devices]) | |||
current_env.update(proc_env) | |||
if os.environ.get(FASTNLP_GLOBAL_SEED) is None and FASTNLP_GLOBAL_SEED in current_env: | |||
del current_env[FASTNLP_GLOBAL_SEED] | |||
if idx != 0: | |||
                # child process
if os.environ.get(FASTNLP_LOG_LEVEL, None) is None: | |||
current_env[FASTNLP_LOG_LEVEL] = "warning" | |||
proc = distributed_open_proc(self.output_from_new_proc, command, current_env, t.rank) | |||
else: | |||
                # update the environment variables of the current (main) process
os.environ.update(current_env) | |||
def get_global_env(self): | |||
global_envs = copy.copy(os.environ.copy()) | |||
self.gloo_rendezvous_dir = tempfile.mkdtemp() | |||
        # gloo-related environment variables used during launch
global_envs["PADDLE_WITH_GLOO"] = str(os.getenv("PADDLE_WITH_GLOO", "0")) | |||
global_envs["PADDLE_GLOO_RENDEZVOUS"] = "3" | |||
global_envs["PADDLE_GLOO_FS_PATH"] = self.gloo_rendezvous_dir | |||
global_envs["PADDLE_DISTRI_BACKEND"] = "nccl" | |||
        # flag indicating that initialization went through fastNLP
global_envs[FASTNLP_DISTRIBUTED_CHECK] = f"{','.join([str(g) for g in self.devices])}" | |||
        # gather global information
device_ids = [] | |||
for t in self.sub_trainers: | |||
device_ids.append([str(acc) for acc in t.devices]) | |||
world_device_ids = [':'.join(ele) for ele in device_ids] | |||
        # global environment variables
global_envs.update({ | |||
# world_size | |||
"PADDLE_TRAINERS_NUM": f"{len(self.sub_trainers)}", | |||
"PADDLE_TRAINER_ENDPOINTS": ",".join(self.endpoints), | |||
"PADDLE_WORLD_DEVICE_IDS": ",".join(world_device_ids), | |||
}) | |||
return global_envs | |||
def set_endpoints(self): | |||
""" | |||
Reference to `get_cluster_from_args` | |||
""" | |||
self.node_ip = "127.0.0.1" | |||
free_ports = None | |||
if os.environ.get("FLAGS_START_PORT") is None: | |||
free_ports = find_free_ports(len(self.devices)) | |||
if free_ports is not None: | |||
free_ports = list(free_ports) | |||
else: | |||
start_port = int(os.getenv("FLAGS_START_PORT", "6070")) | |||
free_ports = [ | |||
x for x in range(start_port, start_port + len(self.devices)) | |||
] | |||
self.endpoints = ["%s:%d" % (self.node_ip, port) for port in free_ports] | |||
def get_process_info(self): | |||
""" | |||
Reference to `get_cluster` | |||
""" | |||
sub_trainers = [] | |||
assert len(self.endpoints) >= len( | |||
self.devices | |||
), "current trainer_endpoints size should be greater equal than acclerators size." | |||
for i in range(len(self.devices)): | |||
sub_trainer = SubTrainer(f"{self.endpoints[i]}", i) | |||
if isinstance(self.devices[i], (list, tuple)): | |||
sub_trainer.devices.extend(self.devices[i]) | |||
else: | |||
sub_trainer.devices.append(self.devices[i]) | |||
sub_trainers.append(sub_trainer) | |||
return sub_trainers |
@@ -0,0 +1,87 @@ | |||
import os | |||
from typing import Optional, List, Sequence, Union | |||
from .paddle_driver import PaddleDriver | |||
from .single_device import PaddleSingleDriver | |||
from .fleet import PaddleFleetDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[int]]], | |||
model: paddle.nn.Layer, **kwargs) -> PaddleDriver: | |||
r""" | |||
用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; | |||
注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; | |||
:param driver: 该参数的值应为以下之一:["paddle", "fleet"]; | |||
:param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; | |||
:param model: 训练或者评测的具体的模型; | |||
:return: 返回一个元组,元组的第一个值是具体的基于 pytorch 的 `Driver` 实例,元组的第二个值是该 driver 的名字(用于检测一个脚本中 | |||
先后 driver 的次序的正确问题); | |||
""" | |||
if "PADDLE_TRAINERS_NUM" in os.environ and "PADDLE_RANK_IN_NODE" in os.environ and FASTNLP_DISTRIBUTED_CHECK not in os.environ: | |||
if device is not None: | |||
logger.warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull " | |||
"up your script. And we will directly get the local device via " | |||
"`f'gpu:{os.environ['FLAGS_selected_gpus']}')`.") | |||
device = [int(g) for g in os.environ["FLAGS_selected_gpus"].split(",")] | |||
return PaddleFleetDriver(model, f"gpu:{os.environ['PADDLE_RANK_IN_NODE']}", True, **kwargs) | |||
if driver not in {"paddle", "fleet"}: | |||
raise ValueError("Parameter `driver` can only be one of these values: ['paddle', 'fleet'].") | |||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||
user_visible_devices = os.getenv("USER_CUDA_VISIBLE_DEVICES") | |||
    # priority: USER_CUDA_VISIBLE_DEVICES > CUDA_VISIBLE_DEVICES
    # check the validity of `device` in the single-machine case;
    # in the distributed case this is checked via world_device instead
if user_visible_devices is not None: | |||
_could_use_device_num = len(user_visible_devices.split(",")) | |||
elif cuda_visible_devices is not None: | |||
_could_use_device_num = len(cuda_visible_devices.split(",")) | |||
else: | |||
_could_use_device_num = paddle.device.cuda.device_count() | |||
if isinstance(device, int): | |||
if device < 0 and device != -1: | |||
raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") | |||
if device >= _could_use_device_num: | |||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||
device = f"gpu:{device}" | |||
elif isinstance(device, Sequence) and not isinstance(device, str): | |||
device = list(set(device)) | |||
for each in device: | |||
if not isinstance(each, int): | |||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be 'int' type.") | |||
elif each < 0: | |||
raise ValueError("When parameter `device` is 'Sequence' type, the value in it should be bigger than 0.") | |||
if len(device) == 1: | |||
            # a list like [1] was passed in; treat it as single-card.
device = device[0] | |||
elif device is not None and not isinstance(device, str): | |||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||
if driver == "paddle": | |||
if not isinstance(device, List): | |||
return PaddleSingleDriver(model, device, **kwargs) | |||
else: | |||
logger.warning("Notice you are using `paddle` driver but your chosen `device` are multi gpus, we will use" | |||
"`Fleetriver` by default. But if you mean using `PaddleFleetDriver`, you should choose parameter" | |||
"`driver` as `PaddleFleetDriver`.") | |||
return PaddleFleetDriver(model, device, **kwargs) | |||
elif driver == "fleet": | |||
if not isinstance(device, List): | |||
if device == "cpu": | |||
raise ValueError("You are using `fleet` driver, but your chosen `device` is 'cpu'.") | |||
logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" | |||
"still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " | |||
"choose `paddle` driver.") | |||
return PaddleFleetDriver(model, device, **kwargs) | |||
else: | |||
return PaddleFleetDriver(model, device, **kwargs) |
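# Illustrative usage (a hedged sketch, not executed here): assuming `model` is a paddle.nn.Layer,
# the mapping from arguments to drivers is roughly:
#
#     initialize_paddle_driver("paddle", 0, model)        # -> PaddleSingleDriver on "gpu:0"
#     initialize_paddle_driver("fleet", [0, 1], model)    # -> PaddleFleetDriver on two cards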
@@ -0,0 +1,315 @@ | |||
import os | |||
import random | |||
from typing import Union, Optional, Callable, Dict | |||
from functools import partial | |||
import numpy as np | |||
from .utils import _build_fp16_env | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.core.utils import apply_to_collection, paddle_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.io import DataLoader, IterableDataset | |||
from paddle.optimizer import Optimizer | |||
_reduces = { | |||
'max': paddle.max, | |||
'min': paddle.min, | |||
'mean': paddle.mean, | |||
'sum': paddle.sum | |||
} | |||
class PaddleDriver(Driver): | |||
r""" | |||
    Driver for the Paddle framework; implemented by `PaddleSingleDriver` for single-card training and `PaddleFleetDriver` for distributed training.
""" | |||
def __init__(self, model, fp16: Optional[bool] = False, **kwargs): | |||
if not isinstance(model, paddle.nn.Layer): | |||
raise ValueError(f"Parameter `model` can not be `{type(model)}` in `PaddleDriver`, it should be exactly " | |||
f"`paddle.nn.Layer` type.") | |||
super(PaddleDriver, self).__init__(model) | |||
self.fp16 = fp16 | |||
        # members related to the fp16 grad scaler
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | |||
self.grad_scaler = _grad_scaler() | |||
def zero_grad(self, set_to_none: bool = False): | |||
r""" | |||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | |||
注意梯度累积不需要在这里实现,trainer 已经在内部实现了梯度累积; | |||
:param set_to_none: 用来判断是否需要将梯度直接置为 None;Paddle中这个参数无效。 | |||
""" | |||
# if set_to_none: | |||
# log.warning("Parameter `set_to_none` does nothing in paddle since grad cannot be set directly.") | |||
for optimizer in self.optimizers: | |||
optimizer.clear_grad() | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
r""" | |||
该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性。 | |||
要求传入的 dataloader 必须为 `paddle.io.DataLoader` 或包含该类型的字典。 | |||
:param dataloader: 需要检测的输入的 `dataloader`; | |||
:param dataloader_name: | |||
:param is_train: | |||
""" | |||
if is_train: | |||
if not isinstance(dataloader, DataLoader): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'paddle.io.DataLoader' type, not {type(dataloader)}.") | |||
            # TODO: for now we forbid dataloaders whose dataset is an IterableDataset;
if isinstance(dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
else: | |||
for each_dataloader in dataloader.values(): | |||
if not isinstance(each_dataloader, DataLoader): | |||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be 'paddle.io.DataLoader' " | |||
f"type, not {type(each_dataloader)}.") | |||
if isinstance(each_dataloader.dataset, IterableDataset): | |||
raise TypeError("`IterableDataset` is not allowed.") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
r""" | |||
对于用户传入 trainer 的每一个 optimizer检测其合法性,必须为`paddle.optimizer.Optimizer`类型。 | |||
:param optimizers: 需要检测的 `optimizers`; | |||
""" | |||
for each_optimizer in optimizers: | |||
if not isinstance(each_optimizer, Optimizer): | |||
raise ValueError(f"Each optimizer of parameter `optimizers` should be 'paddle.optimizer.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
def check_evaluator_mode(self, mode: str): | |||
r""" | |||
因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; | |||
因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 | |||
我们应当提醒用户这一行为; | |||
""" | |||
model = self.unwrap_model() | |||
if mode == "validate": | |||
if not hasattr(model, "validate_step"): | |||
if hasattr(model, "test_step"): | |||
logger.warning( | |||
"Your model does not have 'validate_step' method but has 'test_step' method, but you" | |||
"are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" | |||
"'validate_step'.") | |||
else: | |||
if not hasattr(model, "test_step"): | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" | |||
"are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" | |||
"'test_step'.") | |||
@staticmethod | |||
def tensor_to_numeric(tensor, reduce=None): | |||
r""" | |||
将一个 `tensor` 对象(类型为 `paddle.Tensor` )转换为 python 的 `numeric` 对象;如果 tensor 只包含一个 | |||
元素则返回 float 或 int 。 | |||
:param tensor: 需要被转换的 `tensor` 对象 | |||
:param reduce: 可选 ['sum', 'max', 'mea', 'min'],如果不为 None 将使用该 reduce 方法来处理当前 tensor 再返回 | |||
float 或 int 对象。 | |||
:return: 转换后返回的结果 | |||
""" | |||
if tensor is None: | |||
return None | |||
def _translate(_data): | |||
            # if there is only one element, return the element itself rather than a list
if _data.numel().item() == 1: | |||
return _data.item() | |||
if reduce is None: | |||
return _data.tolist() | |||
else: | |||
return _reduces[reduce](_data).item() | |||
return apply_to_collection( | |||
data=tensor, | |||
dtype=paddle.Tensor, | |||
function=_translate | |||
) | |||
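    # Illustrative usage (a hedged sketch, not executed here): `tensor_to_numeric` also works on nested
    # containers of paddle tensors, e.g.:
    #
    #     PaddleDriver.tensor_to_numeric({"loss": paddle.to_tensor(0.5), "logits": paddle.ones((2, 3))})
    #     # -> {"loss": 0.5, "logits": [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]]}
    #     PaddleDriver.tensor_to_numeric(paddle.arange(4), reduce="sum")   # -> 6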
def set_model_mode(self, mode: str): | |||
r""" | |||
设置模型为 `train` / `eval` 的模式;目的是为切换模型训练和推理(会关闭dropout等)模式; | |||
:param mode: 应为二者之一:["train", "eval"]; | |||
""" | |||
assert mode in {"train", "eval"} | |||
getattr(self.model, mode)() | |||
@rank_zero_call | |||
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable]=None, **kwargs): | |||
r""" | |||
保存模型的函数;注意函数 `save` 是用来进行断点重训的函数; | |||
如果 `model_save_fn` 是一个可调用的函数,那么我们会直接运行该函数; | |||
:param filepath: 保存文件的文件位置(需要包括文件名); | |||
:param only_state_dict: 是否只保存模型的 `state_dict`;注意该参数仅当 `model_save_fn` 为 None 时有效; | |||
:param model_save_fn: 用户传入的用来代替该函数本身保存逻辑的函数;如果该参数不为 None,那么我们会调用 model_save_fn(path); | |||
""" | |||
if model_save_fn is not None: | |||
model_save_fn(filepath) | |||
else: | |||
model = self.unwrap_model() | |||
if only_state_dict: | |||
paddle.save(model.state_dict(), filepath) | |||
else: | |||
input_spec = kwargs.get("input_spec", None) | |||
if input_spec is None: | |||
raise Exception("To save the whole Paddle Layer, parameter 'input_spec' is needed.") | |||
paddle.jit.save(model, filepath, input_spec) | |||
@staticmethod | |||
@rank_zero_call | |||
def load_model(filepath: str, load_dict: bool = True): | |||
r""" | |||
加载模型的函数;注意函数 `load` 是用来进行断点重训的函数; | |||
:param filepath: 需要被加载的对象的文件位置(需要包括文件名); | |||
:param load_dict: 是否加载state_dict,默认为True。当用户在save_model时将only_state_dict设置为False时, | |||
即保存了整个模型时,这个参数必须也为False | |||
:return: 返回加载指定文件后的结果; | |||
""" | |||
if load_dict: | |||
return paddle.load(filepath) | |||
else: | |||
return paddle.jit.load(filepath) | |||
@rank_zero_call | |||
def save(self, folder, states: Dict): | |||
r""" | |||
断点重训的保存函数,该函数会负责保存模型和 optimizers 的 state_dict; | |||
需要注意 driver 应当是无状态的,即不管什么时候调用 driver 的接口函数,其返回的结果应该都是一样的;因此,断点重训不需要保存 driver | |||
本身自己的任何状态;而每一个 driver 实例需要在该函数中实现保存模型和 optimizers 的 state_dict 的逻辑;同时妥善存储传入的 | |||
states 中的内容(主要用于恢复 Trainer ,Callback 等) | |||
需要保证该函数只在 global rank 0 上运行 | |||
:param folder: 保存断点重训的状态的文件名; | |||
:param states: 由 trainer 传入的一个字典,其中已经包含了为了实现断点重训所需要保存的其它对象的状态,Driver 应该只需要保存 | |||
该对象即可, Driver 应该不需要理解该对象,同时在 driver.load() 的时候,需要将 states 返回回去,load()返回的值与这里的 | |||
传入的值保持一致。 | |||
""" | |||
        # 1. save the model state;
        model = self.unwrap_model()
        model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()}
        # for a single-card driver we do not (currently) need to consider the user running single-card mode in a
        # distributed environment, which would only cost efficiency;
        states["model_state_dict"] = model_state_dict
        # 2. save the optimizers' state;
optimizers_state_dict = {} | |||
for i in range(len(self.optimizers)): | |||
optimizer: Optimizer = self.optimizers[i] | |||
optimizer_state = optimizer.state_dict() | |||
optimizer_state = {name: param.cpu().detach().clone() for name, param in optimizer_state.items()} | |||
optimizers_state_dict[f"optimizer{i}"] = optimizer_state # 注意这里没有使用 deepcopy,测试是不需要的; | |||
states["optimizers_state_dict"] = optimizers_state_dict | |||
paddle.save(states, folder) | |||
def load(self, filepath) -> Dict: | |||
r""" | |||
断点重训的加载函数,注意该函数会负责读取数据,并且恢复模型和 optimizers 的 state_dict 等; | |||
driver 实例需要在该函数中先加载模型和 optimizers 的 state_dict,然后将一个 state 字典返回给 trainer 。 | |||
因此 save 函数和 load 函数的接受和返回值应该是对应的; | |||
该函数需要在所有 rank 上执行。 | |||
:param filepath: 保存断点重训的状态的文件名; | |||
:return: 需要返回 save 函数输入的 states 内容; | |||
""" | |||
states = paddle.load(filepath) | |||
        # 1. load the optimizers' state;
        optimizers_state_dict = states["optimizers_state_dict"]
        for i in range(len(self.optimizers)):
            optimizer: paddle.optimizer.Optimizer = self.optimizers[i]
            optimizer.set_state_dict(optimizers_state_dict[f"optimizer{i}"])
        # 2. load the model state;
model = self.unwrap_model() | |||
model.load_dict(states["model_state_dict"]) | |||
self.barrier() | |||
return states | |||
def get_evaluate_context(self): | |||
r""" | |||
返回一个不计算梯度的环境用来对模型进行评测; | |||
:return: context 上下文对象 `paddle.no_grad`; | |||
""" | |||
return paddle.no_grad | |||
@staticmethod | |||
def move_model_to_device(model: 'paddle.nn.Layer', device: Union[str, int, 'paddle.CUDAPlace', 'paddle.CPUPlace']): | |||
r""" | |||
用来将模型转移到指定的 device 上; | |||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||
""" | |||
if device is not None: | |||
model.to(device) | |||
def move_data_to_device(self, batch: 'paddle.Tensor'): | |||
r""" | |||
将数据迁移到指定的机器上;batch 可能是 list 也可能 dict ,或其嵌套结构。 | |||
在 Paddle 中使用可能会引起因与设置的设备不一致而产生的问题,请注意。 | |||
:return: 将移动到指定机器上的 batch 对象返回; | |||
""" | |||
return paddle_move_data_to_device(batch, self.data_device) | |||
@staticmethod | |||
def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover | |||
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed | |||
with ``seed_everything(seed, workers=True)``. | |||
See also the PyTorch documentation on | |||
`randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_. | |||
""" | |||
# implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562 | |||
global_rank = rank if rank is not None else rank_zero_call.rank | |||
# TODO gpu | |||
process_seed = paddle.fluid.core.default_cpu_generator().initial_seed() | |||
# back out the base seed so we can use all the bits | |||
base_seed = process_seed - worker_id | |||
ss = np.random.SeedSequence([base_seed, worker_id, global_rank]) | |||
# use 128 bits (4 x 32-bit words) | |||
np.random.seed(ss.generate_state(4)) | |||
# Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module | |||
paddle_ss, stdlib_ss = ss.spawn(2) | |||
paddle.seed(paddle_ss.generate_state(1, dtype=np.uint64)[0]) | |||
# use 128 bits expressed as an integer | |||
stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum() | |||
random.seed(stdlib_seed) | |||
def set_deterministic_dataloader(self, dataloader): | |||
r""" | |||
为了确定性训练要对 dataloader 进行修改,保证在确定随机数种子后,每次重新训练得到的结果是一样的; | |||
作用是替换 datalaoder 的 `worker_init_fn`。 | |||
""" | |||
if int(os.environ.get(FASTNLP_SEED_WORKERS, 0)) and dataloader.worker_init_fn is None: | |||
dataloader.worker_init_fn = partial(self.worker_init_function, rank=self.global_rank) | |||
def set_sampler_epoch(self, dataloader: 'DataLoader', cur_epoch_idx): | |||
r""" | |||
对于分布式的 sampler,dataloader 需要在每一个 epoch 前设置随机数种子,来保证每一个进程上的 shuffle 是一样的; | |||
:param cur_epoch_idx: 当前是第几个 epoch; | |||
""" | |||
if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | |||
dataloader.batch_sampler.set_epoch(cur_epoch_idx) |
@@ -0,0 +1,161 @@ | |||
from typing import Optional, Dict, Union | |||
from .paddle_driver import PaddleDriver | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.utils import auto_param_call, get_paddle_gpu_str | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.fluid.reader import _DatasetKind | |||
__all__ = [ | |||
"PaddleSingleDriver", | |||
] | |||
class PaddleSingleDriver(PaddleDriver): | |||
def __init__(self, model, device: Optional[str], fp16: Optional[bool] = False, **kwargs): | |||
super(PaddleSingleDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
if device is None: | |||
raise ValueError("Parameter `device` can not be None in `PaddleSingleDriver`.") | |||
if isinstance(device, int): | |||
self.model_device = get_paddle_gpu_str(device) | |||
else: | |||
self.model_device = device | |||
self.local_rank = 0 | |||
self.global_rank = 0 | |||
self.world_size = 1 | |||
if isinstance(model, paddle.DataParallel): | |||
            # note that unwrap_model here calls the concrete subclass's implementation;
model = self.unwrap_model() | |||
if hasattr(model, "train_step"): | |||
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||
"implements the `train_step` method, which we can not call actually, we will " | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = self.model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||
"implements the `validate_step` method, which we can not call actually, we " | |||
"will call `forward` function instead of `validate_step` and you should note that.") | |||
self._validate_step = self.model | |||
self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " | |||
"implements the `test_step` method, which we can not call actually, we will " | |||
"call `forward` function instead of `test_step` and you should note that.") | |||
self._test_step = self.model | |||
self._test_signature_fn = model.forward | |||
else: | |||
if hasattr(self.model, "train_step"): | |||
self._train_step = self.model.train_step | |||
self._train_signature_fn = None | |||
else: | |||
self._train_step = self.model | |||
                # fall back to the model itself; make sure the signature_fn is the unwrapped model's forward;
model = self.unwrap_model() | |||
self._train_signature_fn = model.forward | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
self._validate_signature_fn = self.model.test_step | |||
else: | |||
self._validate_step = self.model | |||
model = self.unwrap_model() | |||
self._validate_signature_fn = model.forward | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
self._test_signature_fn = self.model.validate_step | |||
else: | |||
self._test_step = self.model | |||
model = self.unwrap_model() | |||
self._test_signature_fn = model.forward | |||
def setup(self): | |||
paddle.device.set_device(self.model_device) | |||
self.model.to(self.model_device) | |||
def train_step(self, batch) -> Dict: | |||
        # if batch is a Dict we perform parameter matching for it by default; otherwise we pass it directly to `train_step` and let the user handle it;
if isinstance(batch, Dict): | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
def backward(self, loss): | |||
self.grad_scaler.scale(loss).backward() | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
self.grad_scaler.step(optimizer) | |||
self.grad_scaler.update() | |||
def validate_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
def test_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
def replace_sampler(self, dataloader, dist_sampler: Union[str, ReproducibleBatchSampler, ReproducibleIterator], reproducible: bool = False): | |||
        # IterableDataset is not supported for now
        assert dataloader.dataset_kind != _DatasetKind.ITER, \
            "FastNLP does not support `IterableDataset` yet."
if isinstance(dist_sampler, ReproducibleBatchSampler): | |||
dataloader.batch_sampler = dist_sampler | |||
return dataloader | |||
if isinstance(dist_sampler, ReproducibleIterator): | |||
dataloader.batch_sampler.sampler = dist_sampler | |||
return dataloader | |||
if reproducible: | |||
if isinstance(dataloader.batch_sampler.sampler, ReproducibleIterator): | |||
return dataloader | |||
elif isinstance(dataloader.batch_sampler, ReproducibleBatchSampler): | |||
return dataloader | |||
else: | |||
# TODO | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler=dataloader.batch_sampler, | |||
batch_size=dataloader.batch_sampler.batch_size, | |||
drop_last=dataloader.drop_last | |||
) | |||
dataloader.batch_sampler = batch_sampler | |||
return dataloader | |||
else: | |||
return dataloader | |||
def unwrap_model(self): | |||
if isinstance(self.model, paddle.DataParallel): | |||
return self.model._layers | |||
else: | |||
return self.model | |||
@property | |||
def data_device(self): | |||
""" | |||
        In single-card mode there is no separate data_device; the model's device is used.
""" | |||
return self.model_device | |||
def is_distributed(self): | |||
return False |
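# Illustrative usage (a hedged sketch, not executed here): the single-card driver wires the model, the
# optimizers and fp16 together; a minimal loop might look like the following, where `set_optimizers`
# is assumed to be provided by the base Driver and `batch` is a dict yielded by a paddle DataLoader:
#
#     driver = PaddleSingleDriver(model, device="gpu:0", fp16=False)
#     driver.set_optimizers(optimizers)   # assumption: the base Driver exposes such a setter
#     driver.setup()
#     outputs = driver.train_step(batch)
#     driver.backward(outputs["loss"])    # assumes the model returns a dict with a "loss" entry
#     driver.step()
#     driver.zero_grad()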
@@ -0,0 +1,351 @@ | |||
import socket | |||
import os | |||
import struct | |||
import random | |||
import inspect | |||
import numpy as np | |||
from contextlib import ExitStack, closing | |||
from enum import IntEnum | |||
from typing import Dict, Optional, Union | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.core.utils import get_paddle_device_id, auto_param_call | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_SEED, FASTNLP_SEED_WORKERS, USER_CUDA_VISIBLE_DEVICES | |||
from fastNLP.core.log import logger | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle import nn | |||
from paddle.nn import Layer | |||
from paddle.io import DataLoader, BatchSampler | |||
from paddle.amp import auto_cast, GradScaler | |||
else: | |||
from fastNLP.core.utils.dummy_class import DummyClass as Layer | |||
__all__ = [ | |||
"paddle_seed_everything", | |||
] | |||
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int: | |||
return random.randint(min_seed_value, max_seed_value) | |||
def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: | |||
max_seed_value = np.iinfo(np.uint32).max | |||
min_seed_value = np.iinfo(np.uint32).min | |||
if seed is None: | |||
env_seed = os.environ.get("GLOBAL_SEED") | |||
if env_seed is None: | |||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||
# rank_zero_warn(f"No seed found, seed set to {seed}") | |||
else: | |||
try: | |||
seed = int(env_seed) | |||
except ValueError: | |||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||
# rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") | |||
elif not isinstance(seed, int): | |||
seed = int(seed) | |||
if not (min_seed_value <= seed <= max_seed_value): | |||
logger.warning("Your seed value is two big or two small for numpy, we will choose a random seed for " | |||
"you.") | |||
# rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") | |||
seed = _select_seed_randomly(min_seed_value, max_seed_value) | |||
# using `log.info` instead of `rank_zero_info`, | |||
# so users can verify the seed is properly set in distributed training. | |||
# log.info(f"Global seed set to {seed}") | |||
os.environ[FASTNLP_GLOBAL_SEED] = str(seed) | |||
random.seed(seed) | |||
np.random.seed(seed) | |||
    # paddle's seed function checks by itself whether a gpu environment is present and, if so, also sets the gpu seed
paddle.seed(seed) | |||
os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" | |||
return seed | |||
def reset_seed() -> None: | |||
""" | |||
    fleet starts multiple processes, so when the user calls seed_everything in the script, the random seeds
    need to be set again inside each spawned process;
""" | |||
seed = os.environ.get(FASTNLP_GLOBAL_SEED, None) | |||
workers = os.environ.get(FASTNLP_SEED_WORKERS, "0") | |||
if seed is not None: | |||
paddle_seed_everything(int(seed), workers=bool(int(workers))) | |||
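# Illustrative usage (a hedged sketch, not executed here): seeding should happen before the model and the
# dataloaders are built so that initialization and shuffling are reproducible, e.g.:
#
#     paddle_seed_everything(42, workers=True)   # also marks dataloader workers for deterministic seeding
#     model = build_model()                      # `build_model` is a hypothetical user function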
class ForwardState(IntEnum): | |||
TRAIN = 0 | |||
VALIDATE = 1 | |||
TEST = 2 | |||
PREDICT = 3 | |||
_MODE_PARAMETER = "_forward_state" | |||
class _FleetWrappingModel(Layer): | |||
""" | |||
    Modeled after _DDPWrappingModel; paddle's distributed training also requires wrapping the model with
    paddle.DataParallel, so we handle it in a way similar to pytorch.
""" | |||
def __init__(self, model: 'nn.Layer'): | |||
super(_FleetWrappingModel, self).__init__() | |||
self.model = model | |||
if isinstance(model, paddle.DataParallel): | |||
model = model._layers | |||
if hasattr(model, "train_step"): | |||
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model, and it also implements the "
                    "`train_step` method, which we cannot actually call; we will call the `forward` "
                    "function instead of `train_step`, please note this.")
self._train_step = self.model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model, and it also implements the "
                    "`validate_step` method, which we cannot actually call; we will call the `forward` "
                    "function instead of `validate_step`, please note this.")
self._validate_step = self.model | |||
self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
                logger.warning(
                    "Notice your model is a `paddle.DataParallel` model, and it also implements the "
                    "`test_step` method, which we cannot actually call; we will call the `forward` "
                    "function instead of `test_step`, please note this.")
self._test_step = self.model | |||
self._test_signature_fn = model.forward | |||
else: | |||
if hasattr(model, "train_step"): | |||
self._train_step = model.train_step | |||
self._train_signature_fn = None | |||
else: | |||
self._train_step = model | |||
self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
self._validate_step = model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(model, "test_step"): | |||
self._validate_step = model.test_step | |||
self._validate_signature_fn = None | |||
else: | |||
self._validate_step = model | |||
self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
self._test_step = model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(model, "validate_step"): | |||
self._test_step = model.validate_step | |||
self._test_signature_fn = None | |||
else: | |||
self._test_step = model | |||
self._test_signature_fn = model.forward | |||
def forward(self, batch, **kwargs) -> Dict: | |||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||
if _forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
elif _forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
elif _forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
elif _forward_state == ForwardState.PREDICT: | |||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||
else: | |||
raise NotImplementedError("You should direct a concrete mode.") | |||
class DummyGradScaler: | |||
""" | |||
    A stand-in GradScaler object used to avoid writing a large number of repeated if-branches.
""" | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def get_scale(self): | |||
return 1.0 | |||
def is_enabled(self): | |||
return False | |||
def scale(self, outputs): | |||
return outputs | |||
def step(self, optimizer, *args, **kwargs): | |||
optimizer.step(*args, **kwargs) | |||
def update(self, new_scale=None): | |||
pass | |||
def unscale_(self, optimizer): | |||
pass | |||
def load_state_dict(self, state_dict): | |||
pass | |||
def state_dict(self): | |||
return {} | |||
def _build_fp16_env(dummy=False): | |||
if dummy: | |||
auto_cast = ExitStack | |||
GradScaler = DummyGradScaler | |||
else: | |||
if not paddle.device.is_compiled_with_cuda(): | |||
raise RuntimeError("No cuda") | |||
if paddle.device.cuda.get_device_capability(0)[0] < 7: | |||
logger.warning( | |||
"NOTE: your device does NOT support faster training with fp16, " | |||
"please switch to FP32 which is likely to be faster" | |||
) | |||
return auto_cast, GradScaler | |||
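# Illustrative usage (a hedged sketch, not executed here): the driver uses the pair returned by
# `_build_fp16_env` roughly as follows; with dummy=True the context and the scaler are no-ops:
#
#     auto_cast, GradScaler = _build_fp16_env(dummy=not fp16)
#     scaler = GradScaler()
#     with auto_cast():
#         loss = model(batch)["loss"]       # assumes the model returns a dict with a "loss" entry
#     scaler.scale(loss).backward()
#     scaler.step(optimizer)
#     scaler.update()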
def find_free_ports(num): | |||
def __free_port(): | |||
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: | |||
s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, | |||
struct.pack('ii', 1, 0)) | |||
s.bind(('', 0)) | |||
return s.getsockname()[1] | |||
port_set = set() | |||
step = 0 | |||
while True: | |||
port = __free_port() | |||
if port not in port_set: | |||
port_set.add(port) | |||
if len(port_set) >= num: | |||
return port_set | |||
step += 1 | |||
if step > 400: | |||
logger.error( | |||
"can't find avilable port and use the specified static port now!" | |||
) | |||
return None | |||
return None | |||
def get_host_name_ip(): | |||
try: | |||
host_name = socket.gethostname() | |||
host_ip = socket.gethostbyname(host_name) | |||
return host_name, host_ip | |||
    except Exception:
        return None
def get_device_from_visible(device: Union[str, int]): | |||
""" | |||
    Get the actual device index when CUDA_VISIBLE_DEVICES is set.
    For example, with CUDA_VISIBLE_DEVICES=2,3 and device=3, the function returns 1.
    :param device: the unconverted device name or index
    :return: the converted device id
""" | |||
if device == "cpu": | |||
return device | |||
cuda_visible_devices = os.getenv("CUDA_VISIBLE_DEVICES") | |||
idx = get_paddle_device_id(device) | |||
if cuda_visible_devices is None or cuda_visible_devices == "": | |||
        # this branch rarely happens, because fastNLP forcibly injects CUDA_VISIBLE_DEVICES for paddle
return idx | |||
    else:
        # use USER_CUDA_VISIBLE_DEVICES to get the devices the user expects
        user_visible_devices = os.getenv(USER_CUDA_VISIBLE_DEVICES)
        if user_visible_devices is not None and user_visible_devices != "":
            # not empty, meaning the user has set CUDA_VISIBLE_DEVICES themselves
            idx = user_visible_devices.split(",")[idx]
        else:
            idx = str(idx)
cuda_visible_devices_list = cuda_visible_devices.split(',') | |||
assert idx in cuda_visible_devices_list, "Can't find "\ | |||
"your devices %s in CUDA_VISIBLE_DEVICES[%s]."\ | |||
% (idx, cuda_visible_devices) | |||
res = cuda_visible_devices_list.index(idx) | |||
return res | |||
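# Illustrative usage (a hedged sketch, not executed here): with CUDA_VISIBLE_DEVICES="2,3" and no
# user-level override, the global index 3 maps to the visible index 1:
#
#     get_device_from_visible("gpu:3")   # -> 1
#     get_device_from_visible("cpu")     # -> "cpu"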
def replace_sampler(dataloader: "DataLoader", sampler: "BatchSampler"): | |||
    # grab the instance attributes;
    instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}
    # grab the default signature of the dataloader's '__init__';
    init_params = dict(inspect.signature(dataloader.__init__).parameters)
    # The reason for handling this separately: a user customizing their own dataloader may only declare a few
    # parameters explicitly and forward the rest via **kwargs (perhaps adding new parameters when constructing the
    # instance, which is necessary because this is how we inject the sampler). If the subclass simply calls
    # "super().__init__(**kwargs)", those parameters can only be found on DataLoader itself;
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items()) | |||
if has_variadic_kwargs: | |||
init_params.update(dict(inspect.signature(DataLoader.__init__).parameters)) | |||
del init_params["self"] | |||
    # since the step above may have overwritten the user's customized parameters with DataLoader's defaults, we need to recompute this;
non_default_params = {name for name, p in init_params.items() if | |||
name in instance_attrs and p.default != instance_attrs[name]} | |||
# add `dataset` as it might have been replaced with `*args` | |||
non_default_params.add("dataset") | |||
reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params} | |||
reconstruct_args.update({"batch_sampler": sampler, "shuffle": False, "drop_last": False, "batch_size": 1}) | |||
required_args = { | |||
p.name | |||
for p in init_params.values() | |||
if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) | |||
and p.default is p.empty | |||
and p.name not in reconstruct_args | |||
} | |||
    # this error covers the case where an __init__ argument was not stored on self under the same name;
if required_args: | |||
required_args = sorted(required_args) | |||
dataloader_self_name = dataloader.__class__.__name__ | |||
raise Exception( | |||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||
"This would fail as some of the `__init__` arguments are not available as instance attributes. " | |||
f"The missing attributes are {required_args}. " | |||
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or " | |||
"manually add the `DistributedBatchSampler` as: " | |||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||
) | |||
    # this error covers the case where the given dataloader is a customized DataLoader subclass whose __init__ has no **kwargs;
if not has_variadic_kwargs: | |||
# the dataloader signature does not allow keyword arguments that need to be passed | |||
missing_kwargs = reconstruct_args.keys() - init_params.keys() | |||
if missing_kwargs: | |||
missing_kwargs = sorted(missing_kwargs) | |||
dataloader_self_name = dataloader.__class__.__name__ | |||
raise Exception( | |||
f"Trying to inject `DistributedBatchSampler` into the `{dataloader_self_name}` instance. " | |||
"This would fail as it doesn't expose all its attributes in the `__init__` signature. " | |||
f"The missing arguments are {missing_kwargs}. " | |||
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or " | |||
"manually add the `DistributedBatchSampler` as: " | |||
f"`{dataloader_self_name}(dataset, sampler=DistributedBatchSampler(dataset))`." | |||
) | |||
return type(dataloader)(**reconstruct_args) |
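# Illustrative usage (a hedged sketch, not executed here): `replace_sampler` rebuilds the dataloader with a
# new batch sampler while keeping the user's other constructor arguments, e.g.:
#
#     dist_batch_sampler = paddle.io.DistributedBatchSampler(dataloader.dataset, batch_size=32)
#     new_dataloader = replace_sampler(dataloader, dist_batch_sampler)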
@@ -0,0 +1,5 @@ | |||
__all__ = [ | |||
"TorchPaddleDriver", | |||
] | |||
from .torch_paddle_driver import TorchPaddleDriver |
@@ -0,0 +1,218 @@ | |||
from typing import Optional, Dict, Union, Callable | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.io import DataLoader as PaddleDataLoader | |||
from paddle.optimizer import Optimizer as PaddleOptimizer | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch.utils.data import DataLoader as TorchDataLoader | |||
from torch.optim import Optimizer as TorchOptimizer | |||
from fastNLP.core.drivers.driver import Driver | |||
from fastNLP.envs.distributed import rank_zero_call | |||
from fastNLP.core.utils.utils import auto_param_call, apply_to_collection | |||
from fastNLP.core.log.logger import logger | |||
from fastNLP.modules.mix_modules.mix_module import MixModule | |||
__all__ = [ | |||
"TorchPaddleDriver", | |||
] | |||
class TorchPaddleDriver(Driver): | |||
""" | |||
    Driver for models that mix torch and paddle.
    Since coordinating two different frameworks makes multi-card training inconvenient, only single-card CPU and GPU are implemented for now.
""" | |||
def __init__(self, model, device: Optional[str] = None, **kwargs): | |||
super(TorchPaddleDriver, self).__init__(model) | |||
self.model_device = device | |||
self.torch_non_blocking = kwargs.get("torch_non_blocking", None) | |||
self.paddle_blocking = kwargs.get("paddle_blocking", None) | |||
self._data_device = kwargs.get("_data_device", None) | |||
if isinstance(self._data_device, int): | |||
            # convert data_device to its "cuda:x" string form
if self._data_device < 0: | |||
raise ValueError("Parameter `_data_device` can not be smaller than 0.") | |||
_could_use_device_num = paddle.device.cuda.device_count() | |||
if self._data_device >= _could_use_device_num: | |||
raise ValueError("The gpu device that parameter `device` specifies is not existed.") | |||
self._data_device = f"cuda:{self._data_device}" | |||
elif self._data_device is not None: | |||
raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") | |||
if hasattr(self.model, "train_step"): | |||
self._train_step = self.model.train_step | |||
self._train_signature_fn = None | |||
else: | |||
self._train_step = self.model | |||
self._train_signature_fn = self.model.forward | |||
if hasattr(self.model, "validate_step"): | |||
self._validate_step = self.model.validate_step | |||
self._validate_signature_fn = None | |||
elif hasattr(self.model, "test_step"): | |||
self._validate_step = self.model.test_step | |||
self._validate_signature_fn = self.model.forward | |||
else: | |||
self._validate_step = self.model | |||
self._validate_signature_fn = self.model.forward | |||
if hasattr(self.model, "test_step"): | |||
self._test_step = self.model.test_step | |||
self._test_signature_fn = None | |||
elif hasattr(self.model, "validate_step"): | |||
self._test_step = self.model.validate_step | |||
self._test_signature_fn = self.model.forward | |||
else: | |||
self._test_step = self.model | |||
self._test_signature_fn = self.model.forward | |||
def setup(self): | |||
if self.model_device is not None: | |||
paddle.device.set_device(self.model_device.replace("cuda", "gpu")) | |||
self.model.to(self.model_device) | |||
@staticmethod | |||
def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | |||
if is_train: | |||
if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.") | |||
else: | |||
if not isinstance(dataloader, Dict): | |||
raise ValueError(f"Parameter `{dataloader_name}` should be 'Dict' type, not {type(dataloader)}.") | |||
else: | |||
for each_dataloader in dataloader.values(): | |||
if not isinstance(each_dataloader, (TorchDataLoader, PaddleDataLoader)): | |||
raise ValueError(f"Each dataloader of parameter `{dataloader_name}` should be " | |||
f"'torch.util.data.DataLoader' or `paddle.io.dataloader` " | |||
f"type, not {type(each_dataloader)}.") | |||
@staticmethod | |||
def _check_optimizer_legality(optimizers): | |||
for each_optimizer in optimizers: | |||
if not isinstance(each_optimizer, (TorchOptimizer, PaddleOptimizer)): | |||
raise ValueError(f"Each optimizers of parameter `optimizers` should be " | |||
f"'torch.optim.Optimizer' or 'paddle.optimizers.Optimizer' type, " | |||
f"not {type(each_optimizer)}.") | |||
def train_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._train_step, batch) | |||
else: | |||
return self._train_step(batch) | |||
def step(self): | |||
for optimizer in self.optimizers: | |||
optimizer.step() | |||
def backward(self, loss): | |||
loss.backward() | |||
def zero_grad(self): | |||
for optimizer in self.optimizers: | |||
if isinstance(optimizer, TorchOptimizer): | |||
optimizer.zero_grad() | |||
elif isinstance(optimizer, PaddleOptimizer): | |||
optimizer.clear_grad() | |||
else: | |||
raise ValueError("Unknown optimizers type.") | |||
def validate_step(self, batch): | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._validate_step, batch) | |||
else: | |||
return self._validate_step(batch) | |||
def test_step(self, batch): | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._test_step, batch) | |||
else: | |||
return self._test_step(batch) | |||
def predict_step(self, batch): | |||
if isinstance(batch, Dict): | |||
return auto_param_call(self._predict_step, batch) | |||
else: | |||
return self._predict_step(batch) | |||
@rank_zero_call | |||
def save_model(self, filepath: str, only_state_dict: bool = True, model_save_fn: Optional[Callable] = None): | |||
r""" | |||
暂时不提供保存整个模型的方法 | |||
""" | |||
if only_state_dict == False: | |||
logger.warn("TorchPaddleModule only support saving state dicts now.") | |||
if model_save_fn is not None: | |||
model_save_fn(filepath) | |||
else: | |||
model = self.unwrap_model() | |||
self.move_model_to_device(model, "cpu") | |||
self.model.save(filepath) | |||
self.move_model_to_device(model, self.model_device) | |||
def load_model(self, filepath: str): | |||
""" | |||
        Load the model;
        :param filepath: location of the saved file (including the file name);
        :return:
""" | |||
return self.model.load(filepath) | |||
def save(self): | |||
... | |||
def load(self): | |||
... | |||
@staticmethod | |||
def move_model_to_device(model: MixModule, device: str): | |||
if device is not None: | |||
model.to(device) | |||
def unwrap_model(self): | |||
return self.model | |||
@staticmethod | |||
def tensor_to_numeric(tensor): | |||
if tensor is None: | |||
return None | |||
def _translate(_data): | |||
return _data.tolist() | |||
return apply_to_collection( | |||
data=tensor, | |||
dtype=(paddle.Tensor, torch.Tensor), | |||
function=_translate | |||
) | |||
def set_model_mode(self, mode: str): | |||
assert mode in {"train", "eval"} | |||
getattr(self.model, mode)() | |||
def get_model_device(self): | |||
return self.model_device | |||
@property | |||
def data_device(self): | |||
if self.model_device is not None: | |||
return self.model_device | |||
else: | |||
return self._data_device | |||
def set_sampler_epoch(self, dataloader: Union['TorchDataLoader', 'PaddleDataLoader'], cur_epoch_idx): | |||
        # for ddp training with shuffle=True every process's sampler would need the same shuffle seed; nothing needs to be done in the single-card case;
return dataloader |
@@ -0,0 +1,4 @@ | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
if _NEED_IMPORT_PADDLE: | |||
pass |
@@ -0,0 +1,18 @@ | |||
__all__ = [ | |||
"Metric", | |||
"Accuracy", | |||
'Backend', | |||
'AutoBackend', | |||
'PaddleBackend', | |||
'TorchBackend', | |||
'SpanFPreRecMetric', | |||
'ClassifyFPreRecMetric', | |||
'func_post_proc' | |||
] | |||
from .metric import Metric | |||
from .accuracy import Accuracy | |||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | |||
from .span_f1_pre_rec_metric import SpanFPreRecMetric | |||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | |||
from .utils import func_post_proc |
@@ -0,0 +1,75 @@ | |||
__all__ = [ | |||
'Accuracy' | |||
] | |||
from typing import Union | |||
import warnings | |||
import numpy as np | |||
from fastNLP.core.metrics.metric import Metric | |||
from fastNLP.core.metrics.backend import Backend | |||
from fastNLP.core.utils.utils import seq_len_to_mask | |||
class Accuracy(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', | |||
aggregate_when_get_metric: bool = True): | |||
super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend) | |||
self.register_element(name='total', value=0, aggregate_method="sum", backend=backend) | |||
def get_metric(self) -> dict: | |||
r""" | |||
        get_metric computes the final evaluation result from the statistics accumulated by the update function.
        :return dict evaluate_result: {"acc": float}
""" | |||
evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)} | |||
return evaluate_result | |||
def update(self, pred, target, seq_len=None): | |||
r""" | |||
        update accumulates the evaluation statistics for one batch of predictions.
        :param pred: prediction tensor whose shape can be [B,], [B, n_classes],
            [B, max_len], or [B, max_len, n_classes]
        :param target: ground-truth tensor whose shape can be [B,], [B,],
            [B, max_len], or [B, max_len], matching the shapes of pred above
        :param seq_len: sequence-length tensor whose shape can be None, None,
            [B], or [B]; used to mask out padding positions.
""" | |||
        # to stay framework-agnostic, all inputs are converted to numpy for the computation.
pred = self.tensor2numpy(pred) | |||
target = self.tensor2numpy(target) | |||
if seq_len is not None: | |||
seq_len = self.tensor2numpy(seq_len) | |||
if seq_len is not None and target.ndim > 1: | |||
max_len = target.shape[1] | |||
masks = seq_len_to_mask(seq_len, max_len) | |||
else: | |||
masks = None | |||
if pred.ndim == target.ndim: | |||
if np.prod(pred.shape) != np.prod(target.shape): | |||
raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." | |||
f" while target have shape:{target.shape}, " | |||
f"pred have shape: {target.shape}") | |||
elif pred.ndim == target.ndim + 1: | |||
pred = pred.argmax(axis=-1) | |||
if seq_len is None and target.ndim > 1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or " | |||
f"{pred.shape[:-1]}, got {target.shape}.") | |||
if masks is not None: | |||
self.total += masks.sum().item() | |||
self.correct += ((pred == target) * masks).sum().item() | |||
else: | |||
self.total += np.prod(list(pred.shape)).item() | |||
self.correct += (target == pred).sum().item() |
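# Illustrative usage (a hedged sketch, not executed here): with numpy inputs and the default backend,
# the metric accumulates correct/total counts across batches:
#
#     import numpy as np
#     acc = Accuracy()
#     acc.update(pred=np.array([[0.1, 0.9], [0.8, 0.2]]), target=np.array([1, 0]))
#     acc.get_metric()   # -> {'acc': 1.0}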
@@ -0,0 +1,12 @@ | |||
__all__ = [ | |||
'Backend', | |||
'AutoBackend', | |||
'TorchBackend', | |||
'PaddleBackend' | |||
] | |||
from .backend import Backend | |||
from .auto_backend import AutoBackend | |||
from .torch_backend.backend import TorchBackend | |||
from .paddle_backend.backend import PaddleBackend |
@@ -0,0 +1,75 @@ | |||
from typing import Union | |||
from .backend import Backend | |||
from .torch_backend.backend import TorchBackend | |||
from .paddle_backend.backend import PaddleBackend | |||
from .jittor_backend.backend import JittorBackend | |||
class AutoBackend(Backend): | |||
""" | |||
    An AutoBackend that needs no concrete backend at construction time; it can pick the concrete backend
    by inspecting the input data types when get_metric is called.
""" | |||
def __init__(self, backend: Union[str, Backend, None]): | |||
super(AutoBackend, self).__init__() | |||
if backend != 'auto': | |||
self._convert_backend(backend) | |||
def _convert_backend(self, backend): | |||
""" | |||
        Convert this AutoBackend into the appropriate concrete Backend class.
""" | |||
if isinstance(backend, Backend): | |||
self.__class__ = backend.__class__ | |||
        # if it is a str, just pick the matching backend
elif backend == 'torch': | |||
self.__class__ = TorchBackend | |||
elif backend == 'paddle': | |||
self.__class__ = PaddleBackend | |||
elif backend == 'jittor': | |||
self.__class__ = JittorBackend | |||
elif backend is None: | |||
            # nothing needs to be done for initialization
pass | |||
else: | |||
raise RuntimeError(f"We did not support `{backend}` to be used as backend for now.") | |||
self._specified = True | |||
def choose_real_backend(self, args): | |||
assert not self.is_specified(), "This method should not be called after backend has been specified. " \ | |||
"This must be a bug, please report." | |||
types = [] | |||
for arg in args: | |||
types.append(str(type(arg))) | |||
torch_types = [] | |||
jittor_types = [] | |||
paddle_types = [] | |||
for type_name in types: | |||
if 'torch' in type_name: | |||
torch_types.append(type_name) | |||
if 'paddle' in type_name: | |||
paddle_types.append(type_name) | |||
if 'jittor' in type_name: | |||
jittor_types.append(type_name) | |||
        # following https://stackoverflow.com/a/3464154 , reassigning __class__ switches this instance to the real backend
if len(torch_types) > 0 and len(jittor_types) == 0 and len(paddle_types) == 0: | |||
backend = 'torch' | |||
elif len(torch_types) == 0 and len(jittor_types) > 0 and len(paddle_types) == 0: | |||
backend = 'jittor' | |||
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) > 0: | |||
backend = 'paddle' | |||
elif len(torch_types) == 0 and len(jittor_types) == 0 and len(paddle_types) == 0: | |||
            # just use the default backend
backend = None | |||
else: | |||
types = list(set(torch_types + jittor_types + paddle_types)) | |||
raise RuntimeError( | |||
f"Mixture of tensor type:{types} have been accept, please manually set backend instead of " | |||
f"using backend=auto.") | |||
self._convert_backend(backend) |
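# Illustrative sketch (not executed, assumes paddle is importable): AutoBackend stays unresolved until it
# sees real tensors:
#
#     backend = AutoBackend('auto')
#     backend.choose_real_backend([paddle.to_tensor([1.0])])   # switches the instance to PaddleBackend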
@@ -0,0 +1,75 @@ | |||
from ..utils import AggregateMethodError | |||
class Backend: | |||
""" | |||
    All methods of Backend and its subclasses must be stateless.
""" | |||
def __init__(self): | |||
self._specified = False | |||
def aggregate(self, tensor, method: str): | |||
""" | |||
        Aggregate the result and, after applying `method`, return it.
""" | |||
if method is not None: | |||
return AggregateMethodError(should_have_aggregate_method=False, only_warn=True) | |||
return tensor | |||
def create_tensor(self, value: float): | |||
""" | |||
        Create a tensor filled with `value`.
""" | |||
return value | |||
def fill_value(self, tensor, value: float): | |||
""" | |||
        Set the value of `tensor` to `value`.
""" | |||
return value | |||
def get_scalar(self, tensor) -> float: | |||
""" | |||
        Return the scalar value of the tensor.
:param tensor: | |||
:return: | |||
""" | |||
return tensor | |||
def is_specified(self) -> bool: | |||
""" | |||
        Whether this backend is bound to a concrete framework.
:return: | |||
""" | |||
return self._specified | |||
def tensor2numpy(self, tensor): | |||
""" | |||
        Convert the tensor to a numpy array.
:param tensor: | |||
:return: | |||
""" | |||
return tensor | |||
def move_tensor_to_device(self, tensor, device): | |||
""" | |||
""" | |||
return tensor | |||
def all_gather_object(self, obj, group=None): | |||
""" | |||
        Gather `obj` from every rank onto every rank; returns a list containing, in order, the obj from each rank.
:param obj: | |||
:param group: | |||
:return: | |||
""" | |||
raise NotImplementedError(f"all_gather_object() function is not implemented for {self.__class__.__name__}.") | |||
@@ -0,0 +1 @@ | |||
@@ -0,0 +1,72 @@ | |||
import numpy as np | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
from fastNLP.core.metrics.backend import Backend | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
class JittorBackend(Backend): | |||
def __init__(self): | |||
super(JittorBackend, self).__init__() | |||
self._specified = True | |||
def aggregate(self, tensor, method: str): | |||
""" | |||
        Aggregate the result and, after applying `method`, return it.
""" | |||
return tensor | |||
def create_tensor(self, value: float): | |||
""" | |||
        Create a tensor filled with `value`.
""" | |||
value = jittor.Var(value) | |||
return value | |||
def fill_value(self, tensor, value: float): | |||
""" | |||
        Set the value of `tensor` to `value`.
""" | |||
value = jittor.full_like(tensor, value) | |||
return value | |||
def get_scalar(self, tensor) -> float: | |||
""" | |||
        Return the scalar value of the tensor.
:param tensor: | |||
:return: | |||
""" | |||
return tensor.item() | |||
def is_specified(self) -> bool: | |||
""" | |||
        Whether this backend is bound to a concrete framework.
:return: | |||
""" | |||
return self._specified | |||
def tensor2numpy(self, tensor): | |||
""" | |||
        Convert the tensor to a numpy array.
:param tensor: | |||
:return: | |||
""" | |||
if isinstance(tensor, jittor.Var): | |||
return tensor.detach().numpy() | |||
        elif isinstance(tensor, np.ndarray):
            return tensor
        else:
            raise ValueError(f"tensor: {tensor} cannot be converted to an ndarray!")
def move_tensor_to_device(self, tensor, device): | |||
""" | |||
        jittor has no API for moving tensors between devices, so this function is effectively a no-op.
""" | |||
return tensor |
@@ -0,0 +1,5 @@ | |||
__all__ = [ | |||
'PaddleBackend' | |||
] | |||
from .backend import Backend as PaddleBackend |
@@ -0,0 +1,126 @@ | |||
from typing import List, Optional, Any | |||
import numpy as np | |||
from fastNLP.core.metrics.backend import Backend | |||
from fastNLP.core.utils.paddle_utils import paddle_to | |||
from fastNLP.core.metrics.utils import AggregateMethodError | |||
from fastNLP.core.utils import is_in_paddle_dist | |||
from fastNLP.core.drivers.paddle_driver.utils import get_device_from_visible | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.fluid.dygraph import parallel_helper | |||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | |||
gathered_result = [paddle.zeros_like(result) for _ in range(world_size)] | |||
paddle.distributed.all_gather(gathered_result, result, group) | |||
return gathered_result | |||
class PaddleBackend(Backend): | |||
def __init__(self): | |||
super().__init__() | |||
self._specified = True | |||
def aggregate(self, tensor, method: str): | |||
""" | |||
        Aggregate the result and, after applying `method`, return it.
""" | |||
if isinstance(tensor, paddle.Tensor): | |||
if parallel_helper._is_parallel_ctx_initialized(): | |||
if method is None: | |||
raise AggregateMethodError(should_have_aggregate_method=True) | |||
tensor = self._gather_all(tensor) | |||
if isinstance(tensor[0], paddle.Tensor): | |||
tensor = paddle.stack(tensor) | |||
                # step 1: aggregate the result
                if method == 'sum':
                    tensor = paddle.sum(tensor, axis=0)
                elif method == 'mean':
                    tensor = paddle.mean(tensor, axis=0)
                elif method == 'max':
                    tensor = paddle.max(tensor, axis=0)
                elif method == 'min':
                    tensor = paddle.min(tensor, axis=0)
                else:
                    raise AggregateMethodError(should_have_aggregate_method=False)
return tensor | |||
def create_tensor(self, value: float): | |||
""" | |||
创建tensor,并且填入value作为值 | |||
""" | |||
tensor = paddle.ones((1,)).fill_(value) | |||
return tensor | |||
def fill_value(self, tensor, value: float): | |||
""" | |||
将tensor的值设置为value | |||
""" | |||
tensor.fill_(value) | |||
return tensor | |||
def get_scalar(self, tensor) -> float: | |||
return tensor.item() | |||
def tensor2numpy(self, tensor) -> np.array: | |||
if isinstance(tensor, paddle.Tensor): | |||
return tensor.cpu().detach().numpy() | |||
elif isinstance(tensor, np.ndarray): | |||
return tensor | |||
else: | |||
raise ValueError(f"tensor: {tensor} cannot be converted to an ndarray!") | |||
@staticmethod | |||
def _gather_all(result, group: Optional[Any] = None) -> List: | |||
""" | |||
Gather `result` from every process in `group`; since the tensors on different ranks may differ in size, they are padded to a common shape when necessary.
""" | |||
# TODO check 正确性 | |||
if group is None: | |||
group = paddle.distributed.get_group(0) | |||
world_size = group.nranks | |||
paddle.distributed.barrier(group=group) | |||
# 张量为 标量的情况,简单地gather就好 | |||
if result.ndim == 0: | |||
return _simple_gather_all_tensors(result, group, world_size) | |||
# 获得 result 的 shape | |||
local_size = paddle.to_tensor(result.shape) | |||
# 将 group 中所有 result 的大小聚合在一起 | |||
local_sizes = [paddle.zeros_like(local_size) for _ in range(world_size)] | |||
paddle.distributed.all_gather(local_sizes, local_size, group=group) | |||
# 堆叠后,计算出 shape 每一维度的最大值 | |||
max_size = paddle.stack(local_sizes).max(axis=0) | |||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | |||
# 如果所有的结果大小相同,那么可以直接聚合 | |||
if all_sizes_equal: | |||
return _simple_gather_all_tensors(result, group, world_size) | |||
# 否则,padding 与最大的张量对齐 | |||
pad_dims = [] | |||
pad_by = (max_size - local_size).detach().cpu() | |||
for val in reversed(pad_by): | |||
pad_dims.append(0) | |||
pad_dims.append(val.item()) | |||
result_padded = paddle.nn.functional.pad(result, pad_dims) | |||
# 重新进行聚合 | |||
gathered_result = [paddle.zeros_like(result_padded) for _ in range(world_size)] | |||
paddle.distributed.all_gather(gathered_result, result_padded, group) | |||
for idx, item_size in enumerate(local_sizes): | |||
slice_param = [slice(dim_size) for dim_size in item_size] | |||
gathered_result[idx] = gathered_result[idx][slice_param] | |||
return gathered_result | |||
def move_tensor_to_device(self, tensor, device): | |||
# TODO 如果在这里处理的话,会不会在别的地方引起bug? | |||
if is_in_paddle_dist(): | |||
device = get_device_from_visible(device) | |||
return paddle_to(tensor, device) | |||
@@ -0,0 +1,6 @@ | |||
__all__ = [ | |||
'TorchBackend' | |||
] | |||
from .backend import Backend as TorchBackend |
@@ -0,0 +1,154 @@ | |||
from typing import Any, List, Optional | |||
import numpy as np | |||
from fastNLP.core.metrics.backend import Backend | |||
from fastNLP.core.metrics.utils import AggregateMethodError | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.distributed as dist | |||
import torch.nn.functional as F | |||
def _simple_gather_all_tensors(result, group: Any, world_size: int) -> List: | |||
gathered_result = [torch.zeros_like(result) for _ in range(world_size)] | |||
dist.all_gather(gathered_result, result, group) | |||
return gathered_result | |||
class TorchBackend(Backend): | |||
def __init__(self): | |||
super().__init__() | |||
self._specified = True | |||
def aggregate(self, tensor, method: str): | |||
""" | |||
聚集结果,并根据method计算后,返回结果。 | |||
""" | |||
if isinstance(tensor, torch.Tensor): | |||
if dist.is_initialized(): | |||
if method is None: | |||
raise AggregateMethodError(should_have_aggregate_method=True) | |||
tensor = self._gather_all(tensor) | |||
if isinstance(tensor[0], torch.Tensor): | |||
tensor = torch.stack(tensor) | |||
# aggregate the gathered results according to the requested method | |||
if method == 'sum': | |||
tensor = torch.sum(tensor, dim=0) | |||
elif method == 'mean': | |||
tensor = torch.mean(tensor, dim=0) | |||
elif method == 'max': | |||
tensor, _ = torch.max(tensor, dim=0) | |||
elif method == 'min': | |||
tensor, _ = torch.min(tensor, dim=0) | |||
else: | |||
raise AggregateMethodError(should_have_aggregate_method=False) | |||
return tensor | |||
def create_tensor(self, value: float): | |||
""" | |||
创建tensor,并且填入value作为值 | |||
""" | |||
tensor = torch.ones(1).fill_(value) | |||
return tensor | |||
def fill_value(self, tensor, value: float): | |||
""" | |||
将tensor的值设置为value | |||
""" | |||
tensor.fill_(value) | |||
return tensor | |||
def get_scalar(self, tensor) -> float: | |||
return tensor.item() | |||
@staticmethod | |||
def _gather_all(result, group: Optional[Any] = None) -> List: | |||
"""Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. | |||
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case | |||
tensors are padded, gathered and then trimmed to secure equal workload for all processes. | |||
Args: | |||
result: the value to sync | |||
group: the process group to gather results from. Defaults to all processes (world) | |||
Return: | |||
gathered_result: list with size equal to the process group where | |||
gathered_result[i] corresponds to result tensor from process i | |||
""" | |||
if group is None: | |||
group = dist.group.WORLD | |||
# convert tensors to contiguous format | |||
result = result.contiguous() | |||
world_size = dist.get_world_size(group) | |||
dist.barrier(group=group) | |||
# if the tensor is scalar, things are easy | |||
if result.ndim == 0: | |||
return _simple_gather_all_tensors(result, group, world_size) | |||
# 1. Gather sizes of all tensors | |||
local_size = torch.tensor(result.shape, device=result.device) | |||
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] | |||
dist.all_gather(local_sizes, local_size, group=group) | |||
max_size = torch.stack(local_sizes).max(dim=0).values | |||
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) | |||
# 2. If shapes are all the same, then do a simple gather: | |||
if all_sizes_equal: | |||
return _simple_gather_all_tensors(result, group, world_size) | |||
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate | |||
pad_dims = [] | |||
pad_by = (max_size - local_size).detach().cpu() | |||
for val in reversed(pad_by): | |||
pad_dims.append(0) | |||
pad_dims.append(val.item()) | |||
result_padded = torch.nn.functional.pad(result, pad_dims) | |||
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] | |||
dist.all_gather(gathered_result, result_padded, group) | |||
for idx, item_size in enumerate(local_sizes): | |||
slice_param = [slice(dim_size) for dim_size in item_size] | |||
gathered_result[idx] = gathered_result[idx][slice_param] | |||
return gathered_result | |||
def tensor2numpy(self, tensor) -> np.array: | |||
""" | |||
将对应的tensor转为numpy对象 | |||
""" | |||
if isinstance(tensor, torch.Tensor): | |||
return tensor.cpu().detach().numpy() | |||
elif isinstance(tensor, np.ndarray): | |||
return tensor | |||
elif isinstance(tensor, (float, int)): | |||
return tensor | |||
else: | |||
raise ValueError(f"tensor: {tensor} can not convert to ndarray!") | |||
@staticmethod | |||
def is_distributed() -> bool: | |||
""" | |||
:return: whether torch.distributed is available and has been initialized | |||
""" | |||
return dist.is_available() and dist.is_initialized() | |||
def move_tensor_to_device(self, tensor, device): | |||
return tensor.to(device) | |||
def all_gather_object(self, obj, group=None) -> List: | |||
if self.is_distributed(): | |||
obj_list = fastnlp_torch_all_gather(obj, group=group) | |||
return obj_list | |||
return [obj] | |||
@@ -0,0 +1,142 @@ | |||
__all__ = [ | |||
'ClassifyFPreRecMetric' | |||
] | |||
from typing import Union, List | |||
from collections import defaultdict | |||
from functools import partial | |||
import warnings | |||
from .metric import Metric | |||
from .backend import Backend | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.utils.utils import seq_len_to_mask | |||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
r""" | |||
:param beta_square: float, the square of beta in the F-beta score | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||
return f, pre, rec | |||
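# A quick sanity check of the helper above (illustrative note, not part of the original file):
# with beta = 1 (beta_square = 1), tp = 8, fn = 2, fp = 4 we get
#   pre = 8 / 12 ≈ 0.667, rec = 8 / 10 = 0.8, f = 2 * pre * rec / (pre + rec) ≈ 0.727.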
class ClassifyFPreRecMetric(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = False, | |||
tag_vocab: Vocabulary = None, encoding_type: str = None, ignore_labels: List[str] = None, | |||
only_gross: bool = True, f_type='micro', beta=1) -> None: | |||
super(ClassifyFPreRecMetric, self).__init__(backend=backend, | |||
aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
self.beta_square = self.beta ** 2 | |||
self.only_gross = only_gross | |||
self.tag_vocab = tag_vocab | |||
self._tp, self._fp, self._fn = defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||
defaultdict(partial(self.register_element, aggregate_method='sum')),\ | |||
defaultdict(partial(self.register_element, aggregate_method='sum')) | |||
def get_metric(self) -> dict: | |||
r""" | |||
get_metric computes the final evaluation result from the statistics accumulated by update(). | |||
:return dict evaluate_result: e.g. {"f": float, "pre": float, "rec": float} | |||
""" | |||
evaluate_result = {} | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._fn.keys()) | |||
tags.update(set(self._fp.keys())) | |||
tags.update(set(self._tp.keys())) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
for tag in tags: | |||
if self.tag_vocab is not None: | |||
tag_name = self.tag_vocab.to_word(tag) | |||
else: | |||
tag_name = int(tag) | |||
tp = self._tp[tag].get_scalar() | |||
fn = self._fn[tag].get_scalar() | |||
fp = self._fp[tag].get_scalar() | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
rec_sum += rec | |||
if not self.only_gross and tag != '': # tag!=''防止无tag的情况 | |||
f_key = 'f-{}'.format(tag_name) | |||
pre_key = 'pre-{}'.format(tag_name) | |||
rec_key = 'rec-{}'.format(tag_name) | |||
evaluate_result[f_key] = f | |||
evaluate_result[pre_key] = pre | |||
evaluate_result[rec_key] = rec | |||
if self.f_type == 'macro': | |||
evaluate_result['f'] = f_sum / len(tags) | |||
evaluate_result['pre'] = pre_sum / len(tags) | |||
evaluate_result['rec'] = rec_sum / len(tags) | |||
if self.f_type == 'micro': | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
sum(val.get_scalar() for val in self._tp.values()), | |||
sum(val.get_scalar() for val in self._fn.values()), | |||
sum(val.get_scalar() for val in self._fp.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
for key, value in evaluate_result.items(): | |||
evaluate_result[key] = round(value, 6) | |||
return evaluate_result | |||
def update(self, pred, target, seq_len=None): | |||
pred = self.tensor2numpy(pred) | |||
target = self.tensor2numpy(target) | |||
if seq_len is not None: | |||
seq_len = self.tensor2numpy(seq_len) | |||
if seq_len is not None and target.ndim > 1: | |||
max_len = target.shape[-1] | |||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | |||
else: | |||
masks = None | |||
if pred.ndim == target.ndim: | |||
if len(pred.flatten()) != len(target.flatten()): | |||
raise RuntimeError(f"when pred have same dimensions with target, they should have same element numbers." | |||
f" while target have element numbers:{len(pred.flatten())}, " | |||
f"pred have element numbers: {len(target.flatten())}") | |||
pass | |||
elif pred.ndim == target.ndim + 1: | |||
pred = pred.argmax(axis=-1) | |||
if seq_len is None and target.ndim > 1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"when pred have " | |||
f"size:{pred.ndim}, target should have size: {pred.ndim} or " | |||
f"{pred.ndim[:-1]}, got {target.ndim}.") | |||
if masks is not None: | |||
target = target * masks | |||
pred = pred * masks | |||
target_idxes = set(target.reshape(-1).tolist()) | |||
for target_idx in target_idxes: | |||
self._tp[target_idx] += ((pred == target_idx) * (target == target_idx)).sum().item() | |||
self._fp[target_idx] += ((pred == target_idx) * (target != target_idx)).sum().item() | |||
self._fn[target_idx] += ((pred != target_idx) * (target == target_idx)).sum().item() | |||
@@ -0,0 +1,281 @@ | |||
__all__ = [ | |||
'Element' | |||
] | |||
import os | |||
from .backend import Backend, AutoBackend | |||
from fastNLP.core.log import logger | |||
from .utils import AggregateMethodError | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
class Element: | |||
def __init__(self, value: float, aggregate_method, backend: Backend, name=None): | |||
self.init_value = value | |||
self.aggregate_method = aggregate_method | |||
self.name = name | |||
if backend == 'auto': | |||
raise RuntimeError("You have to specify the backend.") | |||
elif isinstance(backend, AutoBackend): | |||
self.backend = backend | |||
else: | |||
self.backend = AutoBackend(backend) | |||
if self.backend.is_specified(): | |||
value = self.backend.create_tensor(self.init_value) | |||
else: | |||
value = None | |||
self._value = value | |||
self.device = None | |||
def aggregate(self): | |||
""" | |||
自动aggregate对应的元素 | |||
""" | |||
try: | |||
self._value = self.backend.aggregate(self._value, self.aggregate_method) | |||
except AggregateMethodError as e: | |||
msg = 'If you see this message, please report a bug.' | |||
if self.name and e.should_have_aggregate_method: | |||
msg = f"Element:{self.name} has no specified `aggregate_method`." | |||
elif e.should_have_aggregate_method: | |||
msg = "Element has no specified `aggregate_method`." | |||
elif self.name and not e.should_have_aggregate_method: | |||
msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ | |||
f'aggregate_method:{self.aggregate_method}.' | |||
elif not e.should_have_aggregate_method: | |||
msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ | |||
f'aggregate_method:{self.aggregate_method}.' | |||
if e.only_warn: | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
logger.warning(msg) | |||
self._value = self.backend.aggregate(self._value, method=None) | |||
else: | |||
raise RuntimeError(msg) | |||
def reset(self): | |||
if self.backend.is_specified(): | |||
self._value = self.backend.fill_value(self._value, self.init_value) | |||
@property | |||
def value(self): | |||
return self._value | |||
@value.setter | |||
def value(self, value): | |||
self._check_value_initialized() | |||
self._value = value | |||
@value.getter | |||
def value(self): | |||
self._check_value_initialized() | |||
return self._value | |||
def get_scalar(self) -> float: | |||
return self.backend.get_scalar(self._value) | |||
def fill_value(self, value): | |||
self._value = self.backend.fill_value(self._value, value) | |||
def to(self, device): | |||
# TODO: how should device be handled here? | |||
if self._value is not None: | |||
self._value = self.backend.move_tensor_to_device(self._value, device) | |||
self.device = device | |||
def _check_value_initialized(self): | |||
if self._value is None: | |||
assert self.backend.is_specified(), f"Backend is not specified, please specify backend in the Metric " \ | |||
f"initialization." | |||
self._value = self.backend.create_tensor(self.init_value) | |||
if self.device is not None: | |||
self.to(device=self.device) | |||
def _check_value_when_call(self): | |||
if self.value is None: | |||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||
"element, or use it after it being used by the `Metric.compute()` method.") | |||
def __add__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value += other.value | |||
else: | |||
self.value += other | |||
return self | |||
def __radd__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value += other.value | |||
else: | |||
self.value += other | |||
return self | |||
def __sub__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value -= other.value | |||
else: | |||
self.value -= other | |||
return self | |||
def __rsub__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value -= other.value | |||
else: | |||
self.value -= other | |||
return self | |||
def __mul__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value *= other.value | |||
else: | |||
self.value *= other | |||
return self | |||
def __imul__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value *= other.value | |||
else: | |||
self.value *= other | |||
return self | |||
def __floordiv__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value //= other.value | |||
else: | |||
self.value //= other | |||
return self | |||
def __rfloordiv__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value //= other.value | |||
else: | |||
self.value //= other | |||
return self | |||
def __truediv__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value /= other.value | |||
else: | |||
self.value /= other | |||
return self | |||
def __rtruediv__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value /= other.value | |||
else: | |||
self.value /= other | |||
return self | |||
def __mod__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value %= other.value | |||
else: | |||
self.value %= other | |||
return self | |||
def __rmod__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value %= other.value | |||
else: | |||
self.value %= other | |||
return self | |||
def __pow__(self, other, modulo=None): | |||
self._check_value_when_call() | |||
if modulo is None: | |||
if isinstance(other, Element): | |||
self.value **= other.value | |||
else: | |||
self.value **= other | |||
else: | |||
if isinstance(other, Element): | |||
self.value = pow(self.value, other.value, modulo) | |||
else: | |||
self.value = pow(self.value, other, modulo) | |||
return self | |||
def __rpow__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
self.value **= other.value | |||
else: | |||
self.value **= other | |||
return self | |||
def __lt__(self, other) -> bool: | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
return self.value < other.value | |||
else: | |||
return self.value < other | |||
def __le__(self, other) -> bool: | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
return self.value <= other.value | |||
else: | |||
return self.value <= other | |||
def __eq__(self, other): | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
return self.value == other.value | |||
else: | |||
return self.value == other | |||
def __ne__(self, other) -> bool: | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
return self.value != other.value | |||
else: | |||
return self.value != other | |||
def __ge__(self, other) -> bool: | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
return self.value >= other.value | |||
else: | |||
return self.value >= other | |||
def __gt__(self, other) -> bool: | |||
self._check_value_when_call() | |||
if isinstance(other, Element): | |||
return self.value > other.value | |||
else: | |||
return self.value > other | |||
def __str__(self): | |||
return str(self.value) | |||
def __repr__(self): | |||
return str(self.value) | |||
def __getattr__(self, item): | |||
""" | |||
Delegate attribute access to the wrapped value, so that an Element can be used like the underlying backend tensor (methods of that tensor can be called on the Element directly).
:param item: | |||
:return: | |||
""" | |||
try: | |||
if self._value is None: | |||
prefix = f'Element:`{self.name}`' if self.name else 'Element' | |||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | |||
"element, or use it after it being used by the `Metric.compute()` method.") | |||
return getattr(self._value, item) | |||
except AttributeError as e: | |||
raise e |
@@ -0,0 +1,184 @@ | |||
__all__ = [ | |||
'Metric' | |||
] | |||
from abc import abstractmethod | |||
from typing import Union | |||
import functools | |||
from contextlib import contextmanager | |||
import numpy as np | |||
from fastNLP.core.metrics.backend import Backend, AutoBackend | |||
from fastNLP.core.metrics.element import Element | |||
class Metric: | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True): | |||
""" | |||
:param str backend: four kinds of backends are currently supported: [torch, paddle, jittor, auto]. With 'auto' the | |||
concrete backend is inferred from the arguments actually passed to Metric.update(); in most cases 'auto' is all you need. | |||
:param bool aggregate_when_get_metric: whether, when computing the metric, the values of the same element on the | |||
different processes are aggregated first. Meaningless when the backend does not support distributed execution. | |||
""" | |||
self.backend = AutoBackend(backend) | |||
self._updated = False | |||
self.get_metric = self._sync_get_metric(self.get_metric) | |||
self.update = self._wrap_update(self.update) | |||
self.reset = self._wrap_auto_reset_elements(self.reset) | |||
self.aggregate_when_get_metric = aggregate_when_get_metric | |||
self._cannot_change_element = False | |||
self._elements = {} | |||
@property | |||
def elements(self) -> dict: | |||
return self._elements | |||
def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: | |||
""" | |||
Register an Element. After registration it can be accessed as self.{name} inside the Metric and used like a tensor of | |||
the corresponding backend, i.e. you can add / subtract / multiply / divide it directly. | |||
Note: if you want the metric to scale automatically to the multi-card setting, make sure to specify aggregate_method. | |||
:param name: name of the element; after registration it is available as self.{name}. | |||
:param value: initial value; the element is also reset to this value when Metric.reset() is called. | |||
:param aggregate_method: how to aggregate the value across cards; meaningless in single-card runs. | |||
:param backend: backend to use. The concrete type of the Element depends on it: for torch the object is a | |||
torch.Tensor, for paddle a paddle.Tensor, for jittor a jittor.Var. Usually 'auto' suffices: fastNLP infers the | |||
backend from the tensors passed to Metric.update(), e.g. when the arguments contain only torch.Tensor as tensor | |||
type (non-tensor inputs are allowed alongside) the backend is taken to be torch; when they contain only jittor.Var | |||
the backend is taken to be jittor. If no tensor type is detected at all, a plain float is used as the element. | |||
:return: the registered Element object | |||
""" | |||
if backend == 'auto': | |||
backend = self.backend | |||
else: | |||
backend = AutoBackend(backend) | |||
# when name is None, generate a default name for the element | |||
if name is None: | |||
name = f'ele_var_{len(self._elements)}' | |||
element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) | |||
self.elements[name] = element | |||
setattr(self, name, element) | |||
return element | |||
def reset(self): | |||
""" | |||
Override this method to reset any state that is not an element; registered elements are reset to their initial values automatically.
""" | |||
pass | |||
def _wrap_auto_reset_elements(self, reset): | |||
@functools.wraps(reset) | |||
def _wrap_reset(*args, **kwargs): | |||
self._updated = False | |||
for ele in self.elements.values(): | |||
ele.reset() | |||
reset(*args, **kwargs) | |||
return _wrap_reset | |||
def _sync_get_metric(self, get_metric): | |||
@functools.wraps(get_metric) | |||
def _wrap_get_metric(*args, **kwargs): | |||
assert self._updated, f"You have to call `{self.__class__.__name__}` update() function before calling " \ | |||
f"get_metric()." | |||
with self.sync(recover=True, aggregate=self.aggregate_when_get_metric): | |||
results = get_metric(*args, **kwargs) | |||
return results | |||
return _wrap_get_metric | |||
def __setattr__(self, key, value): | |||
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | |||
if key in self.elements and value is not self.elements[key]: | |||
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") | |||
object.__setattr__(self, key, value) | |||
def _wrap_update(self, update): | |||
@functools.wraps(update) | |||
def _wrap_update(*args, **kwargs): | |||
self.check_backend(*args, **kwargs) | |||
self._cannot_change_element = True | |||
self._updated = True | |||
return update(*args, **kwargs) | |||
return _wrap_update | |||
def check_backend(self, *args, **kwargs): | |||
if not self.backend.is_specified(): | |||
_args = [] | |||
for arg in args: | |||
_args.append(arg) | |||
for arg in kwargs.values(): | |||
_args.append(arg) | |||
self.backend.choose_real_backend(_args) | |||
@contextmanager | |||
def sync(self, recover=True, aggregate=False): | |||
""" | |||
Inside this context the metric first synchronizes the elements that need aggregation. When recover is True, the | |||
element values are restored to their pre-aggregation values on exiting the context. | |||
""" | |||
keep_value = {} | |||
if aggregate: | |||
for name, element in self.elements.items(): | |||
# 保存过去的值 | |||
keep_value[name] = element.get_scalar() | |||
# 聚合结果 | |||
element.aggregate() | |||
yield | |||
if recover and aggregate: | |||
for name, element in self.elements.items(): | |||
# 恢复结果 | |||
if name in keep_value: | |||
element.fill_value(value=keep_value.get(name)) | |||
@abstractmethod | |||
def update(self, *args, **kwargs): | |||
raise NotImplementedError() | |||
@abstractmethod | |||
def get_metric(self) -> dict: | |||
raise NotImplementedError() | |||
def set_auto_aggregate_when_get_metric(self, flag: bool): | |||
""" | |||
设置是否在 get_metric 的时候自动 aggregate | |||
""" | |||
self.aggregate_when_get_metric = flag | |||
def __getattr__(self, name: str) -> Element: | |||
if '_elements' in self.__dict__: | |||
elements = self.__dict__['_elements'] | |||
if name in elements: | |||
return elements[name] | |||
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) | |||
def tensor2numpy(self, tensor) -> np.array: | |||
""" | |||
将tensor向量转为numpy类型变量 | |||
:param tensor: | |||
:return: | |||
""" | |||
return self.backend.tensor2numpy(tensor) | |||
def to(self, device): | |||
""" | |||
将所有的 element 变量移动到 device 设备上 | |||
:param device: | |||
:return: | |||
""" | |||
for element in self.elements.values(): | |||
element.to(device) |
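# Illustrative sketch (not part of the original file): a minimal accuracy metric built only from the
# public pieces documented above -- register_element(), update(), get_metric() and tensor2numpy().
# The name _SketchAccuracy is hypothetical and used purely for illustration.
class _SketchAccuracy(Metric):
    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
        super().__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
        # both counters are registered elements, so they are summed across cards and reset automatically
        self.correct = self.register_element(name='correct', value=0, aggregate_method='sum')
        self.total = self.register_element(name='total', value=0, aggregate_method='sum')

    def update(self, pred, target):
        # pred: [batch, n_classes] scores, target: [batch] gold labels
        pred = self.tensor2numpy(pred).argmax(axis=-1)
        target = self.tensor2numpy(target)
        self.correct += (pred == target).sum()
        self.total += target.size

    def get_metric(self) -> dict:
        return {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-13), 6)}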
@@ -0,0 +1,344 @@ | |||
__all__ = [ | |||
'SpanFPreRecMetric' | |||
] | |||
from typing import Union, List, Optional | |||
import warnings | |||
from collections import defaultdict | |||
from functools import partial | |||
from fastNLP.core.metrics.backend import Backend | |||
from fastNLP.core.metrics.metric import Metric | |||
from fastNLP.core.vocabulary import Vocabulary | |||
def _check_tag_vocab_and_encoding_type(tag_vocab: Union[Vocabulary, dict], encoding_type: str): | |||
r""" | |||
检查vocab中的tag是否与encoding_type是匹配的 | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:param encoding_type: bio, bmes, bioes, bmeso | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
tags = encoding_type | |||
for tag in tag_set: | |||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||
f"encoding_type." | |||
tags = tags.replace(tag, '') # 删除该值 | |||
if tags: # 如果不为空,说明出现了未使用的tag | |||
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||
"encoding_type.") | |||
def _get_encoding_type_from_tag_vocab(tag_vocab: Union[Vocabulary, dict]) -> str: | |||
r""" | |||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
bmes_tag_set = set('bmes') | |||
if tag_set == bmes_tag_set: | |||
return 'bmes' | |||
bio_tag_set = set('bio') | |||
if tag_set == bio_tag_set: | |||
return 'bio' | |||
bmeso_tag_set = set('bmeso') | |||
if tag_set == bmeso_tag_set: | |||
return 'bmeso' | |||
bioes_tag_set = set('bioes') | |||
if tag_set == bioes_tag_set: | |||
return 'bioes' | |||
raise RuntimeError("encoding_type cannot be inferred automatically. Only support " | |||
"'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
def _bmes_tag_to_spans(tags, ignore_labels=None): | |||
r""" | |||
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | |||
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | |||
也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bmes_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bmes_tag, label = tag[:1], tag[2:] | |||
if bmes_tag in ('b', 's'): | |||
spans.append((label, [idx, idx])) | |||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bmes_tag = bmes_tag | |||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
def _bmeso_tag_to_spans(tags, ignore_labels=None): | |||
r""" | |||
给定一个tags的lis,比如['O', 'B-singer', 'M-singer', 'E-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bmes_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bmes_tag, label = tag[:1], tag[2:] | |||
if bmes_tag in ('b', 's'): | |||
spans.append((label, [idx, idx])) | |||
elif bmes_tag in ('m', 'e') and prev_bmes_tag in ('b', 'm') and label == spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
elif bmes_tag == 'o': | |||
pass | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bmes_tag = bmes_tag | |||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
def _bioes_tag_to_spans(tags, ignore_labels=None): | |||
r""" | |||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'E-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bioes_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bioes_tag, label = tag[:1], tag[2:] | |||
if bioes_tag in ('b', 's'): | |||
spans.append((label, [idx, idx])) | |||
elif bioes_tag in ('i', 'e') and prev_bioes_tag in ('b', 'i') and label == spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
elif bioes_tag == 'o': | |||
pass | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bioes_tag = bioes_tag | |||
return [(span[0], (span[1][0], span[1][1] + 1)) | |||
for span in spans | |||
if span[0] not in ignore_labels | |||
] | |||
def _bio_tag_to_spans(tags, ignore_labels=None): | |||
r""" | |||
给定一个tags的lis,比如['O', 'B-singer', 'I-singer', 'I-singer', 'O', 'O']。 | |||
返回[('singer', (1, 4))] (左闭右开区间) | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
:return: List[Tuple[str, List[int, int]]]. [(label,[start, end])] | |||
""" | |||
ignore_labels = set(ignore_labels) if ignore_labels else set() | |||
spans = [] | |||
prev_bio_tag = None | |||
for idx, tag in enumerate(tags): | |||
tag = tag.lower() | |||
bio_tag, label = tag[:1], tag[2:] | |||
if bio_tag == 'b': | |||
spans.append((label, [idx, idx])) | |||
elif bio_tag == 'i' and prev_bio_tag in ('b', 'i') and label == spans[-1][0]: | |||
spans[-1][1][1] = idx | |||
elif bio_tag == 'o': # o tag does not count | |||
pass | |||
else: | |||
spans.append((label, [idx, idx])) | |||
prev_bio_tag = bio_tag | |||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
def _compute_f_pre_rec(beta_square, tp, fn, fp): | |||
r""" | |||
:param beta_square: float, the square of beta in the F-beta score | |||
:param tp: int, true positive | |||
:param fn: int, false negative | |||
:param fp: int, false positive | |||
:return: (f, pre, rec) | |||
""" | |||
pre = tp / (fp + tp + 1e-13) | |||
rec = tp / (fn + tp + 1e-13) | |||
f = (1 + beta_square) * pre * rec / (beta_square * pre + rec + 1e-13) | |||
return f, pre, rec | |||
class SpanFPreRecMetric(Metric): | |||
def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None, | |||
encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', | |||
beta=1, aggregate_when_get_metric: bool = True,) -> None: | |||
super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if encoding_type: | |||
encoding_type = encoding_type.lower() | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
self.encoding_type = encoding_type | |||
else: | |||
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = _bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
self.tag_to_span_func = _bio_tag_to_spans | |||
elif self.encoding_type == 'bmeso': | |||
self.tag_to_span_func = _bmeso_tag_to_spans | |||
elif self.encoding_type == 'bioes': | |||
self.tag_to_span_func = _bioes_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
self.beta = beta | |||
self.beta_square = self.beta ** 2 | |||
self.only_gross = only_gross | |||
self.tag_vocab = tag_vocab | |||
self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||
self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||
self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) | |||
def get_metric(self) -> dict: | |||
evaluate_result = {} | |||
if not self.only_gross or self.f_type == 'macro': | |||
tags = set(self._false_negatives.keys()) | |||
tags.update(set(self._false_positives.keys())) | |||
tags.update(set(self._true_positives.keys())) | |||
f_sum = 0 | |||
pre_sum = 0 | |||
rec_sum = 0 | |||
for tag in tags: | |||
tp = self._true_positives[tag].get_scalar() | |||
fn = self._false_negatives[tag].get_scalar() | |||
fp = self._false_positives[tag].get_scalar() | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
rec_sum += rec | |||
if not self.only_gross and tag != '': # tag!=''防止无tag的情况 | |||
f_key = 'f-{}'.format(tag) | |||
pre_key = 'pre-{}'.format(tag) | |||
rec_key = 'rec-{}'.format(tag) | |||
evaluate_result[f_key] = f | |||
evaluate_result[pre_key] = pre | |||
evaluate_result[rec_key] = rec | |||
if self.f_type == 'macro': | |||
evaluate_result['f'] = f_sum / len(tags) | |||
evaluate_result['pre'] = pre_sum / len(tags) | |||
evaluate_result['rec'] = rec_sum / len(tags) | |||
if self.f_type == 'micro': | |||
f, pre, rec = _compute_f_pre_rec(self.beta_square, | |||
sum(val.get_scalar() for val in self._true_positives.values()), | |||
sum(val.get_scalar() for val in self._false_negatives.values()), | |||
sum(val.get_scalar() for val in self._false_positives.values())) | |||
evaluate_result['f'] = f | |||
evaluate_result['pre'] = pre | |||
evaluate_result['rec'] = rec | |||
for key, value in evaluate_result.items(): | |||
evaluate_result[key] = round(value, 6) | |||
return evaluate_result | |||
def update(self, pred, target, seq_len: Optional[List] = None) -> None: | |||
r"""update函数将针对一个批次的预测结果做评价指标的累计 | |||
:param pred: [batch, seq_len] 或者 [batch, seq_len, len(tag_vocab)], 预测的结果 | |||
:param target: [batch, seq_len], 真实值 | |||
:param seq_len: [batch] 文本长度标记 | |||
:return: | |||
""" | |||
pred = self.tensor2numpy(pred) | |||
target = self.tensor2numpy(target) | |||
if pred.ndim == target.ndim and target.ndim == 2: | |||
pass | |||
elif pred.ndim == target.ndim + 1 and target.ndim == 2: | |||
num_classes = pred.shape[-1] | |||
pred = pred.argmax(axis=-1) | |||
if (target >= num_classes).any(): | |||
raise ValueError("A gold label passed to SpanBasedF1Metric contains an " | |||
"id >= {}, the number of classes.".format(num_classes)) | |||
else: | |||
raise RuntimeError(f"when pred have size:{pred.ndim}, target should have size: {pred.ndim} or " | |||
f"{pred.shape[:-1]}, got {target.ndim}.") | |||
batch_size = pred.shape[0] | |||
pred = pred.tolist() | |||
target = target.tolist() | |||
for i in range(batch_size): | |||
pred_tags = pred[i][:int(seq_len[i])] | |||
gold_tags = target[i][:int(seq_len[i])] | |||
pred_str_tags = [self.tag_vocab.to_word(tag) for tag in pred_tags] | |||
gold_str_tags = [self.tag_vocab.to_word(tag) for tag in gold_tags] | |||
pred_spans = self.tag_to_span_func(pred_str_tags, ignore_labels=self.ignore_labels) | |||
gold_spans = self.tag_to_span_func(gold_str_tags, ignore_labels=self.ignore_labels) | |||
for span in pred_spans: | |||
if span in gold_spans: | |||
self._true_positives[span[0]] += 1 | |||
gold_spans.remove(span) | |||
else: | |||
self._false_positives[span[0]] += 1 | |||
for span in gold_spans: | |||
self._false_negatives[span[0]] += 1 |
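# Illustrative sketch (not part of the original file): end-to-end usage of SpanFPreRecMetric with a BIO tag
# vocabulary. It assumes the usual fastNLP Vocabulary API (add_word_lst / to_index) and a torch backend;
# the function below is hypothetical and only meant to show the call pattern.
def _sketch_span_f1_usage():
    import torch
    tag_vocab = Vocabulary(padding=None, unknown=None)
    tag_vocab.add_word_lst(['O', 'B-loc', 'I-loc'])
    metric = SpanFPreRecMetric(tag_vocab=tag_vocab)   # encoding_type 'bio' is inferred from the vocab
    to_idx = tag_vocab.to_index
    gold = torch.tensor([[to_idx('O'), to_idx('B-loc'), to_idx('I-loc'), to_idx('O')]])
    pred = torch.tensor([[to_idx('O'), to_idx('B-loc'), to_idx('I-loc'), to_idx('O')]])
    seq_len = torch.tensor([4])
    metric.update(pred, gold, seq_len)
    return metric.get_metric()    # perfect match -> {'f': 1.0, 'pre': 1.0, 'rec': 1.0}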
@@ -0,0 +1,91 @@ | |||
__all__ = [ | |||
'func_post_proc' | |||
] | |||
from typing import Any | |||
from functools import wraps | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.envs.utils import _module_available | |||
_IS_TORCHMETRICS_AVAILABLE = _module_available('torchmetrics') | |||
if _IS_TORCHMETRICS_AVAILABLE: | |||
from torchmetrics import Metric as torchmetrics_Metric | |||
_IS_ALLENNLP_AVAILABLE = _module_available('allennlp') | |||
if _IS_ALLENNLP_AVAILABLE: | |||
from allennlp.training.metrics import Metric as allennlp_Metric | |||
if _NEED_IMPORT_PADDLE: | |||
from paddle.metric import Metric as paddle_Metric | |||
def _is_torchmetrics_metric(metric: Any) -> bool: | |||
""" | |||
检查输入的对象是否为torchmetrics对象 | |||
:param metric: | |||
:return: | |||
""" | |||
if _IS_TORCHMETRICS_AVAILABLE: | |||
return isinstance(metric, torchmetrics_Metric) | |||
else: | |||
return False | |||
def _is_allennlp_metric(metric: Any) -> bool: | |||
""" | |||
检查输入的对象是否为allennlp对象 | |||
:param metric: | |||
:return: | |||
""" | |||
if _IS_ALLENNLP_AVAILABLE: | |||
return isinstance(metric, allennlp_Metric) | |||
else: | |||
return False | |||
def _is_paddle_metric(metric: Any) -> bool: | |||
""" | |||
Check whether the given object is a paddle Metric.
:param metric: | |||
:return: | |||
""" | |||
if _NEED_IMPORT_PADDLE: | |||
return isinstance(metric, paddle_Metric) | |||
else: | |||
return False | |||
def func_post_proc(metric: 'Metric', fn: callable, method_name: str) -> 'Metric': | |||
""" | |||
Wrap fn around the {method_name} method of the metric object, so that the return value of metric.{method_name} is | |||
passed through fn before being returned. Note that the modification of metric's {method_name} method is in-place. | |||
:param metric: the metric object | |||
:param fn: callable applied to the return value of metric.{method_name} | |||
:param method_name: name of the method to wrap, e.g. 'get_metric' or 'accumulate' depending on the metric | |||
:return: metric | |||
""" | |||
assert hasattr(metric, method_name) and callable(getattr(metric, method_name)), \ | |||
f"Parameter `metric` must have a {method_name} function." | |||
assert callable(fn), "Parameter `fn` must be callable." | |||
func = getattr(metric, method_name) | |||
@wraps(func) | |||
def wrap_method(*args, **kwargs): | |||
res = func(*args, **kwargs) | |||
return fn(res) | |||
wrap_method.__wrapped_by_func_post_proc__ = True | |||
setattr(metric, method_name, wrap_method) | |||
return metric | |||
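# Illustrative sketch (not part of the original file): wrap the `get_metric` method of a (hypothetical)
# metric object so that its result dict is rounded before being returned.
def _sketch_func_post_proc_usage():
    class _ToyMetric:
        def get_metric(self):
            return {'acc': 0.987654321}

    metric = func_post_proc(_ToyMetric(),
                            fn=lambda res: {k: round(v, 2) for k, v in res.items()},
                            method_name='get_metric')
    return metric.get_metric()   # -> {'acc': 0.99}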
class AggregateMethodError(BaseException): | |||
def __init__(self, should_have_aggregate_method, only_warn=False): | |||
super(AggregateMethodError, self).__init__(self) | |||
self.should_have_aggregate_method = should_have_aggregate_method | |||
self.only_warn = only_warn |
@@ -0,0 +1,43 @@ | |||
__all__ = [ | |||
'cache_results', | |||
'is_jittor_dataset', | |||
'jittor_collate_wraps', | |||
'paddle_to', | |||
'paddle_move_data_to_device', | |||
'get_paddle_device_id', | |||
'get_paddle_gpu_str', | |||
'is_in_paddle_dist', | |||
'is_in_fnlp_paddle_dist', | |||
'is_in_paddle_launch_dist', | |||
'f_rich_progress', | |||
'torch_paddle_move_data_to_device', | |||
'torch_move_data_to_device', | |||
'get_fn_arg_names', | |||
'check_fn_not_empty_params', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
'match_and_substitute_params', | |||
'apply_to_collection', | |||
'nullcontext', | |||
'pretty_table_printer', | |||
'Option', | |||
'indice_collate_wrapper', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
'synchronize_safe_rm', | |||
'synchronize_mkdir' | |||
] | |||
from .cache_results import cache_results | |||
from .jittor_utils import is_jittor_dataset, jittor_collate_wraps | |||
from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_device_id, get_paddle_gpu_str, is_in_paddle_dist, \ | |||
is_in_fnlp_paddle_dist, is_in_paddle_launch_dist | |||
from .rich_progress import f_rich_progress | |||
from .torch_paddle_utils import torch_paddle_move_data_to_device | |||
from .torch_utils import torch_move_data_to_device | |||
from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \ | |||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | |||
indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir | |||
@@ -0,0 +1,310 @@ | |||
from datetime import datetime | |||
import hashlib | |||
import _pickle | |||
import functools | |||
import os | |||
from typing import Callable, List, Any, Optional | |||
import inspect | |||
import ast | |||
from collections import deque | |||
__all__ = [ | |||
'cache_results' | |||
] | |||
from fastNLP.core.log.logger import logger | |||
from fastNLP.core.log.highlighter import ColorHighlighter | |||
class FuncCallVisitor(ast.NodeVisitor): | |||
# credit to https://gist.github.com/jargnar/0946ab1d985e2b4ab776 | |||
def __init__(self): | |||
self._name = deque() | |||
@property | |||
def name(self): | |||
return '.'.join(self._name) | |||
@name.deleter | |||
def name(self): | |||
self._name.clear() | |||
def visit_Name(self, node): | |||
self._name.appendleft(node.id) | |||
def visit_Attribute(self, node): | |||
try: | |||
self._name.appendleft(node.attr) | |||
self._name.appendleft(node.value.id) | |||
except AttributeError: | |||
self.generic_visit(node) | |||
def get_func_calls(tree): | |||
func_calls = [] | |||
for node in ast.walk(tree): | |||
if isinstance(node, ast.Call): | |||
callvisitor = FuncCallVisitor() | |||
callvisitor.visit(node.func) | |||
func_calls.append(callvisitor.name) | |||
if isinstance(node, ast.FunctionDef): | |||
if not (node is tree): | |||
func_calls.extend(get_func_calls(node)) | |||
return func_calls | |||
def truncate_start_blanks(source:str)->str: | |||
""" | |||
将source中的每一行按照第一行的indent删掉多余的空格 | |||
:param source: | |||
:return: | |||
""" | |||
lines = source.split('\n') | |||
num_blank = 0 | |||
# get the top blank line | |||
for line in lines: | |||
if line: | |||
num_blank = len(line) - len(line.lstrip()) | |||
new_lines = [] | |||
for line in lines: | |||
i = -1 | |||
for i in range(min(len(line), num_blank)): | |||
if line[i] == ' ': | |||
continue | |||
else: | |||
break | |||
line = line[i:] | |||
new_lines.append(line) | |||
return '\n'.join(new_lines) | |||
def _get_func_and_its_called_func_source_code(func) -> List[str]: | |||
""" | |||
Given ``func``, return the source code of ``func`` itself and of all the functions/classes it calls.
:param callable func: | |||
:return: | |||
""" | |||
last_frame = inspect.currentframe().f_back.f_back.f_back | |||
last_frame_f_local = last_frame.f_locals | |||
last_frame_loc = {} | |||
if 'loc' in last_frame_f_local: | |||
last_frame_loc = last_frame_f_local['loc'] | |||
func_calls = list(set(get_func_calls(ast.parse(truncate_start_blanks(inspect.getsource(func)))))) | |||
func_calls.sort() | |||
sources = [] | |||
for _func_name in func_calls: | |||
try: | |||
if _func_name == 'cache_results': # ignore the decorator | |||
continue | |||
if '.' in _func_name: | |||
_funcs = _func_name.split('.') | |||
else: | |||
_funcs = [_func_name] | |||
if _funcs[0] in last_frame_f_local or _funcs[0] in last_frame_loc: | |||
tmp = _funcs.pop(0) | |||
variable = last_frame_f_local.get(tmp, last_frame_loc.get(tmp)) | |||
while len(_funcs) or variable is not None: | |||
if hasattr(variable, '__class__') and not inspect.isbuiltin(variable.__class__): | |||
try: | |||
sources.append(inspect.getsource(variable.__class__)) | |||
except TypeError: | |||
pass | |||
if callable(variable) or inspect.isclass(variable): | |||
sources.append(inspect.getsource(variable)) | |||
if len(_funcs): | |||
tmp = _funcs.pop(0) | |||
if hasattr(variable, tmp): | |||
variable = getattr(variable, tmp) | |||
else: | |||
break | |||
else: | |||
variable = None | |||
except: | |||
# some failure | |||
pass | |||
del last_frame # | |||
sources.append(inspect.getsource(func)) | |||
return sources | |||
def _prepare_cache_filepath(filepath:str): | |||
r""" | |||
检查filepath是否可以作为合理的cache文件. 如果可以的话,会自动创造路径 | |||
:param filepath: str. | |||
:return: None, if not, this function will raise error | |||
""" | |||
_cache_filepath = os.path.abspath(filepath) | |||
if os.path.isdir(_cache_filepath): | |||
raise RuntimeError("The cache_file_path must be a file, not a directory.") | |||
cache_dir = os.path.dirname(_cache_filepath) | |||
if not os.path.exists(cache_dir): | |||
os.makedirs(cache_dir, exist_ok=True) | |||
class Hasher: | |||
def __init__(self): | |||
self.m = hashlib.sha1() | |||
def update(self, value: Any) -> None: | |||
if isinstance(value, str): | |||
value = [value] | |||
for x in value: | |||
self.m.update(x.encode('utf8')) | |||
def hexdigest(self) -> str: | |||
return self.m.hexdigest() | |||
def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = None): | |||
if fn_kwargs is None: | |||
fn_kwargs = {} | |||
hasher = Hasher() | |||
try: | |||
sources = _get_func_and_its_called_func_source_code(fn) | |||
hasher.update(sources) | |||
except: | |||
return "can't be hashed" | |||
for key in sorted(fn_kwargs): | |||
hasher.update(key) | |||
try: | |||
hasher.update(fn_kwargs[key]) | |||
except: | |||
pass | |||
return hasher.hexdigest() | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||
r""" | |||
cache_results is the decorator fastNLP provides for caching (intermediate) data. The example below shows how to use it:: | |||
    import time | |||
    import numpy as np | |||
    from fastNLP import cache_results | |||
    @cache_results('cache.pkl') | |||
    def process_data(): | |||
        # some time-consuming work such as loading and preprocessing data; time.sleep() stands in for it here | |||
        time.sleep(1) | |||
        return np.random.randint(10, size=(5,)) | |||
    start_time = time.time() | |||
    print("res =",process_data()) | |||
    print(time.time() - start_time) | |||
    start_time = time.time() | |||
    print("res =",process_data()) | |||
    print(time.time() - start_time) | |||
    # The output looks as follows: the two results are identical and the second call takes almost no time | |||
    # Save cache to cache.pkl. | |||
    # res = [5 4 9 1 8] | |||
    # 1.0042750835418701 | |||
    # Read cache from cache.pkl. | |||
    # res = [5 4 9 1 8] | |||
    # 0.0040721893310546875 | |||
The second run takes only a few milliseconds because the result is read directly from cache.pkl instead of being | |||
preprocessed again:: | |||
    # Continuing the example above: to generate a separate cache, e.g. for another dataset, call it like this | |||
    process_data(_cache_fp='cache2.pkl')  # does not affect the previous 'cache.pkl' at all | |||
_cache_fp above is a parameter recognised by cache_results: the result is cached to / read from 'cache2.pkl', | |||
overriding the default 'cache.pkl'. Decorating a function with @cache_results() adds three parameters to it: | |||
[_cache_fp, _refresh, _verbose]. They are not passed on to your function, and your own parameters must not use | |||
these three names:: | |||
    process_data(_cache_fp='cache2.pkl', _refresh=True)  # force the preprocessing cache to be regenerated | |||
    # _verbose controls logging: 0 prints nothing; 1 reports whether the cache was read or newly generated | |||
:param str _cache_fp: where the returned result is cached to, or read from. If None, cache_results has no effect | |||
    unless _cache_fp is passed in at call time. | |||
:param bool _refresh: whether to regenerate the cache. | |||
:param int _verbose: whether to print cache information. | |||
:param bool _check_hash: if True, compare the hash of the source code of the decorated function (and of the functions | |||
    it calls) with the hash stored when the cache was saved, and emit a warning if they differ. The warning may be a | |||
    false positive that does not actually affect the result (e.g. adding or removing blank lines); conversely, when a | |||
    change affects the result without touching the source code, no warning can be given. | |||
:return: | |||
""" | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
for key, _ in signature.parameters.items(): | |||
if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'): | |||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||
@functools.wraps(func) | |||
def wrapper(*args, **kwargs): | |||
fn_param = kwargs.copy() | |||
if args: | |||
params = [p.name for p in inspect.signature(func).parameters.values()] | |||
fn_param.update(zip(params, args)) | |||
if '_cache_fp' in kwargs: | |||
cache_filepath = kwargs.pop('_cache_fp') | |||
assert isinstance(cache_filepath, str), "_cache_fp can only be str." | |||
else: | |||
cache_filepath = _cache_fp | |||
if '_refresh' in kwargs: | |||
refresh = kwargs.pop('_refresh') | |||
assert isinstance(refresh, bool), "_refresh can only be bool." | |||
else: | |||
refresh = _refresh | |||
if '_verbose' in kwargs: | |||
verbose = kwargs.pop('_verbose') | |||
assert isinstance(verbose, int), "_verbose can only be integer." | |||
else: | |||
verbose = _verbose | |||
if '_check_hash' in kwargs: | |||
check_hash = kwargs.pop('_check_hash') | |||
else: | |||
check_hash = _check_hash | |||
refresh_flag = True | |||
new_hash_code = None | |||
if check_hash: | |||
new_hash_code = cal_fn_hash_code(func, fn_param) | |||
if cache_filepath is not None and refresh is False: | |||
# load data | |||
if os.path.exists(cache_filepath): | |||
cache_filepath = os.path.abspath(cache_filepath) | |||
with open(cache_filepath, 'rb') as f: | |||
results = _pickle.load(f) | |||
old_hash_code = results['hash'] | |||
save_time = results['save_time'] | |||
results = results['results'] | |||
if verbose == 1: | |||
logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | |||
if check_hash and old_hash_code != new_hash_code: | |||
logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | |||
f"difference may caused by the sourcecode change of the functions by this function.", | |||
extra={'highlighter': ColorHighlighter('red')}) | |||
refresh_flag = False | |||
if refresh_flag: | |||
if new_hash_code is None: | |||
new_hash_code = cal_fn_hash_code(func, fn_param) | |||
results = func(*args, **kwargs) | |||
if cache_filepath is not None: | |||
if results is None: | |||
raise RuntimeError("The return value is None. Cannot save None results.") | |||
cache_filepath = os.path.abspath(cache_filepath) | |||
_prepare_cache_filepath(cache_filepath) | |||
_dict = { | |||
'results': results, | |||
'hash': new_hash_code, | |||
'save_time': datetime.now(), | |||
} | |||
with open(cache_filepath, 'wb') as f: | |||
_pickle.dump(_dict, f) | |||
logger.info("Save cache to {}.".format(cache_filepath)) | |||
return results | |||
return wrapper | |||
return wrapper_ |
@@ -0,0 +1,4 @@ | |||
class DummyClass: | |||
pass |
@@ -0,0 +1,51 @@ | |||
__all__ = [ | |||
'is_jittor_dataset', | |||
'jittor_collate_wraps' | |||
] | |||
from collections.abc import Mapping, Callable | |||
from functools import wraps | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor as jt | |||
from fastNLP.core.dataset import Instance | |||
def is_jittor_dataset(dataset) -> bool: | |||
try: | |||
if isinstance(dataset, jt.dataset.Dataset): | |||
return True | |||
else: | |||
return False | |||
except BaseException: | |||
return False | |||
def jittor_collate_wraps(func, auto_collator: Callable): | |||
""" | |||
Wrap jittor's collate_fn: if the samples are mapping-like (or fastNLP Instance objects), auto_collator is used; otherwise jittor's own collate_batch is used.
:param func: | |||
:param auto_collator: | |||
:return: | |||
""" | |||
@wraps(func) | |||
def wrapper(batch): | |||
if isinstance(batch[0], Instance): | |||
if auto_collator is not None: | |||
result = auto_collator(batch) | |||
else: | |||
raise ValueError(f"auto_collator is None, but batch exist fastnlp instance!") | |||
elif isinstance(batch[0], Mapping): | |||
if auto_collator is not None: | |||
result = auto_collator(batch) | |||
else: | |||
result = func(batch) | |||
else: | |||
result = func(batch) | |||
return result | |||
return wrapper |
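# Illustrative sketch (not part of the original file): wrapping a plain collate function so that
# dict-style samples go through a (hypothetical) auto_collator while everything else falls back to
# the wrapped function. Both inner functions below are stand-ins, not real fastNLP or jittor APIs.
def _sketch_jittor_collate_usage():
    def default_collate(batch):          # stand-in for jittor's collate_batch
        return batch

    def auto_collator(batch):            # stand-in auto collator: merge dict samples key by key
        return {key: [sample[key] for sample in batch] for key in batch[0]}

    collate_fn = jittor_collate_wraps(default_collate, auto_collator)
    # dict samples are routed to auto_collator ...
    assert collate_fn([{'x': 1}, {'x': 2}]) == {'x': [1, 2]}
    # ... while other samples fall back to the wrapped function
    assert collate_fn([1, 2, 3]) == [1, 2, 3]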
@@ -0,0 +1,89 @@ | |||
__all__ = [ | |||
"paddle_to", | |||
"paddle_move_data_to_device", | |||
"get_paddle_gpu_str", | |||
"get_paddle_device_id", | |||
"is_in_paddle_dist", | |||
"is_in_fnlp_paddle_dist", | |||
"is_in_paddle_launch_dist", | |||
] | |||
import os | |||
from typing import Any, Optional, Union | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from .utils import apply_to_collection | |||
def paddle_to(data, device: Union[str, int]): | |||
if device == "cpu": | |||
return data.cpu() | |||
else: | |||
return data.cuda(get_paddle_device_id(device)) | |||
def get_paddle_gpu_str(device: Union[str, int]): | |||
""" | |||
获得 `gpu:x` 类型的设备名 | |||
""" | |||
if isinstance(device, str): | |||
return device.replace("cuda", "gpu") | |||
return f"gpu:{device}" | |||
def get_paddle_device_id(device: Union[str, int]): | |||
""" | |||
获得 gpu 的设备id,注意不要传入 `cpu` 。 | |||
""" | |||
if isinstance(device, int): | |||
return device | |||
if device == "cpu": | |||
raise ValueError("Cannot get device id from `cpu`.") | |||
return paddle.device._convert_to_place(device).get_device_id() | |||
def paddle_move_data_to_device(batch: Any, device: Optional[str] = None, | |||
data_device: Optional[str] = None) -> Any: | |||
r""" | |||
Move a collection of data to the given device. Only paddle.Tensor objects are transferred; everything else is left unchanged. | |||
:param batch: | |||
:param device: `cpu`, `gpu` or `gpu:x` | |||
:param data_device: device to fall back on when ``device`` is None | |||
:return: the same collection, but with all contained tensors residing on the new device | |||
""" | |||
if device is None: | |||
if data_device is not None: | |||
device = data_device | |||
else: | |||
return batch | |||
def batch_to(data: Any) -> Any: | |||
return paddle_to(data, device) | |||
return apply_to_collection(batch, dtype=paddle.Tensor, function=batch_to) | |||
def is_in_paddle_dist(): | |||
""" | |||
Whether we are inside a paddle distributed process, judged from the PADDLE_RANK_IN_NODE and FLAGS_selected_gpus environment variables.
""" | |||
return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ) | |||
def is_in_fnlp_paddle_dist(): | |||
""" | |||
判断是否处于 FastNLP 拉起的分布式进程中 | |||
""" | |||
return FASTNLP_DISTRIBUTED_CHECK in os.environ | |||
def is_in_paddle_launch_dist(): | |||
""" | |||
判断是否处于 launch 启动的分布式进程中 | |||
""" | |||
return 'PADDLE_RANK_IN_NODE' in os.environ and \ | |||
'FLAGS_selected_gpus' in os.environ and \ | |||
FASTNLP_DISTRIBUTED_CHECK not in os.environ |
@@ -0,0 +1,214 @@ | |||
""" | |||
该文件用于为fastNLP提供一个统一的progress bar管理,通过共用一个Task对象,trainer中的progress bar和evaluation中的progress bar才能 | |||
不冲突 | |||
""" | |||
import sys | |||
from typing import Any, Union, Optional | |||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | |||
__all__ = [ | |||
'f_rich_progress' | |||
] | |||
from fastNLP.envs import get_global_rank | |||
class Singleton(type): | |||
_instances = {} | |||
def __call__(cls, *args, **kwargs): | |||
if cls not in cls._instances: | |||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
return cls._instances[cls] | |||
# 如果不打印的时候,使得整个 progress 没有任何意义 | |||
class DummyFRichProgress: | |||
def __getattr__(self, item): | |||
return DummyFRichProgress() | |||
def __call__(self, *args, **kwargs): | |||
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 | |||
return None | |||
class FRichProgress(Progress, metaclass=Singleton): | |||
""" | |||
fastNLP 使用的 progress bar ,新增了 new_progress 函数,通过此函数即可定制 fastNLP 中所有 progress 的样式。 | |||
""" | |||
def new_progess(self, *columns: Union[str, ProgressColumn], | |||
console: Optional[Console] = None, | |||
auto_refresh: bool = True, | |||
refresh_per_second: float = 10, | |||
speed_estimate_period: float = 30.0, | |||
transient: bool = True, | |||
redirect_stdout: bool = True, | |||
redirect_stderr: bool = True, | |||
get_time: Optional[GetTimeCallable] = None, | |||
disable: bool = False, | |||
expand: bool = False): | |||
""" | |||
重新初始化一个rich bar。如果columns不传入,则继续使用之前的column内容。 | |||
:param progress: | |||
:return: | |||
""" | |||
for task_id in self.task_ids: # 首先移除已有的 | |||
self.remove_task(task_id) | |||
assert ( | |||
refresh_per_second is None or refresh_per_second > 0 | |||
), "refresh_per_second must be > 0" | |||
# stop previous columns | |||
self.stop() | |||
# do not change these variables | |||
# self._lock = RLock() | |||
# self._tasks: Dict[TaskID, Task] = {} | |||
# self._task_index: TaskID = TaskID(0) | |||
if len(columns) != 0: | |||
self.columns = columns | |||
self.speed_estimate_period = speed_estimate_period | |||
self.disable = disable | |||
self.expand = expand | |||
self.live = Live( | |||
console=console or get_console(), | |||
auto_refresh=auto_refresh, | |||
refresh_per_second=refresh_per_second, | |||
transient=transient, | |||
redirect_stdout=redirect_stdout, | |||
redirect_stderr=redirect_stderr, | |||
get_renderable=self.get_renderable, | |||
) | |||
self.get_time = get_time or self.console.get_time | |||
self.print = self.console.print | |||
self.log = self.console.log | |||
# start new | |||
self.start() | |||
return self | |||
def set_transient(self, transient: bool = True): | |||
""" | |||
设置是否在bar运行结束之后不关闭 | |||
:param transient: | |||
:return: | |||
""" | |||
self.new_progess(transient=transient) | |||
def set_disable(self, flag: bool = True): | |||
""" | |||
设置当前 progress bar 的状态,如果为 True ,则不会显示进度条了。 | |||
:param flag: | |||
:return: | |||
""" | |||
self.disable = flag | |||
def add_task( | |||
self, | |||
description: str, | |||
start: bool = True, | |||
total: float = 100.0, | |||
completed: int = 0, | |||
visible: bool = True, | |||
**fields: Any, | |||
) -> TaskID: | |||
if self.live._started is False: | |||
self.start() | |||
post_desc = fields.pop('post_desc', '') | |||
return super().add_task(description=description, | |||
start=start, | |||
total=total, | |||
completed=completed, | |||
visible=visible, | |||
post_desc=post_desc, | |||
**fields) | |||
def stop_task(self, task_id: TaskID) -> None: | |||
if task_id in self._tasks: | |||
super().stop_task(task_id) | |||
def remove_task(self, task_id: TaskID) -> None: | |||
if task_id in self._tasks: | |||
super().remove_task(task_id) | |||
def destroy_task(self, task_id: TaskID): | |||
if task_id in self._tasks: | |||
super().stop_task(task_id) | |||
super().remove_task(task_id) | |||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
f_rich_progress = FRichProgress().new_progess( | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
BarColumn(), | |||
TimeElapsedColumn(), | |||
"/", | |||
TimeRemainingColumn(), | |||
TextColumn("{task.fields[post_desc]}", justify="right"), | |||
transient=True, | |||
disable=False, | |||
speed_estimate_period=10 | |||
) | |||
else: | |||
f_rich_progress = DummyFRichProgress() | |||
if __name__ == '__main__': | |||
f = DummyFRichProgress() | |||
f.console.print('xxx') | |||
f.console.print.print('xxx') | |||
# 测试创建 | |||
import time | |||
n_steps = 10 | |||
task_id = f_rich_progress.add_task(description='test', total=n_steps) | |||
for i in range(n_steps): | |||
f_rich_progress.update(task_id, description=f'test:{i}', advance=1, refresh=True) | |||
print(f"test:{i}") | |||
time.sleep(0.3) | |||
f_rich_progress.remove_task(task_id) | |||
# 测试一下 inner/outer | |||
n_steps = 5 | |||
f_rich_progress.start() | |||
outer_task_id = f_rich_progress.add_task(description='Outer:', total=n_steps) | |||
inner_task_id = f_rich_progress.add_task(description='Inner:', total=n_steps) | |||
for i in range(n_steps): | |||
f_rich_progress.reset(inner_task_id, total=n_steps) | |||
f_rich_progress.update(outer_task_id, description=f'Outer:{i}', advance=1, refresh=True) | |||
for j in range(n_steps): | |||
f_rich_progress.update(inner_task_id, description=f'Inner:{j}', advance=1, refresh=True, | |||
post_desc='Loss: 0.334332323') | |||
print(f"Outer:{i}, Inner:{j}") | |||
time.sleep(0.3) | |||
# 测试一下修改bar | |||
f_rich_progress = FRichProgress().new_progess( | |||
BarColumn(), | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
TimeElapsedColumn(), | |||
transient=True) | |||
n_steps = 10 | |||
task_id = f_rich_progress.add_task(description='test', total=n_steps) | |||
for i in range(n_steps): | |||
f_rich_progress.update(task_id, description=f'test:{i}', advance=1) | |||
print(f"test:{i}") | |||
time.sleep(0.3) | |||
f_rich_progress.remove_task(task_id) | |||
f_rich_progress.stop() |
@@ -0,0 +1,49 @@ | |||
from typing import Any, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
__all__ = [ | |||
"torch_paddle_move_data_to_device", | |||
] | |||
from .utils import apply_to_collection | |||
from .paddle_utils import paddle_to | |||
def torch_paddle_move_data_to_device(batch: Any, device: Optional[str] = None, non_blocking: Optional[bool] = True, | |||
data_device: Optional[str] = None) -> Any: | |||
r""" | |||
将数据集合传输到给定设备。只有paddle.Tensor和torch.Tensor对象会被传输到设备中,其余保持不变 | |||
:param batch: | |||
:param device: | |||
:param non_blocking: | |||
:param data_device: | |||
:return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
""" | |||
if device is None: | |||
if data_device is not None: | |||
device = data_device | |||
else: | |||
return batch | |||
torch_device = device.replace("gpu", "cuda") | |||
paddle_device = device.replace("cuda", "gpu") | |||
def batch_to(data: Any) -> Any: | |||
if isinstance(data, torch.Tensor): | |||
data = data.to(torch_device, non_blocking=non_blocking) | |||
elif isinstance(data, paddle.Tensor): | |||
data = paddle_to(data, paddle_device) | |||
return data | |||
return apply_to_collection(batch, dtype=(paddle.Tensor, torch.Tensor), function=batch_to) |
@@ -0,0 +1,63 @@ | |||
from abc import ABC | |||
from typing import Any, Union, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
__all__ = [ | |||
'torch_move_data_to_device' | |||
] | |||
from .utils import apply_to_collection | |||
class TorchTransferableDataType(ABC): | |||
""" | |||
A custom type for data that can be moved to a torch device via `.to(...)`. | |||
Example: | |||
>>> isinstance(dict, TorchTransferableDataType) | |||
False | |||
>>> isinstance(torch.rand(2, 3), TorchTransferableDataType) | |||
True | |||
>>> class CustomObject: | |||
... def __init__(self): | |||
... self.x = torch.rand(2, 2) | |||
... def to(self, device): | |||
... self.x = self.x.to(device) | |||
... return self | |||
>>> isinstance(CustomObject(), TorchTransferableDataType) | |||
True | |||
""" | |||
@classmethod | |||
def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: | |||
if cls is TorchTransferableDataType: | |||
to = getattr(subclass, "to", None) | |||
return callable(to) | |||
return NotImplemented | |||
def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None, | |||
non_blocking: Optional[bool] = True) -> Any: | |||
r""" | |||
将数据集合传输到给定设备。任何定义方法 “to(device)” 的对象都将被移动并且集合中的所有其他对象将保持不变; | |||
:param batch: 应当迁移的数据; | |||
:param device: 数据应当迁移到的设备;当该参数的值为 None 时,表示迁移数据的操作由用户自己完成,我们不需要经管; | |||
:param non_blocking: pytorch 的迁移数据方法 `to` 的参数; | |||
:return: 相同的集合,但所有包含的张量都驻留在新设备上; | |||
""" | |||
if device is None: | |||
return batch | |||
def batch_to(data: Any) -> Any: | |||
kwargs = dict(non_blocking=non_blocking) if isinstance(data, torch.Tensor) else {} | |||
data_output = data.to(device, **kwargs) | |||
if data_output is not None: | |||
return data_output | |||
# user wrongly implemented the `TransferableDataType` and forgot to return `self`. | |||
return data | |||
dtype = TorchTransferableDataType | |||
return apply_to_collection(batch, dtype=dtype, function=batch_to) |
@@ -0,0 +1,591 @@ | |||
import inspect | |||
from inspect import Parameter | |||
import dataclasses | |||
import warnings | |||
from dataclasses import is_dataclass | |||
from copy import deepcopy | |||
from collections import defaultdict, OrderedDict | |||
from typing import Callable, List, Any, Dict, AnyStr, Union, Mapping, Sequence, Optional | |||
from typing import Tuple, Optional | |||
from time import sleep | |||
try: | |||
from typing import Literal, Final | |||
except ImportError: | |||
from typing_extensions import Literal, Final | |||
import os | |||
from contextlib import contextmanager | |||
from functools import wraps | |||
from prettytable import PrettyTable | |||
import numpy as np | |||
from pathlib import Path | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||
__all__ = [ | |||
'get_fn_arg_names', | |||
'check_fn_not_empty_params', | |||
'auto_param_call', | |||
'check_user_specific_params', | |||
'dataclass_to_dict', | |||
'match_and_substitute_params', | |||
'apply_to_collection', | |||
'nullcontext', | |||
'pretty_table_printer', | |||
'Option', | |||
'indice_collate_wrapper', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
'synchronize_safe_rm', | |||
'synchronize_mkdir' | |||
] | |||
def get_fn_arg_names(fn: Callable) -> List[str]: | |||
r""" | |||
返回一个函数的所有参数的名字; | |||
:param fn: 需要查询的函数; | |||
:return: 一个列表,其中的元素则是查询函数的参数的字符串名字; | |||
""" | |||
return list(inspect.signature(fn).parameters) | |||
def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool: | |||
r""" | |||
检查传入的batch_step_fn是否是合法的:(1) 是否是 callable 的; (2) 没有默认值的参数是否只有指定个数; | |||
用户也可以传进一个 partial 的函数进来,只要其保证留有 `trainer` 和 `batch` 的参数位置即可; | |||
:param fn: 传入的用以代替 Loop 中 'step' 函数的函数; | |||
:param param_num: 检测的函数的应当的没有默认值的参数的个数; | |||
:return: bool,表示传入的 `batch_step_fn` 是否正确; | |||
""" | |||
if fn is None: | |||
return True | |||
if not callable(fn): | |||
return False | |||
else: | |||
params = inspect.signature(fn).parameters | |||
not_default_params = {} | |||
for _name, _param in params.items(): | |||
if _param.default == Parameter.empty: | |||
not_default_params[_name] = _param | |||
return len(not_default_params) == param_num | |||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
r""" | |||
1.该函数用来提供给用户根据字符串匹配从而实现自动计算; | |||
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | |||
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | |||
4.如果输入的函数是一个 `partial` 函数,情况同 '3.',即和默认参数的情况相同; | |||
:param fn: 用来进行实际计算的函数,其参数可以包含有默认值; | |||
:param args: 一系列的位置参数,应当为一系列的字典,我们需要从这些输入中提取 `fn` 计算所需要的实际参数; | |||
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | |||
参数值后,再传给 `fn` 进行实际的运算; | |||
:param mapping: 一个字典,用来更改其前面的字典的键值; | |||
:return: 返回 `fn` 运行的结果; | |||
Examples: | |||
>>> # 1 | |||
>>> loss_fn = CrossEntropyLoss() # 如果其需要的参数为 def CrossEntropyLoss(y, pred); | |||
>>> batch = {"x": 20, "y": 1} | |||
>>> output = {"pred": 0} | |||
>>> acc = auto_param_call(loss_fn, batch, output) | |||
>>> # 2 | |||
>>> def test_fn(x, y, a, b=10): | |||
>>> return x + y + a + b | |||
>>> print(auto_param_call(test_fn, {"x": 10}, {"y": 20, "a": 30})) # res: 70 | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | |||
""" | |||
if signature_fn is not None: | |||
if not callable(signature_fn): | |||
raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | |||
_need_params = OrderedDict(inspect.signature(signature_fn).parameters) | |||
else: | |||
_need_params = OrderedDict(inspect.signature(fn).parameters) | |||
_kwargs = None | |||
for _name, _param in _need_params.items(): | |||
if _param.kind == Parameter.VAR_POSITIONAL: | |||
raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.") | |||
if _param.kind == Parameter.VAR_KEYWORD: | |||
_kwargs = (_name, _param) | |||
if _kwargs is not None: | |||
_need_params.pop(_kwargs[0]) | |||
_default_params = {} | |||
for _name, _param in _need_params.items(): | |||
if _param.default != Parameter.empty: | |||
_default_params[_name] = _param.default | |||
if mapping is not None: | |||
assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}." | |||
_has_params = {} | |||
duplicate_names = [] | |||
for arg in args: | |||
assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type." | |||
for _name, _value in arg.items(): | |||
if mapping is not None and _name in mapping: | |||
_name = mapping[_name] | |||
if _name not in _has_params: | |||
if _kwargs is not None or _name in _need_params: | |||
_has_params[_name] = _value | |||
# 同一参数对象在两个输入的资源中都出现,造成混淆; | |||
elif _name in _need_params and not (_has_params[_name] is _value): | |||
duplicate_names.append(_name) | |||
if duplicate_names: | |||
raise ValueError(f"The following key present in several inputs:{duplicate_names}") | |||
# 将具有默认值但是没有被输入修改过的参数值传进去; | |||
for _name, _value in _default_params.items(): | |||
if _name not in _has_params: | |||
_has_params[_name] = _value | |||
if len(_has_params)<len(_need_params): | |||
miss_params = list(set(_need_params.keys()) - set(_has_params.keys())) | |||
raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.") | |||
return fn(**_has_params) | |||
def check_user_specific_params(user_params: Dict, fn: Callable): | |||
""" | |||
该函数使用用户的输入来对指定函数的参数进行赋值; | |||
主要用于一些用户无法直接调用函数的情况; | |||
该函数主要的作用在于帮助检查用户对使用函数 fn 的参数输入是否有误; | |||
:param user_params: 用户指定的参数的值,应当是一个字典,其中 key 表示每一个参数的名字,value 为每一个参数应当的值; | |||
:param fn: 会被调用的函数; | |||
:return: 返回一个字典,其中为在之后调用函数 fn 时真正会被传进去的参数的值; | |||
""" | |||
fn_arg_names = get_fn_arg_names(fn) | |||
for arg_name, arg_value in user_params.items(): | |||
if arg_name not in fn_arg_names: | |||
logger.warning(f"Notice your specific parameter `{arg_name}` is not used by function `{fn.__name__}`.") | |||
return user_params | |||
def dataclass_to_dict(data: "dataclass") -> Dict: | |||
if not is_dataclass(data): | |||
raise TypeError(f"Parameter `data` can only be `dataclass` type instead of {type(data)}.") | |||
_dict = dict() | |||
for _key in data.__dataclass_fields__: | |||
_dict[_key] = getattr(data, _key) | |||
return _dict | |||
def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any: | |||
r""" | |||
用来实现将输入:batch,或者输出:outputs,通过 `mapping` 将键值进行更换的功能; | |||
该函数应用于 `input_mapping` 和 `output_mapping`; | |||
对于 `input_mapping`,该函数会在 `TrainBatchLoop` 中取完数据后立刻被调用; | |||
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; | |||
转换的逻辑按优先级依次为: | |||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; | |||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; | |||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; | |||
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; | |||
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 | |||
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 | |||
:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 | |||
:param data: 需要被转换的对象; | |||
:return: 返回转换好的结果; | |||
""" | |||
if mapping is None: | |||
return data | |||
if callable(mapping): | |||
# 注意我们在 `Trainer.extract_loss_from_outputs` 函数里会检查 outputs 的输出,outputs 的类型目前只支持 `Dict` 和 `dataclass`; | |||
return mapping(data) | |||
if not isinstance(mapping, Dict): | |||
raise ValueError( | |||
f"Parameter `mapping` should be of type `Dict` or `Callable`, not `{type(mapping)}`. This is caused" | |||
f"by your `input_mapping` or `output_mapping` parameter in your `Trainer` or `Evaluator`.") | |||
if not isinstance(data, Dict) and not is_dataclass(data) and not isinstance(data, Sequence): | |||
raise ValueError("Parameter `data` should be type `Dict` or `dataclass` when the other parameter `mapping` is " | |||
"type `Dict`.") | |||
# 如果 `data` 是一个 dataclass,那么先将其转换为一个 `Dict`; | |||
if is_dataclass(data): | |||
data = dataclass_to_dict(data) | |||
# 如果 `data` 是一个 List,那么我们同样先将其转换为一个 `Dict`,为 {"_0": list[0], "_1": list[1], ...}; | |||
elif isinstance(data, Sequence): | |||
data = {"_" + str(i): data[i] for i in range(len(data))} | |||
_new_data = {} | |||
for _name, _value in data.items(): | |||
if _name in mapping: | |||
_new_data[mapping[_name]] = _value | |||
else: | |||
_new_data[_name] = _value | |||
return _new_data | |||
def _is_namedtuple(obj: object) -> bool: | |||
# https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 | |||
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") | |||
def _is_dataclass_instance(obj: object) -> bool: | |||
# https://docs.python.org/3/library/dataclasses.html#module-level-decorators-classes-and-functions | |||
return dataclasses.is_dataclass(obj) and not isinstance(obj, type) | |||
def apply_to_collection( | |||
data: Any, | |||
dtype: Union[type, Any, Tuple[Union[type, Any]]], | |||
function: Callable, | |||
*args: Any, | |||
wrong_dtype: Optional[Union[type, Tuple[type]]] = None, | |||
include_none: bool = True, | |||
**kwargs: Any, | |||
) -> Any: | |||
"""将函数 function 递归地在 data 中的元素执行,但是仅在满足元素为 dtype 时执行。 | |||
this function credit to: https://github.com/PyTorchLightning/pytorch-lightning | |||
Args: | |||
data: the collection to apply the function to | |||
dtype: the given function will be applied to all elements of this dtype | |||
function: the function to apply | |||
*args: positional arguments (will be forwarded to calls of ``function``) | |||
wrong_dtype: the given function won't be applied if this type is specified and the given collections | |||
is of the ``wrong_dtype`` even if it is of type ``dtype`` | |||
include_none: Whether to include an element if the output of ``function`` is ``None``. | |||
**kwargs: keyword arguments (will be forwarded to calls of ``function``) | |||
Returns: | |||
The resulting collection | |||
""" | |||
# Breaking condition | |||
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): | |||
return function(data, *args, **kwargs) | |||
elem_type = type(data) | |||
# Recursively apply to collection items | |||
if isinstance(data, Mapping): | |||
out = [] | |||
for k, v in data.items(): | |||
v = apply_to_collection( | |||
v, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs | |||
) | |||
if include_none or v is not None: | |||
out.append((k, v)) | |||
if isinstance(data, defaultdict): | |||
return elem_type(data.default_factory, OrderedDict(out)) | |||
return elem_type(OrderedDict(out)) | |||
is_namedtuple = _is_namedtuple(data) | |||
is_sequence = isinstance(data, Sequence) and not isinstance(data, str) | |||
if is_namedtuple or is_sequence: | |||
out = [] | |||
for d in data: | |||
v = apply_to_collection( | |||
d, dtype, function, *args, wrong_dtype=wrong_dtype, include_none=include_none, **kwargs | |||
) | |||
if include_none or v is not None: | |||
out.append(v) | |||
return elem_type(*out) if is_namedtuple else elem_type(out) | |||
if _is_dataclass_instance(data): | |||
# make a deepcopy of the data, | |||
# but do not deepcopy mapped fields since the computation would | |||
# be wasted on values that likely get immediately overwritten | |||
fields = {} | |||
memo = {} | |||
for field in dataclasses.fields(data): | |||
field_value = getattr(data, field.name) | |||
fields[field.name] = (field_value, field.init) | |||
memo[id(field_value)] = field_value | |||
result = deepcopy(data, memo=memo) | |||
# apply function to each field | |||
for field_name, (field_value, field_init) in fields.items(): | |||
if field_init: | |||
v = apply_to_collection( | |||
field_value, | |||
dtype, | |||
function, | |||
*args, | |||
wrong_dtype=wrong_dtype, | |||
include_none=include_none, | |||
**kwargs, | |||
) | |||
if not field_init or (not include_none and v is None): # retain old value | |||
v = getattr(data, field_name) | |||
setattr(result, field_name, v) | |||
return result | |||
# data is neither of dtype, nor a collection | |||
return data | |||
@contextmanager | |||
def nullcontext(): | |||
r""" | |||
用来实现一个什么 dummy 的 context 上下文环境; | |||
""" | |||
yield | |||
def sub_column(string: str, c: int, c_size: int, title: str) -> str: | |||
r""" | |||
:param string: 要被截断的字符串 | |||
:param c: 命令行列数 | |||
:param c_size: instance或dataset field数 | |||
:param title: 列名 | |||
:return: 对一个过长的列进行截断的结果 | |||
""" | |||
avg = max(int(c / c_size / 2), len(title)) | |||
string = str(string) | |||
res = "" | |||
counter = 0 | |||
for char in string: | |||
if ord(char) > 255: | |||
counter += 2 | |||
else: | |||
counter += 1 | |||
res += char | |||
if counter > avg: | |||
res = res + "..." | |||
break | |||
return res | |||
def _is_iterable(value): | |||
# 检查是否是iterable的, duck typing | |||
try: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
def pretty_table_printer(dataset_or_ins) -> PrettyTable: | |||
r""" | |||
:param dataset_or_ins: 传入一个dataSet或者instance | |||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||
+-----------+-----------+-----------------+ | |||
| field_1 | field_2 | field_3 | | |||
+-----------+-----------+-----------------+ | |||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | | |||
+-----------+-----------+-----------------+ | |||
:return: 以 pretty table的形式返回根据terminal大小进行自动截断 | |||
""" | |||
x = PrettyTable() | |||
try: | |||
sz = os.get_terminal_size() | |||
column = sz.columns | |||
row = sz.lines | |||
except OSError: | |||
column = 144 | |||
row = 11 | |||
if type(dataset_or_ins).__name__ == "DataSet": | |||
x.field_names = list(dataset_or_ins.field_arrays.keys()) | |||
c_size = len(x.field_names) | |||
for ins in dataset_or_ins: | |||
x.add_row([sub_column(ins[k], column, c_size, k) for k in x.field_names]) | |||
row -= 1 | |||
if row < 0: | |||
x.add_row(["..." for _ in range(c_size)]) | |||
break | |||
elif type(dataset_or_ins).__name__ == "Instance": | |||
x.field_names = list(dataset_or_ins.fields.keys()) | |||
c_size = len(x.field_names) | |||
x.add_row([sub_column(dataset_or_ins[k], column, c_size, k) for k in x.field_names]) | |||
else: | |||
raise Exception("only accept DataSet and Instance") | |||
x.align = "l" | |||
return x | |||
class Option(dict): | |||
r"""a dict can treat keys as attributes""" | |||
def __getattr__(self, item): | |||
try: | |||
return self.__getitem__(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __setattr__(self, key, value): | |||
if key.startswith('__') and key.endswith('__'): | |||
raise AttributeError(key) | |||
self.__setitem__(key, value) | |||
def __delattr__(self, item): | |||
try: | |||
self.pop(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __getstate__(self): | |||
return self | |||
def __setstate__(self, state): | |||
self.update(state) | |||
def indice_collate_wrapper(func): | |||
""" | |||
其功能是封装一层collate_fn,将dataset取到的tuple数据分离开,将idx打包为indices。 | |||
:param func: 需要修饰的函数 | |||
:return: | |||
""" | |||
def wrapper(tuple_data): | |||
indice, ins_list = [], [] | |||
for idx, ins in tuple_data: | |||
indice.append(idx) | |||
ins_list.append(ins) | |||
return indice, func(ins_list) | |||
return wrapper | |||
_emitted_deprecation_warnings = set() | |||
def deprecated(help_message: Optional[str] = None): | |||
"""Decorator to mark a function as deprecated. | |||
Args: | |||
help_message (`Optional[str]`): An optional message to guide the user on how to | |||
switch to non-deprecated usage of the library. | |||
""" | |||
def decorator(deprecated_function: Callable): | |||
global _emitted_deprecation_warnings | |||
warning_msg = ( | |||
( | |||
f"{deprecated_function.__name__} is deprecated and will be removed " | |||
"in the next major version of datasets." | |||
) | |||
+ f" {help_message}" | |||
if help_message | |||
else "" | |||
) | |||
@wraps(deprecated_function) | |||
def wrapper(*args, **kwargs): | |||
func_hash = hash(deprecated_function) | |||
if func_hash not in _emitted_deprecation_warnings: | |||
warnings.warn(warning_msg, category=FutureWarning, stacklevel=2) | |||
_emitted_deprecation_warnings.add(func_hash) | |||
return deprecated_function(*args, **kwargs) | |||
wrapper._decorator_name_ = "deprecated" | |||
return wrapper | |||
return decorator | |||
def seq_len_to_mask(seq_len, max_len=None): | |||
r""" | |||
将一个表示sequence length的一维数组转换为二维的mask,不包含的位置为0。 | |||
转变 1-d seq_len到2-d mask. | |||
.. code-block:: | |||
>>> seq_len = torch.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len) | |||
>>> print(mask.size()) | |||
torch.Size([14, 15]) | |||
>>> seq_len = np.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len) | |||
>>> print(mask.shape) | |||
(14, 15) | |||
>>> seq_len = torch.arange(2, 16) | |||
>>> mask = seq_len_to_mask(seq_len, max_len=100) | |||
>>>print(mask.size()) | |||
torch.Size([14, 100]) | |||
:param np.ndarray,torch.LongTensor seq_len: shape将是(B,) | |||
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有 | |||
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。 | |||
:return: np.ndarray, torch.Tensor 。shape将是(B, max_length), 元素类似为bool或torch.uint8 | |||
""" | |||
if isinstance(seq_len, np.ndarray): | |||
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}." | |||
max_len = int(max_len) if max_len else int(seq_len.max()) | |||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
else: | |||
raise TypeError("Only support 1-d numpy.ndarray.") | |||
return mask | |||
def wait_to_success(fn, no=False): | |||
while True: | |||
sleep(0.01) | |||
if (no and not fn()) or (not no and fn()): | |||
break | |||
# 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||
# 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||
def synchronize_safe_rm(path: Optional[Union[str, Path]]): | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if not path.exists(): | |||
return | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
_recursive_rm(path) | |||
wait_to_success(path.exists, no=True) | |||
def _recursive_rm(path: Path): | |||
if path.is_file() or path.is_symlink(): | |||
if path.exists(): | |||
try: | |||
path.unlink() | |||
except Exception: | |||
pass | |||
return | |||
for sub_path in list(path.iterdir()): | |||
_recursive_rm(sub_path) | |||
path.rmdir() | |||
def synchronize_mkdir(path: Optional[Union[str, Path]]): | |||
""" | |||
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | |||
""" | |||
if path is None: | |||
return | |||
if isinstance(path, str): | |||
path = Path(path) | |||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||
path.mkdir(parents=True, exist_ok=True) | |||
wait_to_success(path.exists) | |||
@@ -0,0 +1,585 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"Vocabulary", | |||
"VocabularyOption", | |||
] | |||
from collections import Counter | |||
from functools import partial | |||
from functools import wraps | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.utils.utils import Option | |||
from fastNLP.core.utils.utils import _is_iterable | |||
import io | |||
class VocabularyOption(Option): | |||
def __init__(self, | |||
max_size=None, | |||
min_freq=None, | |||
padding='<pad>', | |||
unknown='<unk>'): | |||
super().__init__( | |||
max_size=max_size, | |||
min_freq=min_freq, | |||
padding=padding, | |||
unknown=unknown | |||
) | |||
def _check_build_vocab(func): | |||
r"""A decorator to make sure the indexing is built before used. | |||
""" | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self._word2idx is None or self.rebuild is True: | |||
self.build_vocab() | |||
return func(self, *args, **kwargs) | |||
return _wrapper | |||
def _check_build_status(func): | |||
r"""A decorator to check whether the vocabulary updates after the last build. | |||
""" | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self.rebuild is False: | |||
self.rebuild = True | |||
if self.max_size is not None and len(self.word_count) >= self.max_size: | |||
print("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||
self.max_size, func.__name__)) | |||
return func(self, *args, **kwargs) | |||
return _wrapper | |||
class Vocabulary(object): | |||
r""" | |||
用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | |||
vocab = Vocabulary() | |||
word_list = "this is a word list".split() | |||
vocab.update(word_list) | |||
vocab["word"] # str to int | |||
vocab.to_word(5) # int to str | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
r""" | |||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<pad>' | |||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<unk>' | |||
""" | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
self.word_count = Counter() | |||
self.unknown = unknown | |||
self.padding = padding | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | |||
self._no_create_word = Counter() | |||
@property | |||
@_check_build_vocab | |||
def word2idx(self): | |||
return self._word2idx | |||
@word2idx.setter | |||
def word2idx(self, value): | |||
self._word2idx = value | |||
@property | |||
@_check_build_vocab | |||
def idx2word(self): | |||
return self._idx2word | |||
@idx2word.setter | |||
def idx2word(self, value): | |||
self._word2idx = value | |||
@_check_build_status | |||
def update(self, word_lst, no_create_entry=False): | |||
r"""依次增加序列中词在词典中的出现频率 | |||
:param list word_lst: a list of strings | |||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | |||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | |||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||
则这个词将认为是需要创建单独的vector的。 | |||
""" | |||
self._add_no_create_entry(word_lst, no_create_entry) | |||
self.word_count.update(word_lst) | |||
return self | |||
@_check_build_status | |||
def add(self, word, no_create_entry=False): | |||
r""" | |||
增加一个新词在词典中的出现频率 | |||
:param str word: 新词 | |||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | |||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | |||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||
则这个词将认为是需要创建单独的vector的。 | |||
""" | |||
self._add_no_create_entry(word, no_create_entry) | |||
self.word_count[word] += 1 | |||
return self | |||
def _add_no_create_entry(self, word, no_create_entry): | |||
r""" | |||
在新加入word时,检查_no_create_word的设置。 | |||
:param str List[str] word: | |||
:param bool no_create_entry: | |||
:return: | |||
""" | |||
if isinstance(word, str) or not _is_iterable(word): | |||
word = [word] | |||
for w in word: | |||
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): | |||
self._no_create_word[w] += 1 | |||
elif not no_create_entry and w in self._no_create_word: | |||
self._no_create_word.pop(w) | |||
@_check_build_status | |||
def add_word(self, word, no_create_entry=False): | |||
r""" | |||
增加一个新词在词典中的出现频率 | |||
:param str word: 新词 | |||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | |||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | |||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||
则这个词将认为是需要创建单独的vector的。 | |||
""" | |||
self.add(word, no_create_entry=no_create_entry) | |||
@_check_build_status | |||
def add_word_lst(self, word_lst, no_create_entry=False): | |||
r""" | |||
依次增加序列中词在词典中的出现频率 | |||
:param list[str] word_lst: 词的序列 | |||
:param bool no_create_entry: 如果词语来自于非训练集建议设置为True。在使用fastNLP.TokenEmbedding加载预训练模型时,没有从预训练词表中找到这个词的处理方式。 | |||
如果为True,则不会有这个词语创建一个单独的entry,它将一直被指向unk的表示; 如果为False,则为这个词创建一个单独 | |||
的entry。如果这个word来自于dev或者test,一般设置为True,如果来自与train一般设置为False。以下两种情况: 如果新 | |||
加入一个word,且no_create_entry为True,但这个词之前已经在Vocabulary中且并不是no_create_entry的,则还是会为这 | |||
个词创建一个单独的vector; 如果no_create_entry为False,但这个词之前已经在Vocabulary中且并不是no_create_entry的, | |||
则这个词将认为是需要创建单独的vector的。 | |||
""" | |||
self.update(word_lst, no_create_entry=no_create_entry) | |||
return self | |||
def build_vocab(self): | |||
r""" | |||
根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | |||
但已经记录在词典中的词, 不会改变对应的 `int` | |||
""" | |||
if self._word2idx is None: | |||
self._word2idx = {} | |||
if self.padding is not None: | |||
self._word2idx[self.padding] = len(self._word2idx) | |||
if (self.unknown is not None) and (self.unknown != self.padding): | |||
self._word2idx[self.unknown] = len(self._word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
if self.min_freq is not None: | |||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
if self._word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self._word2idx, words) | |||
start_idx = len(self._word2idx) | |||
self._word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
self.rebuild = False | |||
return self | |||
def build_reverse_vocab(self): | |||
r""" | |||
基于 `word to index` dict, 构建 `index to word` dict. | |||
""" | |||
self._idx2word = {i: w for w, i in self._word2idx.items()} | |||
return self | |||
@_check_build_vocab | |||
def __len__(self): | |||
return len(self._word2idx) | |||
@_check_build_vocab | |||
def __contains__(self, item): | |||
r""" | |||
检查词是否被记录 | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return item in self._word2idx | |||
def has_word(self, w): | |||
r""" | |||
检查词是否被记录:: | |||
has_abc = vocab.has_word('abc') | |||
# equals to | |||
has_abc = 'abc' in vocab | |||
:param item: the word | |||
:return: ``True`` or ``False`` | |||
""" | |||
return self.__contains__(w) | |||
@_check_build_vocab | |||
def __getitem__(self, w): | |||
r""" | |||
To support usage like:: | |||
vocab[w] | |||
""" | |||
if w in self._word2idx: | |||
return self._word2idx[w] | |||
if self.unknown is not None: | |||
return self._word2idx[self.unknown] | |||
else: | |||
raise ValueError("word `{}` not in vocabulary".format(w)) | |||
@_check_build_vocab | |||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
r""" | |||
将DataSet中对应field的词转为数字,Example:: | |||
# remember to use `field_name` | |||
vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | |||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||
:param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||
目前支持 ``str`` , ``List[str]`` | |||
:param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
Default: ``None``. | |||
""" | |||
def index_instance(field): | |||
r""" | |||
有几种情况, str, 1d-list, 2d-list | |||
:param ins: | |||
:return: | |||
""" | |||
if isinstance(field, str) or not _is_iterable(field): | |||
return self.to_index(field) | |||
else: | |||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||
return [self.to_index(w) for w in field] | |||
else: | |||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
return [[self.to_index(c) for c in w] for w in field] | |||
new_field_name = new_field_name or field_name | |||
if type(new_field_name) == type(field_name): | |||
if isinstance(new_field_name, list): | |||
assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \ | |||
"field_name." | |||
elif isinstance(new_field_name, str): | |||
field_name = [field_name] | |||
new_field_name = [new_field_name] | |||
else: | |||
raise TypeError("field_name and new_field_name can only be str or List[str].") | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
for f_n, n_f_n in zip(field_name, new_field_name): | |||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | |||
except Exception as e: | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
raise e | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
return self | |||
@property | |||
def _no_create_word_length(self): | |||
return len(self._no_create_word) | |||
def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None): | |||
r""" | |||
使用dataset的对应field中词构建词典:: | |||
# remember to use `field_name` | |||
vocab.from_dataset(train_data1, train_data2, field_name='words') | |||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||
:param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . | |||
构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构 | |||
: ``str`` , ``List[str]`` | |||
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认), 建议直接将非训练数据都传入到这个参数。该选项用在接下来的模型会使用pretrain | |||
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | |||
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | |||
如果一个词出现在了train中,但是没在预训练模型中,embedding会为它用unk初始化,但它是单独的一个vector,如果 | |||
finetune embedding的话,这个词在更新之后可能会有更好的表示; 而如果这个词仅出现在了dev或test中,那么就不能为它们单独建立vector, | |||
而应该让它指向unk这个vector的值。所以只位于no_create_entry_dataset中的token,将首先从预训练的词表中寻找它的表示, | |||
如果找到了,就使用该表示; 如果没有找到,则认为该词的表示应该为unk的表示。 | |||
:return self: | |||
""" | |||
if isinstance(field_name, str): | |||
field_name = [field_name] | |||
elif not isinstance(field_name, list): | |||
raise TypeError('invalid argument field_name: {}'.format(field_name)) | |||
def construct_vocab(ins, no_create_entry=False): | |||
for fn in field_name: | |||
field = ins[fn] | |||
if isinstance(field, str) or not _is_iterable(field): | |||
self.add_word(field, no_create_entry=no_create_entry) | |||
else: | |||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||
for word in field: | |||
self.add_word(word, no_create_entry=no_create_entry) | |||
else: | |||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
for words in field: | |||
for word in words: | |||
self.add_word(word, no_create_entry=no_create_entry) | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
dataset.apply(construct_vocab) | |||
except BaseException as e: | |||
print("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
raise e | |||
else: | |||
raise TypeError("Only DataSet type is allowed.") | |||
if no_create_entry_dataset is not None: | |||
partial_construct_vocab = partial(construct_vocab, no_create_entry=True) | |||
if isinstance(no_create_entry_dataset, DataSet): | |||
no_create_entry_dataset.apply(partial_construct_vocab) | |||
elif isinstance(no_create_entry_dataset, list): | |||
for dataset in no_create_entry_dataset: | |||
if not isinstance(dataset, DataSet): | |||
raise TypeError("Only DataSet type is allowed.") | |||
dataset.apply(partial_construct_vocab) | |||
return self | |||
def _is_word_no_create_entry(self, word): | |||
r""" | |||
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明 | |||
:param word: str | |||
:return: bool | |||
""" | |||
return word in self._no_create_word | |||
def to_index(self, w): | |||
r""" | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | |||
index = vocab.to_index('abc') | |||
# equals to | |||
index = vocab['abc'] | |||
:param str w: a word | |||
:return int index: the number | |||
""" | |||
return self.__getitem__(w) | |||
@property | |||
@_check_build_vocab | |||
def unknown_idx(self): | |||
r""" | |||
unknown 对应的数字. | |||
""" | |||
if self.unknown is None: | |||
return None | |||
return self._word2idx[self.unknown] | |||
@property | |||
@_check_build_vocab | |||
def padding_idx(self): | |||
r""" | |||
padding 对应的数字 | |||
""" | |||
if self.padding is None: | |||
return None | |||
return self._word2idx[self.padding] | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
r""" | |||
给定一个数字, 将其转为对应的词. | |||
:param int idx: the index | |||
:return str word: the word | |||
""" | |||
return self._idx2word[idx] | |||
def clear(self): | |||
r""" | |||
删除Vocabulary中的词表数据。相当于重新初始化一下。 | |||
:return: | |||
""" | |||
self.word_count.clear() | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
self._no_create_word.clear() | |||
return self | |||
def __getstate__(self): | |||
r"""Use to prepare data for pickle. | |||
""" | |||
len(self) # make sure vocab has been built | |||
state = self.__dict__.copy() | |||
# no need to pickle _idx2word as it can be constructed from _word2idx | |||
del state['_idx2word'] | |||
return state | |||
def __setstate__(self, state): | |||
r"""Use to restore state from pickle. | |||
""" | |||
self.__dict__.update(state) | |||
self.build_reverse_vocab() | |||
def __repr__(self): | |||
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | |||
@_check_build_vocab | |||
def __iter__(self): | |||
# 依次(word1, 0), (word1, 1) | |||
for index in range(len(self._word2idx)): | |||
yield self.to_word(index), index | |||
def save(self, filepath): | |||
r""" | |||
:param str,io.StringIO filepath: Vocabulary的储存路径 | |||
:return: | |||
""" | |||
if isinstance(filepath, io.IOBase): | |||
assert filepath.writable() | |||
f = filepath | |||
elif isinstance(filepath, str): | |||
try: | |||
f = open(filepath, 'w', encoding='utf-8') | |||
except Exception as e: | |||
raise e | |||
else: | |||
raise TypeError("Illegal `path`.") | |||
f.write(f'max_size\t{self.max_size}\n') | |||
f.write(f'min_freq\t{self.min_freq}\n') | |||
f.write(f'unknown\t{self.unknown}\n') | |||
f.write(f'padding\t{self.padding}\n') | |||
f.write(f'rebuild\t{self.rebuild}\n') | |||
f.write('\n') | |||
# idx: 如果idx为-2, 说明还没有进行build; 如果idx为-1,说明该词未编入 | |||
# no_create_entry: 如果为1,说明该词是no_create_entry; 0 otherwise | |||
# word \t count \t idx \t no_create_entry \n | |||
idx = -2 | |||
for word, count in self.word_count.items(): | |||
if self._word2idx is not None: | |||
idx = self._word2idx.get(word, -1) | |||
is_no_create_entry = int(self._is_word_no_create_entry(word)) | |||
f.write(f'{word}\t{count}\t{idx}\t{is_no_create_entry}\n') | |||
if isinstance(filepath, str): # 如果是file的话就关闭 | |||
f.close() | |||
@staticmethod | |||
def load(filepath): | |||
r""" | |||
:param str,io.StringIO filepath: Vocabulary的读取路径 | |||
:return: Vocabulary | |||
""" | |||
if isinstance(filepath, io.IOBase): | |||
assert filepath.writable() | |||
f = filepath | |||
elif isinstance(filepath, str): | |||
try: | |||
f = open(filepath, 'r', encoding='utf-8') | |||
except Exception as e: | |||
raise e | |||
else: | |||
raise TypeError("Illegal `path`.") | |||
vocab = Vocabulary() | |||
for line in f: | |||
line = line.strip('\n') | |||
if line: | |||
name, value = line.split() | |||
if name in ('max_size', 'min_freq'): | |||
value = int(value) if value!='None' else None | |||
setattr(vocab, name, value) | |||
elif name in ('unknown', 'padding'): | |||
value = value if value!='None' else None | |||
setattr(vocab, name, value) | |||
elif name == 'rebuild': | |||
vocab.rebuild = True if value=='True' else False | |||
else: | |||
break | |||
word_counter = {} | |||
no_create_entry_counter = {} | |||
word2idx = {} | |||
for line in f: | |||
line = line.strip('\n') | |||
if line: | |||
parts = line.split('\t') | |||
word,count,idx,no_create_entry = parts[0], int(parts[1]), int(parts[2]), int(parts[3]) | |||
if idx >= 0: | |||
word2idx[word] = idx | |||
word_counter[word] = count | |||
if no_create_entry: | |||
no_create_entry_counter[word] = count | |||
word_counter = Counter(word_counter) | |||
no_create_entry_counter = Counter(no_create_entry_counter) | |||
if len(word2idx)>0: | |||
if vocab.padding: | |||
word2idx[vocab.padding] = 0 | |||
if vocab.unknown: | |||
word2idx[vocab.unknown] = 1 if vocab.padding else 0 | |||
idx2word = {value:key for key,value in word2idx.items()} | |||
vocab.word_count = word_counter | |||
vocab._no_create_word = no_create_entry_counter | |||
if word2idx: | |||
vocab._word2idx = word2idx | |||
vocab._idx2word = idx2word | |||
if isinstance(filepath, str): # 如果是file的话就关闭 | |||
f.close() | |||
return vocab |
@@ -0,0 +1,121 @@ | |||
r""" | |||
用于IO的模块, 具体包括: | |||
1. 用于读入 embedding 的 :mod:`EmbedLoader <fastNLP.io.embed_loader>` 类, | |||
2. 用于读入不同格式数据的 :mod:`Loader <fastNLP.io.loader>` 类 | |||
3. 用于处理读入数据的 :mod:`Pipe <fastNLP.io.pipe>` 类 | |||
4. 用于保存和载入模型的类, 参考 :mod:`model_io模块 <fastNLP.io.model_io>` | |||
这些类的使用方法如下: | |||
""" | |||
__all__ = [ | |||
'DataBundle', | |||
'EmbedLoader', | |||
'Loader', | |||
'CLSBaseLoader', | |||
'AGsNewsLoader', | |||
'DBPediaLoader', | |||
'YelpFullLoader', | |||
'YelpPolarityLoader', | |||
'IMDBLoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
"THUCNewsLoader", | |||
"WeiboSenti100kLoader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'Conll2003NERLoader', | |||
'OntoNotesNERLoader', | |||
'CTBLoader', | |||
"MsraNERLoader", | |||
"WeiboNERLoader", | |||
"PeopleDailyNERLoader", | |||
'CSVLoader', | |||
'JsonLoader', | |||
'CWSLoader', | |||
'MNLILoader', | |||
"QuoraLoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"CNXNLILoader", | |||
"BQCorpusLoader", | |||
"LCQMCLoader", | |||
"CMRC2018Loader", | |||
"Pipe", | |||
"CLSBasePipe", | |||
"AGsNewsPipe", | |||
"DBPediaPipe", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"Conll2003Pipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"PeopleDailyPipe", | |||
"WeiboNERPipe", | |||
"CWSPipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"WeiboNERPipe", | |||
"PeopleDailyPipe", | |||
"Conll2003Pipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"CNXNLIBertPipe", | |||
"BQCorpusBertPipe", | |||
"LCQMCBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"LCQMCPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"RenamePipe", | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
"CMRC2018BertPipe", | |||
'ModelLoader', | |||
'ModelSaver', | |||
] | |||
from .data_bundle import DataBundle | |||
from .embed_loader import EmbedLoader | |||
from .loader import * | |||
from .model_io import ModelLoader, ModelSaver | |||
from .pipe import * |
@@ -0,0 +1,97 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CWSLoader" | |||
] | |||
import glob | |||
import os | |||
import random | |||
import shutil | |||
import time | |||
from .loader import Loader | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class CWSLoader(Loader): | |||
r""" | |||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||
Example:: | |||
上海 浦东 开发 与 法制 建设 同步 | |||
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) | |||
... | |||
该Loader读取后的DataSet具有如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words" | |||
"上海 浦东 开发 与 法制 建设 同步" | |||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||
"..." | |||
""" | |||
def __init__(self, dataset_name: str = None): | |||
r""" | |||
:param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None | |||
""" | |||
super().__init__() | |||
datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} | |||
if dataset_name in datanames: | |||
self.dataset_name = datanames[dataset_name] | |||
else: | |||
self.dataset_name = None | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self, dev_ratio=0.1, re_download=False) -> str: | |||
r""" | |||
如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | |||
2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str | |||
""" | |||
if self.dataset_name is None: | |||
return '' | |||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||
modify_time = 0 | |||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||
modify_time = os.stat(filepath).st_mtime | |||
break | |||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
try: | |||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||
for line in f: | |||
if random.random() < dev_ratio: | |||
f2.write(line) | |||
else: | |||
f1.write(line) | |||
os.remove(os.path.join(data_dir, 'train.txt')) | |||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||
return data_dir |
@@ -0,0 +1,354 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
'DataBundle', | |||
] | |||
from typing import Union, List, Callable | |||
from ..core.dataset import DataSet | |||
from fastNLP.core.vocabulary import Vocabulary | |||
# from ..core._logger import _logger | |||
class DataBundle: | |||
r""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | |||
Loader的load函数生成,可以通过以下的方法获取里面的内容 | |||
Example:: | |||
data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'}) | |||
train_vocabs = data_bundle.vocabs['train'] | |||
train_data = data_bundle.datasets['train'] | |||
dev_data = data_bundle.datasets['train'] | |||
""" | |||
def __init__(self, vocabs=None, datasets=None): | |||
r""" | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict。建议不要将相同的DataSet对象重复传入,可能会在 | |||
使用Pipe处理数据的时候遇到问题,若多个数据集确需一致,请手动deepcopy后传入。 | |||
""" | |||
self.vocabs = vocabs or {} | |||
self.datasets = datasets or {} | |||
def set_vocab(self, vocab: Vocabulary, field_name: str): | |||
r""" | |||
向DataBunlde中增加vocab | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
:param str field_name: 这个vocab对应的field名称 | |||
:return: self | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports." | |||
self.vocabs[field_name] = vocab | |||
return self | |||
def set_dataset(self, dataset: DataSet, name: str): | |||
r""" | |||
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet | |||
:param str name: dataset的名称 | |||
:return: self | |||
""" | |||
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports." | |||
self.datasets[name] = dataset | |||
return self | |||
def get_dataset(self, name: str) -> DataSet: | |||
r""" | |||
获取名为name的dataset | |||
:param str name: dataset的名称,一般为'train', 'dev', 'test' | |||
:return: DataSet | |||
""" | |||
if name in self.datasets.keys(): | |||
return self.datasets[name] | |||
else: | |||
error_msg = f'DataBundle do NOT have DataSet named {name}. ' \ | |||
f'It should be one of {self.datasets.keys()}.' | |||
print(error_msg) | |||
raise KeyError(error_msg) | |||
def delete_dataset(self, name: str): | |||
r""" | |||
删除名为name的DataSet | |||
:param str name: | |||
:return: self | |||
""" | |||
self.datasets.pop(name, None) | |||
return self | |||
def get_vocab(self, field_name: str) -> Vocabulary: | |||
r""" | |||
获取field名为field_name对应的vocab | |||
:param str field_name: 名称 | |||
:return: Vocabulary | |||
""" | |||
if field_name in self.vocabs.keys(): | |||
return self.vocabs[field_name] | |||
else: | |||
error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ | |||
f'It should be one of {self.vocabs.keys()}.' | |||
print(error_msg) | |||
raise KeyError(error_msg) | |||
def delete_vocab(self, field_name: str): | |||
r""" | |||
删除vocab | |||
:param str field_name: | |||
:return: self | |||
""" | |||
self.vocabs.pop(field_name, None) | |||
return self | |||
@property | |||
def num_dataset(self): | |||
return len(self.datasets) | |||
@property | |||
def num_vocab(self): | |||
return len(self.vocabs) | |||
def copy_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True): | |||
r""" | |||
将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. | |||
:param str field_name: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.copy_field(field_name=field_name, new_field_name=new_field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def rename_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True, rename_vocab=True): | |||
r""" | |||
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | |||
:param str field_name: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.rename_field(field_name=field_name, new_field_name=new_field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if rename_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs[new_field_name] = self.vocabs.pop(field_name) | |||
return self | |||
def delete_field(self, field_name: str, ignore_miss_dataset=True, delete_vocab=True): | |||
r""" | |||
将DataBundle中所有DataSet中名为field_name的field删除掉. | |||
:param str field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.delete_field(field_name=field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if delete_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs.pop(field_name) | |||
return self | |||
def iter_datasets(self) -> Iterable[Tuple[str, DataSet]]:
r""" | |||
迭代data_bundle中的DataSet | |||
Example:: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
pass | |||
:return: | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
yield name, dataset | |||
def get_dataset_names(self) -> List[str]: | |||
r""" | |||
返回DataBundle中DataSet的名称 | |||
:return: | |||
""" | |||
return list(self.datasets.keys()) | |||
def get_vocab_names(self) -> List[str]: | |||
r""" | |||
返回DataBundle中Vocabulary的名称 | |||
:return: | |||
""" | |||
return list(self.vocabs.keys()) | |||
def iter_vocabs(self): | |||
r""" | |||
迭代data_bundle中的Vocabulary
Example::
for field_name, vocab in data_bundle.iter_vocabs(): | |||
pass | |||
:return: | |||
""" | |||
for field_name, vocab in self.vocabs.items(): | |||
yield field_name, vocab | |||
def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, | |||
ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str field_name: 传入func的是哪个field。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param int num_proc: 处理数据时使用的进程数目。
:param str progress_desc: 当show_progress_bar为True时,可以显示当前tqdm正在处理的名称
:param bool show_progress_bar: 是否显示tqdm进度条
""" | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if dataset.has_field(field_name=field_name): | |||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | |||
progress_desc=progress_desc, show_progress_bar=show_progress_bar) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | |||
ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | |||
.. note:: | |||
``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
``apply`` 区别的介绍。 | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param str field_name: 传入func的是哪个field。 | |||
:param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param show_progress_bar: 是否显示tqdm进度条 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqdm正在处理的名称
:param num_proc: 处理数据时使用的进程数目。
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
""" | |||
res = {} | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if dataset.has_field(field_name=field_name): | |||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | |||
modify_fields=modify_fields, | |||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name} .") | |||
return res | |||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param _apply_field: | |||
:param show_progress_bar: 是否显示tqdm进度条
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqdm正在处理的名称
:param num_proc: 处理数据时使用的进程数目。
""" | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | |||
progress_desc=progress_desc, _apply_field=_apply_field) | |||
return self | |||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | |||
.. note:: | |||
``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 | |||
``apply`` 区别的介绍。 | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param show_progress_bar: 是否显示tqdm进度条
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqdm正在处理的名称
:param num_proc: 处理数据时使用的进程数目。
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
""" | |||
res = {} | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | |||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | |||
return res | |||
def set_pad_val(self, *field_names, val=0) -> None: | |||
for _, ds in self.iter_datasets(): | |||
ds.set_pad_val(*field_names, val=val) | |||
def set_input(self, *field_names) -> None: | |||
for _, ds in self.iter_datasets(): | |||
ds.set_input(*field_names) | |||
def __repr__(self) -> str: | |||
_str = '' | |||
if len(self.datasets): | |||
_str += 'In total {} datasets:\n'.format(self.num_dataset) | |||
for name, dataset in self.datasets.items(): | |||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||
if len(self.vocabs): | |||
_str += 'In total {} vocabs:\n'.format(self.num_vocab) | |||
for name, vocab in self.vocabs.items(): | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
return _str | |||
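# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the module's public API): how a DataBundle
# is typically assembled and processed. The field names, sentences and the
# lower-casing step below are made-up examples; only methods defined above
# (the constructor, apply_field, iter_datasets) are exercised.
def _demo_data_bundle():
    from ..core.dataset import Instance  # Instance lives next to DataSet

    train_ds, dev_ds = DataSet(), DataSet()
    for words, target in [('This is fine', 'pos'), ('Bad movie', 'neg')]:
        train_ds.append(Instance(raw_words=words, target=target))
    dev_ds.append(Instance(raw_words='Great plot', target='pos'))

    vocab = Vocabulary()
    for word in 'This is fine Bad movie Great plot'.split():
        vocab.add_word(word)

    bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds},
                        vocabs={'raw_words': vocab})
    # lower-case raw_words in every dataset contained in the bundle
    bundle.apply_field(lambda words: words.lower(),
                       field_name='raw_words', new_field_name='raw_words')
    for name, ds in bundle.iter_datasets():
        print(name, 'has', len(ds), 'instances')
    return bundle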
@@ -0,0 +1,188 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"EmbedLoader", | |||
"EmbeddingOption", | |||
] | |||
import logging | |||
import os | |||
import warnings | |||
import numpy as np | |||
from fastNLP.core.utils.utils import Option | |||
from fastNLP.core.vocabulary import Vocabulary | |||
class EmbeddingOption(Option): | |||
def __init__(self, | |||
embed_filepath=None, | |||
dtype=np.float32, | |||
normalize=True, | |||
error='ignore'): | |||
super().__init__( | |||
embed_filepath=embed_filepath, | |||
dtype=dtype, | |||
normalize=normalize, | |||
error=error | |||
) | |||
class EmbedLoader: | |||
r""" | |||
用于读取预训练的embedding, 读取结果可直接载入为模型参数。 | |||
""" | |||
def __init__(self): | |||
super(EmbedLoader, self).__init__() | |||
@staticmethod | |||
def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore', init_method=None): | |||
r""" | |||
从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 | |||
word2vec(第一行只有两个元素)还是glove格式的数据。 | |||
:param str embed_filepath: 预训练的embedding的路径。 | |||
:param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 | |||
没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 | |||
:param dtype: 读出的embedding的类型 | |||
:param str padding: 词表中padding的token | |||
:param str unknown: 词表中unknown的token | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 | |||
这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 | |||
:param callable init_method: 传入numpy.ndarray, 返回numpy.ndarray, 用以初始化embedding | |||
:return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." | |||
if not os.path.exists(embed_filepath): | |||
raise FileNotFoundError("`{}` does not exist.".format(embed_filepath)) | |||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||
hit_flags = np.zeros(len(vocab), dtype=bool) | |||
line = f.readline().strip() | |||
parts = line.split() | |||
start_idx = 0 | |||
if len(parts) == 2: | |||
dim = int(parts[1]) | |||
start_idx += 1 | |||
else: | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||
if init_method: | |||
matrix = init_method(matrix) | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
parts = line.strip().split() | |||
word = ''.join(parts[:-dim]) | |||
nums = parts[-dim:] | |||
# 对齐unk与pad | |||
if word == padding and vocab.padding is not None: | |||
word = vocab.padding | |||
elif word == unknown and vocab.unknown is not None: | |||
word = vocab.unknown | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
matrix[index] = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) | |||
hit_flags[index] = True | |||
except Exception as e: | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
else: | |||
logging.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
total_hits = sum(hit_flags) | |||
logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab))) | |||
if init_method is None: | |||
found_vectors = matrix[hit_flags] | |||
if len(found_vectors) != 0: | |||
mean = np.mean(found_vectors, axis=0, keepdims=True) | |||
std = np.std(found_vectors, axis=0, keepdims=True) | |||
unfound_vec_num = len(vocab) - total_hits | |||
r_vecs = np.random.randn(unfound_vec_num, dim).astype(dtype) * std + mean | |||
matrix[~hit_flags] = r_vecs
if normalize: | |||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||
return matrix | |||
@staticmethod | |||
def load_without_vocab(embed_filepath, dtype=np.float32, padding='<pad>', unknown='<unk>', normalize=True, | |||
error='ignore'): | |||
r""" | |||
从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 | |||
:param str embed_filepath: 预训练的embedding的路径。 | |||
:param dtype: 读出的embedding的类型 | |||
:param str padding: 词表中的padding的token. 并以此用做vocab的padding。 | |||
:param str unknown: 词表中的unknown的token. 并以此用做vocab的unknown。 | |||
:param bool normalize: 是否将每个vector归一化到norm为1 | |||
:param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 | |||
方在于词表有空行或者词表出现了维度不一致。 | |||
:return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 | |||
是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 | |||
""" | |||
vocab = Vocabulary(padding=padding, unknown=unknown) | |||
vec_dict = {} | |||
found_unknown = False | |||
found_pad = False | |||
with open(embed_filepath, 'r', encoding='utf-8') as f: | |||
line = f.readline() | |||
start = 1 | |||
dim = -1 | |||
if len(line.strip().split()) != 2: | |||
f.seek(0) | |||
start = 0 | |||
for idx, line in enumerate(f, start=start): | |||
try: | |||
parts = line.strip().split() | |||
if dim == -1: | |||
dim = len(parts) - 1 | |||
word = ''.join(parts[:-dim]) | |||
nums = parts[-dim:] | |||
vec = np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim) | |||
vec_dict[word] = vec | |||
vocab.add_word(word) | |||
if unknown is not None and unknown == word: | |||
found_unknown = True | |||
if padding is not None and padding == word: | |||
found_pad = True | |||
except Exception as e: | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
pass | |||
else: | |||
logging.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
if dim == -1: | |||
raise RuntimeError("{} is an empty file.".format(embed_filepath)) | |||
matrix = np.random.randn(len(vocab), dim).astype(dtype) | |||
for key, vec in vec_dict.items(): | |||
index = vocab.to_index(key) | |||
matrix[index] = vec | |||
if ((unknown is not None) and (not found_unknown)) or ((padding is not None) and (not found_pad)): | |||
start_idx = 0 | |||
if padding is not None: | |||
start_idx += 1 | |||
if unknown is not None: | |||
start_idx += 1 | |||
mean = np.mean(matrix[start_idx:], axis=0, keepdims=True) | |||
std = np.std(matrix[start_idx:], axis=0, keepdims=True) | |||
if (unknown is not None) and (not found_unknown): | |||
matrix[start_idx - 1] = np.random.randn(1, dim).astype(dtype) * std + mean | |||
if (padding is not None) and (not found_pad): | |||
matrix[0] = np.random.randn(1, dim).astype(dtype) * std + mean | |||
if normalize: | |||
matrix /= np.linalg.norm(matrix, axis=1, keepdims=True) | |||
return matrix, vocab |
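# ---------------------------------------------------------------------------
# Illustrative sketch (not shipped with the module): exercising the two
# loaders on a tiny, made-up glove-style file. The words and 3-d vectors are
# fabricated purely to show the expected file layout and return shapes.
def _demo_embed_loader():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        embed_path = os.path.join(tmp_dir, 'toy_glove.txt')
        with open(embed_path, 'w', encoding='utf-8') as f:
            f.write('the 0.1 0.2 0.3\n')
            f.write('movie 0.2 0.1 0.4\n')

        vocab = Vocabulary()
        for word in ['the', 'movie', 'unseen']:
            vocab.add_word(word)

        # words present in the file get their pretrained vectors, the rest are
        # sampled from the fitted normal distribution (see load_with_vocab)
        matrix = EmbedLoader.load_with_vocab(embed_path, vocab, normalize=False)
        assert matrix.shape == (len(vocab), 3)

        # alternatively, let the file itself define the vocabulary
        matrix2, file_vocab = EmbedLoader.load_without_vocab(embed_path, normalize=False)
        assert matrix2.shape[0] == len(file_vocab)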
@@ -0,0 +1,136 @@ | |||
r"""undocumented | |||
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API | |||
""" | |||
__all__ = [] | |||
import json | |||
import csv | |||
# from ..core import log | |||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
r""" | |||
Construct a generator to read csv items. | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param headers: file's headers, if None, make file's first line as headers. default: None | |||
:param sep: separator for each column. default: ',' | |||
:param dropna: whether to ignore and drop invalid data;
    if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, csv item) | |||
""" | |||
with open(path, 'r', encoding=encoding) as csv_file: | |||
f = csv.reader(csv_file, delimiter=sep) | |||
start_idx = 0 | |||
if headers is None: | |||
headers = next(f) | |||
start_idx += 1 | |||
elif not isinstance(headers, (list, tuple)): | |||
raise TypeError("headers should be list or tuple, not {}." \ | |||
.format(type(headers))) | |||
for line_idx, line in enumerate(f, start_idx): | |||
contents = line | |||
if len(contents) != len(headers): | |||
if dropna: | |||
continue | |||
else: | |||
if "" in headers: | |||
raise ValueError(("Line {} has {} parts, while header has {} parts.\n" + | |||
"Please check the empty parts or unnecessary '{}'s in header.") | |||
.format(line_idx, len(contents), len(headers), sep)) | |||
else: | |||
raise ValueError("Line {} has {} parts, while header has {} parts." \ | |||
.format(line_idx, len(contents), len(headers))) | |||
_dict = {} | |||
for header, content in zip(headers, contents): | |||
_dict[header] = content | |||
yield line_idx, _dict | |||
def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||
r""" | |||
Construct a generator to read json items. | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param fields: json object's fields that needed, if None, all fields are needed. default: None | |||
:param dropna: whether to ignore and drop invalid data;
    if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, json item) | |||
""" | |||
if fields: | |||
fields = set(fields) | |||
with open(path, 'r', encoding=encoding) as f: | |||
for line_idx, line in enumerate(f): | |||
data = json.loads(line) | |||
if fields is None: | |||
yield line_idx, data | |||
continue | |||
_res = {} | |||
for k, v in data.items(): | |||
if k in fields: | |||
_res[k] = v | |||
if len(_res) < len(fields): | |||
if dropna: | |||
continue | |||
else: | |||
raise ValueError('invalid instance at line: {}'.format(line_idx)) | |||
yield line_idx, _res | |||
def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||
r""" | |||
Construct a generator to read conll items. | |||
:param path: file path | |||
:param encoding: file's encoding, default: utf-8 | |||
:param sep: separator for each column. default: None (split on any whitespace)
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
:param dropna: whether to ignore and drop invalid data;
    if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item) | |||
""" | |||
def parse_conll(sample): | |||
sample = list(map(list, zip(*sample))) | |||
if indexes:
    sample = [sample[i] for i in indexes]
for f in sample: | |||
if len(f) <= 0: | |||
raise ValueError('empty field') | |||
return sample | |||
with open(path, 'r', encoding=encoding) as f: | |||
sample = [] | |||
start = next(f).strip() | |||
if start != '': | |||
sample.append(start.split(sep)) if sep else sample.append(start.split()) | |||
for line_idx, line in enumerate(f, 1): | |||
line = line.strip() | |||
if line == '': | |||
if len(sample): | |||
try: | |||
res = parse_conll(sample) | |||
sample = [] | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
print('Invalid instance which ends at line: {} has been dropped.'.format(line_idx)) | |||
sample = [] | |||
continue | |||
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
continue | |||
else: | |||
sample.append(line.split(sep)) if sep else sample.append(line.split()) | |||
if len(sample) > 0: | |||
try: | |||
res = parse_conll(sample) | |||
yield line_idx, res | |||
except Exception as e: | |||
if dropna: | |||
return | |||
print('invalid instance ends at line: {}'.format(line_idx)) | |||
raise e |
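# ---------------------------------------------------------------------------
# Illustrative sketch (not used by the library): feeding the readers above
# with tiny, made-up files to show what each generator yields.
def _demo_file_reader():
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, 'toy.csv')
        with open(csv_path, 'w', encoding='utf-8') as f:
            f.write('raw_words,target\n')
            f.write('a charming journey,1\n')

        # headers are taken from the first line when headers=None
        for line_idx, item in _read_csv(csv_path):
            assert item == {'raw_words': 'a charming journey', 'target': '1'}

        conll_path = os.path.join(tmp_dir, 'toy.conll')
        with open(conll_path, 'w', encoding='utf-8') as f:
            f.write('Nadim NNP B-NP B-PER\n')
            f.write('Ladki NNP I-NP I-PER\n')
            f.write('\n')

        # each yielded sample is a list of the requested columns, here 0 and 3
        for line_idx, sample in _read_conll(conll_path, indexes=[0, 3]):
            assert sample == [['Nadim', 'Ladki'], ['B-PER', 'I-PER']]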
@@ -0,0 +1,578 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"cached_path", | |||
"get_filepath", | |||
"get_cache_path", | |||
"split_filename_suffix", | |||
"get_from_cache", | |||
] | |||
import os | |||
import re | |||
import shutil | |||
import tempfile | |||
from pathlib import Path | |||
from urllib.parse import urlparse | |||
import requests | |||
from requests import HTTPError | |||
from fastNLP.core.log import logger | |||
from rich.progress import Progress, BarColumn, DownloadColumn, TimeRemainingColumn, TimeElapsedColumn | |||
PRETRAINED_BERT_MODEL_DIR = { | |||
'en': 'bert-base-cased.zip', | |||
'en-large-cased-wwm': 'bert-large-cased-wwm.zip', | |||
'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip', | |||
'en-large-uncased': 'bert-large-uncased.zip', | |||
'en-large-cased': 'bert-large-cased.zip', | |||
'en-base-uncased': 'bert-base-uncased.zip', | |||
'en-base-cased': 'bert-base-cased.zip', | |||
'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip', | |||
'en-distilbert-base-uncased': 'distilbert-base-uncased.zip', | |||
'multi-base-cased': 'bert-base-multilingual-cased.zip', | |||
'multi-base-uncased': 'bert-base-multilingual-uncased.zip', | |||
'cn': 'bert-chinese-wwm.zip', | |||
'cn-base': 'bert-base-chinese.zip', | |||
'cn-wwm': 'bert-chinese-wwm.zip', | |||
'cn-wwm-ext': "bert-chinese-wwm-ext.zip" | |||
} | |||
PRETRAINED_GPT2_MODEL_DIR = { | |||
'en': 'gpt2.zip', | |||
'en-medium': 'gpt2-medium.zip', | |||
'en-large': 'gpt2-large.zip', | |||
'en-xl': 'gpt2-xl.zip' | |||
} | |||
PRETRAINED_ROBERTA_MODEL_DIR = { | |||
'en': 'roberta-base.zip', | |||
'en-large': 'roberta-large.zip' | |||
} | |||
PRETRAINED_ELMO_MODEL_DIR = { | |||
'en': 'elmo_en_Medium.zip', | |||
'en-small': "elmo_en_Small.zip", | |||
'en-original-5.5b': 'elmo_en_Original_5.5B.zip', | |||
'en-original': 'elmo_en_Original.zip', | |||
'en-medium': 'elmo_en_Medium.zip' | |||
} | |||
PRETRAIN_STATIC_FILES = { | |||
'en': 'glove.840B.300d.zip', | |||
'en-glove-6b-50d': 'glove.6B.50d.zip', | |||
'en-glove-6b-100d': 'glove.6B.100d.zip', | |||
'en-glove-6b-200d': 'glove.6B.200d.zip', | |||
'en-glove-6b-300d': 'glove.6B.300d.zip', | |||
'en-glove-42b-300d': 'glove.42B.300d.zip', | |||
'en-glove-840b-300d': 'glove.840B.300d.zip', | |||
'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip', | |||
'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip', | |||
'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip', | |||
'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip', | |||
'en-word2vec-300d': "GoogleNews-vectors-negative300.txt.gz", | |||
'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip", | |||
'en-fasttext-crawl': "crawl-300d-2M.vec.zip", | |||
'cn': "tencent_cn.zip", | |||
'cn-tencent': "tencent_cn.zip", | |||
'cn-fasttext': "cc.zh.300.vec.gz", | |||
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip', | |||
'cn-char-fastnlp-100d': "cn_char_fastnlp_100d.zip", | |||
'cn-bi-fastnlp-100d': "cn_bi_fastnlp_100d.zip", | |||
"cn-tri-fastnlp-100d": "cn_tri_fastnlp_100d.zip" | |||
} | |||
DATASET_DIR = { | |||
# Classification, English | |||
'aclImdb': "imdb.zip", | |||
"yelp-review-full": "yelp_review_full.tar.gz", | |||
"yelp-review-polarity": "yelp_review_polarity.tar.gz", | |||
"sst-2": "SST-2.zip", | |||
"sst": "SST.zip", | |||
'mr': 'mr.zip', | |||
"R8": "R8.zip", | |||
"R52": "R52.zip", | |||
"20ng": "20ng.zip", | |||
"ohsumed": "ohsumed.zip", | |||
# Classification, Chinese | |||
"chn-senti-corp": "chn_senti_corp.zip", | |||
"weibo-senti-100k": "WeiboSenti100k.zip", | |||
"thuc-news": "THUCNews.zip", | |||
# Matching, English | |||
"mnli": "MNLI.zip", | |||
"snli": "SNLI.zip", | |||
"qnli": "QNLI.zip", | |||
"rte": "RTE.zip", | |||
# Matching, Chinese | |||
"cn-xnli": "XNLI.zip", | |||
# Sequence Labeling, Chinese | |||
"msra-ner": "MSRA_NER.zip", | |||
"peopledaily": "peopledaily.zip", | |||
"weibo-ner": "weibo_NER.zip", | |||
# Chinese Word Segmentation | |||
"cws-pku": 'cws_pku.zip', | |||
"cws-cityu": "cws_cityu.zip", | |||
"cws-as": 'cws_as.zip', | |||
"cws-msra": 'cws_msra.zip', | |||
# Summarization, English | |||
"ext-cnndm": "ext-cnndm.zip", | |||
# Question & answer, Chinese | |||
"cmrc2018": "cmrc2018.zip" | |||
} | |||
PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR, | |||
"bert": PRETRAINED_BERT_MODEL_DIR, | |||
"static": PRETRAIN_STATIC_FILES, | |||
'gpt2': PRETRAINED_GPT2_MODEL_DIR, | |||
'roberta': PRETRAINED_ROBERTA_MODEL_DIR} | |||
# 用于扩展fastNLP的下载 | |||
FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt' | |||
FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', | |||
'bert': 'fastnlp_bert_url.txt', | |||
'static': 'fastnlp_static_url.txt', | |||
'gpt2': 'fastnlp_gpt2_url.txt', | |||
'roberta': 'fastnlp_roberta_url.txt' | |||
} | |||
def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: | |||
r""" | |||
给定一个url或者文件名(可以是具体的文件名,也可以是文件夹),先尝试通过url解析出的文件名filename,
到{cache_dir}/{name}/{filename}下寻找这个文件:
1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则使用传入的cache_dir
2. 如果name=None, 则没有中间的{name}这一层结构; 否则中间结构就为{name}
如果找到了该文件,就直接返回其路径; 如果没有找到且传入的是url,则下载该资源并将文件放入cache_dir中。
:param str url_or_filename: 文件的下载url或者文件名称。 | |||
:param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 | |||
:param str name: 中间一层的名称。如embedding, dataset | |||
:return: | |||
""" | |||
if cache_dir is None: | |||
data_cache = Path(get_cache_path()) | |||
else: | |||
data_cache = cache_dir | |||
if name: | |||
data_cache = os.path.join(data_cache, name) | |||
parsed = urlparse(url_or_filename) | |||
if parsed.scheme in ("http", "https"): | |||
# URL, so get it from the cache (downloading if necessary) | |||
return get_from_cache(url_or_filename, Path(data_cache)) | |||
elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists(): | |||
# File, and it exists. | |||
return Path(os.path.join(data_cache, url_or_filename)) | |||
elif parsed.scheme == "": | |||
# File, but it doesn't exist. | |||
raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache)) | |||
else: | |||
# Something unknown | |||
raise ValueError( | |||
"unable to parse {} as a URL or as a local path".format(url_or_filename) | |||
) | |||
def get_filepath(filepath): | |||
r""" | |||
如果filepath为文件夹, | |||
如果内含多个文件, 返回filepath | |||
如果只有一个文件, 返回filepath + filename | |||
如果filepath为文件 | |||
返回filepath | |||
:param str filepath: 路径 | |||
:return: | |||
""" | |||
if os.path.isdir(filepath): | |||
files = os.listdir(filepath) | |||
if len(files) == 1: | |||
return os.path.join(filepath, files[0]) | |||
else: | |||
return filepath | |||
elif os.path.isfile(filepath): | |||
return filepath | |||
else: | |||
raise FileNotFoundError(f"{filepath} is not a valid file or directory.") | |||
def get_cache_path(): | |||
r""" | |||
获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_DIR设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。
:return str: 存放路径 | |||
""" | |||
if 'FASTNLP_CACHE_DIR' in os.environ: | |||
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') | |||
if os.path.isdir(fastnlp_cache_dir): | |||
return fastnlp_cache_dir | |||
else: | |||
raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.") | |||
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP")) | |||
return fastnlp_cache_dir | |||
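# ---------------------------------------------------------------------------
# Illustrative sketch: pointing the cache at a custom location through the
# FASTNLP_CACHE_DIR environment variable. The temporary directory below is
# only a stand-in for a real, shared cache directory.
def _demo_cache_path():
    import tempfile

    default_dir = get_cache_path()  # ~/.fastNLP unless the env var is set
    with tempfile.TemporaryDirectory() as tmp_dir:
        os.environ['FASTNLP_CACHE_DIR'] = tmp_dir  # must be an existing directory
        try:
            assert get_cache_path() == tmp_dir
        finally:
            os.environ.pop('FASTNLP_CACHE_DIR', None)
    return default_dir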
def _get_base_url(name): | |||
r""" | |||
根据name返回下载的url地址。 | |||
:param str name: 支持dataset和embedding两种 | |||
:return: | |||
""" | |||
# 返回的URL结尾必须是/ | |||
environ_name = "FASTNLP_{}_URL".format(name.upper()) | |||
if environ_name in os.environ: | |||
url = os.environ[environ_name] | |||
if url.endswith('/'): | |||
return url | |||
else: | |||
return url + '/' | |||
else: | |||
URLS = { | |||
'embedding': "http://download.fastnlp.top/embedding/", | |||
"dataset": "http://download.fastnlp.top/dataset/" | |||
} | |||
if name.lower() not in URLS: | |||
raise KeyError(f"{name} is not recognized.") | |||
return URLS[name.lower()] | |||
def _get_embedding_url(embed_type, name): | |||
r""" | |||
给定embedding类型和名称,返回下载url
:param str embed_type: 支持static, bert, elmo。即embedding的类型 | |||
:param str name: embedding的名称, 例如en, cn, based等 | |||
:return: str, 下载的url地址 | |||
""" | |||
# 从扩展中寻找下载的url | |||
_filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None) | |||
if _filename: | |||
url = _read_extend_url_file(_filename, name) | |||
if url: | |||
return url | |||
embed_map = PRETRAIN_MAP.get(embed_type, None) | |||
if embed_map: | |||
filename = embed_map.get(name, None) | |||
if filename: | |||
url = _get_base_url('embedding') + filename | |||
return url | |||
raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys()))) | |||
else: | |||
raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static, gpt2, roberta") | |||
def _read_extend_url_file(filename, name) -> str: | |||
r""" | |||
filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址 | |||
:param str filename: 在默认的路径下寻找file这个文件 | |||
:param str name: 需要寻找的资源的名称 | |||
:return: str,None | |||
""" | |||
cache_dir = get_cache_path() | |||
filepath = os.path.join(cache_dir, filename) | |||
if os.path.exists(filepath): | |||
with open(filepath, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
if len(parts) == 2: | |||
if name == parts[0]: | |||
return parts[1] | |||
return None | |||
def _get_dataset_url(name, dataset_dir: dict = None): | |||
r""" | |||
给定dataset的名称,返回下载url | |||
:param str name: 给定dataset的名称,比如imdb, sst-2等 | |||
:return: str | |||
""" | |||
# 从扩展中寻找下载的url | |||
url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name) | |||
if url: | |||
return url | |||
dataset_dir = DATASET_DIR if dataset_dir is None else dataset_dir | |||
filename = dataset_dir.get(name, None) | |||
if filename: | |||
url = _get_base_url('dataset') + filename | |||
return url | |||
else: | |||
raise KeyError(f"There is no {name}.") | |||
def split_filename_suffix(filepath): | |||
r""" | |||
给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型 | |||
:param filepath: 文件路径 | |||
:return: filename, suffix | |||
""" | |||
filename = os.path.basename(filepath) | |||
if filename.endswith('.tar.gz'): | |||
return filename[:-7], '.tar.gz' | |||
return os.path.splitext(filename) | |||
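# ---------------------------------------------------------------------------
# Illustrative sketch of split_filename_suffix: '.tar.gz' is the only
# multi-dot suffix kept intact, everything else goes through os.path.splitext.
# The paths are made-up examples.
def _demo_split_filename_suffix():
    assert split_filename_suffix('/tmp/imdb.tar.gz') == ('imdb', '.tar.gz')
    assert split_filename_suffix('/tmp/glove.6B.100d.zip') == ('glove.6B.100d', '.zip')
    assert split_filename_suffix('embedding.txt') == ('embedding', '.txt')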
def get_from_cache(url: str, cache_dir: Path = None) -> Path: | |||
r""" | |||
尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 | |||
文件解压,将解压后的文件全部放在cache_dir文件夹中。 | |||
如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 | |||
:param url: 资源的 url | |||
:param cache_dir: cache 目录 | |||
:return: 路径 | |||
""" | |||
cache_dir.mkdir(parents=True, exist_ok=True) | |||
filename = re.sub(r".+/", "", url) | |||
dir_name, suffix = split_filename_suffix(filename) | |||
# 寻找与它名字匹配的内容, 而不关心后缀 | |||
match_dir_name = match_file(dir_name, cache_dir) | |||
if match_dir_name: | |||
dir_name = match_dir_name | |||
cache_path = cache_dir / dir_name | |||
# get cache path to put the file | |||
if cache_path.exists(): | |||
return get_filepath(cache_path) | |||
# make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上 | |||
# response = requests.head(url, headers={"User-Agent": "fastNLP"}) | |||
# if response.status_code != 200: | |||
# raise IOError( | |||
# f"HEAD request failed for url {url} with status code {response.status_code}." | |||
# ) | |||
# add ETag to filename if it exists | |||
# etag = response.headers.get("ETag") | |||
if not cache_path.exists(): | |||
# Download to temporary file, then copy to cache dir once finished. | |||
# Otherwise you get corrupt cache entries if the download gets interrupted. | |||
# GET file object | |||
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"}) | |||
if req.status_code == 200: | |||
success = False | |||
fd, temp_filename = tempfile.mkstemp() | |||
uncompress_temp_dir = None | |||
try: | |||
content_length = req.headers.get("Content-Length") | |||
total = int(content_length) if content_length is not None else None | |||
# progress = tqdm(unit="B", total=total, unit_scale=1) | |||
progress = Progress( | |||
BarColumn(), | |||
TimeElapsedColumn(), | |||
"/", | |||
TimeRemainingColumn(), | |||
DownloadColumn() | |||
) | |||
task = progress.add_task(total=total, description='download') | |||
progress.start() | |||
logger.info("%s not found in cache, downloading to %s" % (url, temp_filename)) | |||
with open(temp_filename, "wb") as temp_file: | |||
for chunk in req.iter_content(chunk_size=1024 * 16): | |||
if chunk: # filter out keep-alive new chunks | |||
progress.update(task, advance=len(chunk)) | |||
temp_file.write(chunk) | |||
progress.stop() | |||
progress.remove_task(task) | |||
logger.info(f"Finish download from {url}") | |||
# 开始解压 | |||
if suffix in ('.zip', '.tar.gz', '.gz'): | |||
uncompress_temp_dir = tempfile.mkdtemp() | |||
logger.info(f"Start to uncompress file to {uncompress_temp_dir}") | |||
if suffix == '.zip': | |||
unzip_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
elif suffix == '.gz': | |||
ungzip_file(temp_filename, uncompress_temp_dir, dir_name) | |||
else: | |||
untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir)) | |||
filenames = os.listdir(uncompress_temp_dir) | |||
if len(filenames) == 1: | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])): | |||
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0]) | |||
cache_path.mkdir(parents=True, exist_ok=True) | |||
logger.info("Finish un-compressing file.") | |||
else: | |||
uncompress_temp_dir = temp_filename | |||
cache_path = Path(str(cache_path) + suffix)
# 复制到指定的位置 | |||
logger.info(f"Copy file to {cache_path}") | |||
if os.path.isdir(uncompress_temp_dir): | |||
for filename in os.listdir(uncompress_temp_dir): | |||
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)): | |||
shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename) | |||
else: | |||
shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename) | |||
else: | |||
shutil.copyfile(uncompress_temp_dir, cache_path) | |||
success = True | |||
except Exception as e: | |||
logger.info(e) | |||
raise e | |||
finally: | |||
if not success: | |||
if cache_path.exists(): | |||
if cache_path.is_file(): | |||
os.remove(cache_path) | |||
else: | |||
shutil.rmtree(cache_path) | |||
os.close(fd) | |||
os.remove(temp_filename) | |||
if uncompress_temp_dir is None: | |||
pass | |||
elif os.path.isdir(uncompress_temp_dir): | |||
shutil.rmtree(uncompress_temp_dir) | |||
elif os.path.isfile(uncompress_temp_dir): | |||
os.remove(uncompress_temp_dir) | |||
return get_filepath(cache_path) | |||
else: | |||
raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.") | |||
def unzip_file(file: Path, to: Path): | |||
# unpack and write out in CoNLL column-like format | |||
from zipfile import ZipFile | |||
with ZipFile(file, "r") as zipObj: | |||
# Extract all the contents of zip file in current directory | |||
zipObj.extractall(to) | |||
def untar_gz_file(file: Path, to: Path): | |||
import tarfile | |||
with tarfile.open(file, 'r:gz') as tar: | |||
tar.extractall(to) | |||
def ungzip_file(file: str, to: str, filename: str): | |||
import gzip | |||
g_file = gzip.GzipFile(file) | |||
with open(os.path.join(to, filename), 'wb+') as f: | |||
f.write(g_file.read()) | |||
g_file.close() | |||
def match_file(dir_name: str, cache_dir: Path) -> str: | |||
r""" | |||
匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。 | |||
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串 | |||
:param dir_name: 需要匹配的名称 | |||
:param cache_dir: 在该目录下找匹配dir_name是否存在 | |||
:return str: 做为匹配结果的字符串 | |||
""" | |||
files = os.listdir(cache_dir) | |||
matched_filenames = [] | |||
for file_name in files: | |||
if re.match(re.escape(dir_name) + '$', file_name) or re.match(re.escape(dir_name) + r'\..*', file_name):
matched_filenames.append(file_name) | |||
if len(matched_filenames) == 0: | |||
return '' | |||
elif len(matched_filenames) == 1: | |||
return matched_filenames[-1] | |||
else: | |||
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.") | |||
def _get_bert_dir(model_dir_or_name: str = 'en-base-uncased'): | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
else: | |||
logger.info(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") | |||
raise ValueError(f"Cannot recognize BERT dir or name ``{model_dir_or_name}``.") | |||
return str(model_dir) | |||
def _get_gpt2_dir(model_dir_or_name: str = 'en'): | |||
if model_dir_or_name.lower() in PRETRAINED_GPT2_MODEL_DIR: | |||
model_url = _get_embedding_url('gpt2', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
else: | |||
logger.info(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") | |||
raise ValueError(f"Cannot recognize GPT2 dir or name ``{model_dir_or_name}``.") | |||
return str(model_dir) | |||
def _get_roberta_dir(model_dir_or_name: str = 'en'): | |||
if model_dir_or_name.lower() in PRETRAINED_ROBERTA_MODEL_DIR: | |||
model_url = _get_embedding_url('roberta', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
else: | |||
logger.info(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.") | |||
raise ValueError(f"Cannot recognize RoBERTa dir or name ``{model_dir_or_name}``.") | |||
return str(model_dir) | |||
def _get_file_name_base_on_postfix(dir_path, postfix): | |||
r""" | |||
在dir_path中寻找后缀为postfix的文件. | |||
:param dir_path: str, 文件夹 | |||
:param postfix: 形如".bin", ".json"等 | |||
:return: str,文件的路径 | |||
""" | |||
files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(dir_path)))
if len(files) == 0:
    raise FileNotFoundError(f"There is no file ending with {postfix} in {dir_path}")
elif len(files) > 1: | |||
raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}") | |||
return os.path.join(dir_path, files[0]) |
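# ---------------------------------------------------------------------------
# Illustrative sketch (the file name below is a placeholder): cached_path
# resolves a plain name relative to {cache_dir}/{name}/ and only downloads
# when it is handed an http(s) URL.
def _demo_cached_path():
    cache_root = get_cache_path()
    try:
        # resolve an already-cached file; FileNotFoundError if it is absent
        local = cached_path('embedding/glove.6B.50d', cache_dir=cache_root)
        print('found cached copy at', local)
    except FileNotFoundError:
        pass
    # an http(s) URL would be fetched by get_from_cache into {cache_dir}/{name}/,
    # e.g. cached_path('http://download.fastnlp.top/dataset/SST-2.zip', name='dataset')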
@@ -0,0 +1,107 @@ | |||
r""" | |||
Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 | |||
三个方法: ``__init__`` , ``_load`` , ``load`` . 其中 ``__init__(...)`` 用于声明读取参数,以及说明该Loader支持的数据格式,
读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; | |||
``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: | |||
0.传入None | |||
将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 | |||
1.传入一个文件的 path | |||
返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.get_dataset('train')`` 获取 | |||
2.传入一个文件夹目录 | |||
将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为:: | |||
| | |||
+-train.txt | |||
+-dev.txt | |||
+-test.txt | |||
+-other.txt | |||
在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , | |||
``data_bundle.get_dataset('dev')`` , | |||
``data_bundle.get_dataset('test')`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为:: | |||
| | |||
+-train.txt | |||
+-dev.txt | |||
在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , | |||
``data_bundle.get_dataset('dev')`` 获取对应的 dataset。 | |||
3.传入一个字典 | |||
字典的的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径:: | |||
paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'} | |||
在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , ``data_bundle.get_dataset('dev')`` , | |||
``data_bundle.get_dataset('test')`` 来获取对应的 `dataset` | |||
fastNLP 目前提供了如下的 Loader | |||
""" | |||
__all__ = [ | |||
'Loader', | |||
'CLSBaseLoader', | |||
'YelpFullLoader', | |||
'YelpPolarityLoader', | |||
'AGsNewsLoader', | |||
'DBPediaLoader', | |||
'IMDBLoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
"THUCNewsLoader", | |||
"WeiboSenti100kLoader", | |||
"MRLoader", | |||
"R8Loader", "R52Loader", "OhsumedLoader", "NG20Loader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'Conll2003NERLoader', | |||
'OntoNotesNERLoader', | |||
'CTBLoader', | |||
"MsraNERLoader", | |||
"PeopleDailyNERLoader", | |||
"WeiboNERLoader", | |||
'CSVLoader', | |||
'JsonLoader', | |||
'CWSLoader', | |||
'MNLILoader', | |||
"QuoraLoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"CNXNLILoader", | |||
"BQCorpusLoader", | |||
"LCQMCLoader", | |||
"CoReferenceLoader", | |||
"CMRC2018Loader" | |||
] | |||
from .classification import CLSBaseLoader, YelpFullLoader, YelpPolarityLoader, AGsNewsLoader, IMDBLoader, \ | |||
SSTLoader, SST2Loader, DBPediaLoader, \ | |||
ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, \ | |||
MRLoader, R8Loader, R52Loader, OhsumedLoader, NG20Loader | |||
from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader | |||
from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader | |||
from .coreference import CoReferenceLoader | |||
from .csv import CSVLoader | |||
from .cws import CWSLoader | |||
from .json import JsonLoader | |||
from .loader import Loader | |||
from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, CNXNLILoader, BQCorpusLoader, \ | |||
LCQMCLoader | |||
from .qa import CMRC2018Loader | |||
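# ---------------------------------------------------------------------------
# Illustrative sketch (all file paths are placeholders): the three ways of
# calling load() described in the module docstring above, shown with the
# SST2Loader imported in this package.
def _demo_loader_usage():
    loader = SST2Loader()

    # 1. a single file -> DataBundle with one dataset named 'train'
    bundle = loader.load('/path/to/SST-2/train.tsv')

    # 2. a directory containing train/dev/test files
    bundle = loader.load('/path/to/SST-2/')

    # 3. an explicit mapping from dataset name to file path
    bundle = loader.load({'train': '/path/to/SST-2/train.tsv',
                          'dev': '/path/to/SST-2/dev.tsv'})
    return bundle.get_dataset('train')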
@@ -0,0 +1,647 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CLSBaseLoader", | |||
"YelpFullLoader", | |||
"YelpPolarityLoader", | |||
"AGsNewsLoader", | |||
"DBPediaLoader", | |||
"IMDBLoader", | |||
"SSTLoader", | |||
"SST2Loader", | |||
"ChnSentiCorpLoader", | |||
"THUCNewsLoader", | |||
"WeiboSenti100kLoader", | |||
"MRLoader", | |||
"R8Loader", | |||
"R52Loader", | |||
"OhsumedLoader", | |||
"NG20Loader", | |||
] | |||
import glob | |||
import os | |||
import random | |||
import shutil | |||
import time | |||
import warnings | |||
from .loader import Loader | |||
from fastNLP.core.dataset import Instance, DataSet | |||
# from ...core._logger import log | |||
class CLSBaseLoader(Loader): | |||
r""" | |||
文本分类Loader的一个基类 | |||
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 | |||
Example:: | |||
"1","I got 'new' tires from the..." | |||
"1","Don't waste your time..." | |||
读取的DataSet将具备以下的数据结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"I got 'new' tires from them and... ", "1" | |||
"Don't waste your time. We had two...", "1" | |||
"...", "..." | |||
""" | |||
def __init__(self, sep=',', has_header=False): | |||
super().__init__() | |||
self.sep = sep | |||
self.has_header = has_header | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
try: | |||
with open(path, 'r', encoding='utf-8') as f: | |||
read_header = self.has_header | |||
for line in f: | |||
if read_header: | |||
read_header = False | |||
continue | |||
line = line.strip() | |||
sep_index = line.index(self.sep) | |||
target = line[:sep_index] | |||
raw_words = line[sep_index + 1:] | |||
if target.startswith("\""): | |||
target = target[1:] | |||
if target.endswith("\""): | |||
target = target[:-1] | |||
if raw_words.endswith("\""): | |||
raw_words = raw_words[:-1] | |||
if raw_words.startswith('"'): | |||
raw_words = raw_words[1:] | |||
raw_words = raw_words.replace('""', '"') # 替换双引号 | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words, target=target)) | |||
except Exception as e: | |||
print(f'Load file `{path}` failed due to `{e}`')
return ds | |||
def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix='csv'): | |||
if dev_ratio == 0.0: | |||
return data_dir | |||
modify_time = 0 | |||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||
modify_time = os.stat(filepath).st_mtime | |||
break | |||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||
shutil.rmtree(data_dir) | |||
data_dir = Loader()._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, f'dev.{suffix}')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
try: | |||
with open(os.path.join(data_dir, f'train.{suffix}'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, f'middle_file.{suffix}'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, f'dev.{suffix}'), 'w', encoding='utf-8') as f2: | |||
for line in f: | |||
if random.random() < dev_ratio: | |||
f2.write(line) | |||
else: | |||
f1.write(line) | |||
os.remove(os.path.join(data_dir, f'train.{suffix}')) | |||
os.renames(os.path.join(data_dir, f'middle_file.{suffix}'), os.path.join(data_dir, f'train.{suffix}')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, f'middle_file.{suffix}')): | |||
os.remove(os.path.join(data_dir, f'middle_file.{suffix}')) | |||
return data_dir | |||
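# ---------------------------------------------------------------------------
# Illustrative sketch of _split_dev (dataset name and rows are made up): with
# dev_ratio > 0 and no dev file present, dev.csv is carved out of train.csv
# in place and train.csv is rewritten without the sampled lines.
def _demo_split_dev():
    import tempfile

    tmp_dir = tempfile.mkdtemp()
    with open(os.path.join(tmp_dir, 'train.csv'), 'w', encoding='utf-8') as f:
        for i in range(100):
            f.write(f'"1","made-up sample {i}"\n')

    data_dir = _split_dev(dataset_name='demo', data_dir=tmp_dir,
                          dev_ratio=0.1, re_download=False, suffix='csv')
    assert os.path.exists(os.path.join(data_dir, 'dev.csv'))
    shutil.rmtree(tmp_dir)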
class AGsNewsLoader(CLSBaseLoader): | |||
def download(self): | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
in Neural Information Processing Systems 28 (NIPS 2015) | |||
:return: str, 数据集的目录地址 | |||
""" | |||
return self._get_dataset_path(dataset_name='ag-news') | |||
class DBPediaLoader(CLSBaseLoader): | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
in Neural Information Processing Systems 28 (NIPS 2015) | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'dbpedia' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class IMDBLoader(CLSBaseLoader): | |||
r""" | |||
原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 | |||
Example:: | |||
neg Alan Rickman & Emma... | |||
neg I have seen this... | |||
IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 | |||
读取的DataSet具备以下的结构: | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"Alan Rickman & Emma... ", "neg" | |||
"I have seen this... ", "neg" | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super().__init__(sep='\t') | |||
def download(self, dev_ratio: float = 0.0, re_download=False): | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
http://www.aclweb.org/anthology/P11-1015 | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。
下载完成后在output_dir中有train.txt, test.txt, dev.txt三个文件;否则只有train.txt和test.txt
:param float dev_ratio: 如果路径中没有dev.txt,从train划分多少作为dev的数据. 如果为0,则不划分dev
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'aclImdb' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='txt') | |||
return data_dir | |||
class SSTLoader(Loader): | |||
r""" | |||
原始数据中内容应该为: | |||
Example:: | |||
(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)... | |||
(3 (3 (2 If) (3 (2 you) (3 (2 sometimes)... | |||
读取之后的DataSet具有以下的结构 | |||
.. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field | |||
:header: "raw_words" | |||
"(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." | |||
"(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." | |||
"..." | |||
raw_words列是str。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
r""" | |||
从path读取SST文件 | |||
:param str path: 文件路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self): | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf | |||
:return: str, 数据集的目录地址 | |||
""" | |||
output_dir = self._get_dataset_path(dataset_name='sst') | |||
return output_dir | |||
class YelpFullLoader(CLSBaseLoader): | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
in Neural Information Processing Systems 28 (NIPS 2015) | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'yelp-review-full' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class YelpPolarityLoader(CLSBaseLoader): | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False): | |||
r""" | |||
自动下载数据集,如果你使用了这个数据集,请引用以下的文章 | |||
Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances | |||
in Neural Information Processing Systems 28 (NIPS 2015) | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'yelp-review-polarity' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class SST2Loader(Loader): | |||
r""" | |||
原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label | |||
Example:: | |||
sentence label | |||
it 's a charming and often affecting journey . 1 | |||
unflinchingly bleak and desperate 0 | |||
读取之后DataSet将如下所示 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"it 's a charming and often affecting journey .", "1" | |||
"unflinchingly bleak and desperate", "0" | |||
"..." | |||
test的DataSet没有target列。 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
r"""从path读取SST2文件 | |||
:param str path: 数据路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if 'test' in os.path.split(path)[1]: | |||
warnings.warn("SST2's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
sep_index = line.index('\t') | |||
raw_words = line[sep_index + 1:] | |||
index = int(line[: sep_index]) | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words, index=index)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
raw_words = line[:-2] | |||
target = line[-1] | |||
if raw_words: | |||
ds.append(Instance(raw_words=raw_words, target=target)) | |||
return ds | |||
def download(self): | |||
r""" | |||
自动下载数据集,如果你使用了该数据集,请引用以下的文章 | |||
https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path(dataset_name='sst-2') | |||
return output_dir | |||
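# ---------------------------------------------------------------------------
# Illustrative sketch (the sentence comes from the docstring example above):
# how SST2Loader._load parses a train/dev style file. A file whose name
# contains 'test' would instead be read as (index, sentence) pairs without a
# target column.
def _demo_sst2_loader():
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        train_path = os.path.join(tmp_dir, 'train.tsv')
        with open(train_path, 'w', encoding='utf-8') as f:
            f.write('sentence\tlabel\n')  # header line is skipped
            f.write("it 's a charming and often affecting journey .\t1\n")

        ds = SST2Loader()._load(train_path)
        assert len(ds) == 1  # one instance with fields raw_words and target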
class ChnSentiCorpLoader(Loader): | |||
r""" | |||
支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 | |||
一个制表符之后认为是句子 | |||
Example:: | |||
label text_a | |||
1 基金痛所有投资项目一样,必须先要有所了解... | |||
1 系统很好装,LED屏是不错,就是16比9的比例... | |||
读取后的DataSet具有以下的field | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"基金痛所有投资项目一样,必须先要有所了解...", "1" | |||
"系统很好装,LED屏是不错,就是16比9的比例...", "1" | |||
"..." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
r""" | |||
从path中读取数据 | |||
:param path: | |||
:return: | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() | |||
for line in f: | |||
line = line.strip() | |||
tab_index = line.find('\t')
if tab_index != -1:
target = line[:tab_index] | |||
raw_chars = line[tab_index + 1:] | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self) -> str: | |||
r""" | |||
自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 | |||
https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('chn-senti-corp') | |||
return output_dir | |||
class THUCNewsLoader(Loader): | |||
r""" | |||
数据集简介:document-level分类任务,新闻10分类 | |||
原始数据内容为:每行一个sample,第一个 "\\t" 之前为target,第一个 "\\t" 之后为raw_words | |||
Example:: | |||
体育 调查-您如何评价热火客场胜绿军总分3-1夺赛点?... | |||
读取后的Dataset将具有以下数据结构: | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"调查-您如何评价热火客场胜绿军总分3-1夺赛点?...", "体育" | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super(THUCNewsLoader, self).__init__() | |||
def _load(self, path: str = None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
sep_index = line.index('\t') | |||
raw_chars = line[sep_index + 1:] | |||
target = line[:sep_index] | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self) -> str: | |||
r""" | |||
自动下载数据,该数据取自 | |||
http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('thuc-news') | |||
return output_dir | |||
class WeiboSenti100kLoader(Loader): | |||
r""" | |||
数据集简介:微博sentiment classification,二分类 | |||
Example:: | |||
label text | |||
1 多谢小莲,好运满满[爱你] | |||
1 能在他乡遇老友真不赖,哈哈,珠儿,我也要用... | |||
读取后的Dataset将具有以下数据结构: | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"多谢小莲,好运满满[爱你]", "1" | |||
"能在他乡遇老友真不赖,哈哈,珠儿,我也要用...", "1" | |||
"...", "..." | |||
""" | |||
def __init__(self): | |||
super(WeiboSenti100kLoader, self).__init__() | |||
def _load(self, path: str = None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
next(f) | |||
for line in f: | |||
line = line.strip() | |||
target = line[0] | |||
raw_chars = line[1:] | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
return ds | |||
def download(self) -> str: | |||
r""" | |||
自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ | |||
在 https://arxiv.org/abs/1906.08101 有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('weibo-senti-100k') | |||
return output_dir | |||
class MRLoader(CLSBaseLoader): | |||
def __init__(self): | |||
super(MRLoader, self).__init__() | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: | |||
r""" | |||
自动下载数据集 | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = r'mr' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class R8Loader(CLSBaseLoader): | |||
def __init__(self): | |||
super(R8Loader, self).__init__() | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: | |||
r""" | |||
自动下载数据集 | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = r'R8' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class R52Loader(CLSBaseLoader): | |||
def __init__(self): | |||
super(R52Loader, self).__init__() | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: | |||
r""" | |||
自动下载数据集 | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = r'R52' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class NG20Loader(CLSBaseLoader): | |||
def __init__(self): | |||
super(NG20Loader, self).__init__() | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: | |||
r""" | |||
自动下载数据集 | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = r'20ng' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir | |||
class OhsumedLoader(CLSBaseLoader): | |||
def __init__(self): | |||
super(OhsumedLoader, self).__init__() | |||
def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: | |||
r""" | |||
自动下载数据集 | |||
如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 | |||
下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = r'ohsumed' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
data_dir = _split_dev(dataset_name=dataset_name, | |||
data_dir=data_dir, | |||
dev_ratio=dev_ratio, | |||
re_download=re_download, | |||
suffix='csv') | |||
return data_dir |
@@ -0,0 +1,542 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ConllLoader", | |||
"Conll2003Loader", | |||
"Conll2003NERLoader", | |||
"OntoNotesNERLoader", | |||
"CTBLoader", | |||
"CNNERLoader", | |||
"MsraNERLoader", | |||
"WeiboNERLoader", | |||
"PeopleDailyNERLoader" | |||
] | |||
import glob | |||
import os | |||
import random | |||
import shutil | |||
import time | |||
from .loader import Loader | |||
from ..file_reader import _read_conll | |||
# from ...core.const import Const | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class ConllLoader(Loader): | |||
r""" | |||
ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: | |||
Example:: | |||
# 文件中的内容 | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列 | |||
dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列 | |||
dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll') | |||
# 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field | |||
dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') | |||
ConllLoader返回的DataSet的field由传入的headers确定。 | |||
数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 | |||
""" | |||
def __init__(self, headers, sep=None, indexes=None, dropna=True): | |||
r""" | |||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param str sep: 指定分隔符,默认为制表符
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
""" | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
raise TypeError( | |||
'invalid headers: {}, should be list of strings'.format(headers)) | |||
self.headers = headers | |||
self.dropna = dropna | |||
self.sep=sep | |||
if indexes is None: | |||
self.indexes = list(range(len(self.headers))) | |||
else: | |||
if len(indexes) != len(headers): | |||
raise ValueError | |||
self.indexes = indexes | |||
def _load(self, path): | |||
r""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
class Conll2003Loader(ConllLoader): | |||
r""" | |||
用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 | |||
Example:: | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
返回的DataSet的内容为 | |||
.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 | |||
:header: "raw_words", "pos", "chunk", "ner" | |||
"[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]", "[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'pos', 'chunk', 'ner', | |||
] | |||
super(Conll2003Loader, self).__init__(headers=headers) | |||
def _load(self, path): | |||
r""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
doc_start = False | |||
for i, h in enumerate(self.headers): | |||
field = data[i] | |||
if str(field[0]).startswith('-DOCSTART-'): | |||
doc_start = True | |||
break | |||
if doc_start: | |||
continue | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
def download(self, output_dir=None): | |||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||
class Conll2003NERLoader(ConllLoader): | |||
r""" | |||
用于读取conll2003任务的NER数据。每一行有4列内容,空行意味着隔开两个句子 | |||
支持读取的内容如下 | |||
Example:: | |||
Nadim NNP B-NP B-PER | |||
Ladki NNP I-NP I-PER | |||
AL-AIN NNP B-NP B-LOC | |||
United NNP B-NP B-LOC | |||
Arab NNP I-NP I-LOC | |||
Emirates NNPS I-NP I-LOC | |||
1996-12-06 CD I-NP O | |||
... | |||
返回的DataSet的内容为 | |||
.. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码 | |||
:header: "raw_words", "target" | |||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
headers = [ | |||
'raw_words', 'target', | |||
] | |||
super().__init__(headers=headers, indexes=[0, 3]) | |||
def _load(self, path): | |||
r""" | |||
传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。 | |||
:param str path: 文件的路径 | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna): | |||
doc_start = False | |||
for i, h in enumerate(self.headers): | |||
field = data[i] | |||
if str(field[0]).startswith('-DOCSTART-'): | |||
doc_start = True | |||
break | |||
if doc_start: | |||
continue | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
if len(ds) == 0: | |||
raise RuntimeError("No data found {}.".format(path)) | |||
return ds | |||
def download(self): | |||
raise RuntimeError("conll2003 cannot be downloaded automatically.") | |||
class OntoNotesNERLoader(ConllLoader): | |||
r""" | |||
用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 | |||
https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 | |||
读取的数据格式为: | |||
Example:: | |||
bc/msnbc/00/msnbc_0000 0 0 Hi UH (TOP(FRAG(INTJ*) - - - Dan_Abrams * - | |||
bc/msnbc/00/msnbc_0000 0 1 everyone NN (NP*) - - - Dan_Abrams * - | |||
... | |||
返回的DataSet的内容为 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"['Hi', 'everyone', '.']", "['O', 'O', 'O']" | |||
"['first', 'up', 'on', 'the', 'docket']", "['O', 'O', 'O', 'O', 'O']" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__(headers=['raw_words', 'target'], indexes=[3, 10]) | |||
def _load(self, path: str): | |||
dataset = super()._load(path) | |||
def convert_to_bio(tags): | |||
bio_tags = [] | |||
flag = None | |||
for tag in tags: | |||
label = tag.strip("()*") | |||
if '(' in tag: | |||
bio_label = 'B-' + label | |||
flag = label | |||
elif flag: | |||
bio_label = 'I-' + flag | |||
else: | |||
bio_label = 'O' | |||
if ')' in tag: | |||
flag = None | |||
bio_tags.append(bio_label) | |||
return bio_tags | |||
def convert_word(words): | |||
converted_words = [] | |||
for word in words: | |||
word = word.replace('/.', '.') # 有些结尾的.是/.形式的 | |||
if not word.startswith('-'): | |||
converted_words.append(word) | |||
continue | |||
# 以下是由于这些符号被转义了,再转回来 | |||
tfrs = {'-LRB-': '(', | |||
'-RRB-': ')', | |||
'-LSB-': '[', | |||
'-RSB-': ']', | |||
'-LCB-': '{', | |||
'-RCB-': '}' | |||
} | |||
if word in tfrs: | |||
converted_words.append(tfrs[word]) | |||
else: | |||
converted_words.append(word) | |||
return converted_words | |||
dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words') | |||
dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target') | |||
return dataset | |||
def download(self): | |||
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer " | |||
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.") | |||
class CTBLoader(Loader): | |||
r""" | |||
支持加载的数据应该具备以下格式, 其中第二列为词语,第四列为pos tag,第七列为依赖树的head,第八列为依赖树的label | |||
Example:: | |||
1 印度 _ NR NR _ 3 nn _ _ | |||
2 海军 _ NN NN _ 3 nn _ _ | |||
3 参谋长 _ NN NN _ 5 nsubjpass _ _ | |||
4 被 _ SB SB _ 5 pass _ _ | |||
5 解职 _ VV VV _ 0 root _ _ | |||
1 新华社 _ NR NR _ 7 dep _ _ | |||
2 新德里 _ NR NR _ 7 dep _ _ | |||
3 12月 _ NT NT _ 7 dep _ _ | |||
... | |||
读取之后DataSet具备的格式为 | |||
.. csv-table:: | |||
:header: "raw_words", "pos", "dep_head", "dep_label" | |||
"[印度, 海军, ...]", "[NR, NN, SB, ...]", "[3, 3, ...]", "[nn, nn, ...]" | |||
"[新华社, 新德里, ...]", "[NR, NR, NT, ...]", "[7, 7, 7, ...]", "[dep, dep, dep, ...]" | |||
"[...]", "[...]", "[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
headers = [ | |||
'raw_words', 'pos', 'dep_head', 'dep_label', | |||
] | |||
indexes = [ | |||
1, 3, 6, 7, | |||
] | |||
self.loader = ConllLoader(headers=headers, indexes=indexes) | |||
def _load(self, path: str): | |||
dataset = self.loader._load(path) | |||
return dataset | |||
def download(self): | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://catalog.ldc.upenn.edu/LDC2013T21 | |||
:return: | |||
""" | |||
raise RuntimeError("CTB cannot be downloaded automatically.") | |||
class CNNERLoader(Loader): | |||
def _load(self, path: str): | |||
r""" | |||
支持加载形如以下格式的内容,一行两列,以空格隔开两个sample | |||
Example:: | |||
我 O | |||
们 O | |||
变 O | |||
而 O | |||
以 O | |||
书 O | |||
会 O | |||
... | |||
:param str path: 文件路径 | |||
:return: DataSet,包含raw_words列和target列 | |||
""" | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
raw_chars = [] | |||
target = [] | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split() | |||
if len(parts) == 1: # 网上下载的数据有一些列少tag,默认补充O | |||
parts.append('O') | |||
raw_chars.append(parts[0]) | |||
target.append(parts[1]) | |||
else: | |||
if raw_chars: | |||
ds.append(Instance(raw_chars=raw_chars, target=target)) | |||
raw_chars = [] | |||
target = [] | |||
return ds | |||
class MsraNERLoader(CNNERLoader): | |||
r""" | |||
读取MSRA-NER数据,数据中的格式应该类似与下列的内容 | |||
Example:: | |||
把 O | |||
欧 B-LOC | |||
美 B-LOC | |||
、 O | |||
港 B-LOC | |||
台 B-LOC | |||
流 O | |||
行 O | |||
的 O | |||
食 O | |||
... | |||
读取后的DataSet包含以下的field | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"['把', '欧'] ", "['O', 'B-LOC']" | |||
"['美', '、']", "['B-LOC', 'O']" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: | |||
r""" | |||
自动下载MSRA-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language
Processing Bakeoff: Word Segmentation and Named Entity Recognition. | |||
根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.conll, test.conll, | |||
dev.conll三个文件。 | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str, 数据集的目录地址 | |||
""" | |||
dataset_name = 'msra-ner' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
modify_time = 0 | |||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||
modify_time = os.stat(filepath).st_mtime | |||
break | |||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.conll')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
try: | |||
with open(os.path.join(data_dir, 'train.conll'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, 'middle_file.conll'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.conll'), 'w', encoding='utf-8') as f2: | |||
lines = [] # 一个sample包含很多行 | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
lines.append(line) | |||
else: | |||
if random.random() < dev_ratio: | |||
f2.write('\n'.join(lines) + '\n\n') | |||
else: | |||
f1.write('\n'.join(lines) + '\n\n') | |||
lines.clear() | |||
os.remove(os.path.join(data_dir, 'train.conll')) | |||
os.renames(os.path.join(data_dir, 'middle_file.conll'), os.path.join(data_dir, 'train.conll')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.conll')): | |||
os.remove(os.path.join(data_dir, 'middle_file.conll')) | |||
return data_dir | |||
class WeiboNERLoader(CNNERLoader): | |||
r""" | |||
读取WeiboNER数据,数据中的格式应该类似与下列的内容 | |||
Example:: | |||
老 B-PER.NOM | |||
百 I-PER.NOM | |||
姓 I-PER.NOM | |||
心 O | |||
... | |||
读取后的DataSet包含以下的field | |||
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"['老', '百', '姓']", "['B-PER.NOM', 'I-PER.NOM', 'I-PER.NOM']" | |||
"['心']", "['O']" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def download(self) -> str: | |||
r""" | |||
自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for | |||
Chinese Social Media with Jointly Trained Embeddings. | |||
:return: str | |||
""" | |||
dataset_name = 'weibo-ner' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
return data_dir | |||
class PeopleDailyNERLoader(CNNERLoader): | |||
r""" | |||
支持加载的数据格式如下 | |||
Example:: | |||
中 B-ORG | |||
共 I-ORG | |||
中 I-ORG | |||
央 I-ORG | |||
致 O | |||
中 B-ORG | |||
... | |||
读取后的DataSet包含以下的field | |||
.. csv-table:: target列是基于BIO的编码方式 | |||
:header: "raw_chars", "target" | |||
"['中', '共', '中', '央']", "['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG']" | |||
"[...]", "[...]" | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def download(self) -> str: | |||
dataset_name = 'peopledaily' | |||
data_dir = self._get_dataset_path(dataset_name=dataset_name) | |||
return data_dir |
@@ -0,0 +1,64 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CoReferenceLoader", | |||
] | |||
from ...core.dataset import DataSet | |||
from ..file_reader import _read_json | |||
from fastNLP.core.dataset import Instance | |||
# from ...core.const import Const | |||
from .json import JsonLoader | |||
class CoReferenceLoader(JsonLoader): | |||
r""" | |||
原始数据中内容应该为, 每一行为一个json对象,其中doc_key包含文章的种类信息,speakers包含每句话的说话者信息,cluster是指向现实中同一个事物的聚集,sentences是文本信息内容。 | |||
Example:: | |||
{"doc_key": "bc/cctv/00/cctv_0000_0", | |||
"speakers": [["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"], ["Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1", "Speaker#1"]], | |||
"clusters": [[[70, 70], [485, 486], [500, 500], [73, 73], [55, 55], [153, 154], [366, 366]]], | |||
"sentences": [["In", "the", "summer", "of", "2005", ",", "a", "picture", "that", "people", "have", "long", "been", "looking", "forward", "to", "started", "emerging", "with", "frequency", "in", "various", "major", "Hong", "Kong", "media", "."], ["With", "their", "unique", "charm", ",", "these", "well", "-", "known", "cartoon", "images", "once", "again", "caused", "Hong", "Kong", "to", "be", "a", "focus", "of", "worldwide", "attention", "."]] | |||
} | |||
读取预处理好的Conll2012数据,数据结构如下: | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "raw_words3", "raw_words4" | |||
"bc/cctv/00/cctv_0000_0", "[['Speaker#1', 'Speaker#1', 'Speaker#1...", "[[[70, 70], [485, 486], [500, 500], [7...", "[['In', 'the', 'summer', 'of', '2005',..." | |||
"...", "...", "...", "..." | |||
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super().__init__(fields, dropna) | |||
self.fields = {"doc_key": "raw_words1", "speakers": "raw_words2", "clusters": "raw_words3", | |||
"sentences": "raw_words4"} | |||
def _load(self, path): | |||
r""" | |||
加载数据 | |||
:param path: 数据文件路径,文件为json | |||
:return: | |||
""" | |||
dataset = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
if self.fields: | |||
ins = {self.fields[k]: v for k, v in d.items()} | |||
else: | |||
ins = d | |||
dataset.append(Instance(**ins)) | |||
return dataset | |||
def download(self): | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://www.aclweb.org/anthology/W12-4501 | |||
:return: | |||
""" | |||
raise RuntimeError("CoReference cannot be downloaded automatically.") |
@@ -0,0 +1,38 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CSVLoader", | |||
] | |||
from .loader import Loader | |||
from ..file_reader import _read_csv | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class CSVLoader(Loader): | |||
r""" | |||
读取CSV格式的数据集, 返回 ``DataSet`` 。 | |||
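下面给出一个大致的使用示意(文件路径与表头均为假设值)::
loader = CSVLoader(headers=['raw_words', 'target'], sep='\t')
dataset = loader._load('/path/to/train.tsv')  # 每一列分别存入raw_words与target两个field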
""" | |||
def __init__(self, headers=None, sep=",", dropna=False): | |||
r""" | |||
:param List[str] headers: CSV文件的文件头,定义每一列的属性名称,即返回的DataSet中 `field` 的名称。
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` | |||
:param str sep: CSV文件中列与列之间的分隔符. Default: "," | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
""" | |||
super().__init__() | |||
self.headers = headers | |||
self.sep = sep | |||
self.dropna = dropna | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, data in _read_csv(path, headers=self.headers, | |||
sep=self.sep, dropna=self.dropna): | |||
ds.append(Instance(**data)) | |||
return ds | |||
@@ -0,0 +1,97 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CWSLoader" | |||
] | |||
import glob | |||
import os | |||
import random | |||
import shutil | |||
import time | |||
from .loader import Loader | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class CWSLoader(Loader): | |||
r""" | |||
CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: | |||
Example:: | |||
上海 浦东 开发 与 法制 建设 同步 | |||
新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) | |||
... | |||
该Loader读取后的DataSet具有如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words" | |||
"上海 浦东 开发 与 法制 建设 同步" | |||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||
"..." | |||
""" | |||
def __init__(self, dataset_name: str = None): | |||
r""" | |||
:param str dataset_name: 数据集的名称,支持pku, msra, cityu(繁体), as(繁体), None
""" | |||
super().__init__() | |||
datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} | |||
if dataset_name in datanames: | |||
self.dataset_name = datanames[dataset_name] | |||
else: | |||
self.dataset_name = None | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
ds.append(Instance(raw_words=line)) | |||
return ds | |||
def download(self, dev_ratio=0.1, re_download=False) -> str: | |||
r""" | |||
如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, | |||
2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 | |||
:param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 | |||
:param bool re_download: 是否重新下载数据,以重新切分数据。 | |||
:return: str | |||
""" | |||
if self.dataset_name is None: | |||
return '' | |||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||
modify_time = 0 | |||
for filepath in glob.glob(os.path.join(data_dir, '*')): | |||
modify_time = os.stat(filepath).st_mtime | |||
break | |||
if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的 | |||
shutil.rmtree(data_dir) | |||
data_dir = self._get_dataset_path(dataset_name=self.dataset_name) | |||
if not os.path.exists(os.path.join(data_dir, 'dev.txt')): | |||
if dev_ratio > 0: | |||
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)." | |||
try: | |||
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \ | |||
open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \ | |||
open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2: | |||
for line in f: | |||
if random.random() < dev_ratio: | |||
f2.write(line) | |||
else: | |||
f1.write(line) | |||
os.remove(os.path.join(data_dir, 'train.txt')) | |||
os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt')) | |||
finally: | |||
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')): | |||
os.remove(os.path.join(data_dir, 'middle_file.txt')) | |||
return data_dir |
@@ -0,0 +1,45 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"JsonLoader" | |||
] | |||
from .loader import Loader | |||
from ..file_reader import _read_json | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class JsonLoader(Loader): | |||
r""" | |||
别名: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader`
读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 | |||
:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name | |||
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , | |||
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 | |||
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . | |||
Default: ``False`` | |||
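下面给出一个大致的使用示意(json属性名与文件路径均为假设值)::
loader = JsonLoader(fields={'text': 'raw_words', 'label': 'target'}, dropna=True)
dataset = loader._load('/path/to/data.jsonl')  # 每行json中的text、label分别存入raw_words与target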
""" | |||
def __init__(self, fields=None, dropna=False): | |||
super(JsonLoader, self).__init__() | |||
self.dropna = dropna | |||
self.fields = None | |||
self.fields_list = None | |||
if fields: | |||
self.fields = {} | |||
for k, v in fields.items(): | |||
self.fields[k] = k if v is None else v | |||
self.fields_list = list(self.fields.keys()) | |||
def _load(self, path): | |||
ds = DataSet() | |||
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): | |||
if self.fields: | |||
ins = {self.fields[k]: v for k, v in d.items()} | |||
else: | |||
ins = d | |||
ds.append(Instance(**ins)) | |||
return ds |
@@ -0,0 +1,94 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Loader" | |||
] | |||
from typing import Union, Dict | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.io.file_utils import _get_dataset_url, get_cache_path, cached_path | |||
from fastNLP.io.utils import check_loader_paths | |||
from fastNLP.core.dataset import DataSet | |||
class Loader: | |||
r""" | |||
各种数据 Loader 的基类,提供了 API 的参考. | |||
Loader支持以下的三个函数 | |||
- download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。 | |||
- _load() 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` 。返回的DataSet的内容可以通过每个Loader的文档判断出。 | |||
- load() 函数:将文件分别读取为DataSet,然后将多个DataSet放入到一个DataBundle中并返回 | |||
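三个函数的典型配合方式大致如下(xxxLoader为假设的子类,文件名亦为假设值,仅作示意)::
loader = xxxLoader()
cache_dir = loader.download()                      # 下载并缓存数据(若该Loader支持)
dataset = loader._load(cache_dir + '/train.txt')   # 读取单个文件为DataSet
data_bundle = loader.load(cache_dir)               # 读取整个目录,返回DataBundle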
""" | |||
def __init__(self): | |||
pass | |||
def _load(self, path: str) -> DataSet: | |||
r""" | |||
给定一个路径,返回读取的DataSet。 | |||
:param str path: 路径 | |||
:return: DataSet | |||
""" | |||
raise NotImplementedError | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
r""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
:param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: | |||
0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 | |||
1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: | |||
data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train | |||
# dev、 test等有所变化,可以通过以下的方式取出DataSet | |||
tr_data = data_bundle.get_dataset('train') | |||
te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 | |||
2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: | |||
paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} | |||
data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" | |||
dev_data = data_bundle.get_dataset('dev') | |||
3.传入文件路径:: | |||
data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' | |||
tr_data = data_bundle.get_dataset('train') # 取出DataSet | |||
:return: 返回的 :class:`~fastNLP.io.DataBundle` | |||
""" | |||
if paths is None: | |||
paths = self.download() | |||
paths = check_loader_paths(paths) | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self) -> str: | |||
r""" | |||
自动下载该数据集 | |||
:return: 下载后解压目录 | |||
""" | |||
raise NotImplementedError(f"{self.__class__} cannot download data automatically.") | |||
@staticmethod | |||
def _get_dataset_path(dataset_name): | |||
r""" | |||
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存(如果支持的话) | |||
:param str dataset_name: 数据集的名称 | |||
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。 | |||
""" | |||
default_cache_path = get_cache_path() | |||
url = _get_dataset_url(dataset_name) | |||
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset') | |||
return output_dir |
@@ -0,0 +1,577 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MNLILoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"QuoraLoader", | |||
"BQCorpusLoader", | |||
"CNXNLILoader", | |||
"LCQMCLoader" | |||
] | |||
import os | |||
import warnings | |||
from typing import Union, Dict | |||
from .csv import CSVLoader | |||
from .json import JsonLoader | |||
from .loader import Loader | |||
from fastNLP.io.data_bundle import DataBundle | |||
from ..utils import check_loader_paths | |||
# from ...core.const import Const | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class MNLILoader(Loader): | |||
r""" | |||
读取的数据格式为: | |||
Example:: | |||
index promptID pairID genre sentence1_binary_parse sentence2_binary_parse sentence1_parse sentence2_parse sentence1 sentence2 label1 gold_label | |||
0 31193 31193n government ( ( Conceptually ( cream skimming ) ) ... | |||
1 101457 101457e telephone ( you ( ( know ( during ( ( ( the season ) and ) ( i guess ) ) )... | |||
... | |||
读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,raw_words1是sentence1, raw_words2是sentence2, target是gold_label, 测试集中没
有target列。 | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"Conceptually cream ...", "Product and geography...", "neutral" | |||
"you know during the ...", "You lose the things to the...", "entailment" | |||
"...", "...", "..." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'): | |||
warnings.warn("MNLI's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[8] | |||
raw_words2 = parts[9] | |||
idx = int(parts[0]) | |||
if raw_words1 and raw_words2: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, index=idx)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[8] | |||
raw_words2 = parts[9] | |||
target = parts[-1] | |||
idx = int(parts[0]) | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target, index=idx)) | |||
return ds | |||
def load(self, paths: str = None): | |||
r""" | |||
:param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, | |||
test_mismatched.tsv, train.tsv这五个文件
:return: DataBundle | |||
""" | |||
if paths: | |||
paths = os.path.abspath(os.path.expanduser(paths)) | |||
else: | |||
paths = self.download() | |||
if not os.path.isdir(paths): | |||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||
files = {'dev_matched': "dev_matched.tsv", | |||
"dev_mismatched": "dev_mismatched.tsv", | |||
"test_matched": "test_matched.tsv", | |||
"test_mismatched": "test_mismatched.tsv", | |||
"train": 'train.tsv'} | |||
datasets = {} | |||
for name, filename in files.items(): | |||
filepath = os.path.join(paths, filename) | |||
if not os.path.isfile(filepath):
if 'test' not in name:
raise FileNotFoundError(f"{name} not found in directory {paths}.")
continue  # 允许test文件缺失,跳过以免读取不存在的路径
datasets[name] = self._load(filepath)
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
r""" | |||
如果你使用了这个数据,请引用 | |||
https://www.nyu.edu/projects/bowman/multinli/paper.pdf | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('mnli') | |||
return output_dir | |||
class SNLILoader(JsonLoader): | |||
r""" | |||
文件每一行是一个sample,每一行都为一个json对象,其数据格式为: | |||
Example:: | |||
{"annotator_labels": ["neutral", "entailment", "neutral", "neutral", "neutral"], "captionID": "4705552913.jpg#2", | |||
"gold_label": "neutral", "pairID": "4705552913.jpg#2r1n", | |||
"sentence1": "Two women are embracing while holding to go packages.", | |||
"sentence1_binary_parse": "( ( Two women ) ( ( are ( embracing ( while ( holding ( to ( go packages ) ) ) ) ) ) . ) )", | |||
"sentence1_parse": "(ROOT (S (NP (CD Two) (NNS women)) (VP (VBP are) (VP (VBG embracing) (SBAR (IN while) (S (NP (VBG holding)) (VP (TO to) (VP (VB go) (NP (NNS packages)))))))) (. .)))", | |||
"sentence2": "The sisters are hugging goodbye while holding to go packages after just eating lunch.", | |||
"sentence2_binary_parse": "( ( The sisters ) ( ( are ( ( hugging goodbye ) ( while ( holding ( to ( ( go packages ) ( after ( just ( eating lunch ) ) ) ) ) ) ) ) ) . ) )", | |||
"sentence2_parse": "(ROOT (S (NP (DT The) (NNS sisters)) (VP (VBP are) (VP (VBG hugging) (NP (UH goodbye)) (PP (IN while) (S (VP (VBG holding) (S (VP (TO to) (VP (VB go) (NP (NNS packages)) (PP (IN after) (S (ADVP (RB just)) (VP (VBG eating) (NP (NN lunch))))))))))))) (. .)))" | |||
} | |||
读取之后的DataSet中的field情况为 | |||
.. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field | |||
:header: "target", "raw_words1", "raw_words2", | |||
"neutral ", "Two women are embracing while holding..", "The sisters are hugging goodbye..." | |||
"entailment", "Two women are embracing while holding...", "Two woman are holding packages." | |||
"...", "...", "..." | |||
""" | |||
def __init__(self): | |||
super().__init__(fields={ | |||
'sentence1': 'raw_words1', | |||
'sentence2': 'raw_words2', | |||
'gold_label': 'target', | |||
}) | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
r""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据Loader初始化时传入的field决定。 | |||
:param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl | |||
和snli_1.0_test.jsonl三个文件。 | |||
:return: 返回的 :class:`~fastNLP.io.DataBundle` | |||
""" | |||
_paths = {} | |||
if paths is None: | |||
paths = self.download() | |||
if paths: | |||
if os.path.isdir(paths): | |||
if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')): | |||
raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}") | |||
_paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl') | |||
for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']: | |||
filepath = os.path.join(paths, filename) | |||
_paths[filename.split('_')[-1].split('.')[0]] = filepath | |||
paths = _paths | |||
else: | |||
raise NotADirectoryError(f"{paths} is not a valid directory.") | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
r""" | |||
如果您的文章使用了这份数据,请引用 | |||
http://nlp.stanford.edu/pubs/snli_paper.pdf | |||
:return: str | |||
""" | |||
return self._get_dataset_path('snli') | |||
class QNLILoader(JsonLoader): | |||
r""" | |||
第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、问题、句子和标签构成(以制表符分割),数据结构如下: | |||
Example:: | |||
index question sentence label | |||
0 What came into force after the new constitution was herald? As of that day, the new constitution heralding the Second Republic came into force. entailment | |||
QNLI数据集的Loader, | |||
加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"What came into force after the new...", "As of that day...", "entailment" | |||
"...","." | |||
test数据集没有target列 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
warnings.warn("QNLI's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
if raw_words1 and raw_words2: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
target = parts[-1] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
r""" | |||
如果您的实验使用到了该数据,请引用 | |||
https://arxiv.org/pdf/1809.05053.pdf | |||
:return: | |||
""" | |||
return self._get_dataset_path('qnli') | |||
class RTELoader(Loader): | |||
r""" | |||
第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、句子1、句子2和标签构成(以制表符分割),数据结构如下: | |||
Example:: | |||
index sentence1 sentence2 label | |||
0 Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation. Christopher Reeve had an accident. not_entailment | |||
RTE数据的loader | |||
加载的DataSet将具备以下的field, raw_words1是sentence1,raw_words2是sentence2, target是label
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" | |||
"...","..." | |||
test数据集没有target列 | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
f.readline() # 跳过header | |||
if path.endswith("test.tsv"): | |||
warnings.warn("RTE's test file has no target.") | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
if raw_words1 and raw_words2: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2)) | |||
else: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
target = parts[-1] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
r""" | |||
如果您的实验使用到了该数据,请引用GLUE Benchmark | |||
https://openreview.net/pdf?id=rJ4km2R5t7 | |||
:return: | |||
""" | |||
return self._get_dataset_path('rte') | |||
class QuoraLoader(Loader): | |||
r""" | |||
Quora matching任务的数据集Loader | |||
支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 | |||
Example:: | |||
1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970 | |||
0 Is honey a viable alternative to sugar for diabetics ? How would you compare the United States ' euthanasia laws to Denmark ? 90348 | |||
... | |||
加载的DataSet将具备以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"How do I get funding for my web based...", "How do I get seed funding...","1" | |||
"Is honey a viable alternative ...", "How would you compare the United...","0" | |||
"...","...","..." | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
if line: | |||
parts = line.split('\t') | |||
raw_words1 = parts[1] | |||
raw_words2 = parts[2] | |||
target = parts[0] | |||
if raw_words1 and raw_words2 and target: | |||
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target)) | |||
return ds | |||
def download(self): | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://www.kaggle.com/c/quora-question-pairs/data | |||
:return: | |||
""" | |||
raise RuntimeError("Quora cannot be downloaded automatically.") | |||
class CNXNLILoader(Loader): | |||
r""" | |||
数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize | |||
原始数据数据为: | |||
Example:: | |||
premise hypo label | |||
我们 家里 有 一个 但 我 没 找到 我 可以 用 的 时间 我们 家里 有 一个 但 我 从来 没有 时间 使用 它 . entailment | |||
dev和test中的数据为csv或json格式,包括十多个field,这里只取以上三个field的数据
读取后的Dataset将具有以下数据结构: | |||
.. csv-table:: | |||
:header: "raw_chars1", "raw_chars2", "target" | |||
"我们 家里 有 一个 但 我 没 找到 我 可以 用 的 时间", "我们 家里 有 一个 但 我 从来 没有 时间 使用 它 .", "0" | |||
"...", "...", "..." | |||
""" | |||
def __init__(self): | |||
super(CNXNLILoader, self).__init__() | |||
def _load(self, path: str = None): | |||
ds_all = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
head_name_list = f.readline().strip().split('\t') | |||
sentence1_index = head_name_list.index('sentence1') | |||
sentence2_index = head_name_list.index('sentence2') | |||
gold_label_index = head_name_list.index('gold_label') | |||
language_index = head_name_list.index('language')
for line in f: | |||
line = line.strip() | |||
raw_instance = line.split('\t') | |||
sentence1 = raw_instance[sentence1_index] | |||
sentence2 = raw_instance[sentence2_index] | |||
gold_label = raw_instance[gold_label_index] | |||
language = raw_instance[language_index] | |||
if sentence1: | |||
ds_all.append(Instance(sentence1=sentence1, sentence2=sentence2, gold_label=gold_label, language=language)) | |||
ds_zh = DataSet() | |||
for i in ds_all: | |||
if i['language'] == 'zh': | |||
ds_zh.append(Instance(raw_chars1=i['sentence1'], raw_chars2=i['sentence2'], target=i['gold_label'])) | |||
return ds_zh | |||
def _load_train(self, path: str = None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
next(f) | |||
for line in f: | |||
raw_instance = line.strip().split('\t') | |||
premise = "".join(raw_instance[0].split())# 把已经分好词的premise和hypo强制还原为character segmentation | |||
hypo = "".join(raw_instance[1].split()) | |||
label = "".join(raw_instance[-1].split()) | |||
if premise: | |||
ds.append(Instance(premise=premise, hypo=hypo, label=label)) | |||
ds.rename_field('label', 'target') | |||
ds.rename_field('premise', 'raw_chars1') | |||
ds.rename_field('hypo', 'raw_chars2') | |||
ds.apply(lambda i: "".join(i['raw_chars1'].split()), new_field_name='raw_chars1') | |||
ds.apply(lambda i: "".join(i['raw_chars2'].split()), new_field_name='raw_chars2') | |||
return ds | |||
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: | |||
if paths is None: | |||
paths = self.download() | |||
paths = check_loader_paths(paths) | |||
datasets = {} | |||
for name, path in paths.items(): | |||
if name == 'train': | |||
datasets[name] = self._load_train(path) | |||
else: | |||
datasets[name] = self._load(path) | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self) -> str: | |||
r""" | |||
自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 | |||
在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf | |||
https://arxiv.org/pdf/1809.05053.pdf 有使用 | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('cn-xnli') | |||
return output_dir | |||
class BQCorpusLoader(Loader): | |||
r""" | |||
数据集简介:句子对二分类任务(判断是否具有相同的语义) | |||
原始数据结构为: | |||
Example:: | |||
sentence1,sentence2,label | |||
综合评分不足什么原因,综合评估的依据,0 | |||
什么时候我能使用微粒贷,你就赶快给我开通就行了,0 | |||
读取后的Dataset将具有以下数据结构: | |||
.. csv-table:: | |||
:header: "raw_chars1", "raw_chars2", "target" | |||
"综合评分不足什么原因", "综合评估的依据", "0" | |||
"什么时候我能使用微粒贷", "你就赶快给我开通就行了", "0" | |||
"...", "...", "..." | |||
""" | |||
def __init__(self): | |||
super(BQCorpusLoader, self).__init__() | |||
def _load(self, path: str = None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
next(f) | |||
for line in f: | |||
line = line.strip() | |||
target = line[-1]
first_sep = line.index(',')
last_sep = line.rindex(',')
raw_chars1 = line[:first_sep]
raw_chars2 = line[first_sep + 1:last_sep]  # 去掉行尾的",label",只保留sentence2
if raw_chars1: | |||
ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) | |||
return ds | |||
def download(self): | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://github.com/ymcui/Chinese-BERT-wwm | |||
:return: | |||
""" | |||
raise RuntimeError("BQCorpus cannot be downloaded automatically.") | |||
class LCQMCLoader(Loader): | |||
r""" | |||
数据集简介:句对匹配(question matching) | |||
原始数据为: | |||
Example:: | |||
喜欢打篮球的男生喜欢什么样的女生 爱打篮球的男生喜欢什么样的女生 1 | |||
你帮我设计小说的封面吧 谁能帮我给小说设计个封面? 0 | |||
读取后的Dataset将具有以下的数据结构 | |||
.. csv-table:: | |||
:header: "raw_chars1", "raw_chars2", "target" | |||
"喜欢打篮球的男生喜欢什么样的女生", "爱打篮球的男生喜欢什么样的女生", "1" | |||
"你帮我设计小说的封面吧", "妇可以戴耳机听音乐吗?", "0" | |||
"...", "...", "..." | |||
""" | |||
def __init__(self): | |||
super(LCQMCLoader, self).__init__() | |||
def _load(self, path: str = None): | |||
ds = DataSet() | |||
with open(path, 'r', encoding='utf-8') as f: | |||
for line in f: | |||
line = line.strip() | |||
line_segments = line.split('\t') | |||
assert len(line_segments) == 3 | |||
target = line_segments[-1] | |||
raw_chars1 = line_segments[0] | |||
raw_chars2 = line_segments[1] | |||
if raw_chars1: | |||
ds.append(Instance(raw_chars1=raw_chars1, raw_chars2=raw_chars2, target=target)) | |||
return ds | |||
def download(self): | |||
r""" | |||
由于版权限制,不能提供自动下载功能。可参考 | |||
https://github.com/ymcui/Chinese-BERT-wwm | |||
:return: | |||
""" | |||
raise RuntimeError("LCQMC cannot be downloaded automatically.") | |||
@@ -0,0 +1,74 @@ | |||
r""" | |||
该文件中的Loader主要用于读取问答式任务的数据 | |||
""" | |||
from .loader import Loader | |||
import json | |||
from fastNLP.core.dataset import DataSet, Instance | |||
__all__ = ['CMRC2018Loader'] | |||
class CMRC2018Loader(Loader): | |||
r""" | |||
请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 | |||
读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 | |||
.. csv-table:: | |||
:header:"title", "context", "question", "answers", "answer_starts", "id" | |||
"范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" | |||
"范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" | |||
"...", "...", "...","...", ".", "..." | |||
其中title是文本的标题,多条记录可能是相同的title;id是该问题的id,具备唯一性 | |||
验证集DataSet将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案) | |||
.. csv-table:: | |||
:header: "title", "context", "question", "answers", "answer_starts", "id" | |||
"战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "《战国无双3》是由哪两个公司合作开发的?", "['光荣和ω-force', '光荣和ω-force', '光荣和ω-force']", "[30, 30, 30]", "DEV_0_QUERY_0" | |||
"战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", "['村雨城', '村雨城', '任天堂游戏谜之村雨城']", "[226, 226, 219]", "DEV_0_QUERY_1" | |||
"...", "...", "...","...", ".", "..." | |||
其中answer_starts是从0开始的index。例如"我来自a复旦大学?",其中"复"的开始index为4。另外"Russell评价说"中的说的index为9, 因为 | |||
英文和数字都直接按照character计量的。 | |||
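下面用一个小例子说明answer_starts的计算方式(字符串为假设的例子)::
context = "我来自a复旦大学?"
context[4]   # '复',因此答案"复旦大学"对应的answer_start为4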
""" | |||
def __init__(self): | |||
super().__init__() | |||
def _load(self, path: str) -> DataSet: | |||
with open(path, 'r', encoding='utf-8') as f: | |||
data = json.load(f)['data'] | |||
ds = DataSet() | |||
for entry in data: | |||
title = entry['title'] | |||
para = entry['paragraphs'][0] | |||
context = para['context'] | |||
qas = para['qas'] | |||
for qa in qas: | |||
question = qa['question'] | |||
ans = qa['answers'] | |||
answers = [] | |||
answer_starts = [] | |||
id = qa['id'] | |||
for an in ans: | |||
answers.append(an['text']) | |||
answer_starts.append(an['answer_start']) | |||
ds.append(Instance(title=title, context=context, question=question, answers=answers, | |||
answer_starts=answer_starts,id=id)) | |||
return ds | |||
def download(self) -> str: | |||
r""" | |||
如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('cmrc2018') | |||
return output_dir | |||
@@ -0,0 +1,63 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"ExtCNNDMLoader" | |||
] | |||
import os | |||
from typing import Union, Dict | |||
from ..data_bundle import DataBundle | |||
from ..utils import check_loader_paths | |||
from .json import JsonLoader | |||
class ExtCNNDMLoader(JsonLoader): | |||
r""" | |||
读取之后的DataSet中的field情况为 | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication" | |||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||
["..."], ["..."], [], "cnndm" | |||
""" | |||
def __init__(self, fields=None): | |||
fields = fields or {"text": None, "summary": None, "label": None, "publication": None} | |||
super(ExtCNNDMLoader, self).__init__(fields=fields) | |||
def load(self, paths: Union[str, Dict[str, str]] = None): | |||
r""" | |||
从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 | |||
读取的field根据ExtCNNDMLoader初始化时传入的headers决定。 | |||
:param str paths: 传入一个目录, 将在该目录下寻找train.label.jsonl, dev.label.jsonl | |||
test.label.jsonl三个文件(该目录还应该需要有一个名字为vocab的文件,在 :class:`~fastNLP.io.ExtCNNDMPipe` | |||
当中需要用到)。 | |||
:return: 返回 :class:`~fastNLP.io.DataBundle` | |||
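下面给出一个大致的使用示意(目录路径为假设值)::
data_bundle = ExtCNNDMLoader().load('/path/to/cnndm')  # 目录下需包含train.label.jsonl等文件以及vocab
train_data = data_bundle.get_dataset('train')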
""" | |||
if paths is None: | |||
paths = self.download() | |||
paths = check_loader_paths(paths) | |||
if ('train' in paths) and ('test' not in paths): | |||
paths['test'] = paths['train'] | |||
paths.pop('train') | |||
datasets = {name: self._load(path) for name, path in paths.items()} | |||
data_bundle = DataBundle(datasets=datasets) | |||
return data_bundle | |||
def download(self): | |||
r""" | |||
如果你使用了这个数据,请引用 | |||
https://arxiv.org/pdf/1506.03340.pdf | |||
:return: | |||
""" | |||
output_dir = self._get_dataset_path('ext-cnndm') | |||
return output_dir |
@@ -0,0 +1,71 @@ | |||
r""" | |||
用于载入和保存模型 | |||
""" | |||
__all__ = [ | |||
"ModelLoader", | |||
"ModelSaver" | |||
] | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
class ModelLoader: | |||
r""" | |||
用于读取模型 | |||
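Example::
# 仅作示意:model需是结构已初始化的PyTorch模型,路径为假设值
ModelLoader.load_pytorch(model, "./save/model_ckpt_100.pkl")               # 将参数读入model
whole_model = ModelLoader.load_pytorch_model("./save/model_ckpt_100.pkl")  # 读取整个模型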
""" | |||
def __init__(self): | |||
super(ModelLoader, self).__init__() | |||
@staticmethod | |||
def load_pytorch(empty_model, model_path): | |||
r""" | |||
从 ".pkl" 文件读取 PyTorch 模型 | |||
:param empty_model: 初始化参数的 PyTorch 模型 | |||
:param str model_path: 模型保存的路径 | |||
""" | |||
empty_model.load_state_dict(torch.load(model_path)) | |||
@staticmethod | |||
def load_pytorch_model(model_path): | |||
r""" | |||
读取整个模型 | |||
:param str model_path: 模型保存的路径 | |||
""" | |||
return torch.load(model_path) | |||
class ModelSaver(object): | |||
r""" | |||
用于保存模型 | |||
Example:: | |||
saver = ModelSaver("./save/model_ckpt_100.pkl") | |||
saver.save_pytorch(model) | |||
""" | |||
def __init__(self, save_path): | |||
r""" | |||
:param save_path: 模型保存的路径 | |||
""" | |||
self.save_path = save_path | |||
def save_pytorch(self, model, param_only=True): | |||
r""" | |||
把 PyTorch 模型存入 ".pkl" 文件 | |||
:param model: PyTorch 模型 | |||
:param bool param_only: 是否只保存模型的参数(否则保存整个模型) | |||
""" | |||
if param_only is True: | |||
torch.save(model.state_dict(), self.save_path) | |||
else: | |||
torch.save(model, self.save_path) |
@@ -0,0 +1,80 @@ | |||
r""" | |||
Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。 | |||
``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回; | |||
``process_from_file(paths)`` 传入文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。
``process(data_bundle)`` 或 ``process_from_file(paths)`` 返回的 `data_bundle` 中的 :class:`~fastNLP.DataSet`
一般都包含原文与转换为index的输入以及转换为index的target;除了 :class:`~fastNLP.DataSet` 之外, | |||
`data_bundle` 还会包含将field转为index时所建立的词表。 | |||
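典型的使用方式大致如下(xxxPipe、xxxLoader为假设的类名,路径为假设值,仅作示意)::
data_bundle = xxxPipe().process_from_file('/path/to/data')          # 读取并完成预处理
data_bundle = xxxPipe().process(xxxLoader().load('/path/to/data'))  # 等价于先用Loader读取再process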
""" | |||
__all__ = [ | |||
"Pipe", | |||
"CWSPipe", | |||
"CLSBasePipe", | |||
"AGsNewsPipe", | |||
"DBPediaPipe", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"MRPipe", "R52Pipe", "R8Pipe", "OhsumedPipe", "NG20Pipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"WeiboNERPipe", | |||
"PeopleDailyPipe", | |||
"Conll2003Pipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"CNXNLIBertPipe", | |||
"BQCorpusBertPipe", | |||
"LCQMCBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"LCQMCPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"RenamePipe", | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
"CoReferencePipe", | |||
"CMRC2018BertPipe", | |||
"R52PmiGraphPipe", | |||
"R8PmiGraphPipe", | |||
"OhsumedPmiGraphPipe", | |||
"NG20PmiGraphPipe", | |||
"MRPmiGraphPipe" | |||
] | |||
from .classification import CLSBasePipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, THUCNewsPipe, \ | |||
WeiboSenti100kPipe, AGsNewsPipe, DBPediaPipe, MRPipe, R8Pipe, R52Pipe, OhsumedPipe, NG20Pipe | |||
from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe | |||
from .conll import Conll2003Pipe | |||
from .coreference import CoReferencePipe | |||
from .cws import CWSPipe | |||
from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ | |||
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ | |||
LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe | |||
from .pipe import Pipe | |||
from .qa import CMRC2018BertPipe | |||
from .construct_graph import MRPmiGraphPipe, R8PmiGraphPipe, R52PmiGraphPipe, NG20PmiGraphPipe, OhsumedPmiGraphPipe |
@@ -0,0 +1,939 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CLSBasePipe", | |||
"AGsNewsPipe", | |||
"DBPediaPipe", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
'IMDBPipe', | |||
"ChnSentiCorpPipe", | |||
"THUCNewsPipe", | |||
"WeiboSenti100kPipe", | |||
"MRPipe", "R8Pipe", "R52Pipe", "OhsumedPipe", "NG20Pipe" | |||
] | |||
import re | |||
import warnings | |||
try:
from nltk import Tree
except ImportError:
# nltk为可选依赖;未安装时仅影响需要Tree的功能
pass
from .pipe import Pipe | |||
from .utils import get_tokenizer, _indexize, _add_words_field, _add_chars_field, _granularize | |||
from ..data_bundle import DataBundle | |||
from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader | |||
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \ | |||
AGsNewsLoader, DBPediaLoader, MRLoader, R52Loader, R8Loader, OhsumedLoader, NG20Loader | |||
# from ...core._logger import log | |||
# from ...core.const import Const | |||
from fastNLP.core.dataset import DataSet, Instance | |||
class CLSBasePipe(Pipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy', lang='en'): | |||
super().__init__() | |||
self.lower = lower | |||
self.tokenizer = get_tokenizer(tokenizer, lang=lang) | |||
def _tokenize(self, data_bundle, field_name='words', new_field_name=None): | |||
r""" | |||
将DataBundle中的数据进行tokenize | |||
:param DataBundle data_bundle: | |||
:param str field_name: | |||
:param str new_field_name: | |||
:return: 传入的DataBundle对象 | |||
""" | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"I got 'new' tires from them and... ", "1" | |||
"Don't waste your time. We had two...", "1" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
# 复制一列words | |||
data_bundle = _add_words_field(data_bundle, lower=self.lower) | |||
# 进行tokenize | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='words') | |||
# 建立词表并index | |||
data_bundle = _indexize(data_bundle=data_bundle) | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len('words') | |||
data_bundle.set_input('words', 'seq_len', 'target') | |||
return data_bundle | |||
def process_from_file(self, paths) -> DataBundle: | |||
r""" | |||
传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load`
:param paths: | |||
:return: DataBundle | |||
""" | |||
raise NotImplementedError | |||
class YelpFullPipe(CLSBasePipe): | |||
r""" | |||
处理YelpFull的数据, 处理之后DataSet中的内容如下 | |||
.. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 | |||
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 | |||
"...", ., "[...]", . | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the input.
:param int granularity: one of 2, 3 or 5. With 2, the task is binary: labels 1 and 2 form one class, 4 and 5 the other,
and label 3 is dropped; with 3, labels 1/2, label 3 and labels 4/5 form three classes; with 5, all five labels are kept.
:param str tokenizer: which tokenizer to use to split the text into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
if granularity == 2: | |||
self.tag_map = {"1": "negative", "2": "negative", "4": "positive", "5": "positive"} | |||
elif granularity == 3: | |||
self.tag_map = {"1": "negative", "2": "negative", "3": "medium", "4": "positive", "5": "positive"} | |||
else: | |||
self.tag_map = None | |||
def process(self, data_bundle): | |||
r""" | |||
The input DataSet should have the following structure
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"I got 'new' tires from them and... ", "1" | |||
"Don't waste your time. We had two...", "1" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
if self.tag_map is not None: | |||
data_bundle = _granularize(data_bundle, self.tag_map) | |||
data_bundle = super().process(data_bundle) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: | |||
:return: DataBundle | |||
""" | |||
data_bundle = YelpFullLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
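# A minimal usage sketch for YelpFullPipe (the dataset path below is a placeholder, not a
# location the library guarantees): with granularity=2, 1/2-star reviews become "negative",
# 4/5-star reviews become "positive", and 3-star reviews are dropped before tokenization
# and indexing.
#
#     pipe = YelpFullPipe(lower=True, granularity=2, tokenizer='spacy')
#     data_bundle = pipe.process_from_file('/path/to/yelp_review_full')
#     print(data_bundle.get_dataset('train')[0])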
class YelpPolarityPipe(CLSBasePipe): | |||
r""" | |||
Processes the Yelp Polarity data. After processing, the DataSet contains the following fields
.. csv-table:: Fields of the DataSet produced by YelpPolarityPipe
:header: "raw_words", "target", "words", "seq_len" | |||
"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 | |||
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 | |||
"...", ., "[...]", . | |||
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenizer to use to split the text into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param str paths: | |||
:return: DataBundle | |||
""" | |||
data_bundle = YelpPolarityLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
class AGsNewsPipe(CLSBasePipe): | |||
r""" | |||
Processes the AG's News data. After processing, the DataSet contains the following fields
.. csv-table:: Fields of the DataSet produced by AGsNewsPipe
:header: "raw_words", "target", "words", "seq_len" | |||
"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 | |||
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 | |||
"...", ., "[...]", . | |||
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenizer to use to split the text into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param str paths: | |||
:return: DataBundle | |||
""" | |||
data_bundle = AGsNewsLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
class DBPediaPipe(CLSBasePipe): | |||
r""" | |||
Processes the DBPedia data. After processing, the DataSet contains the following fields
.. csv-table:: Fields of the DataSet produced by DBPediaPipe
:header: "raw_words", "target", "words", "seq_len" | |||
"I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 | |||
" Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 | |||
"...", ., "[...]", . | |||
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenizer to use to split the text into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param str paths: | |||
:return: DataBundle | |||
""" | |||
data_bundle = DBPediaLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
class SSTPipe(CLSBasePipe): | |||
r""" | |||
After this Pipe, the DataSet contains the following fields
.. csv-table:: Fields of the DataSet produced by SSTPipe
:header: "raw_words", "words", "target", "seq_len" | |||
"It 's a lovely film with lovely perfor...", 1, "[187, 6, 5, 132, 120, 70, 132, 188, 25...", 13 | |||
"No one goes unindicted here , which is...", 0, "[191, 126, 192, 193, 194, 4, 195, 17, ...", 13 | |||
"...", ., "[...]", . | |||
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'): | |||
r""" | |||
:param bool subtree: whether to expand train, test and dev into subtrees to enlarge the datasets. Default: ``False``
:param bool train_subtree: whether to expand the train set with subtrees.
:param bool lower: whether to lowercase the input.
:param int granularity: one of 2, 3 or 5. With 2, the task is binary: labels 0 and 1 form one class, 3 and 4 the other,
and label 2 is dropped; with 3, labels 0/1, label 2 and labels 3/4 form three classes; with 5, all five labels are kept.
:param str tokenizer: which tokenizer to use to split the text into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.subtree = subtree | |||
self.train_tree = train_subtree | |||
self.lower = lower | |||
assert granularity in (2, 3, 5), "granularity can only be 2,3,5." | |||
self.granularity = granularity | |||
if granularity == 2: | |||
self.tag_map = {"0": "negative", "1": "negative", "3": "positive", "4": "positive"} | |||
elif granularity == 3: | |||
self.tag_map = {"0": "negative", "1": "negative", "2": "medium", "3": "positive", "4": "positive"} | |||
else: | |||
self.tag_map = None | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
Preprocesses the data in the DataBundle. The input DataSet must contain at least a raw_words column whose content looks like
.. csv-table:: Fields of the DataSet read by SSTLoader
:header: "raw_words" | |||
"(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." | |||
"(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." | |||
"..." | |||
:param ~fastNLP.io.DataBundle data_bundle: the DataBundle to process
:return: | |||
""" | |||
# expand subtrees first
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
ds = DataSet() | |||
use_subtree = self.subtree or (name == 'train' and self.train_tree) | |||
for ins in dataset: | |||
raw_words = ins['raw_words'] | |||
tree = Tree.fromstring(raw_words) | |||
if use_subtree: | |||
for t in tree.subtrees(): | |||
raw_words = " ".join(t.leaves()) | |||
instance = Instance(raw_words=raw_words, target=t.label()) | |||
ds.append(instance) | |||
else: | |||
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label()) | |||
ds.append(instance) | |||
data_bundle.set_dataset(ds, name) | |||
# map the tags according to granularity
data_bundle = _granularize(data_bundle, tag_map=self.tag_map) | |||
data_bundle = super().process(data_bundle) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
data_bundle = SSTLoader().load(paths) | |||
return self.process(data_bundle=data_bundle) | |||
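# A small illustration of the subtree expansion performed in process() above, assuming nltk
# is installed; every subtree of the constituency parse yields one (text, label) instance:
#
#     from nltk import Tree
#     tree = Tree.fromstring("(3 (2 It) (3 (2 is) (3 great)))")
#     print(tree.label())  # '3', the sentence-level label
#     print([(" ".join(t.leaves()), t.label()) for t in tree.subtrees()])
#     # [('It is great', '3'), ('It', '2'), ('is great', '3'), ('is', '2'), ('great', '3')]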
class SST2Pipe(CLSBasePipe): | |||
r""" | |||
Loads the SST-2 data. After processing, the DataSet has the following fields
.. csv-table:: | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"it 's a charming and often affecting j... ", 1, "[19, 9, 6, 111, 5, 112, 113, 114, 3]", 9 | |||
"unflinchingly bleak and desperate", 0, "[115, 116, 5, 117]", 4 | |||
"...", "...", ., . | |||
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower=False, tokenizer='spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenizer to use to split the text into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(lower=lower, tokenizer=tokenizer, lang='en') | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param str paths: if None, the data is downloaded automatically and cached in fastNLP's cache directory.
:return: DataBundle | |||
""" | |||
data_bundle = SST2Loader().load(paths) | |||
return self.process(data_bundle) | |||
class IMDBPipe(CLSBasePipe): | |||
r""" | |||
After this Pipe, the DataSet looks as follows
.. csv-table:: Fields of the output DataSet
:header: "raw_words", "target", "words", "seq_len" | |||
"Bromwell High is a cartoon ... ", 0, "[3, 5, 6, 9, ...]", 20 | |||
"Story of a man who has ...", 1, "[20, 43, 9, 10, ...]", 31 | |||
"...", ., "[...]", . | |||
raw_words is str, the original text; words is the input converted to indices; target is the label converted to an index;
the words column is set as input and the target column is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | False | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
The DataSet in the incoming DataBundle is expected to look like the following, with two str fields, raw_words and target
.. csv-table:: Fields of the input DataSet
:header: "raw_words", "target" | |||
"Bromwell High is a cartoon ... ", "pos" | |||
"Story of a man who has ...", "neg" | |||
"...", "..." | |||
:param DataBundle data_bundle: the DataSet in the incoming DataBundle must contain the raw_words and target fields, and
both columns should be of type str.
:return: DataBundle | |||
""" | |||
# replace <br /> markup with spaces
def replace_br(raw_words): | |||
raw_words = raw_words.replace("<br />", ' ') | |||
return raw_words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(replace_br, field_name='raw_words', new_field_name='raw_words') | |||
data_bundle = super().process(data_bundle) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = IMDBLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
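# A minimal usage sketch for IMDBPipe (the path is a placeholder); "<br />" markup is
# replaced by spaces before the shared CLSBasePipe tokenization and indexing run:
#
#     pipe = IMDBPipe(lower=True, tokenizer='spacy')
#     data_bundle = pipe.process_from_file('/path/to/aclImdb')
#     print(data_bundle.get_dataset('train')[0]['words'][:10])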
class ChnSentiCorpPipe(Pipe): | |||
r""" | |||
The processed DataSet has the following structure
.. csv-table:: | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", 1, "[2, 3, 4, 5, ...]", 31 | |||
"<荐书> 推荐所有喜欢<红楼>...", 1, "[10, 21, ....]", 25 | |||
"..." | |||
chars and seq_len are set as input, and target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
r""" | |||
:param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]. If
True, the returned DataSet has a bigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('bigrams').
:param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
If True, the returned DataSet has a trigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('trigrams').
""" | |||
super().__init__() | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
def _tokenize(self, data_bundle): | |||
r""" | |||
Splits "复旦大学" in the DataSet into ["复", "旦", "大", "学"]. This method can later be extended to do word segmentation.
:param data_bundle: | |||
:return: | |||
""" | |||
data_bundle.apply_field(list, field_name='chars', new_field_name='chars') | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
A DataSet that can be processed should have the following fields
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"這間酒店環境和服務態度亦算不錯,但房間空間太小~~", "1" | |||
"<荐书> 推荐所有喜欢<红楼>...", "1" | |||
"..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
_add_chars_field(data_bundle, lower=False) | |||
data_bundle = self._tokenize(data_bundle) | |||
input_field_names = ['chars'] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
# index | |||
_indexize(data_bundle, input_field_names, 'target') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target'] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len('chars') | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = ChnSentiCorpLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
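# How the optional bigrams/trigrams columns above are built, on a toy character sequence:
#
#     chars = ['复', '旦', '大', '学']
#     bigrams = [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])]
#     # -> ['复旦', '旦大', '大学', '学<eos>']
#     trigrams = [c1 + c2 + c3 for c1, c2, c3 in
#                 zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)]
#     # -> ['复旦大', '旦大学', '大学<eos>', '学<eos><eos>']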
class THUCNewsPipe(CLSBasePipe): | |||
r""" | |||
The processed DataSet has the following structure
.. csv-table:: | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 | |||
"..." | |||
chars and seq_len are set as input, and target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
:param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]. If
True, the returned DataSet has a bigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('bigrams').
:param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
If True, the returned DataSet has a trigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('trigrams').
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
super().__init__() | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
def _character_split(self, sent):
return list(sent) | |||
# return [w for w in sent] | |||
def _raw_split(self, sent): | |||
return sent.split() | |||
def _tokenize(self, data_bundle, field_name='words', new_field_name=None): | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self._character_split, field_name=field_name, new_field_name=new_field_name)
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
A DataSet that can be processed should have the following fields
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道 ... ", "体育" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
# map the raw category labels with tag_map
tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} | |||
data_bundle = _granularize(data_bundle=data_bundle, tag_map=tag_map) | |||
# clean,lower | |||
# CWS(tokenize) | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') | |||
input_field_names = ['chars'] | |||
# n-grams | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') | |||
# add length | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(field_name='chars', new_field_name='seq_len') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target'] | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
data_loader = THUCNewsLoader()  # instantiate a loader here so that load() receives a valid object even when paths is None
data_bundle = data_loader.load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class WeiboSenti100kPipe(CLSBasePipe): | |||
r""" | |||
The processed DataSet has the following structure
.. csv-table:: | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", 0, "[0, 690, 18, ...]", 56 | |||
"..." | |||
chars and seq_len are set as input, and target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | False | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
:param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]. If
True, the returned DataSet has a bigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('bigrams').
:param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
If True, the returned DataSet has a trigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('trigrams').
""" | |||
def __init__(self, bigrams=False, trigrams=False): | |||
super().__init__() | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
def _character_split(self, sent):
return list(sent) | |||
def _tokenize(self, data_bundle, field_name='words', new_field_name=None): | |||
new_field_name = new_field_name or field_name | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self._character_split, field_name=field_name, new_field_name=new_field_name)
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
A DataSet that can be processed should have the following fields
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", "0" | |||
"...", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
# clean,lower | |||
# CWS(tokenize) | |||
data_bundle = self._tokenize(data_bundle=data_bundle, field_name='raw_chars', new_field_name='chars') | |||
input_field_names = ['chars'] | |||
# n-grams | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
# index | |||
data_bundle = _indexize(data_bundle=data_bundle, input_field_names='chars') | |||
# add length | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len(field_name='chars', new_field_name='seq_len') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target'] | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
data_loader = WeiboSenti100kLoader()  # instantiate a loader here so that load() receives a valid object even when paths is None
data_bundle = data_loader.load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class MRPipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = MRLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class R8Pipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = R8Loader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class R52Pipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = R52Loader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class OhsumedPipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = OhsumedLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class NG20Pipe(CLSBasePipe): | |||
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'): | |||
r""" | |||
:param bool lower: whether to lowercase the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
""" | |||
super().__init__(tokenizer=tokenizer, lang='en') | |||
self.lower = lower | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.Loader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = NG20Loader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle |
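# MRPipe, R8Pipe, R52Pipe, OhsumedPipe and NG20Pipe all reuse the CLSBasePipe processing
# unchanged; a hedged usage sketch with a placeholder path:
#
#     pipe = R8Pipe(lower=True, tokenizer='raw')
#     data_bundle = pipe.process_from_file('/path/to/R8')
#     for name, ds in data_bundle.iter_datasets():
#         print(name, len(ds))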
@@ -0,0 +1,427 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Conll2003NERPipe", | |||
"Conll2003Pipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"PeopleDailyPipe", | |||
"WeiboNERPipe" | |||
] | |||
from .pipe import Pipe | |||
from .utils import _add_chars_field | |||
from .utils import _indexize, _add_words_field | |||
from .utils import iob2, iob2bioes | |||
from fastNLP.io.data_bundle import DataBundle | |||
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader | |||
from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader | |||
# from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
class _NERPipe(Pipe): | |||
r""" | |||
A processing Pipe for NER tasks. It (1) copies the raw_words column and names it words; (2) builds vocabularies over the
words and target columns (creating :class:`fastNLP.Vocabulary` objects, so the returned DataBundle carries two
Vocabulary objects); and (3) converts the words and target columns to indices with those vocabularies.
raw_words is List[str], the raw, unconverted text; words is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, words, target and seq_len are set as input; target and seq_len are set as target.
""" | |||
def __init__(self, encoding_type: str = 'bio', lower: bool = False): | |||
r""" | |||
:param str encoding_type: which encoding scheme to use for the target column; supports bioes and bio.
:param bool lower: whether to lowercase words before building the vocabulary; in almost all cases this should stay False.
""" | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
elif encoding_type == 'bioes': | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
else: | |||
raise ValueError("encoding_type only supports `bio` and `bioes`.") | |||
self.lower = lower | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
The supported DataSet fields are
.. csv-table:: | |||
:header: "raw_words", "target" | |||
"[Nadim, Ladki]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]" | |||
"[...]", "[...]" | |||
:param ~fastNLP.DataBundle data_bundle: the DataSet in the incoming DataBundle must contain the raw_words and target fields, both of type List[str]. The DataBundle is modified in place.
:return DataBundle: | |||
""" | |||
# convert tags
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# index | |||
_indexize(data_bundle) | |||
input_fields = ['target', 'words', 'seq_len'] | |||
target_fields = ['target', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('words') | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
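# The tag conversion configured in __init__ above, on a short example: with
# encoding_type='bioes', ['B-PER', 'I-PER', 'O', 'B-LOC'] is first normalized by iob2 and
# then rewritten by iob2bioes to ['B-PER', 'E-PER', 'O', 'S-LOC'] (multi-token spans get
# B-/E- boundaries, single-token spans become S-).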
class Conll2003NERPipe(_NERPipe): | |||
r""" | |||
A processing Pipe for the CoNLL-2003 NER task. It (1) copies the raw_words column and names it words; (2) builds
vocabularies over the words and target columns (creating :class:`fastNLP.Vocabulary` objects, so the returned DataBundle
carries two Vocabulary objects); and (3) converts the words and target columns to indices with those vocabularies.
After this Pipe, the DataSet content looks as follows
.. csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 | |||
"[...]", "[...]", "[...]", . | |||
raw_words is List[str], the raw, unconverted text; words is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, words, target and seq_len are set as input; target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths) -> DataBundle: | |||
r""" | |||
:param paths: see the load() method of :class:`fastNLP.io.loader.ConllLoader` for the supported path formats.
:return: DataBundle
"""
# load the data
data_bundle = Conll2003NERLoader().load(paths) | |||
data_bundle = self.process(data_bundle) | |||
return data_bundle | |||
class Conll2003Pipe(Pipe): | |||
r""" | |||
After this Pipe, the DataSet content is as follows
.. csv-table:: | |||
:header: "raw_words" , "pos", "chunk", "ner", "words", "seq_len" | |||
"[Nadim, Ladki]", "[0, 0]", "[1, 2]", "[1, 2]", "[2, 3]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", "[4, 5, 6,...]", 6 | |||
"[...]", "[...]", "[...]", "[...]", "[...]", . | |||
words and seq_len are set as input; pos, chunk, ner and seq_len are set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+-------+-------+-------+-------+---------+ | |||
| field_names | raw_words | pos | chunk | ner | words | seq_len | | |||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||
| is_input | False | False | False | False | True | True | | |||
| is_target | False | True | True | True | False | True | | |||
| ignore_type | | False | False | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | 0 | 0 | | |||
+-------------+-----------+-------+-------+-------+-------+---------+ | |||
""" | |||
def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False): | |||
r""" | |||
:param str chunk_encoding_type: supports bioes and bio.
:param str ner_encoding_type: supports bioes and bio.
:param bool lower: whether to lowercase the words column before building the vocabulary
""" | |||
if chunk_encoding_type == 'bio': | |||
self.chunk_convert_tag = iob2 | |||
elif chunk_encoding_type == 'bioes': | |||
self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
else: | |||
raise ValueError("chunk_encoding_type only supports `bio` and `bioes`.") | |||
if ner_encoding_type == 'bio': | |||
self.ner_convert_tag = iob2 | |||
elif ner_encoding_type == 'bioes': | |||
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags)) | |||
else: | |||
raise ValueError("ner_encoding_type only supports `bio` and `bioes`.") | |||
self.lower = lower | |||
def process(self, data_bundle) -> DataBundle: | |||
r""" | |||
The input DataSet should look like the following
.. csv-table:: | |||
:header: "raw_words", "pos", "chunk", "ner" | |||
"[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" | |||
"[AL-AIN, United, Arab, ...]", "[NNP, NNP...]", "[B-NP, B-NP, ...]", "[B-LOC, B-LOC,...]" | |||
"[...]", "[...]", "[...]", "[...]", . | |||
:param data_bundle: | |||
:return: the DataBundle that was passed in
""" | |||
# convert tags
for name, dataset in data_bundle.datasets.items(): | |||
dataset.drop(lambda x: "-DOCSTART-" in x['raw_words']) | |||
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk') | |||
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner') | |||
_add_words_field(data_bundle, lower=self.lower) | |||
# index | |||
_indexize(data_bundle, input_field_names='words', target_field_names=['pos', 'ner']) | |||
# some chunk tags appear only in dev and never in train
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||
tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk') | |||
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk') | |||
data_bundle.set_vocab(tgt_vocab, 'chunk') | |||
input_fields = ['words', 'seq_len'] | |||
target_fields = ['pos', 'ner', 'chunk', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('words') | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths): | |||
r""" | |||
:param paths: | |||
:return: | |||
""" | |||
data_bundle = ConllLoader(headers=['raw_words', 'pos', 'chunk', 'ner']).load(paths) | |||
return self.process(data_bundle) | |||
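# A hedged usage sketch for Conll2003Pipe (the path is a placeholder); pos keeps its raw
# tags while chunk and ner are converted to the configured encodings:
#
#     pipe = Conll2003Pipe(chunk_encoding_type='bioes', ner_encoding_type='bio')
#     data_bundle = pipe.process_from_file('/path/to/conll2003')
#     print(data_bundle.get_vocab('chunk'))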
class OntoNotesNERPipe(_NERPipe): | |||
r""" | |||
Processes the OntoNotes NER data. After processing, the DataSet fields are
.. csv-table:: | |||
:header: "raw_words", "target", "words", "seq_len" | |||
"[Nadim, Ladki]", "[1, 2]", "[2, 3]", 2 | |||
"[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 | |||
"[...]", "[...]", "[...]", . | |||
raw_words is List[str], the raw, unconverted text; words is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, words, target and seq_len are set as input; target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_words | target | words | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths): | |||
data_bundle = OntoNotesNERLoader().load(paths) | |||
return self.process(data_bundle) | |||
class _CNNERPipe(Pipe): | |||
r""" | |||
A processing Pipe for Chinese NER tasks. It (1) copies the raw_chars column and names it chars; (2) builds vocabularies
over the chars and target columns (creating :class:`fastNLP.Vocabulary` objects, so the returned DataBundle carries two
Vocabulary objects); and (3) converts the chars and target columns to indices with those vocabularies.
raw_chars is List[str], the raw, unconverted text; chars is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, chars, target and seq_len are set as input; target and seq_len are set as target.
""" | |||
def __init__(self, encoding_type: str = 'bio', bigrams=False, trigrams=False): | |||
r""" | |||
:param str encoding_type: which encoding scheme to use for the target column; supports bioes and bio.
:param bool bigrams: whether to add a bigrams column, built as ['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]. If
True, the returned DataSet has a bigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('bigrams').
:param bool trigrams: whether to add a trigrams column, built as ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...].
If True, the returned DataSet has a trigrams column that has already been converted to indices and set as input; the
corresponding vocabulary can be obtained via data_bundle.get_vocab('trigrams').
""" | |||
if encoding_type == 'bio': | |||
self.convert_tag = iob2 | |||
elif encoding_type == 'bioes': | |||
self.convert_tag = lambda words: iob2bioes(iob2(words)) | |||
else: | |||
raise ValueError("encoding_type only supports `bio` and `bioes`.") | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
The supported DataSet fields are
.. csv-table:: | |||
:header: "raw_chars", "target" | |||
"[相, 比, 之, 下,...]", "[O, O, O, O, ...]" | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]" | |||
"[...]", "[...]" | |||
raw_chars is List[str], the raw, unconverted text; chars is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
:param ~fastNLP.DataBundle data_bundle: the DataSet in the incoming DataBundle must contain the raw_chars and target fields, both of type List[str]. The DataBundle is modified in place.
:return: DataBundle | |||
""" | |||
# convert tags
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(self.convert_tag, field_name='target', new_field_name='target') | |||
_add_chars_field(data_bundle, lower=False) | |||
input_field_names = ['chars'] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
# index | |||
_indexize(data_bundle, input_field_names, 'target') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('chars') | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
class MsraNERPipe(_CNNERPipe): | |||
r""" | |||
Processes the MSRA-NER data. After processing, the DataSet fields are
.. csv-table:: | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 | |||
"[...]", "[...]", "[...]", . | |||
raw_chars is List[str], the raw, unconverted text; chars is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = MsraNERLoader().load(paths) | |||
return self.process(data_bundle) | |||
class PeopleDailyPipe(_CNNERPipe): | |||
r""" | |||
Processes the People's Daily NER data. After processing, the DataSet fields are
.. csv-table:: | |||
:header: "raw_chars", "target", "chars", "seq_len" | |||
"[相, 比, 之, 下,...]", "[0, 0, 0, 0, ...]", "[2, 3, 4, 5, ...]", 11 | |||
"[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 | |||
"[...]", "[...]", "[...]", . | |||
raw_chars is List[str], the raw, unconverted text; chars is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = PeopleDailyNERLoader().load(paths) | |||
return self.process(data_bundle) | |||
class WeiboNERPipe(_CNNERPipe): | |||
r""" | |||
Processes the Weibo NER data. After processing, the DataSet fields are
.. csv-table:: | |||
:header: "raw_chars", "chars", "target", "seq_len" | |||
"['老', '百', '姓']", "[4, 3, 3]", "[38, 39, 40]", 3 | |||
"['心']", "[0]", "[41]", 1 | |||
"[...]", "[...]", "[...]", . | |||
raw_chars is List[str], the raw, unconverted text; chars is List[int], the indexed input; target is List[int], the indexed
labels. In the returned DataSet, chars, target and seq_len are set as input; target is set as target.
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = WeiboNERLoader().load(paths) | |||
return self.process(data_bundle) |
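# A hedged usage sketch of the character-level Chinese NER pipes (the path is a
# placeholder); the bigrams column and its vocabulary follow the scheme documented above:
#
#     pipe = MsraNERPipe(encoding_type='bioes', bigrams=True)
#     data_bundle = pipe.process_from_file('/path/to/msra_ner')
#     print(data_bundle.get_vocab('bigrams'))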
@@ -0,0 +1,286 @@ | |||
__all__ = [ | |||
'MRPmiGraphPipe', | |||
'R8PmiGraphPipe', | |||
'R52PmiGraphPipe', | |||
'OhsumedPmiGraphPipe', | |||
'NG20PmiGraphPipe' | |||
] | |||
try: | |||
import networkx as nx | |||
from sklearn.feature_extraction.text import CountVectorizer | |||
from sklearn.feature_extraction.text import TfidfTransformer | |||
from sklearn.pipeline import Pipeline | |||
except ImportError:
# networkx and scikit-learn are optional dependencies of the graph-construction pipes
pass
from collections import defaultdict | |||
import itertools | |||
import math | |||
import numpy as np | |||
from ..data_bundle import DataBundle | |||
# from ...core.const import Const | |||
from ..loader.classification import MRLoader, OhsumedLoader, R52Loader, R8Loader, NG20Loader | |||
from fastNLP.core.utils import f_rich_progress | |||
def _get_windows(content_lst: list, window_size: int): | |||
r""" | |||
Slides a window over the texts to collect word frequencies and word co-occurrence frequencies.
:param content_lst:
:param window_size:
:return: word window frequencies, word-pair co-occurrence frequencies, and the number of windows
""" | |||
word_window_freq = defaultdict(int)  # w(i): number of windows in which word i occurs
word_pair_count = defaultdict(int)  # w(i, j): number of windows in which words i and j co-occur
windows_len = 0 | |||
task_id = f_rich_progress.add_task(description="Split by window", total=len(content_lst)) | |||
for words in content_lst: | |||
windows = list() | |||
if isinstance(words, str): | |||
words = words.split() | |||
length = len(words) | |||
if length <= window_size: | |||
windows.append(words) | |||
else: | |||
for j in range(length - window_size + 1): | |||
window = words[j: j + window_size] | |||
windows.append(list(set(window))) | |||
for window in windows: | |||
for word in window: | |||
word_window_freq[word] += 1 | |||
for word_pair in itertools.combinations(window, 2): | |||
word_pair_count[word_pair] += 1 | |||
windows_len += len(windows) | |||
f_rich_progress.update(task_id, advance=1) | |||
f_rich_progress.destroy_task(task_id) | |||
return word_window_freq, word_pair_count, windows_len | |||
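# A quick check of the counting above: with window_size=3, "a b c d".split() yields the
# windows {a, b, c} and {b, c, d}, so word_window_freq becomes
# {'a': 1, 'b': 2, 'c': 2, 'd': 1}, every pair of distinct words within a window increments
# word_pair_count, and windows_len is 2.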
def _cal_pmi(W_ij, W, word_freq_i, word_freq_j): | |||
r""" | |||
:param W_ij: co-occurrence count of words i and j
:param W: total number of windows
:param word_freq_i: window frequency of word i
:param word_freq_j: window frequency of word j
:return: the PMI of words i and j
""" | |||
p_i = word_freq_i / W | |||
p_j = word_freq_j / W | |||
p_i_j = W_ij / W | |||
pmi = math.log(p_i_j / (p_i * p_j)) | |||
return pmi | |||
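# Sanity check of the formula above: with W = 100 windows, W_ij = 10 co-occurrences and
# both word window frequencies equal to 20,
# pmi = log((10/100) / ((20/100) * (20/100))) = log(2.5) ≈ 0.916.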
def _count_pmi(windows_len, word_pair_count, word_window_freq, threshold): | |||
r""" | |||
:param windows_len: number of text windows
:param word_pair_count: dictionary of word co-occurrence counts
:param word_window_freq: dictionary of word window frequencies
:param threshold: PMI threshold
:return: a list of [word1, word2, pmi] entries
""" | |||
word_pmi_lst = list() | |||
task_id = f_rich_progress.add_task(description="Calculate pmi between words", total=len(word_pair_count)) | |||
for word_pair, W_i_j in word_pair_count.items(): | |||
word_freq_1 = word_window_freq[word_pair[0]] | |||
word_freq_2 = word_window_freq[word_pair[1]] | |||
pmi = _cal_pmi(W_i_j, windows_len, word_freq_1, word_freq_2) | |||
if pmi <= threshold: | |||
continue | |||
word_pmi_lst.append([word_pair[0], word_pair[1], pmi]) | |||
f_rich_progress.update(task_id, advance=1) | |||
f_rich_progress.destroy_task(task_id)
return word_pmi_lst | |||
class GraphBuilderBase: | |||
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): | |||
self.graph = nx.Graph() | |||
self.word2id = dict() | |||
self.graph_type = graph_type | |||
self.window_size = widow_size | |||
self.doc_node_num = 0 | |||
self.tr_doc_index = None | |||
self.te_doc_index = None | |||
self.dev_doc_index = None | |||
self.doc = None | |||
self.threshold = threshold | |||
def _get_doc_edge(self, data_bundle: DataBundle): | |||
r""" | |||
Processes the incoming DataBundle and builds document-word edges weighted by tf-idf.
:param data_bundle: English text looks like ['This is the first document.']; Chinese text must already be segmented, e.g. ['他 喜欢 吃 苹果'].
:return: a sparse document-word matrix whose edges carry tf-idf weights
""" | |||
tr_doc = list(data_bundle.get_dataset("train").get_field('raw_words')) | |||
val_doc = list(data_bundle.get_dataset("dev").get_field('raw_words')) | |||
te_doc = list(data_bundle.get_dataset("test").get_field('raw_words')) | |||
doc = tr_doc + val_doc + te_doc | |||
self.doc = doc | |||
self.tr_doc_index = [ind for ind in range(len(tr_doc))] | |||
self.dev_doc_index = [ind + len(tr_doc) for ind in range(len(val_doc))] | |||
self.te_doc_index = [ind + len(tr_doc) + len(val_doc) for ind in range(len(te_doc))] | |||
text_tfidf = Pipeline([('count', CountVectorizer(token_pattern=r'\S+', min_df=1, max_df=1.0)), | |||
('tfidf', | |||
TfidfTransformer(norm=None, use_idf=True, smooth_idf=False, sublinear_tf=False))]) | |||
tfidf_vec = text_tfidf.fit_transform(doc) | |||
self.doc_node_num = tfidf_vec.shape[0] | |||
vocab_lst = text_tfidf['count'].get_feature_names() | |||
for ind, word in enumerate(vocab_lst): | |||
self.word2id[word] = ind | |||
for ind, row in enumerate(tfidf_vec): | |||
for col_index, value in zip(row.indices, row.data): | |||
self.graph.add_edge(ind, self.doc_node_num + col_index, weight=value) | |||
return nx.to_scipy_sparse_matrix(self.graph) | |||
def _get_word_edge(self): | |||
word_window_freq, word_pair_count, windows_len = _get_windows(self.doc, self.window_size) | |||
pmi_edge_lst = _count_pmi(windows_len, word_pair_count, word_window_freq, self.threshold) | |||
for edge_item in pmi_edge_lst: | |||
word_indx1 = self.doc_node_num + self.word2id[edge_item[0]] | |||
word_indx2 = self.doc_node_num + self.word2id[edge_item[1]] | |||
if word_indx1 == word_indx2: | |||
continue | |||
self.graph.add_edge(word_indx1, word_indx2, weight=edge_item[2]) | |||
def build_graph(self, data_bundle: DataBundle): | |||
r""" | |||
Processes the incoming DataBundle and returns the adjacency matrix as a scipy sparse matrix.
:param ~fastNLP.DataBundle data_bundle: the DataBundle to process
:return: | |||
""" | |||
raise NotImplementedError | |||
def build_graph_from_file(self, path: str): | |||
r""" | |||
Builds the processed scipy sparse matrix from a file path. See :meth:`fastNLP.io.Loader.load()` for the supported path formats.
:param path: | |||
:return: scipy_sparse_matrix | |||
""" | |||
raise NotImplementedError | |||
class MRPmiGraphPipe(GraphBuilderBase): | |||
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): | |||
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) | |||
def build_graph(self, data_bundle: DataBundle): | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
:return: a CSR sparse adjacency matrix, plus the graph indices of the train, dev and test documents.
""" | |||
self._get_doc_edge(data_bundle) | |||
self._get_word_edge() | |||
return nx.to_scipy_sparse_matrix(self.graph, | |||
nodelist=list(range(self.graph.number_of_nodes())), | |||
weight='weight', dtype=np.float32, format='csr'), ( | |||
self.tr_doc_index, self.dev_doc_index, self.te_doc_index) | |||
def build_graph_from_file(self, path: str): | |||
data_bundle = MRLoader().load(path) | |||
return self.build_graph(data_bundle) | |||
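# A hedged usage sketch (placeholder path); the same pattern applies to the other
# *PmiGraphPipe classes below:
#
#     pipe = MRPmiGraphPipe(threshold=0.)
#     adj, (train_idx, dev_idx, test_idx) = pipe.build_graph_from_file('/path/to/mr')
#     print(adj.shape, len(train_idx), len(dev_idx), len(test_idx))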
class R8PmiGraphPipe(GraphBuilderBase): | |||
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): | |||
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) | |||
def build_graph(self, data_bundle: DataBundle): | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
:return: a CSR sparse adjacency matrix, plus the graph indices of the train, dev and test documents.
""" | |||
self._get_doc_edge(data_bundle) | |||
self._get_word_edge() | |||
return nx.to_scipy_sparse_matrix(self.graph, | |||
nodelist=list(range(self.graph.number_of_nodes())), | |||
weight='weight', dtype=np.float32, format='csr'), ( | |||
self.tr_doc_index, self.dev_doc_index, self.te_doc_index) | |||
def build_graph_from_file(self, path: str): | |||
data_bundle = R8Loader().load(path) | |||
return self.build_graph(data_bundle) | |||
class R52PmiGraphPipe(GraphBuilderBase): | |||
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): | |||
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) | |||
def build_graph(self, data_bundle: DataBundle): | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
:return: a CSR sparse adjacency matrix, plus the graph indices of the train, dev and test documents.
""" | |||
self._get_doc_edge(data_bundle) | |||
self._get_word_edge() | |||
return nx.to_scipy_sparse_matrix(self.graph, | |||
nodelist=list(range(self.graph.number_of_nodes())), | |||
weight='weight', dtype=np.float32, format='csr'), ( | |||
self.tr_doc_index, self.dev_doc_index, self.te_doc_index) | |||
def build_graph_from_file(self, path: str): | |||
data_bundle = R52Loader().load(path) | |||
return self.build_graph(data_bundle) | |||
class OhsumedPmiGraphPipe(GraphBuilderBase): | |||
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): | |||
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) | |||
def build_graph(self, data_bundle: DataBundle): | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
:return: a CSR sparse adjacency matrix, plus the graph indices of the train, dev and test documents.
""" | |||
self._get_doc_edge(data_bundle) | |||
self._get_word_edge() | |||
return nx.to_scipy_sparse_matrix(self.graph, | |||
nodelist=list(range(self.graph.number_of_nodes())), | |||
weight='weight', dtype=np.float32, format='csr'), ( | |||
self.tr_doc_index, self.dev_doc_index, self.te_doc_index) | |||
def build_graph_from_file(self, path: str): | |||
data_bundle = OhsumedLoader().load(path) | |||
return self.build_graph(data_bundle) | |||
class NG20PmiGraphPipe(GraphBuilderBase): | |||
def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): | |||
super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) | |||
def build_graph(self, data_bundle: DataBundle): | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: the DataBundle to process.
:return: a CSR sparse adjacency matrix, plus the graph indices of the train, dev and test documents.
""" | |||
self._get_doc_edge(data_bundle) | |||
self._get_word_edge() | |||
return nx.to_scipy_sparse_matrix(self.graph, | |||
nodelist=list(range(self.graph.number_of_nodes())), | |||
weight='weight', dtype=np.float32, format='csr'), ( | |||
self.tr_doc_index, self.dev_doc_index, self.te_doc_index) | |||
def build_graph_from_file(self, path: str): | |||
r""" | |||
:param path: path to the dataset.
:return: a CSR sparse adjacency matrix, plus the graph indices of the train, dev and test documents.
""" | |||
data_bundle = NG20Loader().load(path) | |||
return self.build_graph(data_bundle) |
@@ -0,0 +1,186 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CoReferencePipe" | |||
] | |||
import collections | |||
import numpy as np | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from .pipe import Pipe | |||
from ..data_bundle import DataBundle | |||
from ..loader.coreference import CoReferenceLoader | |||
# from ...core.const import Const | |||
class CoReferencePipe(Pipe): | |||
r""" | |||
Processes coreference-resolution data to obtain the document genre, speakers, character-level information and sequence lengths.
After processing, the data contains the genre, speaker information, sentences, sentence word indices, chars, sentence lengths and target:
.. csv-table:: | |||
:header: "words1", "words2","words3","words4","chars","seq_len","target" | |||
"bc", "[[0,0],[1,1]]","[['I','am'],[]]","[[1,2],[]]","[[[1],[2,3]],[]]","[2,3]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||
"[...]", "[...]","[...]","[...]","[...]","[...]","[...]" | |||
The input/target settings of each field, as shown by the dataset's print_field_meta() method, are::
+-------------+-----------+--------+-------+---------+ | |||
| field_names | raw_chars | target | chars | seq_len | | |||
+-------------+-----------+--------+-------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | True | False | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+--------+-------+---------+ | |||
""" | |||
def __init__(self, config): | |||
super().__init__() | |||
self.config = config | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
Further processes the loaded data. The raw data contains raw_key, raw_speaker, raw_words and raw_clusters
.. csv-table:: | |||
:header: "raw_key", "raw_speaker","raw_words","raw_clusters" | |||
"bc/cctv/00/cctv_0000_0", "[[Speaker#1, Speaker#1],[]]","[['I','am'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||
"bc/cctv/00/cctv_0000_1", "[['Speaker#1', 'peaker#1'],[]]","[['He','is'],[]]","[[[2,3],[6,7]],[[10,12],[20,22]]]" | |||
"[...]", "[...]","[...]","[...]" | |||
:param data_bundle: | |||
:return: | |||
""" | |||
genres = {g: i for i, g in enumerate(["bc", "bn", "mz", "nw", "pt", "tc", "wb"])} | |||
vocab = Vocabulary().from_dataset(*data_bundle.datasets.values(), field_name='raw_words4') | |||
vocab.build_vocab() | |||
word2id = vocab.word2idx | |||
data_bundle.set_vocab(vocab, 'words1') | |||
if self.config.char_path: | |||
char_dict = get_char_dict(self.config.char_path) | |||
else: | |||
char_set = set() | |||
for i, w in enumerate(word2id): | |||
if i < 2: | |||
continue | |||
for c in w: | |||
char_set.add(c) | |||
char_dict = collections.defaultdict(int) | |||
char_dict.update({c: i for i, c in enumerate(char_set)}) | |||
for name, ds in data_bundle.iter_datasets(): | |||
# genre | |||
ds.apply(lambda x: genres[x['raw_words1'][:2]], new_field_name='words1') | |||
# speaker_ids_np | |||
ds.apply(lambda x: speaker2numpy(x['raw_words2'], self.config.max_sentences, is_train=name == 'train'), | |||
new_field_name='words2') | |||
# sentences | |||
ds.rename_field('raw_words4', 'words3') | |||
# doc_np | |||
ds.apply(lambda x: doc2numpy(x['words3'], word2id, char_dict, max(self.config.filter), | |||
self.config.max_sentences, is_train=name == 'train')[0], | |||
new_field_name='words4') | |||
# char_index | |||
ds.apply(lambda x: doc2numpy(x['words3'], word2id, char_dict, max(self.config.filter), | |||
self.config.max_sentences, is_train=name == 'train')[1], | |||
new_field_name='chars') | |||
# seq len | |||
ds.apply(lambda x: doc2numpy(x['words3'], word2id, char_dict, max(self.config.filter), | |||
self.config.max_sentences, is_train=name == 'train')[2], | |||
new_field_name='seq_len') | |||
# clusters | |||
ds.rename_field('raw_words3', 'target') | |||
ds.set_input('words1', 'words2', 'words3', 'words4', 'chars', 'seq_len', 'target') | |||
return data_bundle | |||
def process_from_file(self, paths): | |||
bundle = CoReferenceLoader().load(paths) | |||
return self.process(bundle) | |||
# helper | |||
def doc2numpy(doc, word2id, chardict, max_filter, max_sentences, is_train): | |||
docvec, char_index, length, max_len = _doc2vec(doc, word2id, chardict, max_filter, max_sentences, is_train) | |||
assert max(length) == max_len | |||
assert char_index.shape[0] == len(length) | |||
assert char_index.shape[1] == max_len | |||
doc_np = np.zeros((len(docvec), max_len), int) | |||
for i in range(len(docvec)): | |||
for j in range(len(docvec[i])): | |||
doc_np[i][j] = docvec[i][j] | |||
return doc_np, char_index, length | |||
def _doc2vec(doc, word2id, char_dict, max_filter, max_sentences, is_train): | |||
max_len = 0 | |||
max_word_length = 0 | |||
docvex = [] | |||
length = [] | |||
if is_train: | |||
sent_num = min(max_sentences, len(doc)) | |||
else: | |||
sent_num = len(doc) | |||
for i in range(sent_num): | |||
sent = doc[i] | |||
length.append(len(sent)) | |||
if (len(sent) > max_len): | |||
max_len = len(sent) | |||
sent_vec = [] | |||
for j, word in enumerate(sent): | |||
if len(word) > max_word_length: | |||
max_word_length = len(word) | |||
if word in word2id: | |||
sent_vec.append(word2id[word]) | |||
else: | |||
sent_vec.append(word2id["UNK"]) | |||
docvex.append(sent_vec) | |||
char_index = np.zeros((sent_num, max_len, max_word_length), dtype=int) | |||
for i in range(sent_num): | |||
sent = doc[i] | |||
for j, word in enumerate(sent): | |||
char_index[i, j, :len(word)] = [char_dict[c] for c in word] | |||
return docvex, char_index, length, max_len | |||
def speaker2numpy(speakers_raw, max_sentences, is_train): | |||
if is_train and len(speakers_raw) > max_sentences: | |||
speakers_raw = speakers_raw[0:max_sentences] | |||
speakers = flatten(speakers_raw) | |||
speaker_dict = {s: i for i, s in enumerate(set(speakers))} | |||
speaker_ids = np.array([speaker_dict[s] for s in speakers]) | |||
return speaker_ids | |||
# 展平 | |||
def flatten(l): | |||
return [item for sublist in l for item in sublist] | |||
def get_char_dict(path): | |||
vocab = ["<UNK>"] | |||
with open(path) as f: | |||
vocab.extend(c.strip() for c in f.readlines()) | |||
char_dict = collections.defaultdict(int) | |||
char_dict.update({c: i for i, c in enumerate(vocab)}) | |||
return char_dict |
@@ -0,0 +1,282 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"CWSPipe" | |||
] | |||
import re | |||
from itertools import chain | |||
from .pipe import Pipe | |||
from .utils import _indexize | |||
from fastNLP.io.data_bundle import DataBundle | |||
from fastNLP.io.loader import CWSLoader | |||
# from ...core.const import Const | |||
def _word_lens_to_bmes(word_lens): | |||
r""" | |||
:param list word_lens: List[int], 每个词语的长度 | |||
:return: List[str], BMES的序列 | |||
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
if word_len == 1: | |||
tags.append('S') | |||
else: | |||
tags.append('B') | |||
tags.extend(['M'] * (word_len - 2)) | |||
tags.append('E') | |||
return tags | |||
def _word_lens_to_segapp(word_lens): | |||
r""" | |||
:param list word_lens: List[int], 每个词语的长度 | |||
    :return: List[str], SEG/APP的序列
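
    For comparison, both tagging schemes applied to word lengths [1, 2, 3]::

        >>> _word_lens_to_bmes([1, 2, 3])
        ['S', 'B', 'E', 'B', 'M', 'E']
        >>> _word_lens_to_segapp([1, 2, 3])
        ['SEG', 'APP', 'SEG', 'APP', 'APP', 'SEG']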
""" | |||
tags = [] | |||
for word_len in word_lens: | |||
if word_len == 1: | |||
tags.append('SEG') | |||
else: | |||
tags.extend(['APP'] * (word_len - 1)) | |||
tags.append('SEG') | |||
return tags | |||
def _alpha_span_to_special_tag(span): | |||
r""" | |||
将span替换成特殊的字符 | |||
:param str span: | |||
:return: | |||
""" | |||
    if 'oo' == span.lower():  # special case: the letter 'O' standing in for zero, e.g. 2OO8
return span | |||
if len(span) == 1: | |||
return span | |||
else: | |||
return '<ENG>' | |||
def _find_and_replace_alpha_spans(line): | |||
r""" | |||
传入原始句子,替换其中的字母为特殊标记 | |||
:param str line:原始数据 | |||
:return: str | |||
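
    A short illustration (the sentence is made up; letters followed by a space or a CJK character
    are collapsed into the <ENG> tag)::

        >>> _find_and_replace_alpha_spans('我在用iPhone 发消息')
        '我在用<ENG> 发消息'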
""" | |||
new_line = '' | |||
pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])' | |||
prev_end = 0 | |||
for match in re.finditer(pattern, line): | |||
start, end = match.span() | |||
span = line[start:end] | |||
new_line += line[prev_end:start] + _alpha_span_to_special_tag(span) | |||
prev_end = end | |||
new_line += line[prev_end:] | |||
return new_line | |||
def _digit_span_to_special_tag(span): | |||
r""" | |||
:param str span: 需要替换的str | |||
:return: | |||
""" | |||
if span[0] == '0' and len(span) > 2: | |||
return '<NUM>' | |||
decimal_point_count = 0 # one might have more than one decimal pointers | |||
for idx, char in enumerate(span): | |||
if char == '.' or char == '﹒' or char == '·': | |||
decimal_point_count += 1 | |||
    if span[-1] in ('.', '﹒', '·'):  # a trailing decimal point means this is not a well-formed number
if decimal_point_count == 1: | |||
return span | |||
else: | |||
return '<UNKDGT>' | |||
if decimal_point_count == 1: | |||
return '<DEC>' | |||
elif decimal_point_count > 1: | |||
return '<UNKDGT>' | |||
else: | |||
return '<NUM>' | |||
def _find_and_replace_digit_spans(line): | |||
r""" | |||
    Only spans that start with a digit and consist of digits and decimal points ('.', '﹒', '·')
    are considered, and only when they are followed by a space, a Chinese character or certain
    punctuation; spans that contain or are followed by English letters are not handled.
    Integers are replaced by <NUM>, floats by <DEC>, and ill-formed numbers by <UNKDGT>.
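
    A short illustration (the sentence is made up)::

        >>> _find_and_replace_digit_spans('共有120人,价格3.14元')
        '共有<NUM>人,价格<DEC>元'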
""" | |||
new_line = '' | |||
pattern = r'\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%%,。!<-“])' | |||
prev_end = 0 | |||
for match in re.finditer(pattern, line): | |||
start, end = match.span() | |||
span = line[start:end] | |||
new_line += line[prev_end:start] + _digit_span_to_special_tag(span) | |||
prev_end = end | |||
new_line += line[prev_end:] | |||
return new_line | |||
class CWSPipe(Pipe): | |||
r""" | |||
对CWS数据进行预处理, 处理之后的数据,具备以下的结构 | |||
.. csv-table:: | |||
:header: "raw_words", "chars", "target", "seq_len" | |||
"共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", 13 | |||
"2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20 | |||
"...", "[...]","[...]", . | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+-----------+-------+--------+---------+ | |||
| field_names | raw_words | chars | target | seq_len | | |||
+-------------+-----------+-------+--------+---------+ | |||
| is_input | False | True | True | True | | |||
| is_target | False | False | True | True | | |||
| ignore_type | | False | False | False | | |||
| pad_value | | 0 | 0 | 0 | | |||
+-------------+-----------+-------+--------+---------+ | |||
""" | |||
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False): | |||
r""" | |||
:param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None | |||
:param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp | |||
的tag为[seg, app, seg, app, app, app, seg, ...] | |||
:param bool replace_num_alpha: 是否将数字和字母用特殊字符替换。 | |||
:param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] | |||
:param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] | |||
""" | |||
if encoding_type == 'bmes': | |||
self.word_lens_to_tags = _word_lens_to_bmes | |||
else: | |||
self.word_lens_to_tags = _word_lens_to_segapp | |||
self.dataset_name = dataset_name | |||
self.bigrams = bigrams | |||
self.trigrams = trigrams | |||
self.replace_num_alpha = replace_num_alpha | |||
def _tokenize(self, data_bundle): | |||
r""" | |||
将data_bundle中的'chars'列切分成一个一个的word. | |||
例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ] | |||
:param data_bundle: | |||
:return: | |||
""" | |||
def split_word_into_chars(raw_chars): | |||
words = raw_chars.split() | |||
chars = [] | |||
for word in words: | |||
char = [] | |||
subchar = [] | |||
for c in word: | |||
if c == '<': | |||
if subchar: | |||
char.extend(subchar) | |||
subchar = [] | |||
subchar.append(c) | |||
continue | |||
if c == '>' and len(subchar)>0 and subchar[0] == '<': | |||
subchar.append(c) | |||
char.append(''.join(subchar)) | |||
subchar = [] | |||
continue | |||
if subchar: | |||
subchar.append(c) | |||
else: | |||
char.append(c) | |||
char.extend(subchar) | |||
chars.append(char) | |||
return chars | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(split_word_into_chars, field_name='chars', | |||
new_field_name='chars') | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
可以处理的DataSet需要包含raw_words列 | |||
.. csv-table:: | |||
:header: "raw_words" | |||
"上海 浦东 开发 与 法制 建设 同步" | |||
"新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" | |||
"..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
data_bundle.copy_field('raw_words', 'chars') | |||
if self.replace_num_alpha: | |||
data_bundle.apply_field(_find_and_replace_alpha_spans, 'chars', 'chars') | |||
data_bundle.apply_field(_find_and_replace_digit_spans, 'chars', 'chars') | |||
self._tokenize(data_bundle) | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name='chars', | |||
new_field_name='target') | |||
dataset.apply_field(lambda chars: list(chain(*chars)), field_name='chars', | |||
new_field_name='chars') | |||
input_field_names = ['chars'] | |||
if self.bigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])], | |||
field_name='chars', new_field_name='bigrams') | |||
input_field_names.append('bigrams') | |||
if self.trigrams: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in | |||
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)], | |||
field_name='chars', new_field_name='trigrams') | |||
input_field_names.append('trigrams') | |||
_indexize(data_bundle, input_field_names, 'target') | |||
input_fields = ['target', 'seq_len'] + input_field_names | |||
target_fields = ['target', 'seq_len'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('chars') | |||
data_bundle.set_input(*input_fields, *target_fields) | |||
return data_bundle | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
r""" | |||
:param str paths: | |||
:return: | |||
""" | |||
if self.dataset_name is None and paths is None: | |||
raise RuntimeError( | |||
"You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.") | |||
if self.dataset_name is not None and paths is not None: | |||
raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously") | |||
data_bundle = CWSLoader(self.dataset_name).load(paths) | |||
return self.process(data_bundle) |
@@ -0,0 +1,545 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"CNXNLIBertPipe", | |||
"BQCorpusBertPipe", | |||
"LCQMCBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
"LCQMCPipe", | |||
"CNXNLIPipe", | |||
"BQCorpusPipe", | |||
"RenamePipe", | |||
"GranularizePipe", | |||
"MachingTruncatePipe", | |||
] | |||
import warnings | |||
from .pipe import Pipe | |||
from .utils import get_tokenizer | |||
from ..data_bundle import DataBundle | |||
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, \ | |||
LCQMCLoader | |||
# from ...core._logger import log | |||
# from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
class MatchingBertPipe(Pipe): | |||
r""" | |||
Matching任务的Bert pipe,输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target", "words", "seq_len" | |||
"The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", 10 | |||
"This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5 | |||
"...", "...", ., "[...]", . | |||
words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 | |||
words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, | |||
如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+------------+------------+--------+-------+---------+ | |||
| field_names | raw_words1 | raw_words2 | target | words | seq_len | | |||
+-------------+------------+------------+--------+-------+---------+ | |||
| is_input | False | False | False | True | True | | |||
| is_target | False | False | True | False | False | | |||
| ignore_type | | | False | False | False | | |||
| pad_value | | | 0 | 0 | 0 | | |||
+-------------+------------+------------+--------+-------+---------+ | |||
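
    A minimal usage sketch (assuming the usual ``fastNLP.io.pipe`` export; with ``paths=None``
    the RTELoader will try to download the data automatically)::

        from fastNLP.io.pipe import RTEBertPipe

        data_bundle = RTEBertPipe(lower=True, tokenizer='raw').process_from_file()
        print(data_bundle)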
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
r""" | |||
:param bool lower: 是否将word小写化。 | |||
:param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 | |||
""" | |||
super().__init__() | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
r""" | |||
:param DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | |||
:return: 输入的DataBundle对象 | |||
""" | |||
for name, dataset in data_bundle.iter_datasets(): | |||
for field_name, new_field_name in zip(field_names, new_field_names): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
r""" | |||
输入的data_bundle中的dataset需要具有以下结构: | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" | |||
"...","..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field('target'): | |||
dataset.drop(lambda x: x['target'] == '-') | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.copy_field('raw_words1', 'words1', ) | |||
dataset.copy_field('raw_words2', 'words2', ) | |||
if self.lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset['words1'].lower() | |||
dataset['words2'].lower() | |||
data_bundle = self._tokenize(data_bundle, ['words1', 'words2'], | |||
['words1', 'words2']) | |||
# concat两个words | |||
def concat(ins): | |||
words0 = ins['words1'] | |||
words1 = ins['words2'] | |||
words = words0 + ['[SEP]'] + words1 | |||
return words | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.apply(concat, new_field_name='words') | |||
dataset.delete_field('words1') | |||
dataset.delete_field('words2') | |||
word_vocab = Vocabulary() | |||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | |||
field_name='words', | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
'train' not in name]) | |||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name='words') | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
field_name='target', | |||
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() | |||
if ('train' not in name) and (ds.has_field('target'))] | |||
) | |||
if len(target_vocab._no_create_word) > 0: | |||
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | |||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | |||
f"data set but not in train data set!." | |||
warnings.warn(warn_msg) | |||
print(warn_msg) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field('target')] | |||
target_vocab.index_dataset(*has_target_datasets, field_name='target') | |||
data_bundle.set_vocab(word_vocab, 'words') | |||
data_bundle.set_vocab(target_vocab, 'target') | |||
input_fields = ['words', 'seq_len'] | |||
target_fields = ['target'] | |||
for name, dataset in data_bundle.iter_datasets(): | |||
dataset.add_seq_len('words') | |||
dataset.set_input(*input_fields) | |||
for fields in target_fields: | |||
if dataset.has_field(fields): | |||
dataset.set_input(fields) | |||
return data_bundle | |||
class RTEBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = RTELoader().load(paths) | |||
return self.process(data_bundle) | |||
class SNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = SNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class QuoraBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths): | |||
data_bundle = QuoraLoader().load(paths) | |||
return self.process(data_bundle) | |||
class QNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = QNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class MNLIBertPipe(MatchingBertPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = MNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class MatchingPipe(Pipe): | |||
r""" | |||
Matching任务的Pipe。输出的DataSet将包含以下的field | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2" | |||
"The new rights are...", "Everyone really likes..", 1, "[2, 3, 4, 5, ...]", "[10, 20, 6]", 10, 13 | |||
"This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7 | |||
"...", "...", ., "[...]", "[...]", ., . | |||
words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target | |||
和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 | |||
的形参名进行传参)。 | |||
dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: | |||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||
| field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 | | |||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||
| is_input | False | False | False | True | True | True | True | | |||
| is_target | False | False | True | False | False | False | False | | |||
| ignore_type | | | False | False | False | False | False | | |||
| pad_value | | | 0 | 0 | 0 | 0 | 0 | | |||
+-------------+------------+------------+--------+--------+--------+----------+----------+ | |||
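
    A minimal usage sketch (assuming the usual ``fastNLP.io.pipe`` export; with ``paths=None``
    the SNLILoader will try to download the data automatically)::

        from fastNLP.io.pipe import SNLIPipe

        data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
        print(data_bundle)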
""" | |||
def __init__(self, lower=False, tokenizer: str = 'raw'): | |||
r""" | |||
:param bool lower: 是否将所有raw_words转为小写。 | |||
:param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。 | |||
""" | |||
super().__init__() | |||
self.lower = bool(lower) | |||
self.tokenizer = get_tokenizer(tokenize_method=tokenizer) | |||
def _tokenize(self, data_bundle, field_names, new_field_names): | |||
r""" | |||
:param ~fastNLP.DataBundle data_bundle: DataBundle. | |||
:param list field_names: List[str], 需要tokenize的field名称 | |||
:param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。 | |||
:return: 输入的DataBundle对象 | |||
""" | |||
for name, dataset in data_bundle.iter_datasets(): | |||
for field_name, new_field_name in zip(field_names, new_field_names): | |||
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name, | |||
new_field_name=new_field_name) | |||
return data_bundle | |||
def process(self, data_bundle): | |||
r""" | |||
接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 | |||
.. csv-table:: | |||
:header: "raw_words1", "raw_words2", "target" | |||
"The new rights are...", "Everyone really likes..", "entailment" | |||
"This site includes a...", "The Government Executive...", "not_entailment" | |||
"...", "..." | |||
:param ~fastNLP.DataBundle data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 | |||
:return: data_bundle | |||
""" | |||
data_bundle = self._tokenize(data_bundle, ['raw_words1', 'raw_words2'], | |||
['words1', 'words2']) | |||
for dataset in data_bundle.datasets.values(): | |||
if dataset.has_field('target'): | |||
dataset.drop(lambda x: x['target'] == '-') | |||
if self.lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset['words1'].lower() | |||
dataset['words2'].lower() | |||
word_vocab = Vocabulary() | |||
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name], | |||
field_name=['words1', 'words2'], | |||
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if | |||
'train' not in name]) | |||
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=['words1', 'words2']) | |||
target_vocab = Vocabulary(padding=None, unknown=None) | |||
target_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
field_name='target', | |||
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() | |||
if ('train' not in name) and (ds.has_field('target'))] | |||
) | |||
if len(target_vocab._no_create_word) > 0: | |||
warn_msg = f"There are {len(target_vocab._no_create_word)} target labels" \ | |||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | |||
f"data set but not in train data set!." | |||
warnings.warn(warn_msg) | |||
print(warn_msg) | |||
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if | |||
dataset.has_field('target')] | |||
target_vocab.index_dataset(*has_target_datasets, field_name='target') | |||
data_bundle.set_vocab(word_vocab, 'words1') | |||
data_bundle.set_vocab(target_vocab, 'target') | |||
input_fields = ['words1', 'words2', 'seq_len1', 'seq_len2'] | |||
target_fields = ['target'] | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.add_seq_len('words1', 'seq_len1') | |||
dataset.add_seq_len('words2', 'seq_len2') | |||
dataset.set_input(*input_fields) | |||
for fields in target_fields: | |||
if dataset.has_field(fields): | |||
dataset.set_input(fields) | |||
return data_bundle | |||
class RTEPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = RTELoader().load(paths) | |||
return self.process(data_bundle) | |||
class SNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = SNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class QuoraPipe(MatchingPipe): | |||
def process_from_file(self, paths): | |||
data_bundle = QuoraLoader().load(paths) | |||
return self.process(data_bundle) | |||
class QNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = QNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class MNLIPipe(MatchingPipe): | |||
def process_from_file(self, paths=None): | |||
data_bundle = MNLILoader().load(paths) | |||
return self.process(data_bundle) | |||
class LCQMCPipe(MatchingPipe): | |||
    def __init__(self, tokenizer='cn-char'):
super().__init__(tokenizer=tokenizer) | |||
def process_from_file(self, paths=None): | |||
data_bundle = LCQMCLoader().load(paths) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
return data_bundle | |||
class CNXNLIPipe(MatchingPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) # 使中文数据的field | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
return data_bundle | |||
class BQCorpusPipe(MatchingPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def process_from_file(self, paths=None): | |||
data_bundle = BQCorpusLoader().load(paths) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = RenamePipe().process(data_bundle) | |||
return data_bundle | |||
class RenamePipe(Pipe): | |||
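    r"""
    Rename fields so that the Chinese matching datasets (whose raw fields are ``raw_chars1/2``)
    can be fed to the word-based Matching pipes, and rename them back to the char-based
    convention afterwards.

    task='cn-nli': rename raw_chars1/2 -> raw_words1/2 before processing, and words1/2 -> chars1/2
    (with raw_words1/2 -> raw_chars1/2) after processing.
    task='cn-nli-bert': the same idea for the Bert pipes, additionally renaming 'words' -> 'chars'.
    """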
def __init__(self, task='cn-nli'): | |||
super().__init__() | |||
self.task = task | |||
def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset | |||
if (self.task == 'cn-nli'): | |||
for name, dataset in data_bundle.datasets.items(): | |||
if (dataset.has_field('raw_chars1')): | |||
dataset.rename_field('raw_chars1', 'raw_words1') # RAW_CHARS->RAW_WORDS | |||
dataset.rename_field('raw_chars2', 'raw_words2') | |||
elif (dataset.has_field('words1')): | |||
dataset.rename_field('words1', 'chars1') # WORDS->CHARS | |||
dataset.rename_field('words2', 'chars2') | |||
dataset.rename_field('raw_words1', 'raw_chars1') | |||
dataset.rename_field('raw_words2', 'raw_chars2') | |||
else: | |||
raise RuntimeError( | |||
"field name of dataset is not qualified. It should have ether RAW_CHARS or WORDS") | |||
elif (self.task == 'cn-nli-bert'): | |||
for name, dataset in data_bundle.datasets.items(): | |||
if (dataset.has_field('raw_chars1')): | |||
dataset.rename_field('raw_chars1', 'raw_words1') # RAW_CHARS->RAW_WORDS | |||
dataset.rename_field('raw_chars2', 'raw_words2') | |||
elif (dataset.has_field('raw_words1')): | |||
dataset.rename_field('raw_words1', 'raw_chars1') | |||
dataset.rename_field('raw_words2', 'raw_chars2') | |||
dataset.rename_field('words', 'chars') | |||
else: | |||
raise RuntimeError( | |||
"field name of dataset is not qualified. It should have ether RAW_CHARS or RAW_WORDS" | |||
) | |||
else: | |||
raise RuntimeError( | |||
"Only support task='cn-nli' or 'cn-nli-bert'" | |||
) | |||
return data_bundle | |||
class GranularizePipe(Pipe): | |||
def __init__(self, task=None): | |||
super().__init__() | |||
self.task = task | |||
def _granularize(self, data_bundle, tag_map): | |||
r""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
:param data_bundle: | |||
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, | |||
且将"1"认为是第0类。 | |||
:return: 传入的data_bundle | |||
""" | |||
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', | |||
new_field_name='target') | |||
dataset.drop(lambda ins: ins['target'] == -100) | |||
data_bundle.set_dataset(dataset, name) | |||
return data_bundle | |||
def process(self, data_bundle: DataBundle): | |||
task_tag_dict = { | |||
'XNLI': {'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2} | |||
} | |||
if self.task in task_tag_dict: | |||
data_bundle = self._granularize(data_bundle=data_bundle, tag_map=task_tag_dict[self.task]) | |||
else: | |||
raise RuntimeError(f"Only support {task_tag_dict.keys()} task_tag_map.") | |||
return data_bundle | |||
class MachingTruncatePipe(Pipe):  # truncate sentence for bert, modify seq_len
    def __init__(self):
        super().__init__()
    def process(self, data_bundle: DataBundle):
        # not implemented yet; return the data_bundle unchanged so the Pipe contract still holds
        # (see TruncateBertPipe below for actual truncation)
        return data_bundle
class LCQMCBertPipe(MatchingBertPipe): | |||
    def __init__(self, tokenizer='cn-char'):
super().__init__(tokenizer=tokenizer) | |||
def process_from_file(self, paths=None): | |||
data_bundle = LCQMCLoader().load(paths) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = TruncateBertPipe(task='cn').process(data_bundle) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
return data_bundle | |||
class BQCorpusBertPipe(MatchingBertPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def process_from_file(self, paths=None): | |||
data_bundle = BQCorpusLoader().load(paths) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = TruncateBertPipe(task='cn').process(data_bundle) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
return data_bundle | |||
class CNXNLIBertPipe(MatchingBertPipe): | |||
def __init__(self, tokenizer='cn-char'): | |||
super().__init__(tokenizer=tokenizer) | |||
def process_from_file(self, paths=None): | |||
data_bundle = CNXNLILoader().load(paths) | |||
data_bundle = GranularizePipe(task='XNLI').process(data_bundle) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
data_bundle = self.process(data_bundle) | |||
data_bundle = TruncateBertPipe(task='cn').process(data_bundle) | |||
data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) | |||
return data_bundle | |||
class TruncateBertPipe(Pipe): | |||
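    r"""
    Truncate the concatenated ``words`` field produced by the Bert matching pipes so that the
    tokens on each side of ``[SEP]`` fit a fixed budget (250 tokens per side for ``task='cn'``,
    215 per side for ``task='en'``), then refresh ``seq_len`` accordingly.

    :param str task: 'cn' or 'en'
    """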
def __init__(self, task='cn'): | |||
super().__init__() | |||
self.task = task | |||
def _truncate(self, sentence_index:list, sep_index_vocab): | |||
# 根据[SEP]在vocab中的index,找到[SEP]在dataset的field['words']中的index | |||
sep_index_words = sentence_index.index(sep_index_vocab) | |||
words_before_sep = sentence_index[:sep_index_words] | |||
words_after_sep = sentence_index[sep_index_words:] # 注意此部分包括了[SEP] | |||
if self.task == 'cn': | |||
# 中文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过250 | |||
words_before_sep = words_before_sep[:250] | |||
words_after_sep = words_after_sep[:250] | |||
elif self.task == 'en': | |||
# 英文任务将Instance['words']中在[SEP]前后的文本分别截至长度不超过215 | |||
words_before_sep = words_before_sep[:215] | |||
words_after_sep = words_after_sep[:215] | |||
else: | |||
raise RuntimeError("Only support 'cn' or 'en' task.") | |||
return words_before_sep + words_after_sep | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
for name in data_bundle.datasets.keys(): | |||
dataset = data_bundle.get_dataset(name) | |||
sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') | |||
dataset.apply_field(lambda sent_index: self._truncate(sentence_index=sent_index, sep_index_vocab=sep_index_vocab), field_name='words', new_field_name='words') | |||
# truncate之后需要更新seq_len | |||
dataset.add_seq_len(field_name='words') | |||
return data_bundle | |||
@@ -0,0 +1,41 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"Pipe", | |||
] | |||
from fastNLP.io.data_bundle import DataBundle | |||
class Pipe: | |||
r""" | |||
Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe | |||
文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 | |||
一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或raw_chars进行tokenize以切分成不同的词或字; | |||
(2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target列建立词表并将target列转为index; | |||
    Pipe中提供了两个方法:

    - process() 函数,输入为 DataBundle;
    - process_from_file() 函数,输入为对应 Loader 的 load 函数可接受的类型。
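
    A minimal skeleton for a custom Pipe (illustrative only; any Loader could be used here)::

        class MyPipe(Pipe):
            def process(self, data_bundle: DataBundle) -> DataBundle:
                # tokenize, build vocabularies and index the fields here
                return data_bundle

            def process_from_file(self, paths: str) -> DataBundle:
                from fastNLP.io.loader import CWSLoader
                return self.process(CWSLoader().load(paths))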
""" | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
对输入的DataBundle进行处理,然后返回该DataBundle。 | |||
:param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 | |||
:return: DataBundle | |||
""" | |||
raise NotImplementedError | |||
def process_from_file(self, paths: str) -> DataBundle: | |||
r""" | |||
        传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load`
:param str paths: | |||
:return: DataBundle | |||
""" | |||
raise NotImplementedError |
@@ -0,0 +1,144 @@ | |||
r""" | |||
本文件中的Pipe主要用于处理问答任务的数据。 | |||
""" | |||
from copy import deepcopy | |||
from .pipe import Pipe | |||
from fastNLP.io.data_bundle import DataBundle | |||
from ..loader.qa import CMRC2018Loader | |||
from .utils import get_tokenizer | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.vocabulary import Vocabulary | |||
__all__ = ['CMRC2018BertPipe'] | |||
def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): | |||
r""" | |||
处理data_bundle中的DataSet,将context与question按照character进行tokenize,然后使用[SEP]将两者连接起来。 | |||
    会新增field: context_len(int), raw_chars(list[str], 形如["a", "b", "[SEP]", "c", ...]), target_start(int),
    target_end(int)。其中target_start和target_end是答案在raw_chars中的起止下标,为前闭后闭的区间。
    :param DataBundle data_bundle: 需要处理的DataBundle
:return: | |||
""" | |||
tokenizer = get_tokenizer('cn-char', lang='cn') | |||
for name in list(data_bundle.datasets.keys()): | |||
ds = data_bundle.get_dataset(name) | |||
data_bundle.delete_dataset(name) | |||
new_ds = DataSet() | |||
for ins in ds: | |||
new_ins = deepcopy(ins) | |||
context = ins['context'] | |||
question = ins['question'] | |||
cnt_lst = tokenizer(context) | |||
q_lst = tokenizer(question) | |||
answer_start = -1 | |||
if len(cnt_lst) + len(q_lst) + 3 > max_len: # 预留开头的[CLS]和[SEP]和中间的[sep] | |||
if 'answer_starts' in ins and 'answers' in ins: | |||
answer_start = int(ins['answer_starts'][0]) | |||
answer = ins['answers'][0] | |||
answer_end = answer_start + len(answer) | |||
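                    # if plain right-truncation would cut the answer off, slide the context window
                    # left so that it ends exactly at the answer span, then re-offset answer_start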
if answer_end > max_len - 3 - len(q_lst): | |||
span_start = answer_end + 3 + len(q_lst) - max_len | |||
span_end = answer_end | |||
else: | |||
span_start = 0 | |||
span_end = max_len - 3 - len(q_lst) | |||
cnt_lst = cnt_lst[span_start:span_end] | |||
answer_start = int(ins['answer_starts'][0]) | |||
answer_start -= span_start | |||
answer_end = answer_start + len(ins['answers'][0]) | |||
else: | |||
cnt_lst = cnt_lst[:max_len - len(q_lst) - 3] | |||
else: | |||
if 'answer_starts' in ins and 'answers' in ins: | |||
answer_start = int(ins['answer_starts'][0]) | |||
answer_end = answer_start + len(ins['answers'][0]) | |||
tokens = cnt_lst + ['[SEP]'] + q_lst | |||
new_ins['context_len'] = len(cnt_lst) | |||
new_ins[concat_field_name] = tokens | |||
if answer_start != -1: | |||
new_ins['target_start'] = answer_start | |||
new_ins['target_end'] = answer_end - 1 | |||
new_ds.append(new_ins) | |||
data_bundle.set_dataset(new_ds, name) | |||
return data_bundle | |||
class CMRC2018BertPipe(Pipe): | |||
r""" | |||
处理之后的DataSet将新增以下的field(传入的field仍然保留) | |||
.. csv-table:: | |||
:header: "context_len", "raw_chars", "target_start", "target_end", "chars" | |||
492, ['范', '廷', '颂... ], 30, 34, "[21, 25, ...]" | |||
491, ['范', '廷', '颂... ], 41, 61, "[21, 25, ...]" | |||
".", "...", "...","...", "..." | |||
    raw_chars列是context与question拼起来的结果(连接的地方加入了[SEP]),chars是转为index的值, target_start为答案start的index,target_end为答案end的index
    (闭区间);context_len指示的是raw_chars列中context的长度。
其中各列的meta信息如下: | |||
.. code:: | |||
+-------------+-------------+-----------+--------------+------------+-------+---------+ | |||
| field_names | context_len | raw_chars | target_start | target_end | chars | answers | | |||
        +-------------+-------------+-----------+--------------+------------+-------+---------+
| is_input | False | False | False | False | True | False | | |||
| is_target | True | True | True | True | False | True | | |||
| ignore_type | False | True | False | False | False | True | | |||
| pad_value | 0 | 0 | 0 | 0 | 0 | 0 | | |||
+-------------+-------------+-----------+--------------+------------+-------+---------+ | |||
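
    A minimal usage sketch (with ``paths=None`` the CMRC2018Loader will try to download the data)::

        pipe = CMRC2018BertPipe(max_len=510)
        data_bundle = pipe.process_from_file()
        print(data_bundle)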
""" | |||
def __init__(self, max_len=510): | |||
super().__init__() | |||
self.max_len = max_len | |||
def process(self, data_bundle: DataBundle) -> DataBundle: | |||
r""" | |||
传入的DataSet应该具备以下的field | |||
.. csv-table:: | |||
:header:"title", "context", "question", "answers", "answer_starts", "id" | |||
"范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" | |||
"范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" | |||
"...", "...", "...","...", ".", "..." | |||
:param data_bundle: | |||
:return: | |||
""" | |||
data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars') | |||
src_vocab = Vocabulary() | |||
src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
field_name='raw_chars', | |||
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() | |||
if 'train' not in name] | |||
) | |||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name='raw_chars', new_field_name='chars') | |||
data_bundle.set_vocab(src_vocab, 'chars') | |||
data_bundle.set_input('chars', 'raw_chars', 'answers', 'target_start', 'target_end', 'context_len') | |||
return data_bundle | |||
def process_from_file(self, paths=None) -> DataBundle: | |||
data_bundle = CMRC2018Loader().load(paths) | |||
return self.process(data_bundle) |
@@ -0,0 +1,196 @@ | |||
r"""undocumented""" | |||
import os | |||
import numpy as np | |||
from .pipe import Pipe | |||
from .utils import _drop_empty_instance | |||
from ..loader.summarization import ExtCNNDMLoader | |||
from ..data_bundle import DataBundle | |||
# from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
# from ...core._logger import log | |||
WORD_PAD = "[PAD]" | |||
WORD_UNK = "[UNK]" | |||
DOMAIN_UNK = "X" | |||
TAG_UNK = "X" | |||
class ExtCNNDMPipe(Pipe): | |||
r""" | |||
对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" | |||
""" | |||
def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False): | |||
r""" | |||
:param vocab_size: int, 词表大小 | |||
:param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 | |||
:param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 | |||
:param vocab_path: str, 外部词表路径 | |||
:param domain: bool, 是否需要建立domain词表 | |||
""" | |||
self.vocab_size = vocab_size | |||
self.vocab_path = vocab_path | |||
self.sent_max_len = sent_max_len | |||
self.doc_max_timesteps = doc_max_timesteps | |||
self.domain = domain | |||
def process(self, data_bundle: DataBundle): | |||
r""" | |||
传入的DataSet应该具备如下的结构 | |||
.. csv-table:: | |||
:header: "text", "summary", "label", "publication" | |||
["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" | |||
["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" | |||
["..."], ["..."], [], "cnndm" | |||
:param data_bundle: | |||
:return: 处理得到的数据包括 | |||
.. csv-table:: | |||
:header: "text_wd", "words", "seq_len", "target" | |||
[["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] | |||
[["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] | |||
[[""],...,[""]], [[],...,[]], [], [] | |||
""" | |||
if self.vocab_path is None: | |||
            error_msg = 'vocab_path is not set! Pass `vocab_path` to ExtCNNDMPipe, or use process_from_file(), which infers it from `paths`.'
print(error_msg) | |||
raise RuntimeError(error_msg) | |||
data_bundle.apply(lambda x: _lower_text(x['text']), new_field_name='text') | |||
data_bundle.apply(lambda x: _lower_text(x['summary']), new_field_name='summary') | |||
data_bundle.apply(lambda x: _split_list(x['text']), new_field_name='text_wd') | |||
data_bundle.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name='target') | |||
data_bundle.apply(lambda x: _pad_sent(x["text_wd"], self.sent_max_len), new_field_name='words') | |||
# db.apply(lambda x: _token_mask(x["text_wd"], self.sent_max_len), new_field_name="pad_token_mask") | |||
# pad document | |||
data_bundle.apply(lambda x: _pad_doc(x['words'], self.sent_max_len, self.doc_max_timesteps), new_field_name='words') | |||
data_bundle.apply(lambda x: _sent_mask(x['words'], self.doc_max_timesteps), new_field_name='seq_len') | |||
data_bundle.apply(lambda x: _pad_label(x['target'], self.doc_max_timesteps), new_field_name='target') | |||
data_bundle = _drop_empty_instance(data_bundle, "label") | |||
# set input and target | |||
        data_bundle.set_input('words', 'seq_len', 'target')
# print("[INFO] Load existing vocab from %s!" % self.vocab_path) | |||
word_list = [] | |||
with open(self.vocab_path, 'r', encoding='utf8') as vocab_f: | |||
cnt = 2 # pad and unk | |||
for line in vocab_f: | |||
pieces = line.split("\t") | |||
word_list.append(pieces[0]) | |||
cnt += 1 | |||
if cnt > self.vocab_size: | |||
break | |||
vocabs = Vocabulary(max_size=self.vocab_size, padding=WORD_PAD, unknown=WORD_UNK) | |||
vocabs.add_word_lst(word_list) | |||
vocabs.build_vocab() | |||
data_bundle.set_vocab(vocabs, "vocab") | |||
if self.domain is True: | |||
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK) | |||
domaindict.from_dataset(data_bundle.get_dataset("train"), field_name="publication") | |||
data_bundle.set_vocab(domaindict, "domain") | |||
return data_bundle | |||
def process_from_file(self, paths=None): | |||
r""" | |||
:param paths: dict or string | |||
:return: DataBundle | |||
""" | |||
loader = ExtCNNDMLoader() | |||
if self.vocab_path is None: | |||
if paths is None: | |||
paths = loader.download() | |||
if not os.path.isdir(paths): | |||
                error_msg = 'when vocab_path is not set, `paths` must be a directory containing a file named "vocab"!'
print(error_msg) | |||
raise RuntimeError(error_msg) | |||
self.vocab_path = os.path.join(paths, 'vocab') | |||
db = loader.load(paths=paths) | |||
db = self.process(db) | |||
for ds in db.datasets.values(): | |||
db.get_vocab("vocab").index_dataset(ds, field_name='words', new_field_name='words') | |||
return db | |||
def _lower_text(text_list): | |||
return [text.lower() for text in text_list] | |||
def _split_list(text_list): | |||
return [text.split() for text in text_list] | |||
def _convert_label(label, sent_len): | |||
np_label = np.zeros(sent_len, dtype=int) | |||
if label != []: | |||
np_label[np.array(label)] = 1 | |||
return np_label.tolist() | |||
def _pad_sent(text_wd, sent_max_len): | |||
pad_text_wd = [] | |||
for sent_wd in text_wd: | |||
if len(sent_wd) < sent_max_len: | |||
pad_num = sent_max_len - len(sent_wd) | |||
sent_wd.extend([WORD_PAD] * pad_num) | |||
else: | |||
sent_wd = sent_wd[:sent_max_len] | |||
pad_text_wd.append(sent_wd) | |||
return pad_text_wd | |||
def _token_mask(text_wd, sent_max_len): | |||
token_mask_list = [] | |||
for sent_wd in text_wd: | |||
token_num = len(sent_wd) | |||
if token_num < sent_max_len: | |||
mask = [1] * token_num + [0] * (sent_max_len - token_num) | |||
else: | |||
mask = [1] * sent_max_len | |||
token_mask_list.append(mask) | |||
return token_mask_list | |||
def _pad_label(label, doc_max_timesteps): | |||
text_len = len(label) | |||
if text_len < doc_max_timesteps: | |||
pad_label = label + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_label = label[:doc_max_timesteps] | |||
return pad_label | |||
def _pad_doc(text_wd, sent_max_len, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
padding = [WORD_PAD] * sent_max_len | |||
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len) | |||
else: | |||
pad_text = text_wd[:doc_max_timesteps] | |||
return pad_text | |||
def _sent_mask(text_wd, doc_max_timesteps): | |||
text_len = len(text_wd) | |||
if text_len < doc_max_timesteps: | |||
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len) | |||
else: | |||
sent_mask = [1] * doc_max_timesteps | |||
return sent_mask | |||
@@ -0,0 +1,224 @@ | |||
r"""undocumented""" | |||
__all__ = [ | |||
"iob2", | |||
"iob2bioes", | |||
"get_tokenizer", | |||
] | |||
from typing import List | |||
import warnings | |||
# from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
# from ...core._logger import log | |||
from pkg_resources import parse_version | |||
def iob2(tags: List[str]) -> List[str]: | |||
r""" | |||
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 | |||
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format | |||
:param tags: 需要转换的tags | |||
""" | |||
for i, tag in enumerate(tags): | |||
if tag == "O": | |||
continue | |||
split = tag.split("-") | |||
if len(split) != 2 or split[0] not in ["I", "B"]: | |||
raise TypeError("The encoding schema is not a valid IOB type.") | |||
if split[0] == "B": | |||
continue | |||
elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
elif tags[i - 1][1:] == tag[1:]: | |||
continue | |||
else: # conversion IOB1 to IOB2 | |||
tags[i] = "B" + tag[1:] | |||
return tags | |||
def iob2bioes(tags: List[str]) -> List[str]: | |||
r""" | |||
将iob的tag转换为bioes编码 | |||
:param tags: | |||
:return: | |||
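
    Example (IOB1 input converted to IOB2 first, then to BIOES)::

        >>> iob2(['I-ORG', 'I-ORG', 'O', 'I-PER'])
        ['B-ORG', 'I-ORG', 'O', 'B-PER']
        >>> iob2bioes(['B-ORG', 'I-ORG', 'O', 'B-PER'])
        ['B-ORG', 'E-ORG', 'O', 'S-PER']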
""" | |||
new_tags = [] | |||
for i, tag in enumerate(tags): | |||
if tag == 'O': | |||
new_tags.append(tag) | |||
else: | |||
split = tag.split('-')[0] | |||
if split == 'B': | |||
if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('B-', 'S-')) | |||
elif split == 'I': | |||
if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I': | |||
new_tags.append(tag) | |||
else: | |||
new_tags.append(tag.replace('I-', 'E-')) | |||
else: | |||
raise TypeError("Invalid IOB format.") | |||
return new_tags | |||
def get_tokenizer(tokenize_method: str, lang='en'): | |||
r""" | |||
:param str tokenize_method: 获取tokenzier方法 | |||
:param str lang: 语言,当前仅支持en | |||
:return: 返回tokenize函数 | |||
""" | |||
tokenizer_dict = { | |||
'spacy': None, | |||
'raw': _raw_split, | |||
'cn-char': _cn_char_split, | |||
} | |||
if tokenize_method == 'spacy': | |||
import spacy | |||
spacy.prefer_gpu() | |||
if lang != 'en': | |||
raise RuntimeError("Spacy only supports en right right.") | |||
if parse_version(spacy.__version__) >= parse_version('3.0'): | |||
en = spacy.load('en_core_web_sm') | |||
else: | |||
en = spacy.load(lang) | |||
tokenizer = lambda x: [w.text for w in en.tokenizer(x)] | |||
elif tokenize_method in tokenizer_dict: | |||
tokenizer = tokenizer_dict[tokenize_method] | |||
else: | |||
raise RuntimeError(f"Only support {tokenizer_dict.keys()} tokenizer.") | |||
return tokenizer | |||
def _cn_char_split(sent): | |||
return [chars for chars in sent] | |||
def _raw_split(sent): | |||
return sent.split() | |||
def _indexize(data_bundle, input_field_names='words', target_field_names='target'): | |||
r""" | |||
在dataset中的field_name列建立词表,'target'列建立词表,并把词表加入到data_bundle中。 | |||
:param ~fastNLP.DataBundle data_bundle: | |||
    :param str,list input_field_names: 需要建立词表并转换为index的field名称
    :param str,list target_field_names: 需要建立词表的target field名称,这一列的vocabulary没有unknown和padding
:return: | |||
""" | |||
if isinstance(input_field_names, str): | |||
input_field_names = [input_field_names] | |||
if isinstance(target_field_names, str): | |||
target_field_names = [target_field_names] | |||
for input_field_name in input_field_names: | |||
src_vocab = Vocabulary() | |||
src_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
field_name=input_field_name, | |||
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() | |||
if ('train' not in name) and (ds.has_field(input_field_name))] | |||
) | |||
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name) | |||
data_bundle.set_vocab(src_vocab, input_field_name) | |||
for target_field_name in target_field_names: | |||
tgt_vocab = Vocabulary(unknown=None, padding=None) | |||
tgt_vocab.from_dataset(*[ds for name, ds in data_bundle.iter_datasets() if 'train' in name], | |||
field_name=target_field_name, | |||
no_create_entry_dataset=[ds for name, ds in data_bundle.iter_datasets() | |||
if ('train' not in name) and (ds.has_field(target_field_name))] | |||
) | |||
if len(tgt_vocab._no_create_word) > 0: | |||
warn_msg = f"There are {len(tgt_vocab._no_create_word)} `{target_field_name}` labels" \ | |||
f" in {[name for name in data_bundle.datasets.keys() if 'train' not in name]} " \ | |||
f"data set but not in train data set!.\n" \ | |||
f"These label(s) are {tgt_vocab._no_create_word}" | |||
warnings.warn(warn_msg) | |||
# log.warning(warn_msg) | |||
tgt_vocab.index_dataset(*[ds for ds in data_bundle.datasets.values() if ds.has_field(target_field_name)], field_name=target_field_name) | |||
data_bundle.set_vocab(tgt_vocab, target_field_name) | |||
return data_bundle | |||
def _add_words_field(data_bundle, lower=False): | |||
r""" | |||
给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化 | |||
:param data_bundle: | |||
:param bool lower:是否要小写化 | |||
:return: 传入的DataBundle | |||
""" | |||
data_bundle.copy_field(field_name='raw_words', new_field_name='words', ignore_miss_dataset=True) | |||
if lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset['words'].lower() | |||
return data_bundle | |||
def _add_chars_field(data_bundle, lower=False): | |||
r""" | |||
给data_bundle中的dataset中复制一列chars. 并根据lower参数判断是否需要小写化 | |||
:param data_bundle: | |||
:param bool lower:是否要小写化 | |||
:return: 传入的DataBundle | |||
""" | |||
data_bundle.copy_field(field_name='raw_chars', new_field_name='chars', ignore_miss_dataset=True) | |||
if lower: | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset['chars'].lower() | |||
return data_bundle | |||
def _drop_empty_instance(data_bundle, field_name): | |||
r""" | |||
删除data_bundle的DataSet中存在的某个field为空的情况 | |||
:param ~fastNLP.DataBundle data_bundle: | |||
:param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉 | |||
:return: 传入的DataBundle | |||
""" | |||
def empty_instance(ins): | |||
if field_name: | |||
field_value = ins[field_name] | |||
if field_value in ((), {}, [], ''): | |||
return True | |||
return False | |||
for _, field_value in ins.items(): | |||
if field_value in ((), {}, [], ''): | |||
return True | |||
return False | |||
for name, dataset in data_bundle.datasets.items(): | |||
dataset.drop(empty_instance) | |||
return data_bundle | |||
def _granularize(data_bundle, tag_map): | |||
r""" | |||
该函数对data_bundle中'target'列中的内容进行转换。 | |||
:param data_bundle: | |||
:param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance, | |||
且将"1"认为是第0类。 | |||
:return: 传入的data_bundle | |||
""" | |||
if tag_map is None: | |||
return data_bundle | |||
for name in list(data_bundle.datasets.keys()): | |||
dataset = data_bundle.get_dataset(name) | |||
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name='target', | |||
new_field_name='target') | |||
dataset.drop(lambda ins: ins['target'] == -100) | |||
data_bundle.set_dataset(dataset, name) | |||
return data_bundle |
@@ -0,0 +1,82 @@ | |||
r""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"check_loader_paths" | |||
] | |||
import os | |||
from pathlib import Path | |||
from typing import Union, Dict | |||
# from ..core import log | |||
def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: | |||
r""" | |||
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: | |||
{ | |||
'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 | |||
'test': 'xxx' # 可能有,也可能没有 | |||
... | |||
} | |||
如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 | |||
:param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 | |||
中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 | |||
:return: | |||
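
    Example (paths are illustrative)::

        check_loader_paths('~/data/conll2003/train.txt')
        # -> {'train': '/home/<user>/data/conll2003/train.txt'}

        check_loader_paths('~/data/conll2003/')
        # -> {'train': '.../train.txt', 'dev': '.../dev.txt', 'test': '.../test.txt'},
        #    matched by the substrings 'train', 'dev', 'test' in the file names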
""" | |||
if isinstance(paths, (str, Path)): | |||
paths = os.path.abspath(os.path.expanduser(paths)) | |||
if os.path.isfile(paths): | |||
return {'train': paths} | |||
elif os.path.isdir(paths): | |||
filenames = os.listdir(paths) | |||
filenames.sort() | |||
files = {} | |||
for filename in filenames: | |||
path_pair = None | |||
if 'train' in filename: | |||
path_pair = ('train', filename) | |||
if 'dev' in filename: | |||
if path_pair: | |||
raise Exception( | |||
"Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0])) | |||
path_pair = ('dev', filename) | |||
if 'test' in filename: | |||
if path_pair: | |||
raise Exception( | |||
"Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0])) | |||
path_pair = ('test', filename) | |||
if path_pair: | |||
if path_pair[0] in files: | |||
raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the " | |||
f"filepath for `{path_pair[0]}`.") | |||
files[path_pair[0]] = os.path.join(paths, path_pair[1]) | |||
if 'train' not in files: | |||
raise KeyError(f"There is no train file in {paths}.") | |||
return files | |||
else: | |||
raise FileNotFoundError(f"{paths} is not a valid file path.") | |||
elif isinstance(paths, dict): | |||
if paths: | |||
if 'train' not in paths: | |||
raise KeyError("You have to include `train` in your dict.") | |||
for key, value in paths.items(): | |||
if isinstance(key, str) and isinstance(value, str): | |||
value = os.path.abspath(os.path.expanduser(value)) | |||
if not os.path.exists(value): | |||
raise TypeError(f"{value} is not a valid path.") | |||
paths[key] = value | |||
else: | |||
raise TypeError("All keys and values in paths should be str.") | |||
return paths | |||
else: | |||
raise ValueError("Empty paths is not allowed.") | |||
else: | |||
raise TypeError(f"paths only supports str and dict. not {type(paths)}.") |
@@ -0,0 +1,9 @@ | |||
__all__ = [ | |||
"MixModule", | |||
"torch2paddle", | |||
"paddle2torch", | |||
"torch2jittor", | |||
"jittor2torch", | |||
] | |||
from .mix_modules import MixModule, torch2paddle, paddle2torch, torch2jittor, jittor2torch |
@@ -0,0 +1,10 @@ | |||
__all__ = [ | |||
"MixModule", | |||
"torch2paddle", | |||
"paddle2torch", | |||
"torch2jittor", | |||
"jittor2torch", | |||
] | |||
from .mix_module import MixModule | |||
from .utils import * |
@@ -0,0 +1,306 @@ | |||
import os | |||
import io | |||
import pickle | |||
from typing import Dict | |||
from collections import OrderedDict | |||
import numpy as np | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | |||
from fastNLP.core.utils.paddle_utils import paddle_to | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
from paddle.nn import Layer as PaddleLayer | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch.nn import Module as TorchModule, Parameter as TorchParameter | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
__all__ = [ | |||
"MixModule", | |||
] | |||
class MixModule: | |||
""" | |||
TODO: 支持不同的混合方式;添加state_dict的支持;如果参数里有List of Tensors该怎么处理; | |||
是否需要仿照Module那样在初始化的时候给各种模型分类 | |||
可以同时使用Torch和Paddle框架的混合模型 | |||
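
    A minimal sketch of a mixed model (illustrative only; converting through numpy is just one
    way to hand tensors from torch to paddle, and this package also exposes the
    torch2paddle/paddle2torch helpers)::

        class MixMLP(MixModule):
            def __init__(self):
                super().__init__()
                self.torch_fc = torch.nn.Linear(16, 16)    # a torch sub-module
                self.paddle_fc = paddle.nn.Linear(16, 16)  # a paddle sub-module

            def forward(self, x):
                h = self.torch_fc(x)                        # x is a torch.Tensor
                h = paddle.to_tensor(h.detach().cpu().numpy())
                return self.paddle_fc(h)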
""" | |||
def __init__(self, *args, **kwargs): | |||
pass | |||
def __call__(self, *args, **kwargs): | |||
return self.forward(*args, **kwargs) | |||
def named_parameters(self, prefix='', recurse: bool=True, backend=None): | |||
""" | |||
返回模型的名字和参数 | |||
:param prefix: 输出时在参数名前加上的前缀 | |||
:param recurse: 是否递归地输出参数 | |||
:param backend: `backend`=`None`时,将所有模型和张量的参数返回; | |||
`backend`=`torch`时,返回`torch`的参数; | |||
`backend`=`paddle`时,返回`paddle`的参数。 | |||
""" | |||
if backend is None: | |||
generator = self.attributes(TorchModule, TorchParameter, PaddleLayer) | |||
elif backend == "torch": | |||
generator = self.attributes(TorchModule, TorchParameter) | |||
elif backend == "paddle": | |||
generator = self.attributes(PaddleLayer) | |||
else: | |||
raise ValueError("Unknown backend parameter.") | |||
for name, value in generator: | |||
name = prefix + ('.' if prefix else '') + name | |||
if isinstance(value, TorchParameter): | |||
# 非Module/Layer类型,直接输出名字和值 | |||
yield name, value | |||
elif recurse: | |||
# 递归地调用named_parameters | |||
for name_r, value_r in value.named_parameters(name, recurse): | |||
yield name_r, value_r | |||
def parameters(self, recurse: bool = True, backend: str = None): | |||
""" | |||
返回模型的参数 | |||
:param recurse: | |||
:param backend: `backend`=`None`时,将所有模型和张量的参数返回; | |||
`backend`=`torch`时,返回`torch`的参数; | |||
`backend`=`paddle`时,返回`paddle`的参数。 | |||
""" | |||
for name, value in self.named_parameters(recurse=recurse, backend=backend): | |||
yield value | |||
def forward(self, *args, **kwargs): | |||
raise NotImplementedError | |||
def train_step(self, batch): | |||
raise NotImplementedError | |||
def test_step(self, batch): | |||
raise NotImplementedError | |||
def validate_step(self, batch): | |||
raise NotImplementedError | |||
def train(self): | |||
for name, value in self.attributes(TorchModule, PaddleLayer): | |||
value.train() | |||
def eval(self): | |||
for name, value in self.attributes(TorchModule, PaddleLayer): | |||
value.eval() | |||
def to(self, device): | |||
""" | |||
:param device: 设备名 | |||
""" | |||
        # TODO: warn if jittor sub-modules are present (only torch and paddle are handled here)
if device == "cpu": | |||
paddle_device = device | |||
elif device.startswith("cuda"): | |||
paddle_device = device.replace("cuda", "gpu") | |||
elif device.startswith("gpu"): | |||
paddle_device = device | |||
device = device.replace("gpu", "cuda") | |||
else: | |||
raise ValueError("Device value error") | |||
for name, value in self.attributes(TorchModule): | |||
# torch's to() does not modify the Tensor in place | |||
vars(self)[name] = value.to(device) | |||
for name, value in self.attributes(TorchParameter): | |||
# after to(), a Parameter becomes a plain Tensor, so re-wrap it | |||
vars(self)[name] = TorchParameter(value.to(device), requires_grad=value.requires_grad) | |||
for name, value in self.attributes(PaddleLayer): | |||
vars(self)[name] = value.to(paddle_device) | |||
for name, value in self.attributes(paddle.Tensor): | |||
# paddle's to() (paddle_to) does affect the Tensor | |||
vars(self)[name] = paddle_to(value, paddle_device) | |||
return self | |||
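# Sketch of the device-string convention implemented above (illustrative): one device argument
# is accepted in either torch or paddle spelling and translated for the other framework.
#
#     model.to("cuda:0")   # torch members -> "cuda:0", paddle members -> "gpu:0"
#     model.to("gpu:1")    # paddle members -> "gpu:1", torch members -> "cuda:1"
#     model.to("cpu")      # everything on cpu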
def state_dict(self, backend: str = None) -> Dict: | |||
""" | |||
返回模型的state_dict。 | |||
NOTE: torch的destination参数会在将来删除,因此不提供destination参数 | |||
:param backend: `backend`=`None`时,将所有模型和张量的state dict返回; | |||
`backend`=`torch`时,返回`torch`的state dict; | |||
`backend`=`paddle`时,返回`paddle`的state dict。 | |||
""" | |||
if backend is None: | |||
generator = self.attributes(TorchModule, TorchParameter, PaddleLayer) | |||
elif backend == "torch": | |||
generator = self.attributes(TorchModule, TorchParameter) | |||
elif backend == "paddle": | |||
generator = self.attributes(PaddleLayer) | |||
else: | |||
raise ValueError(f"Unknown backend {backend}.") | |||
destination = OrderedDict() | |||
for name, value in generator: | |||
if value is None: | |||
continue | |||
if isinstance(value, TorchParameter): | |||
destination[name] = value | |||
else: | |||
# the state_dict functions of the two frameworks use different argument names and orders | |||
if isinstance(value, PaddleLayer): | |||
kwargs = { | |||
"structured_name_prefix": name + ".", | |||
} | |||
elif isinstance(value, TorchModule): | |||
kwargs = { | |||
"prefix": name + ".", | |||
} | |||
else: | |||
raise ValueError(f"Unknown item type {type(value)}") | |||
destination.update(value.state_dict(**kwargs)) | |||
return destination | |||
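# Illustrative sketch of the merged result (attribute names are hypothetical): keys of each
# sub-module are prefixed with the attribute name that holds it.
#
#     state = model.state_dict()
#     # {"torch_encoder.weight": ..., "torch_encoder.bias": ...,
#     #  "paddle_head.weight": ..., "paddle_head.bias": ...}
#     torch_state = model.state_dict(backend="torch")   # only the torch_encoder.* entries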
def save_state_dict_to_file(self, path: str): | |||
""" | |||
保存模型的state dict到path | |||
""" | |||
# TODO 设备限制 | |||
filename = os.path.basename(path) | |||
if filename == "": | |||
raise ValueError("Received empty filename.") | |||
dirname = os.path.dirname(path) | |||
if dirname and not os.path.exists(dirname): | |||
os.makedirs(dirname) | |||
protocol = 4 | |||
saved = {} | |||
paddle_dict = self.state_dict(backend="paddle") | |||
torch_dict = self.state_dict(backend="torch") | |||
# save the paddle part | |||
# reuse the helpers paddle calls internally when saving | |||
paddle_saved_obj = paddle.framework.io._build_saved_state_dict(paddle_dict) | |||
paddle_saved_obj = paddle.fluid.io._unpack_saved_dict(paddle_saved_obj, protocol) | |||
# store the returned dict | |||
saved["paddle"] = paddle_saved_obj | |||
# save the torch part | |||
buffer = io.BytesIO() | |||
torch.save(torch_dict, buffer) | |||
saved["torch"] = buffer.getvalue() | |||
# write everything to disk | |||
with open(path, "wb") as f: | |||
pickle.dump(saved, f, protocol) | |||
def load_state_dict_from_file(self, path: str): | |||
""" | |||
从 `path` 中加载保存的state dict | |||
""" | |||
state_dict = {} | |||
with open(path, "rb") as f: | |||
loaded = pickle.load(f) | |||
# load the paddle part | |||
paddle_loaded_obj = loaded["paddle"] | |||
paddle_load_result = paddle.fluid.io._pack_loaded_dict(paddle_loaded_obj) | |||
if "StructuredToParameterName@@" in paddle_load_result: | |||
for key in paddle_load_result["StructuredToParameterName@@"]: | |||
if isinstance(paddle_load_result[key], np.ndarray): | |||
paddle_load_result[key] = paddle.to_tensor(paddle_load_result[key]) | |||
state_dict.update(paddle_load_result) | |||
# load the torch part | |||
torch_loaded_obj = loaded["torch"] | |||
torch_bytes = io.BytesIO(torch_loaded_obj) | |||
torch_load_result = torch.load(torch_bytes) | |||
state_dict.update(torch_load_result) | |||
self.load_state_dict(state_dict) | |||
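# Round-trip sketch (illustrative; the path is arbitrary): the file written by
# save_state_dict_to_file is a pickle with a "paddle" entry (paddle's packed dict) and a
# "torch" entry (raw bytes from torch.save), so it should be read back with
# load_state_dict_from_file rather than torch.load or paddle.load.
#
#     model.save_state_dict_to_file("checkpoints/mix_model.pkl")
#     model.load_state_dict_from_file("checkpoints/mix_model.pkl")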
def load_state_dict(self, state_dict): | |||
""" | |||
从state dict中加载数据 | |||
""" | |||
missing_keys = [] | |||
unexpected_keys = [] | |||
error_msgs = [] | |||
new_state = {} | |||
local_state = self.state_dict() | |||
# group the entries of the dict by prefix | |||
for key, value in state_dict.items(): | |||
splited = key.split(".", 1) | |||
if len(splited) == 1: | |||
# no prefix; in practice only torch.nn.Parameter entries end up here | |||
new_state[key] = value | |||
else: | |||
prefix, name = splited | |||
if prefix not in new_state: | |||
new_state[prefix] = {} | |||
new_state[prefix][name] = value | |||
for key, param in self.attributes(TorchModule, TorchParameter, PaddleLayer): | |||
if key in new_state: | |||
# a matching entry was found in the given state dict | |||
input_param = new_state[key] | |||
if not isinstance(input_param, dict): | |||
# and it is not a dict, i.e. the un-prefixed case above | |||
# assign it following torch.nn.Module._load_from_state_dict | |||
if not torch.overrides.is_tensor_like(input_param): | |||
error_msgs.append('While copying the parameter named "{}", ' | |||
'expected torch.Tensor or Tensor-like object from checkpoint but ' | |||
'received {}' | |||
.format(key, type(input_param))) | |||
continue | |||
# This is used to avoid copying uninitialized parameters into | |||
# non-lazy modules, since they don't have the hook to do the checks | |||
# in such case, it will error when accessing the .shape attribute. | |||
is_param_lazy = torch.nn.parameter.is_lazy(param) | |||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+ | |||
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1: | |||
input_param = input_param[0] | |||
if not is_param_lazy and input_param.shape != param.shape: | |||
# local shape should match the one in checkpoint | |||
error_msgs.append('size mismatch for {}: copying a param with shape {} from checkpoint, ' | |||
'the shape in current model is {}.' | |||
.format(key, input_param.shape, param.shape)) | |||
continue | |||
try: | |||
with torch.no_grad(): | |||
param.copy_(input_param) | |||
except Exception as ex: | |||
error_msgs.append('While copying the parameter named "{}", ' | |||
'whose dimensions in the model are {} and ' | |||
'whose dimensions in the checkpoint are {}, ' | |||
'an exception occurred : {}.' | |||
.format(key, param.size(), input_param.size(), ex.args)) | |||
else: | |||
# otherwise the entry belongs to a sub-module | |||
if isinstance(param, TorchModule): | |||
# torch module | |||
# paddle provides no strict-like argument, so strict loading is not enforced for torch either | |||
param.load_state_dict(input_param, strict=False) | |||
elif isinstance(param, PaddleLayer): | |||
# paddle layer | |||
param.load_dict(input_param) | |||
else: | |||
missing_keys.append(key) | |||
if len(error_msgs) > 0: | |||
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( | |||
self.__class__.__name__, "\n\t".join(error_msgs))) | |||
def attributes(self, *types): | |||
""" | |||
查找对应类型的成员 | |||
""" | |||
for name, value in vars(self).items(): | |||
if isinstance(value, types): | |||
yield name, value |
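# End-to-end sketch of a MixModule subclass (illustrative only; the class, layer sizes and the
# exact import path are assumptions based on the package layout above). The forward pass hands
# tensors between frameworks with torch2paddle / paddle2torch from utils.py below.
#
#     from fastNLP.modules.mix_modules import MixModule, torch2paddle, paddle2torch
#
#     class MixedClassifier(MixModule):
#         def __init__(self):
#             super().__init__()
#             self.embed = torch.nn.Linear(8, 16)     # torch sub-module
#             self.head = paddle.nn.Linear(16, 2)     # paddle sub-module
#
#         def forward(self, x):                       # x: torch.Tensor of shape (batch, 8)
#             h = torch.relu(self.embed(x))
#             logits = self.head(torch2paddle(h))     # cross the framework boundary
#             return paddle2torch(logits)             # return a torch tensor to the caller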
@@ -0,0 +1,229 @@ | |||
import warnings | |||
import os | |||
from typing import Any, Optional, Union | |||
import numpy as np | |||
from fastNLP.core.utils.utils import apply_to_collection | |||
from fastNLP.core.utils.paddle_utils import paddle_to | |||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE | |||
if _NEED_IMPORT_PADDLE: | |||
import paddle | |||
if _NEED_IMPORT_JITTOR: | |||
import jittor | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
__all__ = [ | |||
"paddle2torch", | |||
"torch2paddle", | |||
"jittor2torch", | |||
"torch2jittor", | |||
] | |||
def _paddle2torch(paddle_tensor: 'paddle.Tensor', target_device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor': | |||
""" | |||
将paddle tensor转换为torch tensor,并且能够保留梯度进行反向传播 | |||
:param paddle_tensor: 要转换的paddle张量 | |||
:param target_device: 是否将转换后的张量迁移到特定设备上,输入为`None`时,和输入的张量相同。 | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 转换后的torch张量 | |||
""" | |||
no_gradient = paddle_tensor.stop_gradient if no_gradient is None else no_gradient | |||
paddle_numpy = paddle_tensor.numpy() | |||
if not np.issubdtype(paddle_numpy.dtype, np.inexact): | |||
no_gradient = True | |||
if target_device is None: | |||
if paddle_tensor.place.is_gpu_place(): | |||
# paddlepaddle has two kinds of Place, each with its own way of exposing the device id | |||
if hasattr(paddle_tensor.place, "gpu_device_id"): | |||
# paddle.fluid.core_avx.Place | |||
# tensors created in a gpu environment carry a place of this type | |||
target_device = f"cuda:{paddle_tensor.place.gpu_device_id()}" | |||
else: | |||
# paddle.CUDAPlace | |||
target_device = f"cuda:{paddle_tensor.place.get_device_id()}" | |||
else: | |||
# TODO: may need to support xpu and other devices | |||
target_device = "cpu" | |||
if not no_gradient: | |||
# keep the gradient and keep backpropagation working | |||
# torch.tensor keeps the dtype of the numpy array | |||
torch_tensor = torch.tensor(paddle_numpy, requires_grad=True, device=target_device) | |||
hook = torch_tensor.register_hook( | |||
lambda grad: paddle.autograd.backward(paddle_tensor, paddle.to_tensor(grad.cpu().numpy())) | |||
) | |||
else: | |||
# do not keep the gradient | |||
torch_tensor = torch.tensor(paddle_numpy, requires_grad=False, device=target_device) | |||
return torch_tensor | |||
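# Gradient-flow sketch (illustrative): with no_gradient=False the returned torch tensor is a new
# leaf whose gradient is re-injected into the paddle graph by the registered hook.
#
#     x = paddle.to_tensor([1.0, 2.0], stop_gradient=False)
#     y = _paddle2torch(x)          # y.requires_grad is True
#     y.sum().backward()            # the hook forwards the gradient to paddle.autograd.backward
#     print(x.grad)                 # gradient arrives on the paddle side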
def _torch2paddle(torch_tensor: 'torch.Tensor', target_device: str = None, no_gradient: bool = None) -> 'paddle.Tensor': | |||
""" | |||
将torch tensor转换为paddle tensor,并且能够保留梯度进行反向传播。 | |||
:param torch_tensor: 要转换的torch张量 | |||
:param target_device: 是否将转换后的张量迁移到特定设备上,输入为`None`时,和输入的张量相同。 | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 转换后的paddle张量 | |||
""" | |||
no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient | |||
if target_device is None: | |||
if torch_tensor.is_cuda: | |||
target_device = f"gpu:{torch_tensor.device.index}" | |||
else: | |||
target_device = "cpu" | |||
if not no_gradient: | |||
# keep the gradient and keep backpropagation working | |||
# paddle's stop_gradient is the inverse of torch's requires_grad | |||
paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False) | |||
hook = paddle_tensor.register_hook( | |||
lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy())) | |||
) | |||
else: | |||
paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True) | |||
paddle_tensor = paddle_to(paddle_tensor, target_device) | |||
return paddle_tensor | |||
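# Device-mapping sketch (illustrative): a cpu torch tensor stays on cpu unless target_device
# says otherwise, and gradients follow the same hook mechanism as in _paddle2torch.
#
#     t = torch.ones(2, 3, requires_grad=True)
#     p = _torch2paddle(t)                                                 # on "cpu", stop_gradient == False
#     p_gpu = _torch2paddle(t, target_device="gpu:0", no_gradient=True)    # moved, gradient dropped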
def _jittor2torch(jittor_var: 'jittor.Var', target_device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor': | |||
""" | |||
将jittor Var转换为torch tensor,并且能够保留梯度进行反向传播 | |||
:param jittor_var: 要转换的jittor变量 | |||
:param target_device: 是否将转换后的张量迁移到特定设备上,输入为`None`时,根据jittor.flags.use_cuda决定。 | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 转换后的torch张量 | |||
""" | |||
# TODO: warning: gradients cannot be preserved | |||
# jittor gradients could be propagated through a callback | |||
# if the outputs carried a _grad key, differentiation could be implemented | |||
no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient | |||
if not no_gradient: | |||
warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.") | |||
jittor_numpy = jittor_var.numpy() | |||
if not np.issubdtype(jittor_numpy.dtype, np.inexact): | |||
no_gradient = True | |||
if target_device is None: | |||
# jittor allocates devices automatically | |||
# decide based on use_cuda | |||
if jittor.flags.use_cuda: | |||
target_device = "cuda:0" | |||
else: | |||
target_device = "cpu" | |||
torch_tensor = torch.tensor(jittor_numpy, requires_grad=not no_gradient, device=target_device) | |||
return torch_tensor | |||
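# Sketch (illustrative): jittor places tensors globally, so when no device is given the torch
# device is derived from jittor.flags.use_cuda, and gradients are never carried over.
#
#     v = jittor.array([1.0, 2.0, 3.0])
#     t = _jittor2torch(v)          # torch tensor on "cpu", or "cuda:0" if jittor.flags.use_cuda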
def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'jittor.Var': | |||
""" | |||
将torch tensor转换为jittor Var,并且能够保留梯度进行反向传播 | |||
:param torch_tensor: 要转换的torch张量 | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 转换后的jittor变量 | |||
""" | |||
no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient | |||
if not no_gradient: | |||
# keep the gradient and keep backpropagation working | |||
jittor_var = jittor.Var(torch_tensor.detach().numpy()) | |||
jittor_var.requires_grad = True | |||
hook = jittor_var.register_hook( | |||
lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy())) | |||
) | |||
else: | |||
jittor_var = jittor.Var(torch_tensor.detach().numpy()) | |||
jittor_var.requires_grad = False | |||
return jittor_var | |||
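# Sketch (illustrative): the conversion goes through numpy, so the jittor Var starts a fresh
# graph and the registered hook is what routes its gradient back to the torch side.
#
#     t = torch.ones(3, requires_grad=True)
#     v = _torch2jittor(t)          # jittor.Var with requires_grad == True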
def torch2paddle(torch_in: Any, target_device: str = None, no_gradient: bool = None) -> Any: | |||
""" | |||
递归地将输入中包含的torch张量转换为paddle张量 | |||
:param torch_in: 要转换的包含torch.Tensor类型的变量 | |||
:param target_device: 是否将转换后的张量迁移到特定设备上, | |||
输入为`None`时,和输入的张量相同, | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 将所有torch.Tensor转换为paddle.Tensor的张量 | |||
""" | |||
return apply_to_collection( | |||
torch_in, | |||
dtype=torch.Tensor, | |||
function=_torch2paddle, | |||
target_device=target_device, | |||
no_gradient=no_gradient, | |||
) | |||
def paddle2torch(paddle_in: Any, target_device: str = None, no_gradient: bool = None) -> Any: | |||
""" | |||
递归地将输入中包含的paddle张量转换为torch张量 | |||
:param torch_in: 要转换的包含paddle.Tensor类型的变量 | |||
:param target_device: 是否将转换后的张量迁移到特定设备上, | |||
输入为`None`时,和输入的张量相同, | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 将所有paddle.Tensor转换为torch.Tensor后的变量 | |||
""" | |||
return apply_to_collection( | |||
paddle_in, | |||
dtype=paddle.Tensor, | |||
function=_paddle2torch, | |||
target_device=target_device, | |||
no_gradient=no_gradient, | |||
) | |||
def jittor2torch(jittor_in: Any, target_device: str = None, no_gradient: bool = None) -> Any: | |||
""" | |||
递归地将输入中包含的jittor变量转换为torch张量 | |||
:param jittor_in: 要转换的jittor变量 | |||
:param target_device: 是否将转换后的张量迁移到特定设备上,输入为`None`时,默认为cuda:0。 | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 转换后的torch张量 | |||
""" | |||
return apply_to_collection( | |||
jittor_in, | |||
dtype=jittor.Var, | |||
function=_jittor2torch, | |||
target_device=target_device, | |||
no_gradient=no_gradient, | |||
) | |||
def torch2jittor(torch_in: Any, no_gradient: bool = None) -> Any: | |||
""" | |||
递归地将输入中包含的torch张量转换为jittor变量 | |||
:param torch_tensor: 要转换的torch张量 | |||
:param no_gradient: 是否保留原张量的梯度。为`None`时,新的张量与输入张量保持一致; | |||
为`True`时,全部不保留梯度;为`False`时,全部保留梯度。 | |||
:return: 转换后的jittor变量 | |||
""" | |||
return apply_to_collection( | |||
torch_in, | |||
dtype=torch.Tensor, | |||
function=_torch2jittor, | |||
no_gradient=no_gradient, | |||
) |
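# Usage sketch for the public wrappers (illustrative; the batch layout below is made up):
# apply_to_collection walks dicts, lists and tuples and converts only the tensors it finds,
# leaving every other element untouched.
#
#     batch = {"input": torch.rand(2, 4), "lengths": [4, 3], "label": torch.tensor([0, 1])}
#     paddle_batch = torch2paddle(batch, target_device="cpu")
#     # {"input": paddle.Tensor, "lengths": [4, 3], "label": paddle.Tensor}
#     back = paddle2torch(paddle_batch)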
@@ -0,0 +1,81 @@ | |||
import pytest | |||
from fastNLP.core.collators import AutoCollator | |||
from fastNLP.core.collators.collator import _MultiCollator | |||
from fastNLP.core.dataset import DataSet | |||
class TestCollator: | |||
@pytest.mark.parametrize('as_numpy', [True, False]) | |||
def test_auto_collator(self, as_numpy): | |||
""" | |||
测试auto_collator的auto_pad功能 | |||
:param as_numpy: | |||
:return: | |||
""" | |||
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, | |||
'y': [0, 1, 1, 0] * 100}) | |||
collator = AutoCollator(as_numpy=as_numpy) | |||
collator.set_input('x', 'y') | |||
bucket_data = [] | |||
data = [] | |||
for i in range(len(dataset)): | |||
data.append(dataset[i]) | |||
if len(data) == 40: | |||
bucket_data.append(data) | |||
data = [] | |||
results = [] | |||
for bucket in bucket_data: | |||
res = collator(bucket) | |||
assert res['x'].shape == (40, 5) | |||
assert res['y'].shape == (40,) | |||
results.append(res) | |||
def test_auto_collator_v1(self): | |||
""" | |||
测试auto_collator的set_pad_val和set_pad_val功能 | |||
:return: | |||
""" | |||
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, | |||
'y': [0, 1, 1, 0] * 100}) | |||
collator = AutoCollator(as_numpy=False) | |||
collator.set_input('x') | |||
collator.set_pad_val('x', val=-1) | |||
collator.set_as_numpy(True) | |||
bucket_data = [] | |||
data = [] | |||
for i in range(len(dataset)): | |||
data.append(dataset[i]) | |||
if len(data) == 40: | |||
bucket_data.append(data) | |||
data = [] | |||
for bucket in bucket_data: | |||
res = collator(bucket) | |||
print(res) | |||
def test_multicollator(self): | |||
""" | |||
测试multicollator功能 | |||
:return: | |||
""" | |||
dataset = DataSet({'x': [[1, 2], [0, 1, 2, 3], [3], [9, 0, 10, 1, 5]] * 100, | |||
'y': [0, 1, 1, 0] * 100}) | |||
collator = AutoCollator(as_numpy=False) | |||
multi_collator = _MultiCollator(collator) | |||
multi_collator.set_as_numpy(as_numpy=True) | |||
multi_collator.set_pad_val('x', val=-1) | |||
multi_collator.set_input('x') | |||
bucket_data = [] | |||
data = [] | |||
for i in range(len(dataset)): | |||
data.append(dataset[i]) | |||
if len(data) == 40: | |||
bucket_data.append(data) | |||
data = [] | |||
for bucket in bucket_data: | |||
res = multi_collator(bucket) | |||
print(res) |
@@ -0,0 +1,80 @@ | |||
import pytest | |||
from jittor.dataset import Dataset | |||
import jittor | |||
import numpy as np | |||
from datasets import Dataset as HfDataset | |||
from datasets import load_dataset | |||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | |||
from fastNLP.core.dataset import DataSet as Fdataset | |||
class MyDataset(Dataset): | |||
def __init__(self, data_len=1000): | |||
super(MyDataset, self).__init__() | |||
self.data = [jittor.ones((3, 4)) for _ in range(data_len)] | |||
self.set_attrs(total_len=data_len) | |||
self.dataset_len = data_len | |||
def __getitem__(self, item): | |||
return self.data[item] | |||
# return {'x': [[1, 0], [2, 0, 1]]} | |||
# return np.random.randn(3, 10) | |||
# def __len__(self): | |||
# return self.dataset_len | |||
class TestJittor: | |||
def test_v1(self): | |||
""" | |||
测试jittor类型的dataset使用fdl | |||
:return: | |||
""" | |||
dataset = MyDataset() | |||
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | |||
jtl.set_pad_val('x', 'y') | |||
jtl.set_input('x') | |||
for batch in jtl: | |||
print(batch) | |||
print(jtl.get_batch_indices()) | |||
def test_v2(self): | |||
""" | |||
测试fastnlp的dataset | |||
:return: | |||
""" | |||
dataset = Fdataset({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | |||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) | |||
jtl.set_pad_val('x', val=-1) | |||
jtl.set_input('x', 'y') | |||
for batch in jtl: | |||
assert batch['x'].size() == (16, 4) | |||
def test_v3(self): | |||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | |||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) | |||
jtl.set_input('x', 'y') | |||
for batch in jtl: | |||
print(batch) | |||
def test_v4(self): | |||
dataset = MyDataset() | |||
dl = JittorDataLoader(dataset, batch_size=4, num_workers=2) | |||
print(len(dl)) | |||
for idx, batch in enumerate(dl): | |||
print(batch.shape, idx) | |||
for idx, batch in enumerate(dl): | |||
print(batch.shape, idx) | |||
def test_v5(self): | |||
dataset = MyDataset() | |||
dataset.set_attrs(batch_size=4, num_workers=2) | |||
for idx, batch in enumerate(dataset): | |||
print(idx, batch.shape) | |||
for idx, batch in enumerate(dataset): | |||
print(idx, batch.shape) |
@@ -0,0 +1,53 @@ | |||
import unittest | |||
from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader | |||
from fastNLP.core.dataset import DataSet | |||
from paddle.io import Dataset, DataLoader | |||
import numpy as np | |||
import paddle | |||
class RandomDataset(Dataset): | |||
def __getitem__(self, idx): | |||
image = np.random.random((10, 5)).astype('float32') | |||
return {'image': paddle.Tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]} | |||
def __len__(self): | |||
return 10 | |||
class TestPaddle(unittest.TestCase): | |||
def test_init(self): | |||
# ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | |||
ds = RandomDataset() | |||
fdl = PaddleDataLoader(ds, batch_size=2) | |||
# fdl = DataLoader(ds, batch_size=2, shuffle=True) | |||
for batch in fdl: | |||
print(batch) | |||
# print(fdl.get_batch_indices()) | |||
def test_fdl_batch_indices(self): | |||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | |||
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | |||
fdl.set_input("x", "y") | |||
for batch in fdl: | |||
assert len(fdl.get_batch_indices()) == 4 | |||
print(batch) | |||
print(fdl.get_batch_indices()) | |||
def test_set_inputs_and_set_pad_val(self): | |||
ds = RandomDataset() | |||
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) | |||
fdl.set_input('image', 'label') | |||
fdl.set_pad_val('label', val=-1) | |||
for batch in fdl: | |||
assert batch['image'].shape == [2, 10, 5] | |||
print(batch) | |||
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) | |||
fdl1.set_input('image', 'label') | |||
fdl1.set_pad_val('image', val=None) | |||
for batch in fdl1: | |||
assert batch['image'].shape == [4, 10, 5] | |||
print(batch) |
@@ -0,0 +1,96 @@ | |||
import unittest | |||
from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.io.data_bundle import DataBundle | |||
class TestFdl(unittest.TestCase): | |||
def test_init_v1(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) | |||
# for batch in fdl: | |||
# print(batch) | |||
fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) | |||
# for batch in fdl1: | |||
# print(batch) | |||
def test_set_padding(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
ds.set_pad_val("x", val=-1) | |||
fdl = FDataLoader(ds, batch_size=3) | |||
fdl.set_input("x", "y") | |||
for batch in fdl: | |||
print(batch) | |||
fdl.set_pad_val("x", val=-2) | |||
for batch in fdl: | |||
print(batch) | |||
def test_add_collator(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
def collate_fn(ins_list): | |||
_dict = {"Y": []} | |||
for ins in ins_list: | |||
_dict["Y"].append(ins['y']) | |||
return _dict | |||
fdl = FDataLoader(ds, batch_size=3, as_numpy=True) | |||
fdl.set_input("x", "y") | |||
fdl.add_collator(collate_fn) | |||
for batch in fdl: | |||
print(batch) | |||
def test_get_batch_indices(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
fdl = FDataLoader(ds, batch_size=3, shuffle=True) | |||
fdl.set_input("y", "x") | |||
for batch in fdl: | |||
print(fdl.get_batch_indices()) | |||
def test_other_dataset(self): | |||
import numpy as np | |||
class _DataSet: | |||
def __init__(self): | |||
pass | |||
def __getitem__(self, item): | |||
return np.random.randn(5), [[1, 2], [2, 3, 4]] | |||
def __len__(self): | |||
return 10 | |||
def __getattribute__(self, item): | |||
return object.__getattribute__(self, item) | |||
dataset = _DataSet() | |||
dl = FDataLoader(dataset, batch_size=2, shuffle=True) | |||
# dl.set_inputs('data', 'labels') | |||
# dl.set_pad_val('labels', val=None) | |||
for batch in dl: | |||
print(batch) | |||
print(dl.get_batch_indices()) | |||
def test_prepare_dataloader(self): | |||
ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) | |||
assert isinstance(dl, FDataLoader) | |||
ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) | |||
dbl = DataBundle(datasets={'train': ds, 'val': ds1}) | |||
dl_bundle = prepare_dataloader(dbl) | |||
assert isinstance(dl_bundle['train'], FDataLoader) | |||
assert isinstance(dl_bundle['val'], FDataLoader) | |||
ds_dict = {'train_1': ds, 'val': ds1} | |||
dl_dict = prepare_dataloader(ds_dict) | |||
assert isinstance(dl_dict['train_1'], FDataLoader) | |||
assert isinstance(dl_dict['val'], FDataLoader) | |||
sequence = [ds, ds1] | |||
seq_ds = prepare_dataloader(sequence) | |||
assert isinstance(seq_ds[0], FDataLoader) | |||
assert isinstance(seq_ds[1], FDataLoader) |
@@ -0,0 +1,516 @@ | |||
import os | |||
import unittest | |||
import numpy as np | |||
from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException | |||
class TestDataSetInit(unittest.TestCase): | |||
"""初始化DataSet的办法有以下几种: | |||
1) 用dict: | |||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
2) 用list of Instance: | |||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
只接受纯list或者最外层ndarray | |||
""" | |||
def test_init_v1(self): | |||
# 1-d list | |||
ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) | |||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
def test_init_v2(self): | |||
# from a dict | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) | |||
self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) | |||
self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) | |||
def test_init_assert(self): | |||
with self.assertRaises(AssertionError): | |||
_ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) | |||
with self.assertRaises(AssertionError): | |||
_ = DataSet([[1, 2, 3, 4]] * 10) | |||
with self.assertRaises(ValueError): | |||
_ = DataSet(0.00001) | |||
class TestDataSetMethods(unittest.TestCase): | |||
def test_append(self): | |||
dd = DataSet() | |||
for _ in range(3): | |||
dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) | |||
self.assertEqual(len(dd), 3) | |||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) | |||
self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) | |||
def test_add_field(self): | |||
dd = DataSet() | |||
dd.add_field("x", [[1, 2, 3]] * 10) | |||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||
dd.add_field("z", [[5, 6]] * 10) | |||
self.assertEqual(len(dd), 10) | |||
self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) | |||
self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) | |||
self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) | |||
with self.assertRaises(RuntimeError): | |||
dd.add_field("??", [[1, 2]] * 40) | |||
def test_delete_field(self): | |||
dd = DataSet() | |||
dd.add_field("x", [[1, 2, 3]] * 10) | |||
dd.add_field("y", [[1, 2, 3, 4]] * 10) | |||
dd.delete_field("x") | |||
self.assertFalse("x" in dd.field_arrays) | |||
self.assertTrue("y" in dd.field_arrays) | |||
def test_delete_instance(self): | |||
dd = DataSet() | |||
old_length = 2 | |||
dd.add_field("x", [[1, 2, 3]] * old_length) | |||
dd.add_field("y", [[1, 2, 3, 4]] * old_length) | |||
dd.delete_instance(0) | |||
self.assertEqual(len(dd), old_length - 1) | |||
dd.delete_instance(0) | |||
self.assertEqual(len(dd), old_length - 2) | |||
def test_getitem(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ins_1, ins_0 = ds[0], ds[1] | |||
self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) | |||
self.assertEqual(ins_1["x"], [1, 2, 3, 4]) | |||
self.assertEqual(ins_1["y"], [5, 6]) | |||
self.assertEqual(ins_0["x"], [1, 2, 3, 4]) | |||
self.assertEqual(ins_0["y"], [5, 6]) | |||
sub_ds = ds[:10] | |||
self.assertTrue(isinstance(sub_ds, DataSet)) | |||
self.assertEqual(len(sub_ds), 10) | |||
sub_ds_1 = ds[[10, 0, 2, 3]] | |||
self.assertTrue(isinstance(sub_ds_1, DataSet)) | |||
self.assertEqual(len(sub_ds_1), 4) | |||
field_array = ds['x'] | |||
self.assertTrue(isinstance(field_array, FieldArray)) | |||
self.assertEqual(len(field_array), 40) | |||
def test_get_item_error(self): | |||
with self.assertRaises(RuntimeError): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
_ = ds[40:] | |||
with self.assertRaises(KeyError): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
_ = ds["kom"] | |||
def test_len_(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
self.assertEqual(len(ds), 40) | |||
ds = DataSet() | |||
self.assertEqual(len(ds), 0) | |||
def test_add_fieldarray(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40)) | |||
self.assertEqual(ds['z'].content, [[7, 8]]*40) | |||
with self.assertRaises(RuntimeError): | |||
ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10)) | |||
with self.assertRaises(TypeError): | |||
ds.add_fieldarray('z', [1, 2, 4]) | |||
def test_copy_field(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ds.copy_field('x', 'z') | |||
self.assertEqual(ds['x'].content, ds['z'].content) | |||
def test_has_field(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
self.assertTrue(ds.has_field('x')) | |||
self.assertFalse(ds.has_field('z')) | |||
def test_get_field(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
with self.assertRaises(KeyError): | |||
ds.get_field('z') | |||
x_array = ds.get_field('x') | |||
self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40) | |||
def test_get_all_fields(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
field_arrays = ds.get_all_fields() | |||
self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40) | |||
self.assertEqual(field_arrays['y'], [[5, 6]] * 40) | |||
def test_get_field_names(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
field_names = ds.get_field_names() | |||
self.assertTrue('x' in field_names) | |||
self.assertTrue('y' in field_names) | |||
def test_apply(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000}) | |||
ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx') | |||
self.assertTrue("rx" in ds.field_arrays) | |||
self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) | |||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) | |||
self.assertEqual(ds.field_arrays["y"].content[0], 2) | |||
res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len") | |||
self.assertTrue(isinstance(res, list) and len(res) > 0) | |||
self.assertEqual(res[0], 4) | |||
ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k") | |||
# expect no exception raised | |||
def test_apply_progress_bar(self): | |||
import time | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 400, "y": [[5, 6]] * 400}) | |||
def do_nothing(ins): | |||
time.sleep(0.01) | |||
ds.apply(do_nothing, show_progress_bar=True, num_proc=0) | |||
ds.apply_field(do_nothing, field_name='x', show_progress_bar=True) | |||
def test_apply_cannot_modify_instance(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
def modify_inplace(instance): | |||
instance['words'] = 1 | |||
ds.apply(modify_inplace) | |||
# with self.assertRaises(TypeError): | |||
# ds.apply(modify_inplace) | |||
def test_apply_more(self): | |||
T = DataSet({"a": [1, 2, 3], "b": [2, 4, 5]}) | |||
func_1 = lambda x: {"c": x["a"] * 2, "d": x["a"] ** 2} | |||
func_2 = lambda x: {"c": x * 3, "d": x ** 3} | |||
def func_err_1(x): | |||
if x["a"] == 1: | |||
return {"e": x["a"] * 2, "f": x["a"] ** 2} | |||
else: | |||
return {"e": x["a"] * 2} | |||
def func_err_2(x): | |||
if x == 1: | |||
return {"e": x * 2, "f": x ** 2} | |||
else: | |||
return {"e": x * 2} | |||
T.apply_more(func_1) | |||
# print(T['c'][0, 1, 2]) | |||
self.assertEqual(list(T["c"].content), [2, 4, 6]) | |||
self.assertEqual(list(T["d"].content), [1, 4, 9]) | |||
res = T.apply_field_more(func_2, "a", modify_fields=False) | |||
self.assertEqual(list(T["c"].content), [2, 4, 6]) | |||
self.assertEqual(list(T["d"].content), [1, 4, 9]) | |||
self.assertEqual(list(res["c"]), [3, 6, 9]) | |||
self.assertEqual(list(res["d"]), [1, 8, 27]) | |||
with self.assertRaises(ApplyResultException) as e: | |||
T.apply_more(func_err_1) | |||
print(e) | |||
with self.assertRaises(ApplyResultException) as e: | |||
T.apply_field_more(func_err_2, "a") | |||
print(e) | |||
def test_drop(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) | |||
ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) | |||
self.assertEqual(len(ds), 20) | |||
def test_contains(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
self.assertTrue("x" in ds) | |||
self.assertTrue("y" in ds) | |||
self.assertFalse("z" in ds) | |||
def test_rename_field(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
ds.rename_field("x", "xx") | |||
self.assertTrue("xx" in ds) | |||
self.assertFalse("x" in ds) | |||
with self.assertRaises(KeyError): | |||
ds.rename_field("yyy", "oo") | |||
def test_split(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
d1, d2 = ds.split(0.1) | |||
self.assertEqual(len(d1), len(ds)*0.9) | |||
self.assertEqual(len(d2), len(ds)*0.1) | |||
def test_add_field_v2(self): | |||
ds = DataSet({"x": [3, 4]}) | |||
ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']]) | |||
# ds.apply(lambda x:[x['x']]*3, new_field_name='y') | |||
print(ds) | |||
def test_save_load(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
ds.save("./my_ds.pkl") | |||
self.assertTrue(os.path.exists("./my_ds.pkl")) | |||
ds_1 = DataSet.load("./my_ds.pkl") | |||
os.remove("my_ds.pkl") | |||
def test_add_null(self): | |||
ds = DataSet() | |||
with self.assertRaises(RuntimeError) as RE: | |||
ds.add_field('test', []) | |||
def test_concat(self): | |||
""" | |||
测试两个dataset能否正确concat | |||
""" | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for _ in range(10)], "y": [[5, 6] for _ in range(10)]}) | |||
ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]}) | |||
ds3 = ds1.concat(ds2) | |||
self.assertEqual(len(ds3), 20) | |||
self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4]) | |||
self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1]) | |||
ds2[0]['x'][0] = 100 | |||
self.assertEqual(ds3[10]['x'][0], 4)  # the copied field is not modified | |||
ds3[10]['x'][0] = -100 | |||
self.assertEqual(ds2[0]['x'][0], 100)  # the original field is not modified | |||
# test inplace concat | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"x": [[4, 3, 2, 1] for i in range(10)], "y": [[6, 5] for i in range(10)]}) | |||
ds3 = ds1.concat(ds2, inplace=True) | |||
ds2[0]['x'][0] = 100 | |||
self.assertEqual(ds3[10]['x'][0], 4)  # the copied field is not modified | |||
ds3[10]['x'][0] = -100 | |||
self.assertEqual(ds2[0]['x'][0], 100)  # the original field is not modified | |||
ds3[0]['x'][0] = 100 | |||
self.assertEqual(ds1[0]['x'][0], 100)  # the original field is modified (inplace concat) | |||
# test field_mapping | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]}) | |||
ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'}) | |||
self.assertEqual(len(ds3), 20) | |||
# test that extra fields in ds2 are ignored | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)], 'Z': [0] * 10}) | |||
ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'}) | |||
# test that mismatched fields raise an error | |||
ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) | |||
ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]}) | |||
with self.assertRaises(RuntimeError): | |||
ds3 = ds1.concat(ds2, field_mapping={'X': 'x'}) | |||
def test_instance_field_disappear_bug(self): | |||
data = DataSet({'raw_chars': [[0, 1], [2]], 'target': [0, 1]}) | |||
data.copy_field(field_name='raw_chars', new_field_name='chars') | |||
_data = data[:1] | |||
for field_name in ['raw_chars', 'target', 'chars']: | |||
self.assertTrue(_data.has_field(field_name)) | |||
def test_from_pandas(self): | |||
import pandas as pd | |||
df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
ds = DataSet.from_pandas(df) | |||
print(ds) | |||
self.assertEqual(ds['x'].content, [1, 2, 3]) | |||
self.assertEqual(ds['y'].content, [4, 5, 6]) | |||
def test_to_pandas(self): | |||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
df = ds.to_pandas() | |||
def test_to_csv(self): | |||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
ds.to_csv("1.csv") | |||
self.assertTrue(os.path.exists("1.csv")) | |||
os.remove("1.csv") | |||
def test_add_collate_fn(self): | |||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
def collate_fn(item): | |||
return item | |||
ds.add_collate_fn(collate_fn) | |||
self.assertEqual(len(ds.collate_fns.collators), 2) | |||
def test_get_collator(self): | |||
from typing import Callable | |||
ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) | |||
collate_fn = ds.get_collator() | |||
self.assertEqual(isinstance(collate_fn, Callable), True) | |||
def test_add_seq_len(self): | |||
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) | |||
ds.add_seq_len('x') | |||
print(ds) | |||
def test_set_target(self): | |||
ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) | |||
ds.set_target('x') | |||
class TestFieldArrayInit(unittest.TestCase): | |||
""" | |||
1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: | |||
1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) | |||
1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])}) | |||
1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]}) | |||
2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray; | |||
然后后面的样本使用FieldArray.append进行添加。 | |||
2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])]) | |||
2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))]) | |||
2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])]) | |||
2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))]) | |||
""" | |||
def test_init_v1(self): | |||
# 2-d list | |||
fa = FieldArray("x", [[1, 2], [3, 4]] * 5) | |||
def test_init_v2(self): | |||
# 2-d array | |||
fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5)) | |||
def test_init_v3(self): | |||
# 3-d list | |||
fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]) | |||
def test_init_v4(self): | |||
# 1-d list | |||
val = [1, 2, 3, 4] | |||
fa = FieldArray("x", [val]) | |||
fa.append(val) | |||
def test_init_v5(self): | |||
# 1-d array | |||
val = np.array([1, 2, 3, 4]) | |||
fa = FieldArray("x", [val]) | |||
fa.append(val) | |||
def test_init_v6(self): | |||
# 2-d list | |||
val = [[1, 2], [3, 4]] | |||
fa = FieldArray("x", [val]) | |||
fa.append(val) | |||
def test_init_v7(self): | |||
# list of array | |||
fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])]) | |||
def test_init_v8(self): | |||
# 2-d array | |||
val = np.array([[1, 2], [3, 4]]) | |||
fa = FieldArray("x", [val]) | |||
fa.append(val) | |||
class TestFieldArray(unittest.TestCase): | |||
def test_main(self): | |||
fa = FieldArray("x", [1, 2, 3, 4, 5]) | |||
self.assertEqual(len(fa), 5) | |||
fa.append(6) | |||
self.assertEqual(len(fa), 6) | |||
self.assertEqual(fa[-1], 6) | |||
self.assertEqual(fa[0], 1) | |||
fa[-1] = 60 | |||
self.assertEqual(fa[-1], 60) | |||
self.assertEqual(fa.get(0), 1) | |||
self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) | |||
self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) | |||
def test_getitem_v1(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) | |||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
ans = fa[[0, 1]] | |||
self.assertTrue(isinstance(ans, np.ndarray)) | |||
self.assertTrue(isinstance(ans[0], np.ndarray)) | |||
self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) | |||
self.assertEqual(ans.dtype, np.float64) | |||
def test_getitem_v2(self): | |||
x = np.random.rand(10, 5) | |||
fa = FieldArray("my_field", x) | |||
indices = [0, 1, 3, 4, 6] | |||
for a, b in zip(fa[indices], x[indices]): | |||
self.assertListEqual(a.tolist(), b.tolist()) | |||
def test_append(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) | |||
fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) | |||
self.assertEqual(len(fa), 3) | |||
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) | |||
def test_pop(self): | |||
fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) | |||
fa.pop(0) | |||
self.assertEqual(len(fa), 1) | |||
self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0]) | |||
fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5] | |||
self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) | |||
class TestCase(unittest.TestCase): | |||
def test_init(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||
self.assertTrue(isinstance(ins.fields, dict)) | |||
self.assertEqual(ins.fields, fields) | |||
ins = Instance(**fields) | |||
self.assertEqual(ins.fields, fields) | |||
def test_add_field(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
ins = Instance(**fields) | |||
ins.add_field("z", [1, 1, 1]) | |||
fields.update({"z": [1, 1, 1]}) | |||
self.assertEqual(ins.fields, fields) | |||
def test_get_item(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
ins = Instance(**fields) | |||
self.assertEqual(ins["x"], [1, 2, 3]) | |||
self.assertEqual(ins["y"], [4, 5, 6]) | |||
self.assertEqual(ins["z"], [1, 1, 1]) | |||
def test_repr(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
ins = Instance(**fields) | |||
# simple print, that is enough. | |||
print(ins) |