@@ -48,8 +48,6 @@ def simple_sort_bucketing(lengths):
     """
     :param lengths: list of int, the lengths of all examples.
-    :param buckets: list of int. The length of the list is the number of buckets. Each integer is the maximum length
-        threshold for each bucket (This is usually None.).
     :return data: 2-level list
     ::
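For reference, a minimal usage sketch of the slimmed-down ``simple_sort_bucketing`` (the lengths below are made-up values; per the new test added at the end of this diff, the result's length equals the number of examples)::

    from fastNLP.core.sampler import simple_sort_bucketing

    lengths = [21, 3, 25, 7, 9, 22, 4, 6, 28, 10]  # hypothetical sentence lengths
    order = simple_sort_bucketing(lengths)
    assert len(order) == len(lengths)  # one entry per example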
@@ -75,6 +73,7 @@ def k_means_1d(x, k, max_iter=100):
             assignment: numpy array, 1-D, the bucket id assigned to each example.
     """
     sorted_x = sorted(list(set(x)))
+    x = np.array(x)
     if len(sorted_x) < k:
         raise ValueError("too few buckets")
     gap = len(sorted_x) / k
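The added ``np.array`` conversion lets callers pass a plain Python list, as the new ``test_k_means`` below does; presumably the function body relies on numpy-style indexing. A sketch of the public contract, mirroring that test (input values made up)::

    from fastNLP.core.sampler import k_means_1d

    lengths = [21, 3, 25, 7, 9, 22, 4, 6, 28, 10]
    centroids, assignment = k_means_1d(lengths, 2, max_iter=5)
    assert len(centroids) == 2               # one centroid per cluster
    assert len(assignment) == len(lengths)   # one bucket id per example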
@@ -119,34 +118,3 @@ def k_means_bucketing(lengths, buckets):
             bucket_data[bucket_id].append(idx)
     return bucket_data
-
-
-class BucketSampler(BaseSampler):
-    """Partition all samples into multiple buckets, each of which contains sentences of approximately the same length.
-    In sampling, first random choose a bucket. Then sample data from it.
-    The number of buckets is decided dynamically by the variance of sentence lengths.
-    """
-
-    def __call__(self, data_set, batch_size, num_buckets):
-        return self._process(data_set, batch_size, num_buckets)
-
-    def _process(self, data_set, batch_size, num_buckets, use_kmeans=False):
-        """
-        :param data_set: a DataSet object
-        :param batch_size: int
-        :param num_buckets: int, number of buckets for grouping these sequences.
-        :param use_kmeans: bool, whether to use k-means to create buckets.
-        """
-        buckets = ([None] * num_buckets)
-        if use_kmeans is True:
-            buckets = k_means_bucketing(data_set, buckets)
-        else:
-            buckets = simple_sort_bucketing(data_set)
-        index_list = []
-        for _ in range(len(data_set) // batch_size):
-            chosen_bucket = buckets[np.random.randint(0, len(buckets))]
-            np.random.shuffle(chosen_bucket)
-            index_list += [idx for idx in chosen_bucket[:batch_size]]
-        return index_list
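The deleted sampler's core loop can still be reproduced with the surviving module-level helpers. A hedged sketch under that assumption (names are illustrative; note the removed code passed a ``DataSet`` where ``k_means_bucketing`` expects a list of lengths, which this sketch corrects)::

    import numpy as np
    from fastNLP.core.sampler import k_means_bucketing

    lengths = [21, 3, 25, 7, 9, 22, 4, 6, 28, 10]  # made-up sentence lengths
    batch_size, num_buckets = 2, 2

    # partition example indices into buckets of similar length
    buckets = k_means_bucketing(lengths, [None] * num_buckets)

    # draw batches: pick a random bucket, shuffle it, take batch_size indices
    index_list = []
    for _ in range(len(lengths) // batch_size):
        chosen_bucket = buckets[np.random.randint(0, len(buckets))]
        np.random.shuffle(chosen_bucket)
        index_list += chosen_bucket[:batch_size]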
@@ -1,10 +1,10 @@
 import unittest

-from fastNLP.core.field import CharTextField
+from fastNLP.core.field import CharTextField, LabelField, SeqLabelField


 class TestField(unittest.TestCase):
-    def test_case(self):
+    def test_char_field(self):
         text = "PhD applicants must submit a Research Plan and a resume " \
                "specify your class ranking written in English and a list of research" \
                " publications if any".split()
@@ -21,3 +21,22 @@ class TestField(unittest.TestCase):
         self.assertEqual(field.contents(), text)
         tensor = field.to_tensor(50)
         self.assertEqual(tuple(tensor.shape), (50, max_word_len))
+
+    def test_label_field(self):
+        label = LabelField("A", is_target=True)
+        self.assertEqual(label.get_length(), 1)
+        self.assertEqual(label.index({"A": 10}), 10)
+
+        label = LabelField(30, is_target=True)
+        self.assertEqual(label.get_length(), 1)
+        tensor = label.to_tensor(0)
+        self.assertEqual(tensor.shape, ())
+        self.assertEqual(int(tensor), 30)
+
+    def test_seq_label_field(self):
+        seq = ["a", "b", "c", "d", "a", "c", "a", "b"]
+        field = SeqLabelField(seq)
+        vocab = {"a": 10, "b": 20, "c": 30, "d": 40}
+        self.assertEqual(field.index(vocab), [vocab[x] for x in seq])
+        tensor = field.to_tensor(10)
+        self.assertEqual(tuple(tensor.shape), (10,))
@@ -1,6 +1,7 @@
 import torch

-from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler
+from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
+    k_means_1d, k_means_bucketing, simple_sort_bucketing


 def test_convert_to_torch_tensor():
@@ -26,5 +27,18 @@ def test_random_sampler():
         assert d in data

-
-if __name__ == "__main__":
-    test_sequential_sampler()
+
+def test_k_means():
+    centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
+    centroids, assign = list(centroids), list(assign)
+    assert len(centroids) == 2
+    assert len(assign) == 10
+
+
+def test_k_means_bucketing():
+    res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
+    assert len(res) == 2
+
+
+def test_simple_sort_bucketing():
+    _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
+    assert len(_) == 10