@@ -9,6 +9,7 @@
 from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder
 from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder
 from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder
 from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder
+from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder
 from .exceptions import *
@@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
             return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'paddle':
             return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        elif backend == 'jittor':
+            return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         else:
            raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).")
@@ -103,6 +106,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
             return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         elif backend == 'paddle':
             return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
+        elif backend == 'jittor':
+            return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         else:
            raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).")
@@ -116,6 +121,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder:
             return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
         elif backend == 'paddle':
             return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype)
+        elif backend == 'jittor':
+            return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype)
         else:
            raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).")
@@ -0,0 +1,195 @@
+__all__ = [
+    'JittorNumberPadder',
+    'JittorSequencePadder',
+    'JittorTensorPadder'
+]
+
+from inspect import isclass
+
+import numpy as np
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+
+if _NEED_IMPORT_JITTOR:
+    import jittor
+
+numpy_to_jittor_dtype_dict = {
+    np.bool_: 'bool',
+    np.uint8: 'uint8',
+    np.int8: 'int8',
+    np.int16: 'int16',
+    np.int32: 'int32',
+    np.int64: 'int64',
+    np.float16: 'float16',
+    np.float32: 'float32',
+    np.float64: 'float32',  # unify float64 to float32 here, since numpy defaults to float64 most of the time
+}
+
+from .padder import Padder
+from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class
+from .exceptions import *
+
+
+def is_jittor_tensor(dtype):
+    if not isclass(dtype) and isinstance(dtype, jittor.jittor_core.Var):
+        return True
+    return False
+
+
+def is_jittor_dtype_str(dtype):
+    # the u'' duplicates from the original set are redundant in Python 3 and have been dropped
+    return isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8',
+                                                'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128'}
+
+
+def _get_dtype(ele_dtype, dtype, class_name):
+    if not (ele_dtype is None or (
+            is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(ele_dtype))):
+        raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers "
+                                       f"or numpy numbers or jittor.Var but got `{ele_dtype}`.")
+
+    if dtype is not None:
+        if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)):
+            raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers "
+                                        f"or jittor.dtype but got `{dtype}`.")
+    else:
+        if is_numpy_generic_class(ele_dtype):
+            dtype = numpy_to_jittor_dtype_dict.get(ele_dtype)
+        else:
+            dtype = ele_dtype
+    return dtype
+
+
+class JittorNumberPadder(Padder):
+    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+        """
+        Converts data like [1, 2, 3] to jittor.Var([1, 2, 3]).
+
+        :param pad_val: unused by this padder
+        :param ele_dtype: used to check whether the element type of the current field can be converted to jittor.Var
+        :param dtype: dtype of the output data, e.g. jittor.int64, jittor.float32, int, float, etc.
+        """
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        return jittor.Var(np.array(batch_field, dtype=dtype))
+
+
+class JittorSequencePadder(Padder):
+    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+        """
+        Pads content like [[1], [1, 2]] to jittor.Var([[1, 0], [1, 2]]); nested data of multiple depths is supported.
+
+        :param pad_val: the value to pad with
+        :param ele_dtype: used to check whether the element type of the current field can be converted to jittor.Var
+        :param dtype: dtype of the output data, e.g. jittor.int64, jittor.float32, int, float, etc.
+        """
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val)
+        return tensor
+
+
+class JittorTensorPadder(Padder):
+    def __init__(self, pad_val=0, ele_dtype=None, dtype=None):
+        """
+        Currently supports batches like [jittor.Var([3, 2]), jittor.Var([1])]. If the inner elements are not
+        jittor.Var, they must provide a tolist() method.
+
+        :param pad_val: the value to pad with
+        :param ele_dtype: used to check whether the element type of the current field can be converted to jittor.Var
+        :param dtype: dtype of the output data, e.g. jittor.int64, jittor.float32, int, float, etc.
+        """
+        dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__)
+        super().__init__(pad_val=pad_val, dtype=dtype)
+
+    @staticmethod
+    def pad(batch_field, pad_val, dtype):
+        try:
+            if not isinstance(batch_field[0], jittor.Var):
+                batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field]
+        except AttributeError:
+            raise RuntimeError(f"If the field is not a jittor.Var (it is {type(batch_field[0])}), "
+                               f"it must have a tolist() method.")
+
+        shapes = [field.shape for field in batch_field]
+        # max(_) rather than max(*_): the latter raises TypeError when the batch has a single element
+        max_shape = [len(batch_field)] + [max(_) for _ in zip(*shapes)]
+        tensor = jittor.full(max_shape, pad_val, dtype=dtype)
+        for i, field in enumerate(batch_field):
+            slices = (i,) + tuple(slice(0, s) for s in shapes[i])
+            tensor[slices] = field
+        return tensor
+
+
+def fill_tensor(batch_field, padded_batch, dtype):
+    """
+    Fills the values from batch_field into padded_batch.
+
+    :param batch_field: the content to fill into the tensor
+    :param padded_batch: the tensor to be filled
+    :param dtype: dtype of the data
+    :return:
+    """
+    if padded_batch.ndim == 2:
+        for i, content_i in enumerate(batch_field):
+            padded_batch[i, :len(content_i)] = jittor.Var(np.array(content_i, dtype=dtype))
+    elif padded_batch.ndim == 3:
+        for i, content_i in enumerate(batch_field):
+            for j, content_ii in enumerate(content_i):
+                padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype))
+    elif padded_batch.ndim == 4:
+        try:  # most likely images, so a direct conversion should just work
+            padded_batch = jittor.Var(np.array(batch_field, dtype=dtype))
+        except:
+            for i, content_i in enumerate(batch_field):
+                for j, content_ii in enumerate(content_i):
+                    for k, content_iii in enumerate(content_ii):
+                        padded_batch[i, j, k, :len(content_iii)] = jittor.Var(np.array(content_iii, dtype=dtype))
+    elif padded_batch.ndim == 1:
+        padded_batch[:] = jittor.Var(np.array(batch_field, dtype=dtype))
+    else:
+        raise RuntimeError("fastNLP does not support padding for more than 4 dimensions. If you need this, please "
+                           "open an issue.")
+    return padded_batch
+
+
+def get_padded_jittor_tensor(batch_field, dtype=None, pad_val=0):
+    """
+    For example:
+        [[1, 2], [3]] -> jittor.Var([[1, 2], [3, 0]])
+
+    :param batch_field: the object to pad; it must actually be paddable. Supports 1d (mostly sequence lengths),
+        2d (mostly token sequences), 3d (mostly character sequences) and 4d (mostly images).
+    :param dtype: the target dtype
+    :param pad_val: the value to pad with
+    :return:
+    """
+    shapes = get_shape(batch_field)
+    tensor = jittor.full(shapes, pad_val, dtype=dtype)
+    tensor = fill_tensor(batch_field, tensor, dtype=dtype)
+    return tensor
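For reference, a small sketch of what get_padded_jittor_tensor produces. It is not part of the patch; it assumes jittor is installed and that the new module is importable as .jittor_padder, per the import hunk above:

    from fastNLP.core.collators.padders.jittor_padder import get_padded_jittor_tensor

    batch = [[1], [1, 2], [1, 2, 3]]
    padded = get_padded_jittor_tensor(batch, dtype='int32', pad_val=0)
    # padded should be a jittor.Var of shape [3, 3]:
    # [[1, 0, 0],
    #  [1, 2, 0],
    #  [1, 2, 3]]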
@@ -64,38 +64,40 @@ class JittorDataLoader:
         :param collate_fn: a callable used to pack the fetched samples into a batch
         :param as_numpy: whether to return the data as numpy arrays; otherwise it is converted to jittor.Var
         """
-        # TODO support fastNLP dataset
         # TODO verify support for replacesampler (to be done later)
+        # a FastNLP DataSet requires a non-None collate_fn
+        if isinstance(dataset, FDataSet) and collate_fn is None:
+            raise ValueError("When using a FastNLP DataSet, collate_fn must not be None")
+        # wrap the dataset unless it is already a jittor-style dataset
+        if not isinstance(dataset, _JittorDataset):
+            self.dataset = _JittorDataset(dataset)
+        else:
+            self.dataset = dataset
         if isinstance(collate_fn, str):
             if collate_fn == "auto":
-                if isinstance(dataset, FDataSet):
-                    self._collate_fn = dataset.collator
-                    self._collate_fn.set_backend(backend="jittor")
+                if isinstance(self.dataset.dataset, FDataSet):
+                    self.collate_fn = self.dataset.dataset.collator
+                    self.collate_fn.set_backend(backend="jittor")
                 else:
-                    self._collate_fn = Collator(backend="jittor")
+                    self.collate_fn = Collator(backend="jittor")
             else:
                 raise ValueError(f"collate_fn: {collate_fn} must be 'auto'")
         elif isinstance(collate_fn, Callable):
             if collate_fn is not collate_batch:
-                self._collate_fn = collate_fn
+                self.collate_fn = collate_fn
         else:
-            self._collate_fn = collate_batch
-            self.dataset = _JittorDataset(dataset)
+            self.collate_fn = collate_batch

         self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last,
                                num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad,
                                keep_numpy_array=keep_numpy_array, endless=endless)
+        # set the inner dataset's batch size to 1
         if isinstance(self.dataset.dataset, Dataset):
             self.dataset.dataset.set_attrs(batch_size=1)
-        # a user-provided collate_fn automatically replaces jittor's collate_batch function
-        # self._collate_fn = _collate_fn

         self.cur_batch_indices = None
     def __iter__(self):
         # TODO collate_fn cannot be changed after the first iteration; setting it later has no effect
-        self.collate_fn = self._collate_fn
         if self.cur_batch_indices is None:
             self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn))
         for indices, data in self.dataset.__iter__():
@@ -107,8 +109,8 @@ class JittorDataLoader:
             return len(self.dataset) // self.dataset.batch_size
         return (len(self.dataset) - 1) // self.dataset.batch_size + 1

-    def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None,
-                pad_fn:Callable=None) -> Collator:
+    def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None,
+                pad_fn: Callable = None) -> "JittorDataLoader":
         """
         Use this method if the content of a particular field needs special handling when padding.
@@ -127,13 +129,14 @@ class JittorDataLoader:
             form, and its output will be used directly as the result.
-        :return: the Collator itself
+        :return: the JittorDataLoader itself
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend)
-            return self._collate_fn
+        if isinstance(self.collate_fn, Collator):
+            self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn,
+                                    backend=backend)
+            return self
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.")
-    def set_ignore(self, *field_names) -> Collator:
+    def set_ignore(self, *field_names) -> "JittorDataLoader":
         """
         Use this if some content should not be output; the listed fields will be ignored in the batch output.

         Ex::
@@ -144,9 +147,9 @@ class JittorDataLoader:
             if __getitem__ returns a Sequence, '_0', '_1' can be used to refer to its 0th or 1st element.
-        :return: the Collator itself
+        :return: the JittorDataLoader itself
         """
-        if isinstance(self._collate_fn, Collator):
-            self._collate_fn.set_ignore(*field_names)
-            return self._collate_fn
+        if isinstance(self.collate_fn, Collator):
+            self.collate_fn.set_ignore(*field_names)
+            return self
         else:
             raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.")
@@ -158,5 +161,6 @@ class JittorDataLoader:
         """
         return self.cur_batch_indices

+
 def prepare_jittor_dataloader():
     ...
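Since set_pad() and set_ignore() now return the dataloader itself, calls can be chained. A rough usage sketch under the same assumptions (jittor installed; collate_fn='auto' as exercised by the tests below):

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader

    ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
    jtl = JittorDataLoader(ds, batch_size=4, collate_fn='auto')
    jtl.set_pad('x', pad_val=-1).set_ignore('y')  # chaining is enabled by returning self
    for batch in jtl:
        assert 'y' not in batch  # 'x' arrives as a jittor.Var padded with -1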
@@ -189,7 +189,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
         dl_bundle = {}
         for name, ds in ds_or_db.iter_datasets():
             if 'train' in name:
-                dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
+                dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
                                                    return_list=return_list,
                                                    batch_sampler=batch_sampler, batch_size=train_batch_size,
                                                    shuffle=shuffle,
@@ -199,7 +199,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
                                                    timeout=timeout, worker_init_fn=worker_init_fn,
                                                    persistent_workers=persistent_workers)
             else:
-                dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places,
+                dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places,
                                                    return_list=return_list,
                                                    batch_sampler=batch_sampler, batch_size=non_train_batch_size,
                                                    shuffle=shuffle,
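The fix matters because the loop iterates the bundle's datasets by name: previously every dl_bundle entry wrapped the whole ds_or_db instead of the split ds it was named after. A rough sketch of the intended behaviour; data_bundle and the two batch-size keyword names are taken from the hunk, not verified against the full signature:

    # assuming data_bundle holds 'train' and 'dev' splits
    dls = prepare_paddle_dataloader(data_bundle, train_batch_size=8, non_train_batch_size=4)
    assert set(dls.keys()) == {'train', 'dev'}
    # dls['train'] now iterates only the train split with batch size 8,
    # and dls['dev'] only the dev split with batch size 4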
@@ -1,7 +1,6 @@
 import pytest
 import numpy as np
 from datasets import Dataset as HfDataset
-from datasets import load_dataset

 from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader
 from fastNLP.core.dataset import DataSet as Fdataset
@@ -23,16 +22,12 @@ class MyDataset(Dataset):
     def __getitem__(self, item):
         return self.data[item]
-        # return {'x': [[1, 0], [2, 0, 1]]}
-        # return np.random.randn(3, 10)
-
-    # def __len__(self):
-    #     return self.dataset_len

 @pytest.mark.jittor
 class TestJittor:

-    def test_v1(self):
+    def test_jittor_dataset(self):
         """
         Tests a jittor-style dataset with the fdl
@@ -40,13 +35,13 @@ class TestJittor:
         """
         dataset = MyDataset()
         jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4)
-        # jtl.set_pad_val('x', 'y')
-        # jtl.set_input('x')
         for batch in jtl:
-            print(batch)
-            print(jtl.get_batch_indices())
+            assert batch.size() == [4, 3, 4]
+        jtl1 = JittorDataLoader(dataset, keep_numpy_array=False, batch_size=4, num_workers=2)
+        for batch in jtl1:
+            assert batch.size() == [4, 3, 4]

-    def test_v2(self):
+    def test_fastnlp_Dataset(self):
         """
         Tests a fastNLP dataset
@@ -56,26 +51,27 @@ class TestJittor:
         jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True)
         jtl.set_pad("x", -1)
         jtl.set_ignore("y")
-        # jtl.set_pad_val('x', val=-1)
-        # jtl.set_input('x', 'y')
         for batch in jtl:
             assert batch['x'].size() == (16, 4)
+        jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2)

-    def test_v3(self):
+    def test_huggingface_datasets(self):
         dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100})
         jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True)
-        # jtl.set_input('x', 'y')
         for batch in jtl:
-            print(batch)
+            assert batch['x'].size() == [4, 4]
+            assert len(batch['y']) == 4

-    def test_v4(self):
+    def test_num_workers(self):
         dataset = MyDataset()
         dl = JittorDataLoader(dataset, batch_size=4, num_workers=2)
-        print(len(dl))
         for idx, batch in enumerate(dl):
-            print(batch.shape, idx)
+            assert batch.shape == [4, 3, 4]
         for idx, batch in enumerate(dl):
-            print(batch.shape, idx)
+            assert batch.shape == [4, 3, 4]

     def test_v5(self):
         dataset = MyDataset()
@@ -18,7 +18,7 @@ class RandomDataset(Dataset):
     def __getitem__(self, idx):
         image = np.random.random((10, 5)).astype('float32')
-        return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]}
+        return {'image': paddle.to_tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]}

     def __len__(self):
         return 10
@@ -39,10 +39,16 @@ class TestPaddle:
     def test_fdl_fastnlp_dataset(self):
         ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10})
-        fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True)
+        fdl = PaddleDataLoader(ds, batch_size=3, shuffle=False, drop_last=True)
+        fdl.set_ignore('y')
+        fdl.set_pad('x', -1)
         for batch in fdl:
-            assert len(fdl.get_batch_indices()) == 4
-            print(fdl.get_batch_indices())
+            assert len(fdl.get_batch_indices()) == 3
+            assert 'y' not in batch
+            assert batch['x'].shape == [3, 3]
+        with pytest.raises(ValueError):
+            PaddleDataLoader(ds, batch_size=3, collate_fn=None)

     def test_set_inputs_and_set_pad_val(self):
         logger.setLevel("DEBUG")
@@ -50,11 +56,8 @@ class TestPaddle:
         fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True)
         fdl.set_pad('label', -1)
         for batch in fdl:
-            print(batch['image'])
             assert batch['image'].shape == [2, 10, 5]
-            print(batch)
         fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True)
         fdl1.set_ignore('label')
         for batch in fdl1:
             assert batch['image'].shape == [4, 10, 5]
-            print(batch)