
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729, 2 years ago
commit 36ed894a77
6 changed files with 439 additions and 160 deletions:

  1. fastNLP/core/samplers/reproducible_batch_sampler.py (+0, -1)
  2. tests/core/controllers/test_trainer_event_trigger.py (+122, -1)
  3. tests/core/controllers/utils/test_utils.py (+2, -2)
  4. tests/core/samplers/test_reproducible_batch_sampler.py (+122, -152)
  5. tests/core/samplers/test_reproducible_batch_sampler_torch.py (+141, -0)
  6. tests/helpers/datasets/normal_data.py (+52, -4)

fastNLP/core/samplers/reproducible_batch_sampler.py (+0, -1)

@@ -8,7 +8,6 @@ import math
from copy import deepcopy
from typing import Dict, Union, List
from itertools import chain
import os


import numpy as np




tests/core/controllers/test_trainer_event_trigger.py (+122, -1)

@@ -70,7 +70,7 @@ def model_and_optimizers():
@pytest.mark.parametrize("callbacks", [[RecordTrainerEventTriggerCallback()]])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger(
def test_trainer_event_trigger_1(
        model_and_optimizers: TrainerParameters,
        driver,
        device,
@@ -104,5 +104,126 @@ def test_trainer_event_trigger(
        assert member.value in output[0]




@pytest.mark.parametrize("driver,device", [("torch", "cpu"),("torch", 6), ("torch", [6, 7])]) # , ("torch", 6), ("torch", [6, 7])
@pytest.mark.torch
@magic_argv_env_context
def test_trainer_event_trigger_2(
model_and_optimizers: TrainerParameters,
driver,
device,
n_epochs=2,
):

@Trainer.on(Events.on_after_trainer_initialized)
def on_after_trainer_initialized(trainer, driver):
print("on_after_trainer_initialized")

@Trainer.on(Events.on_sanity_check_begin)
def on_sanity_check_begin(trainer):
print("on_sanity_check_begin")

@Trainer.on(Events.on_sanity_check_end)
def on_sanity_check_end(trainer, sanity_check_res):
print("on_sanity_check_end")

@Trainer.on(Events.on_train_begin)
def on_train_begin(trainer):
print("on_train_begin")

@Trainer.on(Events.on_train_end)
def on_train_end(trainer):
print("on_train_end")

@Trainer.on(Events.on_train_epoch_begin)
def on_train_epoch_begin(trainer):
if trainer.cur_epoch_idx >= 1:
            # trigger on_exception;
raise Exception
print("on_train_epoch_begin")

@Trainer.on(Events.on_train_epoch_end)
def on_train_epoch_end(trainer):
print("on_train_epoch_end")

@Trainer.on(Events.on_fetch_data_begin)
def on_fetch_data_begin(trainer):
print("on_fetch_data_begin")

@Trainer.on(Events.on_fetch_data_end)
def on_fetch_data_end(trainer):
print("on_fetch_data_end")

@Trainer.on(Events.on_train_batch_begin)
def on_train_batch_begin(trainer, batch, indices=None):
print("on_train_batch_begin")

@Trainer.on(Events.on_train_batch_end)
def on_train_batch_end(trainer):
print("on_train_batch_end")

@Trainer.on(Events.on_exception)
def on_exception(trainer, exception):
print("on_exception")

@Trainer.on(Events.on_before_backward)
def on_before_backward(trainer, outputs):
print("on_before_backward")

@Trainer.on(Events.on_after_backward)
def on_after_backward(trainer):
print("on_after_backward")

@Trainer.on(Events.on_before_optimizers_step)
def on_before_optimizers_step(trainer, optimizers):
print("on_before_optimizers_step")

@Trainer.on(Events.on_after_optimizers_step)
def on_after_optimizers_step(trainer, optimizers):
print("on_after_optimizers_step")

@Trainer.on(Events.on_before_zero_grad)
def on_before_zero_grad(trainer, optimizers):
print("on_before_zero_grad")

@Trainer.on(Events.on_after_zero_grad)
def on_after_zero_grad(trainer, optimizers):
print("on_after_zero_grad")

@Trainer.on(Events.on_evaluate_begin)
def on_evaluate_begin(trainer):
print("on_evaluate_begin")

@Trainer.on(Events.on_evaluate_end)
def on_evaluate_end(trainer, results):
print("on_evaluate_end")

with pytest.raises(Exception):
with Capturing() as output:
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,

n_epochs=n_epochs,
)

trainer.run()

if dist.is_initialized():
dist.destroy_process_group()

for name, member in Events.__members__.items():
assert member.value in output[0]
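
The new test registers a handler for every trainer event through the decorator API and deliberately raises once trainer.cur_epoch_idx >= 1 so that on_exception fires as well; each handler just prints its event name, and the final loop checks that every Events member name appears in the captured output. A condensed, non-authoritative sketch of that registration pattern follows (the fastNLP import path is an assumption here, since the test module's imports are outside this diff):

# Condensed sketch of the callback-registration pattern used above.
# Assumption: Trainer and Events are importable from the fastNLP top level
# (the test module's real imports are not shown in this diff).
from fastNLP import Trainer, Events

@Trainer.on(Events.on_train_begin)
def on_train_begin(trainer):
    print("on_train_begin")            # the test later asserts this name is in the captured output

@Trainer.on(Events.on_train_epoch_begin)
def on_train_epoch_begin(trainer):
    if trainer.cur_epoch_idx >= 1:
        raise Exception                # deliberately triggers Events.on_exception
    print("on_train_epoch_begin")

The trainer itself is then constructed and run inside a pytest.raises(Exception) block, exactly as in the test body above.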










tests/core/controllers/utils/test_utils.py (+2, -2)

@@ -1,7 +1,7 @@
from functools import reduce


from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader  # TODO: this class has been modified; remember to update the test as well;
from tests.helpers.datasets.normal_data import NormalIterator
from tests.helpers.datasets.normal_data import NormalSampler




class Test_WrapDataLoader:
@@ -9,7 +9,7 @@ class Test_WrapDataLoader:
    def test_normal_generator(self):
        all_sanity_batches = [4, 20, 100]
        for sanity_batches in all_sanity_batches:
data = NormalIterator(num_of_data=1000)
data = NormalSampler(num_of_data=1000)
            wrapper = _TruncatedDataLoader(dataloader=data, num_batches=sanity_batches)
            dataloader = iter(wrapper)
            mark = 0


tests/core/samplers/test_reproducible_batch_sampler.py (+122, -152)

@@ -1,161 +1,131 @@
from array import array

import numpy as np
import pytest
from itertools import chain
from copy import deepcopy
from array import array


from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler
from fastNLP.core.samplers import ReproduceBatchSampler, BucketedBatchSampler, RandomBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset

#
# class TestReproducibleBatchSampler:
#     # TODO: split the tests; only test one thing here
# def test_torch_dataloader_1(self):
# import torch
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# next(iter_dataloader)
#
#         # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
# assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
# "sampler_type": "ReproduceBatchSampler"}
#
#         # 2. resume from the checkpoint: create a new dataloader;
#         # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
#         # change batch_size;
# after_batch_size = 3
# dataloader = DataLoader(dataset, batch_size=after_batch_size)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
# real_res = []
# supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
# forward_steps = 2
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# real_res.append(next(iter_dataloader))
#
# for i in range(forward_steps):
# assert all(real_res[i] == supposed_res[i])
#
#         # check that the second epoch after checkpoint-resume is a complete dataloader;
#         # first finish the epoch in which the resume happened;
# begin_idx = 27
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
#         # start a new epoch;
# begin_idx = 0
# iter_dataloader = iter(dataloader)
# while True:
# try:
# data = next(iter_dataloader)
# _batch_size = len(data)
# assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
# begin_idx += _batch_size
# except StopIteration:
# break
#
# def test_torch_dataloader_2(self):
#         # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
# from torch.utils.data import DataLoader
# # no shuffle
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
#         # enable shuffle to check whether the second epoch's index list is regenerated after checkpoint-resume;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         # record all the data of one epoch to check that the restored data is correct;
# all_supposed_data = []
# forward_steps = 3
# iter_dataloader = iter(dataloader)
# for _ in range(forward_steps):
# all_supposed_data.extend(next(iter_dataloader).tolist())
#
#         # 1. save the state
# _get_re_batchsampler = dataloader.batch_sampler
# assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
# state = _get_re_batchsampler.state_dict()
#
#         # 2. resume from the checkpoint: create a new dataloader;
#         # keep batch_size unchanged;
# dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
# re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
# re_batchsampler.load_state_dict(state)
# dataloader = replace_batch_sampler(dataloader, re_batchsampler)
#
#         # first run through the rest of this epoch;
# pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
# while True:
# try:
# all_supposed_data.extend(next(iter_dataloader).tolist())
# except StopIteration:
# break
# assert all_supposed_data == list(pre_index_list)
#
#         # start a new epoch again;
# for _ in range(3):
# iter_dataloader = iter(dataloader)
# res = []
# while True:
# try:
# res.append(next(iter_dataloader))
# except StopIteration:
# break
#
# def test_3(self):
# import torch
# from torch.utils.data import DataLoader
# before_batch_size = 7
# dataset = TorchNormalDataset(num_of_data=100)
#         # enable shuffle to check whether the second epoch's index list is regenerated after checkpoint-resume;
# dataloader = DataLoader(dataset, batch_size=before_batch_size)
#
# for idx, data in enumerate(dataloader):
# if idx > 3:
# break
#
# iterator = iter(dataloader)
# for each in iterator:
# pass


class TestReproducibleBatchSampler:
def test_1(self):
        sampler = NormalSampler(num_of_data=100)  # it does not matter whether this is a batch sampler here;

reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)

forward_steps = 3
iterator = iter(reproduce_batch_sampler)
i = 0
while i < forward_steps:
next(iterator)
i += 1

        # save the state;
state = reproduce_batch_sampler.state_dict()

assert state == {"index_list": array("I", list(range(100))),
"num_consumed_samples": forward_steps * 4,
"sampler_type": "ReproduceBatchSampler"}

        # create a new batch sampler and load the state;
        sampler = NormalSampler(num_of_data=100)  # it does not matter whether this is a batch sampler here;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=4, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

real_res = []
supposed_res = (list(range(12, 16)), list(range(16, 20)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert supposed_res[i] == real_res[i]

        # change the batch size;
        sampler = NormalSampler(num_of_data=100)  # it does not matter whether this is a batch sampler here;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, batch_size=7, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

real_res = []
supposed_res = (list(range(12, 19)), list(range(19, 26)))
forward_steps = 2
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert supposed_res[i] == real_res[i]

        # check that the second epoch after checkpoint-resume is a complete pass over the data;
        # first finish the epoch in which the resume happened;
begin_idx = 26
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break

        # start a new epoch;
begin_idx = 0
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert data == list(range(begin_idx, begin_idx + _batch_size))
begin_idx += _batch_size
except StopIteration:
break

def test_2(self):

        # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
before_batch_size = 7
sampler = NormalSampler(num_of_data=100)
        # enable shuffle to check whether the second epoch's index list is regenerated after checkpoint-resume;
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)

        # record all the data of one epoch to check that the restored data is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(reproduce_batch_sampler)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader))

        # 1. save the state
state = reproduce_batch_sampler.state_dict()

        # 2. resume from the checkpoint: create a new sampler;
        # keep batch_size unchanged;
sampler = NormalSampler(num_of_data=100, shuffle=True)
reproduce_batch_sampler = ReproduceBatchSampler(sampler, before_batch_size, drop_last=False)
reproduce_batch_sampler.load_state_dict(state)

        # first run through the rest of this epoch;
pre_index_list = reproduce_batch_sampler.state_dict()["index_list"]
iter_dataloader = iter(reproduce_batch_sampler)
while True:
try:
all_supposed_data.extend(next(iter_dataloader))
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

        # start a new epoch again;
for _ in range(3):
iter_dataloader = iter(reproduce_batch_sampler)
res = []
while True:
try:
res.extend(next(iter_dataloader))
except StopIteration:
break
assert res != all_supposed_data




class DatasetWithVaryLength:
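
The rewritten tests above drive ReproduceBatchSampler directly with the plain NormalSampler helper instead of a torch DataLoader. As a quick orientation, here is a minimal save-and-resume sketch mirroring test_1; the import paths are the ones used in this diff, while the smaller sample count is only illustrative:

from fastNLP.core.samplers import ReproduceBatchSampler
from tests.helpers.datasets.normal_data import NormalSampler

batch_sampler = ReproduceBatchSampler(NormalSampler(num_of_data=20), batch_size=4, drop_last=False)
iterator = iter(batch_sampler)
for _ in range(2):
    next(iterator)                       # consume two batches (8 samples)
state = batch_sampler.state_dict()       # records index_list and num_consumed_samples

# rebuild the sampler (as after a restart) and restore the saved state
resumed = ReproduceBatchSampler(NormalSampler(num_of_data=20), batch_size=4, drop_last=False)
resumed.load_state_dict(state)
assert next(iter(resumed)) == [8, 9, 10, 11]   # continues right after the consumed samples

As the second half of test_1 shows, the batch size may even change between saving and resuming; only the number of already consumed samples matters.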


tests/core/samplers/test_reproducible_batch_sampler_torch.py (+141, -0)

@@ -0,0 +1,141 @@
from array import array
import torch
from torch.utils.data import DataLoader

import pytest

from fastNLP.core.samplers import ReproduceBatchSampler
from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
from tests.helpers.datasets.torch_data import TorchNormalDataset


@pytest.mark.torch
class TestReproducibleBatchSamplerTorch:
def test_torch_dataloader_1(self):
# no shuffle
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
next(iter_dataloader)

        # 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()
assert state == {"index_list": array("I", list(range(100))), "num_consumed_samples": forward_steps*before_batch_size,
"sampler_type": "ReproduceBatchSampler"}

        # 2. resume from the checkpoint: create a new dataloader;
        # keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 28))), torch.tensor(list(range(28, 35))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

        # change batch_size;
after_batch_size = 3
dataloader = DataLoader(dataset, batch_size=after_batch_size)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

real_res = []
supposed_res = (torch.tensor(list(range(21, 24))), torch.tensor(list(range(24, 27))))
forward_steps = 2
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
real_res.append(next(iter_dataloader))

for i in range(forward_steps):
assert all(real_res[i] == supposed_res[i])

        # check that the second epoch after checkpoint-resume is a complete dataloader;
        # first finish the epoch in which the resume happened;
begin_idx = 27
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

        # start a new epoch;
begin_idx = 0
iter_dataloader = iter(dataloader)
while True:
try:
data = next(iter_dataloader)
_batch_size = len(data)
assert all(data == torch.tensor(list(range(begin_idx, begin_idx + _batch_size))))
begin_idx += _batch_size
except StopIteration:
break

def test_torch_dataloader_2(self):
        # test that a new epoch's index list is regenerated rather than carried over from the previous epoch;
from torch.utils.data import DataLoader
before_batch_size = 7
dataset = TorchNormalDataset(num_of_data=100)
        # enable shuffle to check whether the second epoch's index list is regenerated after checkpoint-resume;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

        # record all the data of one epoch to check that the restored data is correct;
all_supposed_data = []
forward_steps = 3
iter_dataloader = iter(dataloader)
for _ in range(forward_steps):
all_supposed_data.extend(next(iter_dataloader).tolist())

        # 1. save the state
_get_re_batchsampler = dataloader.batch_sampler
assert isinstance(_get_re_batchsampler, ReproduceBatchSampler)
state = _get_re_batchsampler.state_dict()

        # 2. resume from the checkpoint: create a new dataloader;
        # keep batch_size unchanged;
dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True)
re_batchsampler = ReproduceBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False)
re_batchsampler.load_state_dict(state)
dataloader = replace_batch_sampler(dataloader, re_batchsampler)

iter_dataloader = iter(dataloader)
        # first run through the rest of this epoch;
pre_index_list = dataloader.batch_sampler.state_dict()["index_list"]
while True:
try:
all_supposed_data.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert all_supposed_data == list(pre_index_list)

        # start a new epoch again;
for _ in range(3):
iter_dataloader = iter(dataloader)
res = []
while True:
try:
res.extend(next(iter_dataloader).tolist())
except StopIteration:
break
assert res != all_supposed_data


tests/helpers/datasets/normal_data.py (+52, -4)

@@ -1,13 +1,25 @@
import numpy as np
import random




class NormalIterator:
def __init__(self, num_of_data=1000):
class NormalSampler:
def __init__(self, num_of_data=1000, shuffle=False):
        self._num_of_data = num_of_data
        self._data = list(range(num_of_data))
if shuffle:
random.shuffle(self._data)
self.shuffle = shuffle
        self._index = 0
self.need_reinitialize = False


    def __iter__(self):
if self.need_reinitialize:
self._index = 0
if self.shuffle:
random.shuffle(self._data)
else:
self.need_reinitialize = True

        return self


    def __next__(self):
@@ -15,12 +27,45 @@ class NormalIterator:
            raise StopIteration
        _data = self._data[self._index]
        self._index += 1
return self._data
return _data


    def __len__(self):
        return self._num_of_data




class NormalBatchSampler:
def __init__(self, sampler, batch_size: int, drop_last: bool) -> None:
# Since collections.abc.Iterable does not check for `__getitem__`, which
# is one way for an object to be an iterable, we don't do an `isinstance`
# check here.
if not isinstance(batch_size, int) or isinstance(batch_size, bool) or \
batch_size <= 0:
raise ValueError("batch_size should be a positive integer value, "
"but got batch_size={}".format(batch_size))
if not isinstance(drop_last, bool):
raise ValueError("drop_last should be a boolean value, but got "
"drop_last={}".format(drop_last))
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last

def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch

def __len__(self) -> int:
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size


class RandomDataset:
    def __init__(self, num_data=10):
        self.data = np.random.rand(num_data)
@@ -29,4 +74,7 @@ class RandomDataset:
        return len(self.data)


    def __getitem__(self, item):
        return self.data[item]
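
To make the behaviour of the updated helpers concrete, here is a small usage sketch (a non-authoritative illustration; the import path is the one used by the tests in this commit and the numbers are arbitrary):

from tests.helpers.datasets.normal_data import NormalSampler, NormalBatchSampler

# NormalSampler re-initialises itself on every __iter__ after the first one,
# reshuffling when shuffle=True (this is what the need_reinitialize flag enables).
sampler = NormalSampler(num_of_data=10, shuffle=True)
first_epoch = list(sampler)
second_epoch = list(sampler)                   # a fresh, possibly reshuffled pass
assert sorted(first_epoch) == sorted(second_epoch) == list(range(10))

# NormalBatchSampler groups the sampler's output into fixed-size batches and keeps
# the trailing partial batch unless drop_last=True.
batch_sampler = NormalBatchSampler(NormalSampler(num_of_data=10), batch_size=4, drop_last=False)
assert list(batch_sampler) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
assert len(batch_sampler) == 3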



