@@ -12,7 +12,11 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||
__all__ = [ | |||
"Instance", | |||
"FieldArray", | |||
"Batch", | |||
"DataSetIter", | |||
"BatchIter", | |||
"TorchLoaderIter", | |||
"Vocabulary", | |||
"DataSet", | |||
"Const", | |||
@@ -14,7 +14,7 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||
介绍core 的子模块的分工,好像必要性不大 | |||
""" | |||
from .batch import Batch | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||
from .const import Const | |||
from .dataset import DataSet | |||
@@ -3,7 +3,9 @@ batch 模块实现了 fastNLP 所需的 Batch 类。 | |||
""" | |||
# Public names exported by this module.
# Every entry needs its own trailing comma: two adjacent string literals are
# silently concatenated by Python ("Batch" "BatchIter" -> "BatchBatchIter").
__all__ = [
    "Batch",
    "BatchIter",
    "DataSetIter",
    "TorchLoaderIter",
]
import atexit | |||
@@ -15,7 +17,7 @@ import torch.multiprocessing as mp | |||
import torch.utils.data | |||
from numbers import Number | |||
from .sampler import RandomSampler | |||
from .sampler import SequentialSampler | |||
from .dataset import DataSet | |||
_python_is_exit = False | |||
@@ -28,14 +30,18 @@ def _set_python_is_exit(): | |||
atexit.register(_set_python_is_exit) | |||
class DataSetGetter: | |||
def __init__(self, dataset: DataSet, as_numpy=False):
    """Wrap ``dataset`` so that single samples can be fetched by position.

    :param dataset: object providing ``get_all_fields()`` and ``len()``.
    :param as_numpy: stored unchanged as ``self.as_numpy``.
    """
    self.dataset = dataset
    self.as_numpy = as_numpy

    # Fields flagged as model input.
    self.inputs = {}
    for name, field in dataset.get_all_fields().items():
        if field.is_input:
            self.inputs[name] = field

    # Fields flagged as target.
    self.targets = {}
    for name, field in dataset.get_all_fields().items():
        if field.is_target:
            self.targets[name] = field

    # Position -> dataset index mapping; the identity until replaced.
    self.idx_list = list(range(len(dataset)))
def __getitem__(self, idx: int):
    """Return ``(dataset_index, inputs, targets)`` for position ``idx``.

    ``idx`` is first translated through ``self.idx_list``; the translated
    index is the one returned and the one passed to every field's ``get``.
    """
    real_idx = self.idx_list[idx]

    sample_x = {}
    for name, field in self.inputs.items():
        sample_x[name] = field.get(real_idx)

    sample_y = {}
    for name, field in self.targets.items():
        sample_y[name] = field.get(real_idx)

    return real_idx, sample_x, sample_y
@@ -60,9 +66,9 @@ class DataSetGetter: | |||
if f.padder is None: | |||
batch_dict[n] = np.array(vlist) | |||
else: | |||
data = f.padder(vlist, field_name=n, field_ele_dtype=f.dtype) | |||
data = f.pad(vlist) | |||
if not self.as_numpy: | |||
data = _to_tensor(data, f.dtype) | |||
data, flag = _to_tensor(data, f.dtype) | |||
batch_dict[n] = data | |||
return batch_dict | |||
@@ -70,24 +76,40 @@ class DataSetGetter: | |||
pad_batch(batch_x, self.inputs), | |||
pad_batch(batch_y, self.targets)) | |||
def set_idx_list(self, idx_list):
    """Replace the position -> dataset index mapping.

    :param idx_list: new mapping; must have exactly as many entries as the
        current ``self.idx_list``.
    :raises ValueError: if the lengths differ (the current mapping is kept).
    """
    if len(idx_list) != len(self.idx_list):
        raise ValueError(
            "idx_list must contain {} indices, got {}".format(
                len(self.idx_list), len(idx_list)))
    self.idx_list = idx_list
class Batch:
    """Mini-batch iterator that feeds a data set through a torch ``DataLoader``.

    :param dataset: data set to iterate over; wrapped in a ``DataSetGetter``.
    :param batch_size: number of samples per batch.
    :param sampler: only inspected to decide shuffling: the loader shuffles
        when this is a ``RandomSampler`` instance.
    :param buffer_size: stored as ``self.buffer_size``.
    :param as_numpy: forwarded to ``DataSetGetter``.
    :param num_workers, pin_memory, drop_last, timeout, worker_init_fn:
        forwarded to ``torch.utils.data.DataLoader``.
    :param kwargs: accepted and ignored.
    """

    def __init__(self, dataset, batch_size, sampler=None, buffer_size=0, as_numpy=False,
                 num_workers=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None, **kwargs):
        dataset_getter = DataSetGetter(dataset, as_numpy)
        self.buffer_size = buffer_size
        self.cur_batch_indices = None
        # Honour drop_last so that len() matches what the loader really yields.
        self.num_batches = BatchIter.get_num_batches(len(dataset), batch_size, drop_last)
        shuffle = isinstance(sampler, RandomSampler)
        self.dataiter = torch.utils.data.DataLoader(
            dataset=dataset_getter, batch_size=batch_size, shuffle=shuffle,
            collate_fn=dataset_getter.collate_fn,
            num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)

    def __iter__(self):
        """Yield ``(batch_x, batch_y)`` and record the indices of each batch."""
        for indices, batch_x, batch_y in self.dataiter:
            self.cur_batch_indices = indices
            yield batch_x, batch_y

    def __len__(self):
        """Return the number of batches per pass."""
        return self.num_batches


class SamplerAdapter(torch.utils.data.Sampler):
    """Expose a callable ``sampler(dataset)`` as a torch ``Sampler``."""

    def __init__(self, sampler, dataset):
        self.sampler = sampler
        self.dataset = dataset

    def __iter__(self):
        # The wrapped sampler is invoked anew on every pass.
        return iter(self.sampler(self.dataset))

    def __len__(self):
        # Lets len(DataLoader) work when this adapter is used as its sampler.
        return len(self.dataset)


class BatchIter:
    """Base class for batch iterators.

    Subclasses are expected to set ``dataiter`` (an iterable yielding
    ``(indices, batch_x, batch_y)`` triples), ``num_batches`` and
    ``batch_size`` in their own ``__init__``.
    """

    def __init__(self):
        # Placeholders only: the base constructor takes no data arguments.
        self.dataiter = None
        self.num_batches = None
        self.cur_batch_indices = None
        self.batch_size = None

    def init_iter(self):
        """Hook called at the start of every pass; does nothing by default."""
        pass

    @staticmethod
    def get_num_batches(num_samples, batch_size, drop_last):
        """Return how many batches ``num_samples`` samples are split into.

        A trailing partial batch is counted unless ``drop_last`` is true.
        """
        num_batches = num_samples // batch_size
        if not drop_last and (num_samples % batch_size > 0):
            num_batches += 1
        return num_batches

    def __iter__(self):
        """Yield ``(batch_x, batch_y)`` and record the indices of each batch."""
        self.init_iter()
        for indices, batch_x, batch_y in self.dataiter:
            self.cur_batch_indices = indices
            yield batch_x, batch_y
@@ -98,163 +120,62 @@ class Batch: | |||
def __len__(self):
    """Return the value stored in ``self.num_batches``."""
    count = self.num_batches
    return count
@property
def dataset(self):
    """Return the ``dataset`` attribute of the wrapped ``self.dataiter``."""
    loader = self.dataiter
    return loader.dataset
class DataSetIter(BatchIter):
    """Batch iterator over a ``DataSet`` backed by a torch ``DataLoader``.

    :param dataset: the ``DataSet`` to iterate over.
    :param batch_size: number of samples per batch.
    :param sampler: callable ``sampler(dataset)`` producing the index order;
        a ``SequentialSampler`` is used when omitted.
    :param as_numpy: forwarded to ``DataSetGetter``.
    :param num_workers, pin_memory, drop_last, timeout, worker_init_fn:
        forwarded to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False,
                 num_workers=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None):
        super().__init__()
        assert isinstance(dataset, DataSet)
        # Bind the sampler to the raw DataSet *before* wrapping it: samplers
        # are called as sampler(dataset) and must see the DataSet itself, not
        # the DataSetGetter wrapper.
        sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
        dataset = DataSetGetter(dataset, as_numpy)
        collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
        self.dataiter = torch.utils.data.DataLoader(
            dataset=dataset, batch_size=batch_size, sampler=sampler,
            collate_fn=collate_fn, num_workers=num_workers,
            pin_memory=pin_memory, drop_last=drop_last,
            timeout=timeout, worker_init_fn=worker_init_fn)
        self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last)
        self.batch_size = batch_size


class Batch1(object):
    """
    Alias: :class:`fastNLP.Batch` :class:`fastNLP.core.batch.Batch`

    Batch takes data out of a `DataSet` in a given order, ``batch_size``
    samples at a time, and assembles them into `x` and `y`::

        batch = Batch(data_set, batch_size=16, sampler=SequentialSampler())
        num_batch = len(batch)
        for batch_x, batch_y in batch:
            # do stuff ...

    :param dataset: :class:`~fastNLP.DataSet` object, the data set.
    :param int batch_size: size of each batch.
    :param sampler: the :class:`~fastNLP.Sampler` to use. If ``None``,
        :class:`~fastNLP.RandomSampler` is used. Default: ``None``
    :param bool as_numpy: if ``True``, batches are numpy arrays, otherwise
        :class:`torch.Tensor`. Default: ``False``
    :param bool prefetch: if ``True``, the next batch is fetched ahead of time
        in a separate process. Default: ``False``
    """

    def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
        self.dataset = dataset
        self.batch_size = batch_size
        if sampler is None:
            sampler = RandomSampler()
        self.sampler = sampler
        self.as_numpy = as_numpy
        self.idx_list = None
        self.curidx = 0
        self.num_batches = len(dataset) // batch_size + int(len(dataset) % batch_size != 0)
        self.cur_batch_indices = None
        self.prefetch = prefetch
        self.lengths = 0

    def fetch_one(self):
        """Return the next ``(batch_x, batch_y)`` pair, or ``None`` when exhausted."""
        if self.curidx >= len(self.idx_list):
            return None
        endidx = min(self.curidx + self.batch_size, len(self.idx_list))
        batch_x, batch_y = {}, {}
        indices = self.idx_list[self.curidx:endidx]
        self.cur_batch_indices = indices

        for field_name, field in self.dataset.get_all_fields().items():
            if field.is_target or field.is_input:
                batch = field.get(indices)
                if not self.as_numpy and \
                        field.dtype is not None and \
                        issubclass(field.dtype, Number) and not isinstance(batch, torch.Tensor):
                    # _to_tensor takes the field dtype and returns
                    # (value, converted_flag); only the value is needed here.
                    batch, _ = _to_tensor(batch, field.dtype)
                if field.is_target:
                    batch_y[field_name] = batch
                if field.is_input:
                    batch_x[field_name] = batch

        self.curidx = endidx
        return batch_x, batch_y

    def __iter__(self):
        """
        Iterate over the data set batch by batch. With ``prefetch`` enabled,
        batches are produced by a separate process.

        :return: an iterator of ``(batch_x, batch_y)`` pairs
        """
        if self.prefetch:
            return self._run_batch_iter(self)

        def batch_iter():
            self.init_iter()
            while 1:
                res = self.fetch_one()
                if res is None:
                    break
                yield res

        return batch_iter()

    def init_iter(self):
        """Draw a fresh index order from the sampler and rewind the cursor."""
        self.idx_list = self.sampler(self.dataset)
        self.curidx = 0
        self.lengths = self.dataset.get_length()

    def __len__(self):
        return self.num_batches

    def get_batch_indices(self):
        """
        Return the DataSet indices of the samples in the current batch.

        :return list(int) indexes: the index sequence
        """
        return self.cur_batch_indices

    @staticmethod
    def _run_fetch(batch, q):
        """Producer: push every batch of ``batch`` (then ``None``) onto ``q``."""
        # Imported locally so the method does not depend on a module-level import.
        from queue import Full
        global _python_is_exit
        try:
            batch.init_iter()
            while 1:
                res = batch.fetch_one()
                # Retry until the item fits, giving up if the interpreter is exiting.
                while 1:
                    try:
                        q.put(res, timeout=3)
                        break
                    except Full:
                        if _python_is_exit:
                            return
                if res is None:
                    break
        except Exception as e:
            # Forward the failure to the consumer, which re-raises it.
            q.put(e)
        finally:
            q.join()

    @staticmethod
    def _run_batch_iter(batch):
        """Consumer: yield the batches produced by ``_run_fetch`` in a child process."""
        # Imported locally so the method does not depend on a module-level import.
        from queue import Empty
        q = mp.JoinableQueue(maxsize=10)
        # The producer lives on this class, not on ``Batch``.
        fetch_p = mp.Process(target=Batch1._run_fetch, args=(batch, q))
        fetch_p.daemon = True
        fetch_p.start()
        while 1:
            try:
                res = q.get(timeout=1)
                q.task_done()
                if res is None:
                    break
                elif isinstance(res, Exception):
                    raise res
                yield res
            except Empty:
                # Nothing arrived in time: keep waiting only while the producer lives.
                if fetch_p.is_alive():
                    continue
                else:
                    break
        fetch_p.terminate()
        fetch_p.join()
class TorchLoaderIter(BatchIter):
    """Expose an existing ``torch.utils.data.DataLoader`` as a ``BatchIter``.

    The loader is iterated by ``BatchIter.__iter__``, which unpacks every item
    as ``(indices, batch_x, batch_y)``; the loader's collate function has to
    produce that shape.

    :param dataset: the ``DataLoader`` to wrap (the parameter name is kept for
        interface compatibility).
    """

    def __init__(self, dataset):
        super().__init__()
        assert isinstance(dataset, torch.utils.data.DataLoader), \
            "TorchLoaderIter expects a torch DataLoader, got {}".format(type(dataset))
        self.dataiter = dataset
        # len(DataLoader) is already the number of batches (it accounts for
        # batch_size and drop_last); dividing it by batch_size again would
        # under-count.
        self.num_batches = len(dataset)
        self.batch_size = dataset.batch_size
class OnlineDataGettter:
    """Placeholder; not implemented yet."""
    # TODO
    pass


class OnlineDataIter(BatchIter):
    """Placeholder batch iterator; only the base-class state is initialised."""

    # TODO
    def __init__(self, dataset, batch_size=1, buffer_size=10000, sampler=None, as_numpy=False,
                 num_workers=0, pin_memory=False, drop_last=False,
                 timeout=0, worker_init_fn=None, **kwargs):
        super().__init__()
def _to_tensor(batch, field_dtype):
    """Convert ``batch`` to a ``torch.Tensor`` when its field holds numbers.

    :param batch: value to convert; left untouched if it already is a tensor.
    :param field_dtype: element type of the field, or ``None``.
    :return: ``(value, converted)``. ``converted`` is ``True`` only when a new
        tensor was created; in every other case (non-numeric field, already a
        tensor, conversion failure) the input is returned with ``False``.
    """
    try:
        if field_dtype is not None \
                and issubclass(field_dtype, Number) \
                and not isinstance(batch, torch.Tensor):
            if issubclass(batch.dtype.type, np.floating):
                new_batch = torch.as_tensor(batch).float()  # use float32 by default
            else:
                new_batch = torch.as_tensor(batch)  # reuse the memory, avoid a copy
            return new_batch, True
        return batch, False
    except Exception:
        # Best effort: anything that cannot be converted is passed through
        # unchanged (KeyboardInterrupt/SystemExit are no longer swallowed).
        return batch, False
@@ -176,7 +176,12 @@ class FieldArray: | |||
if self.padder is None or pad is False: | |||
return np.array(contents) | |||
else: | |||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||
return self.pad(contents) | |||
def pad(self, contents):
    """Pad ``contents`` with this field's padder and return the result.

    :param contents: the values to pad; passed to ``self.padder`` unchanged.
    :return: whatever ``self.padder`` returns.
    :raises RuntimeError: if no padder is set on this field.
    """
    if self.padder is None:
        raise RuntimeError(
            "Field '{}' has no padder; set one before calling pad().".format(self.name))
    return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
def set_padder(self, padder): | |||
""" | |||
@@ -6,7 +6,7 @@ from collections import defaultdict | |||
import torch | |||
from . import Batch | |||
from . import DataSetIter | |||
from . import DataSet | |||
from . import SequentialSampler | |||
from .utils import _build_args | |||
@@ -44,8 +44,7 @@ class Predictor(object): | |||
self.network.eval() | |||
batch_output = defaultdict(list) | |||
data_iterator = Batch(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False, | |||
prefetch=False) | |||
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||
if hasattr(self.network, "predict"): | |||
predict_func = self.network.predict | |||
@@ -37,7 +37,7 @@ import warnings | |||
import torch | |||
import torch.nn as nn | |||
from .batch import Batch | |||
from .batch import BatchIter, DataSetIter | |||
from .dataset import DataSet | |||
from .metrics import _prepare_metrics | |||
from .sampler import SequentialSampler | |||
@@ -82,7 +82,7 @@ class Tester(object): | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
""" | |||
def __init__(self, data, model, metrics, batch_size=16, device=None, verbose=1): | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): | |||
super(Tester, self).__init__() | |||
if not isinstance(data, DataSet): | |||
@@ -96,6 +96,14 @@ class Tester(object): | |||
self._model = _move_model_to_device(model, device=device) | |||
self.batch_size = batch_size | |||
self.verbose = verbose | |||
if isinstance(data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
dataset=data, batch_size=batch_size, num_workers=num_workers) | |||
elif isinstance(data, BatchIter): | |||
self.data_iterator = data | |||
else: | |||
raise TypeError("data type {} not support".format(type(data))) | |||
# 如果是DataParallel将没有办法使用predict方法 | |||
if isinstance(self._model, nn.DataParallel): | |||
@@ -124,7 +132,7 @@ class Tester(object): | |||
self._model_device = _get_model_device(self._model) | |||
network = self._model | |||
self._mode(network, is_test=True) | |||
data_iterator = Batch(self.data, self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||
data_iterator = self.data_iterator | |||
eval_results = {} | |||
try: | |||
with torch.no_grad(): | |||
@@ -312,7 +312,7 @@ try: | |||
except: | |||
from .utils import _pseudo_tqdm as tqdm | |||
from .batch import Batch | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import CallbackManager, CallbackException | |||
from .dataset import DataSet | |||
from .losses import _prepare_losser | |||
@@ -394,7 +394,7 @@ class Trainer(object): | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
batch_size=32, sampler=None, update_every=1, | |||
batch_size=32, sampler=None, update_every=1, num_workers=0, | |||
n_epochs=10, print_every=5, | |||
dev_data=None, metrics=None, metric_key=None, | |||
validate_every=-1, save_path=None, | |||
@@ -439,9 +439,19 @@ class Trainer(object): | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, Sampler): | |||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||
if isinstance(train_data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
dataset=train_data, batch_size=batch_size, num_workers=num_workers) | |||
elif isinstance(train_data, BatchIter): | |||
self.data_iterator = train_data | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(train_data))) | |||
if check_code_level > -1: | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
# TODO 考虑不同的dataset类型怎么check | |||
_check_code(data_iterator=self.data_iterator, | |||
model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
@@ -493,7 +503,7 @@ class Trainer(object): | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
def train(self, load_best_model=True, on_exception='auto'): | |||
""" | |||
使用该函数使Trainer开始训练。 | |||
@@ -572,8 +582,7 @@ class Trainer(object): | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar | |||
avg_loss = 0 | |||
data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, | |||
prefetch=self.prefetch) | |||
data_iterator = self.data_iterator | |||
self.batch_per_epoch = data_iterator.num_batches | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
@@ -786,13 +795,14 @@ def _get_value_info(_dict): | |||
return strs | |||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
def _check_code(data_iterator, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, | |||
check_level=0): | |||
# check get_loss 方法 | |||
model_devcie = model.parameters().__next__().device | |||
batch = Batch(dataset=dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
batch = data_iterator | |||
dataset = data_iterator.dataset | |||
for batch_count, (batch_x, batch_y) in enumerate(batch): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | |||
# forward check | |||
@@ -15,7 +15,7 @@ from ...io.file_utils import cached_path, _get_base_url | |||
from ._bert import _WordBertModel | |||
from typing import List | |||
from ... import DataSet, Batch, SequentialSampler | |||
from ... import DataSet, DataSetIter, SequentialSampler | |||
from ...core.utils import _move_model_to_device, _get_model_device | |||
@@ -226,7 +226,7 @@ class ContextualEmbedding(TokenEmbedding): | |||
with torch.no_grad(): | |||
for index, dataset in enumerate(datasets): | |||
try: | |||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), prefetch=False) | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_x, batch_y in batch: | |||
words = batch_x['words'].to(device) | |||
words_list = words.tolist() | |||
@@ -3,7 +3,7 @@ import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP import Batch | |||
from fastNLP import DataSetIter | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import SequentialSampler | |||
@@ -57,7 +57,7 @@ class TestCase1(unittest.TestCase): | |||
dataset = construct_dataset( | |||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
dataset.set_target() | |||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
batch = DataSetIter(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
cnt = 0 | |||
for _, _ in batch: | |||
@@ -68,7 +68,7 @@ class TestCase1(unittest.TestCase): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray)) | |||
self.assertEqual(len(x["x"]), 4) | |||
@@ -81,7 +81,7 @@ class TestCase1(unittest.TestCase): | |||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
for x, y in iter: | |||
self.assertEqual(x["x"].shape, (4, 4)) | |||
self.assertEqual(y["y"].shape, (4, 4)) | |||
@@ -91,7 +91,7 @@ class TestCase1(unittest.TestCase): | |||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | |||
for x, y in iter: | |||
self.assertEqual(x["x"].shape, (4, 4)) | |||
self.assertEqual(y["y"].shape, (4, 4)) | |||
@@ -101,7 +101,7 @@ class TestCase1(unittest.TestCase): | |||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
@@ -113,7 +113,7 @@ class TestCase1(unittest.TestCase): | |||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
@@ -125,7 +125,7 @@ class TestCase1(unittest.TestCase): | |||
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
self.assertTrue(isinstance(x["x"], torch.Tensor)) | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
@@ -137,7 +137,7 @@ class TestCase1(unittest.TestCase): | |||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||
ds.set_input("x") | |||
ds.set_target("y") | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
iter = DataSetIter(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
print(x, y) | |||
@@ -146,7 +146,7 @@ class TestCase1(unittest.TestCase): | |||
num_samples = 1000 | |||
dataset = generate_fake_dataset(num_samples) | |||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler()) | |||
for batch_x, batch_y in batch: | |||
pass | |||