diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index 79dd56c0..74f67125 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -48,8 +48,6 @@ def simple_sort_bucketing(lengths): """ :param lengths: list of int, the lengths of all examples. - :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length - threshold for each bucket (This is usually None.). :return data: 2-level list :: @@ -75,6 +73,7 @@ def k_means_1d(x, k, max_iter=100): assignment: numpy array, 1-D, the bucket id assigned to each example. """ sorted_x = sorted(list(set(x))) + x = np.array(x) if len(sorted_x) < k: raise ValueError("too few buckets") gap = len(sorted_x) / k @@ -119,34 +118,3 @@ def k_means_bucketing(lengths, buckets): bucket_data[bucket_id].append(idx) return bucket_data - -class BucketSampler(BaseSampler): - """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length. - In sampling, first random choose a bucket. Then sample data from it. - The number of buckets is decided dynamically by the variance of sentence lengths. - - """ - - def __call__(self, data_set, batch_size, num_buckets): - return self._process(data_set, batch_size, num_buckets) - - def _process(self, data_set, batch_size, num_buckets, use_kmeans=False): - """ - - :param data_set: a DataSet object - :param batch_size: int - :param num_buckets: int, number of buckets for grouping these sequences. - :param use_kmeans: bool, whether to use k-means to create buckets. - - """ - buckets = ([None] * num_buckets) - if use_kmeans is True: - buckets = k_means_bucketing(data_set, buckets) - else: - buckets = simple_sort_bucketing(data_set) - index_list = [] - for _ in range(len(data_set) // batch_size): - chosen_bucket = buckets[np.random.randint(0, len(buckets))] - np.random.shuffle(chosen_bucket) - index_list += [idx for idx in chosen_bucket[:batch_size]] - return index_list diff --git a/test/core/test_field.py b/test/core/test_field.py index ccc36f49..7f1dc8c1 100644 --- a/test/core/test_field.py +++ b/test/core/test_field.py @@ -1,10 +1,10 @@ import unittest -from fastNLP.core.field import CharTextField +from fastNLP.core.field import CharTextField, LabelField, SeqLabelField class TestField(unittest.TestCase): - def test_case(self): + def test_char_field(self): text = "PhD applicants must submit a Research Plan and a resume " \ "specify your class ranking written in English and a list of research" \ " publications if any".split() @@ -21,3 +21,22 @@ class TestField(unittest.TestCase): self.assertEqual(field.contents(), text) tensor = field.to_tensor(50) self.assertEqual(tuple(tensor.shape), (50, max_word_len)) + + def test_label_field(self): + label = LabelField("A", is_target=True) + self.assertEqual(label.get_length(), 1) + self.assertEqual(label.index({"A": 10}), 10) + + label = LabelField(30, is_target=True) + self.assertEqual(label.get_length(), 1) + tensor = label.to_tensor(0) + self.assertEqual(tensor.shape, ()) + self.assertEqual(int(tensor), 30) + + def test_seq_label_field(self): + seq = ["a", "b", "c", "d", "a", "c", "a", "b"] + field = SeqLabelField(seq) + vocab = {"a": 10, "b": 20, "c": 30, "d": 40} + self.assertEqual(field.index(vocab), [vocab[x] for x in seq]) + tensor = field.to_tensor(10) + self.assertEqual(tuple(tensor.shape), (10,)) diff --git a/test/core/test_sampler.py b/test/core/test_sampler.py index 179d20d7..cf72fe18 100644 --- a/test/core/test_sampler.py +++ b/test/core/test_sampler.py @@ -1,6 +1,7 @@ import torch -from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler +from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ + k_means_1d, k_means_bucketing, simple_sort_bucketing def test_convert_to_torch_tensor(): @@ -26,5 +27,18 @@ def test_random_sampler(): assert d in data -if __name__ == "__main__": - test_sequential_sampler() +def test_k_means(): + centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) + centroids, assign = list(centroids), list(assign) + assert len(centroids) == 2 + assert len(assign) == 10 + + +def test_k_means_bucketing(): + res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) + assert len(res) == 2 + + +def test_simple_sort_bucketing(): + _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) + assert len(_) == 10