|
- import unittest
- import random
- from fastNLP.core.samplers import SequentialSampler, RandomSampler, BucketSampler
- from fastNLP.core.dataset import DataSet
- from array import array
- import torch
-
- from fastNLP.core.samplers.sampler import ReproduceBatchSampler
- from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
- from tests.helpers.datasets.torch_data import TorchNormalDataset
-
-
-
-
-
-
-
-
-
- class SamplerTest(unittest.TestCase):
-
- def test_sequentialsampler(self):
- ds = DataSet({'x': [1, 2, 3, 4] * 10})
- sqspl = SequentialSampler(ds)
- for idx, inst in enumerate(sqspl):
- self.assertEqual(idx, inst)
-
- def test_randomsampler(self):
- ds = DataSet({'x': [1, 2, 3, 4] * 10})
- rdspl = RandomSampler(ds)
- ans = [ds[i] for i in rdspl]
- self.assertEqual(len(ans), len(ds))
-
- def test_bucketsampler(self):
- data_set = DataSet({"x": [[0] * random.randint(1, 10)] * 10, "y": [[5, 6]] * 10})
- sampler = BucketSampler(data_set, num_buckets=3, batch_size=16, seq_len_field_name="seq_len")
-
|