
[unstable] change Batch to torch's DataLoader

tags/v0.4.10
yunfan 5 years ago
parent commit 7564818f4b
9 changed files with 146 additions and 199 deletions
  1. +5   -1    fastNLP/__init__.py
  2. +1   -1    fastNLP/core/__init__.py
  3. +90  -169  fastNLP/core/batch.py
  4. +6   -1    fastNLP/core/field.py
  5. +2   -3    fastNLP/core/predictor.py
  6. +11  -3    fastNLP/core/tester.py
  7. +19  -9    fastNLP/core/trainer.py
  8. +2   -2    fastNLP/modules/encoder/embedding.py
  9. +10  -10   test/core/test_batch.py

fastNLP/__init__.py (+5 -1)

@@ -12,7 +12,11 @@ The most commonly used fastNLP components can be imported directly from the fastNLP package; their
 __all__ = [
     "Instance",
     "FieldArray",
-    "Batch",
+
+    "DataSetIter",
+    "BatchIter",
+    "TorchLoaderIter",
+
     "Vocabulary",
     "DataSet",
     "Const",

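The three new names replace Batch in the package's public surface. A minimal import sketch, assuming an install that includes this commit:

    from fastNLP import DataSetIter, BatchIter, TorchLoaderIter

    # code that previously did `from fastNLP import Batch` must switch to
    # one of the classes above; `Batch` is no longer exported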

fastNLP/core/__init__.py (+1 -1)

@@ -14,7 +14,7 @@ The core module implements fastNLP's core framework; commonly used functionality can be imported from fa
 Describing the division of labor among core's submodules here seems unnecessary
 """
-from .batch import Batch
+from .batch import DataSetIter, BatchIter, TorchLoaderIter
 from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
 from .const import Const
 from .dataset import DataSet


fastNLP/core/batch.py (+90 -169)

@@ -3,7 +3,9 @@ The batch module implements the Batch classes fastNLP needs.

 """
 __all__ = [
-    "Batch"
+    "BatchIter",
+    "DataSetIter",
+    "TorchLoaderIter",
 ]

 import atexit
@@ -15,7 +17,7 @@ import torch.multiprocessing as mp
 import torch.utils.data
 from numbers import Number

-from .sampler import RandomSampler
+from .sampler import SequentialSampler
 from .dataset import DataSet

 _python_is_exit = False
@@ -28,14 +30,18 @@ def _set_python_is_exit():

 atexit.register(_set_python_is_exit)


 class DataSetGetter:
     def __init__(self, dataset: DataSet, as_numpy=False):
         self.dataset = dataset
         self.inputs = {n: f for n, f in dataset.get_all_fields().items() if f.is_input}
         self.targets = {n: f for n, f in dataset.get_all_fields().items() if f.is_target}
         self.as_numpy = as_numpy
+        self.idx_list = list(range(len(dataset)))

     def __getitem__(self, idx: int):
+        # mapping idx to sampled idx
+        idx = self.idx_list[idx]
         inputs = {n:f.get(idx) for n, f in self.inputs.items()}
         targets = {n:f.get(idx) for n, f in self.targets.items()}
         return idx, inputs, targets
@@ -60,9 +66,9 @@ class DataSetGetter:
             if f.padder is None:
                 batch_dict[n] = np.array(vlist)
             else:
-                data = f.padder(vlist, field_name=n, field_ele_dtype=f.dtype)
+                data = f.pad(vlist)
                 if not self.as_numpy:
-                    data = _to_tensor(data, f.dtype)
+                    data, flag = _to_tensor(data, f.dtype)
                 batch_dict[n] = data
         return batch_dict


@@ -70,24 +76,40 @@
             pad_batch(batch_x, self.inputs),
             pad_batch(batch_y, self.targets))

+    def set_idx_list(self, idx_list):
+        if len(idx_list) != len(self.idx_list):
+            raise ValueError
+        self.idx_list = idx_list
+

-class Batch:
-    def __init__(self, dataset, batch_size, sampler=None, buffer_size=0, as_numpy=False,
-                 num_workers=0, pin_memory=False, drop_last=False,
-                 timeout=0, worker_init_fn=None, **kwargs):
-
-        dataset_getter = DataSetGetter(dataset, as_numpy)
-        self.buffer_size = buffer_size
-        self.cur_batch_indices = None
-        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
-        shuffle = isinstance(sampler, RandomSampler)
-        self.dataiter = torch.utils.data.DataLoader(
-            dataset=dataset_getter, batch_size=batch_size, shuffle=shuffle,
-            collate_fn=dataset_getter.collate_fn,
-            num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
-            timeout=timeout, worker_init_fn=worker_init_fn)
+class SamplerAdapter(torch.utils.data.Sampler):
+    def __init__(self, sampler, dataset):
+        self.sampler = sampler
+        self.dataset = dataset
+
+    def __iter__(self):
+        return iter(self.sampler(self.dataset))
+
+
+class BatchIter:
+    def __init__(self):
+        self.dataiter = None
+        self.num_batches = None
+        self.cur_batch_indices = None
+        self.batch_size = None
+
+    def init_iter(self):
+        pass
+
+    @staticmethod
+    def get_num_batches(num_samples, batch_size, drop_last):
+        num_batches = num_samples // batch_size
+        if not drop_last and (num_samples % batch_size > 0):
+            num_batches += 1
+        return num_batches

     def __iter__(self):
+        self.init_iter()
         for indices, batch_x, batch_y in self.dataiter:
             self.cur_batch_indices = indices
             yield batch_x, batch_y
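
SamplerAdapter exists because fastNLP samplers are callables that take the dataset and return an index order, while torch's DataLoader expects an iterable torch.utils.data.Sampler. A minimal sketch of the bridge; LengthSortedSampler is a made-up example, not part of fastNLP:

    import torch.utils.data

    class LengthSortedSampler:
        # hypothetical fastNLP-style sampler: a callable that maps the
        # dataset to a permutation of its indices
        def __call__(self, dataset):
            return sorted(range(len(dataset)), key=lambda i: len(dataset[i]))

    class SamplerAdapter(torch.utils.data.Sampler):
        # same shape as the class added above: wraps the callable in the
        # iterable interface DataLoader expects
        def __init__(self, sampler, dataset):
            self.sampler = sampler
            self.dataset = dataset

        def __iter__(self):
            return iter(self.sampler(self.dataset))

    data = ["ab", "a", "abc"]
    print(list(SamplerAdapter(LengthSortedSampler(), data)))  # [1, 0, 2]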
@@ -98,163 +120,62 @@
     def __len__(self):
         return self.num_batches

+    @property
+    def dataset(self):
+        return self.dataiter.dataset
+

-class Batch1(object):
-    """
-    Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`
-
-    Batch takes data out of a `DataSet` in a given order, ``batch_size`` items at a time,
-    and assembles it into `x` and `y`::
-
-        batch = Batch(data_set, batch_size=16, sampler=SequentialSampler())
-        num_batch = len(batch)
-        for batch_x, batch_y in batch:
-            # do stuff ...
-
-    :param dataset: a :class:`~fastNLP.DataSet` object, the dataset
-    :param int batch_size: the size of the batches to take
-    :param sampler: the :class:`~fastNLP.Sampler` to use. If ``None``, :class:`~fastNLP.RandomSampler` is used.
-        Default: ``None``
-    :param bool as_numpy: if ``True``, batches are output as numpy.array, otherwise as :class:`torch.Tensor`.
-        Default: ``False``
-    :param bool prefetch: if ``True``, use a separate process to prefetch the next batch.
-        Default: ``False``
-    """
-
-    def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
-        self.dataset = dataset
+class DataSetIter(BatchIter):
+    def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None):
+        super().__init__()
+        assert isinstance(dataset, DataSet)
+        dataset = DataSetGetter(dataset, as_numpy)
+        collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
+        sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
+        self.dataiter = torch.utils.data.DataLoader(
+            dataset=dataset, batch_size=batch_size, sampler=sampler,
+            collate_fn=collate_fn, num_workers=num_workers,
+            pin_memory=pin_memory, drop_last=drop_last,
+            timeout=timeout, worker_init_fn=worker_init_fn)
+        self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last)
         self.batch_size = batch_size
-        if sampler is None:
-            sampler = RandomSampler()
-        self.sampler = sampler
-        self.as_numpy = as_numpy
-        self.idx_list = None
-        self.curidx = 0
-        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
-        self.cur_batch_indices = None
-        self.prefetch = prefetch
-        self.lengths = 0
-
-    def fetch_one(self):
-        if self.curidx >= len(self.idx_list):
-            return None
-        else:
-            endidx = min(self.curidx + self.batch_size, len(self.idx_list))
-            batch_x, batch_y = {}, {}
-
-            indices = self.idx_list[self.curidx:endidx]
-            self.cur_batch_indices = indices
-
-            for field_name, field in self.dataset.get_all_fields().items():
-                if field.is_target or field.is_input:
-                    batch = field.get(indices)
-                    if not self.as_numpy and \
-                            field.dtype is not None and \
-                            issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
-                        batch = _to_tensor(batch)
-                    if field.is_target:
-                        batch_y[field_name] = batch
-                    if field.is_input:
-                        batch_x[field_name] = batch
-
-            self.curidx = endidx
-            return batch_x, batch_y
-
-    def __iter__(self):
-        """
-        Iterate on dataset, fetch batch data. Fetch process don't block the iterate process
-        :return:
-        """
-        if self.prefetch:
-            return self._run_batch_iter(self)
-
-        def batch_iter():
-            self.init_iter()
-            while 1:
-                res = self.fetch_one()
-                if res is None:
-                    break
-                yield res
-
-        return batch_iter()
-
-    def init_iter(self):
-        self.idx_list = self.sampler(self.dataset)
-        self.curidx = 0
-        self.lengths = self.dataset.get_length()
-
-    def __len__(self):
-        return self.num_batches
-
-    def get_batch_indices(self):
-        """
-        Get the DataSet indices of the current batch
-
-        :return list(int) indexes: the index sequence
-        """
-        return self.cur_batch_indices
-
-    @staticmethod
-    def _run_fetch(batch, q):
-        try:
-            global _python_is_exit
-            batch.init_iter()
-            # print('start fetch')
-            while 1:
-                res = batch.fetch_one()
-                # print('fetch one')
-                while 1:
-                    try:
-                        q.put(res, timeout=3)
-                        break
-                    except Full:
-                        if _python_is_exit:
-                            return
-                if res is None:
-                    # print('fetch done, waiting processing')
-                    break
-            # print('fetch exit')
-        except Exception as e:
-            q.put(e)
-        finally:
-            q.join()
-
-    @staticmethod
-    def _run_batch_iter(batch):
-        q = mp.JoinableQueue(maxsize=10)
-        fetch_p = mp.Process(target=Batch._run_fetch, args=(batch, q))
-        fetch_p.daemon = True
-        fetch_p.start()
-        # print('fork fetch process')
-        while 1:
-            try:
-                res = q.get(timeout=1)
-                q.task_done()
-                # print('get fetched')
-                if res is None:
-                    break
-                elif isinstance(res, Exception):
-                    raise res
-                yield res
-            except Empty as e:
-                if fetch_p.is_alive():
-                    continue
-                else:
-                    break
-        fetch_p.terminate()
-        fetch_p.join()
-        # print('iter done')

+class TorchLoaderIter(BatchIter):
+    def __init__(self, dataset):
+        super().__init__()
+        assert isinstance(dataset, torch.utils.data.DataLoader)
+        self.dataiter = dataset
+        self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last)
+        self.batch_size = dataset.batch_size
+
+
+class OnlineDataGettter:
+    # TODO
+    pass
+
+
-def _to_tensor(batch):
+class OnlineDataIter(BatchIter):
+    # TODO
+    def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False,
+                 num_workers=0, pin_memory=False, drop_last=False,
+                 timeout=0, worker_init_fn=None, **kwargs):
+        super().__init__()
+
+
+def _to_tensor(batch, field_dtype):
     try:
-        if issubclass(batch.dtype.type, np.floating):
-            batch = torch.as_tensor(batch).float()  # float32 by default
+        if field_dtype is not None \
+                and issubclass(field_dtype, Number) \
+                and not isinstance(batch, torch.Tensor):
+            if issubclass(batch.dtype.type, np.floating):
+                new_batch = torch.as_tensor(batch).float()  # float32 by default
+            else:
+                new_batch = torch.as_tensor(batch)  # reuses the underlying memory, no copy
+            return new_batch, True
         else:
-            batch = torch.as_tensor(batch)  # reuses the underlying memory, no copy
+            return batch, False
     except:
-        pass
-    return batch
+        return batch, False
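
Taken together, the new classes make fastNLP batching a thin wrapper over torch's DataLoader: DataSetGetter adapts a DataSet to the map-style dataset interface, SamplerAdapter adapts fastNLP samplers, and BatchIter keeps the old (batch_x, batch_y) iteration contract. A minimal usage sketch, assuming a fastNLP build that includes this commit:

    from fastNLP import DataSet, DataSetIter, TorchLoaderIter, SequentialSampler
    import torch.utils.data

    ds = DataSet({"x": [[1, 2, 3, 4]] * 8, "y": [[5, 6]] * 8})
    ds.set_input("x")
    ds.set_target("y")

    # DataSetIter: a fastNLP DataSet in, padded (batch_x, batch_y) dicts out
    for batch_x, batch_y in DataSetIter(ds, batch_size=4, sampler=SequentialSampler()):
        print(batch_x["x"].shape)  # torch.Size([4, 4])

    # TorchLoaderIter: wrap an existing torch DataLoader
    loader = torch.utils.data.DataLoader(list(range(8)), batch_size=4)
    wrapped = TorchLoaderIter(loader)  # reuses loader.batch_size and loader.drop_last

Note that BatchIter.__iter__ unpacks (indices, batch_x, batch_y) from the wrapped iterator, so a DataLoader handed to TorchLoaderIter is expected to yield such triples; the sketch above only constructs the wrapper.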

fastNLP/core/field.py (+6 -1)

@@ -176,7 +176,12 @@ class FieldArray:
         if self.padder is None or pad is False:
             return np.array(contents)
         else:
-            return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
+            return self.pad(contents)
+
+    def pad(self, contents):
+        if self.padder is None:
+            raise RuntimeError
+        return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)

     def set_padder(self, padder):
         """


fastNLP/core/predictor.py (+2 -3)

@@ -6,7 +6,7 @@ from collections import defaultdict

 import torch

-from . import Batch
+from . import DataSetIter
 from . import DataSet
 from . import SequentialSampler
 from .utils import _build_args

@@ -44,8 +44,7 @@ class Predictor(object):

         self.network.eval()
         batch_output = defaultdict(list)
-        data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False,
-                              prefetch=False)
+        data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)

         if hasattr(self.network, "predict"):
             predict_func = self.network.predict


fastNLP/core/tester.py (+11 -3)

@@ -37,7 +37,7 @@ import warnings
 import torch
 import torch.nn as nn

-from .batch import Batch
+from .batch import BatchIter, DataSetIter
 from .dataset import DataSet
 from .metrics import _prepare_metrics
 from .sampler import SequentialSampler

@@ -82,7 +82,7 @@
     :param int verbose: if 0, print no output; if 1, print the evaluation results.
     """

-    def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1):
+    def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
         super(Tester, self).__init__()

         if not isinstance(data, DataSet):

@@ -96,6 +96,14 @@ class Tester(object):
         self._model = _move_model_to_device(model, device=device)
         self.batch_size = batch_size
         self.verbose = verbose
+
+        if isinstance(data, DataSet):
+            self.data_iterator = DataSetIter(
+                dataset=data, batch_size=batch_size, num_workers=num_workers)
+        elif isinstance(data, BatchIter):
+            self.data_iterator = data
+        else:
+            raise TypeError("data type {} not support".format(type(data)))

         # a DataParallel model has no usable predict method
         if isinstance(self._model, nn.DataParallel):

@@ -124,7 +132,7 @@ class Tester(object):
         self._model_device = _get_model_device(self._model)
         network = self._model
         self._mode(network, is_test=True)
-        data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False)
+        data_iterator = self.data_iterator
         eval_results = {}
         try:
             with torch.no_grad():
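
The Tester now builds its iterator once in __init__, so evaluation data can be a DataSet or any ready-made BatchIter. A hedged sketch; dev_data and model are illustrative names, not defined here:

    from fastNLP import Tester, AccuracyMetric

    # DataSet in: wrapped in a DataSetIter, batches loaded by 2 worker processes
    tester = Tester(data=dev_data, model=model, metrics=AccuracyMetric(),
                    batch_size=16, num_workers=2)
    tester.test()

    # alternatively, a prebuilt BatchIter (e.g. a TorchLoaderIter) may be
    # passed as `data` and is used unchanged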


fastNLP/core/trainer.py (+19 -9)

@@ -312,7 +312,7 @@ try:
 except:
     from .utils import _pseudo_tqdm as tqdm

-from .batch import Batch
+from .batch import DataSetIter, BatchIter
 from .callback import CallbackManager, CallbackException
 from .dataset import DataSet
 from .losses import _prepare_losser

@@ -394,7 +394,7 @@
     """

     def __init__(self, train_data, model, optimizer=None, loss=None,
-                 batch_size=32, sampler=None, update_every=1,
+                 batch_size=32, sampler=None, update_every=1, num_workers=0,
                  n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
                  validate_every=-1, save_path=None,

@@ -439,9 +439,19 @@
         # sampler check
         if sampler is not None and not isinstance(sampler, Sampler):
             raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))

+        if isinstance(train_data, DataSet):
+            self.data_iterator = DataSetIter(
+                dataset=train_data, batch_size=batch_size, num_workers=num_workers)
+        elif isinstance(train_data, BatchIter):
+            self.data_iterator = train_data
+        else:
+            raise TypeError("train_data type {} not support".format(type(train_data)))

-        if check_code_level > -1:
-            _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
+        if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
+            # TODO: work out how to run the check for other dataset types
+            _check_code(data_iterator=self.data_iterator,
+                        model=model, losser=losser, metrics=metrics, dev_data=dev_data,
                         metric_key=metric_key, check_level=check_code_level,
                         batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
         # _check_code is how fastNLP checks that your code is correct. If you see this comment in an error traceback, check your code carefully.

@@ -493,7 +503,7 @@
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)

     def train(self, load_best_model=True, on_exception='auto'):
         """
         Call this function to start the Trainer's training.

@@ -572,8 +582,7 @@
         with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
             self.pbar = pbar
             avg_loss = 0
-            data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
-                                  prefetch=self.prefetch)
+            data_iterator = self.data_iterator
             self.batch_per_epoch = data_iterator.num_batches
             for epoch in range(1, self.n_epochs + 1):
                 self.epoch = epoch

@@ -786,13 +795,14 @@ def _get_value_info(_dict):
     return strs


-def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
+def _check_code(data_iterator, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
                 dev_data=None, metric_key=None,
                 check_level=0):
     # check the get_loss method
     model_devcie = model.parameters().__next__().device
-    batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler())
+    batch = data_iterator
+    dataset = data_iterator.dataset

     for batch_count, (batch_x, batch_y) in enumerate(batch):
         _move_dict_value_to_device(batch_x, batch_y, device=model_devcie)
         # forward check
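
With data_iterator built up front, training data may be a DataSet or a BatchIter, and batches can be loaded in worker subprocesses. A hedged sketch; train_ds and model are illustrative names, not defined here:

    from fastNLP import Trainer, CrossEntropyLoss

    trainer = Trainer(train_data=train_ds,      # DataSet -> wrapped in a DataSetIter
                      model=model,
                      loss=CrossEntropyLoss(),
                      batch_size=32,
                      num_workers=4)            # forwarded to torch's DataLoader
    trainer.train()

Per the diff, when train_data is already a BatchIter rather than a DataSetIter, the _check_code pass is skipped.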


fastNLP/modules/encoder/embedding.py (+2 -2)

@@ -15,7 +15,7 @@ from ...io.file_utils import cached_path, _get_base_url
 from ._bert import _WordBertModel
 from typing import List

-from ... import DataSet, Batch, SequentialSampler
+from ... import DataSet, DataSetIter, SequentialSampler
 from ...core.utils import _move_model_to_device, _get_model_device

@@ -226,7 +226,7 @@ class ContextualEmbedding(TokenEmbedding):
         with torch.no_grad():
             for index, dataset in enumerate(datasets):
                 try:
-                    batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False)
+                    batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
                     for batch_x, batch_y in batch:
                         words = batch_x['words'].to(device)
                         words_list = words.tolist()


test/core/test_batch.py (+10 -10)

@@ -3,7 +3,7 @@ import unittest
 import numpy as np
 import torch

-from fastNLP import Batch
+from fastNLP import DataSetIter
 from fastNLP import DataSet
 from fastNLP import Instance
 from fastNLP import SequentialSampler

@@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase):
         dataset = construct_dataset(
             [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
         dataset.set_target()
-        batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)

         cnt = 0
         for _, _ in batch:

@@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
             self.assertEqual(len(x["x"]), 4)

@@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase):
                      "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
         for x, y in iter:
             self.assertEqual(x["x"].shape, (4, 4))
             self.assertEqual(y["y"].shape, (4, 4))

@@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase):
                      "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
         for x, y in iter:
             self.assertEqual(x["x"].shape, (4, 4))
             self.assertEqual(y["y"].shape, (4, 4))

@@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase):
                      "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], torch.Tensor))
             self.assertEqual(tuple(x["x"].shape), (4, 4))

@@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase):
                      "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], torch.Tensor))
             self.assertEqual(tuple(x["x"].shape), (4, 4))

@@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase):
             [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             self.assertTrue(isinstance(x["x"], torch.Tensor))
             self.assertEqual(tuple(x["x"].shape), (4, 4))

@@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase):
             [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
         ds.set_input("x")
         ds.set_target("y")
-        iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
+        iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             print(x, y)

@@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase):
         num_samples = 1000
         dataset = generate_fake_dataset(num_samples)

-        batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
+        batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
         for batch_x, batch_y in batch:
             pass
