|
|
@@ -26,20 +26,11 @@ class _JittorDataset(Dataset): |
|
|
|
def __init__(self, dataset) -> None: |
|
|
|
super(_JittorDataset, self).__init__() |
|
|
|
self.dataset = dataset |
|
|
|
self.total_len = len(dataset) |
|
|
|
|
|
|
|
def __getitem__(self, item): |
|
|
|
return (item, self.dataset[item]) |
|
|
|
|
|
|
|
def __len__(self) -> int: |
|
|
|
return len(self.dataset) |
|
|
|
|
|
|
|
# def __getattr__(self, item): |
|
|
|
# # jittor的Dataset没有的方法而用户的dataset存在且实现了getattribute方法,此时用户可以调用 |
|
|
|
# try: |
|
|
|
# self.dataset.__getattribute__(item) |
|
|
|
# except Exception as e: |
|
|
|
# raise e |
|
|
|
|
|
|
|
|
|
|
|
class JittorDataLoader: |
|
|
|
""" |
|
|
@@ -62,13 +53,17 @@ class JittorDataLoader: |
|
|
|
:param keep_numpy_array: |
|
|
|
:param endless: |
|
|
|
:param collate_fn: 对取得到的数据进行打包的callable函数 |
|
|
|
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 |
|
|
|
""" |
|
|
|
# TODO 验证支持replacesampler (以后完成) |
|
|
|
# 将内部dataset批次设置为1 |
|
|
|
if isinstance(dataset, Dataset): |
|
|
|
dataset.set_attrs(batch_size=1) |
|
|
|
|
|
|
|
# FastNLP Datset, collate_fn not None |
|
|
|
if isinstance(dataset, FDataSet) and collate_fn is None: |
|
|
|
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") |
|
|
|
|
|
|
|
# 将所有dataset转为jittor类型的dataset |
|
|
|
if not isinstance(dataset, _JittorDataset): |
|
|
|
self.dataset = _JittorDataset(dataset) |
|
|
|
|
|
|
@@ -82,17 +77,13 @@ class JittorDataLoader: |
|
|
|
else: |
|
|
|
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") |
|
|
|
elif isinstance(collate_fn, Callable): |
|
|
|
if collate_fn is not collate_batch: |
|
|
|
self.collate_fn = collate_fn |
|
|
|
self.collate_fn = collate_fn |
|
|
|
else: |
|
|
|
self.collate_fn = collate_batch |
|
|
|
|
|
|
|
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, |
|
|
|
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, |
|
|
|
keep_numpy_array=keep_numpy_array, endless=endless) |
|
|
|
# 将内部dataset批次设置为1 |
|
|
|
if isinstance(self.dataset.dataset, Dataset): |
|
|
|
self.dataset.dataset.set_attrs(batch_size=1) |
|
|
|
|
|
|
|
self.cur_batch_indices = None |
|
|
|
|
|
|
@@ -105,12 +96,10 @@ class JittorDataLoader: |
|
|
|
yield data |
|
|
|
|
|
|
|
def __len__(self): |
|
|
|
if self.dataset.drop_last: |
|
|
|
return len(self.dataset) // self.dataset.batch_size |
|
|
|
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 |
|
|
|
return len(self.dataset) |
|
|
|
|
|
|
|
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, |
|
|
|
pad_fn: Callable = None) -> "JittorDataLoader": |
|
|
|
pad_fn: Callable = None) -> Collator: |
|
|
|
""" |
|
|
|
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 |
|
|
|
|
|
|
@@ -129,14 +118,27 @@ class JittorDataLoader: |
|
|
|
形式,输出将被直接作为结果输出。 |
|
|
|
:return: 返回 Collator 自身 |
|
|
|
""" |
|
|
|
if isinstance(self.collate_fn, Collator): |
|
|
|
self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, |
|
|
|
backend=backend) |
|
|
|
return self |
|
|
|
collator = self._get_collator() |
|
|
|
if isinstance(collator, Collator): |
|
|
|
collator.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) |
|
|
|
return collator |
|
|
|
else: |
|
|
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") |
|
|
|
|
|
|
|
def set_ignore(self, *field_names) -> "JittorDataLoader": |
|
|
|
def _get_collator(self): |
|
|
|
""" |
|
|
|
如果 collate_fn 是 Collator 对象,得到该对象。如果没有的话,返回 None |
|
|
|
|
|
|
|
:return: |
|
|
|
""" |
|
|
|
collator = None |
|
|
|
if hasattr(self.collate_fn, '__wrapped__') and isinstance(self.collate_fn.__wrapped__, Collator): |
|
|
|
collator = self.collate_fn.__wrapped__ |
|
|
|
elif isinstance(self.collate_fn, Collator): |
|
|
|
collator = self.collate_fn |
|
|
|
return collator |
|
|
|
|
|
|
|
def set_ignore(self, *field_names) -> Collator: |
|
|
|
""" |
|
|
|
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 |
|
|
|
Example:: |
|
|
@@ -147,9 +149,10 @@ class JittorDataLoader: |
|
|
|
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 |
|
|
|
:return: 返回 Collator 自身 |
|
|
|
""" |
|
|
|
if isinstance(self.collate_fn, Collator): |
|
|
|
self.collate_fn.set_ignore(*field_names) |
|
|
|
return self |
|
|
|
collator = self._get_collator() |
|
|
|
if isinstance(collator, Collator): |
|
|
|
collator.set_ignore(*field_names) |
|
|
|
return collator |
|
|
|
else: |
|
|
|
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") |
|
|
|
|
|
|
|