diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index d8c61047..5e0be4c3 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -5,7 +5,8 @@ class Batch(object):
     """Batch is an iterable object which iterates over mini-batches.
 
     ::
-        for batch_x, batch_y in Batch(data_set):
+        for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
+
 
     """
 
@@ -15,6 +16,8 @@ class Batch(object):
         :param dataset: a DataSet object
         :param batch_size: int, the size of the batch
         :param sampler: a Sampler object
+        :param as_numpy: bool. If True, return numpy arrays. Otherwise, return torch tensors.
+
         """
         self.dataset = dataset
         self.batch_size = batch_size
@@ -30,17 +33,6 @@ class Batch(object):
         return self
 
     def __next__(self):
-        """
-
-        :return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
-                E.g.
-                ::
-                {'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]})
-
-                batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
-        All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.
-
-        """
         if self.curidx >= len(self.idx_list):
             raise StopIteration
         else:
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 550ef7d9..668bb93e 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -117,22 +117,20 @@ class DataSet(object):
             assert name in self.field_arrays
             self.field_arrays[name].append(field)
 
-    def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
+    def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False):
         """
 
-        :param name:
+        :param str name: the name of the field
         :param fields:
-        :param padding_val:
-        :param need_tensor:
-        :param is_target:
+        :param int padding_val: the integer for padding
+        :param bool is_input: whether this field is used as the model input
+        :param bool is_target: whether this field is used to compute loss
        :return:
         """
         if len(self.field_arrays) != 0:
             assert len(self) == len(fields)
-        self.field_arrays[name] = FieldArray(name, fields,
-                                             padding_val=padding_val,
-                                             need_tensor=need_tensor,
-                                             is_target=is_target)
+        self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,
+                                             is_input=is_input)
 
     def delete_field(self, name):
         self.field_arrays.pop(name)
diff --git a/fastNLP/core/fieldarray.py b/fastNLP/core/fieldarray.py
index 473738b0..58e6c09d 100644
--- a/fastNLP/core/fieldarray.py
+++ b/fastNLP/core/fieldarray.py
@@ -2,7 +2,19 @@ import numpy as np
 
 
 class FieldArray(object):
+    """FieldArray is the collection of Instances of the same Field.
+    It is the basic element of the DataSet class.
+
+    """
     def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
+        """
+
+        :param str name: the name of the FieldArray
+        :param list content: a list of int, float, or other objects.
+        :param int padding_val: the integer for padding. Default: 0.
+        :param bool is_target: If True, this FieldArray is used to compute loss.
+        :param bool is_input: If True, this FieldArray is used as the model input.
+        """
         self.name = name
         self.content = content
         self.padding_val = padding_val
@@ -24,23 +36,28 @@ class FieldArray(object):
         assert isinstance(name, int)
         self.content[name] = val
 
-    def get(self, idxes):
-        if isinstance(idxes, int):
-            return self.content[idxes]
+    def get(self, indices):
+        """Fetch instances based on indices.
+
+        :param indices: an int, or a list of int.
+        :return: a single instance if `indices` is an int; otherwise a numpy array, padded if necessary.
+        """
+        if isinstance(indices, int):
+            return self.content[indices]
         assert self.is_input is True or self.is_target is True
-        batch_size = len(idxes)
+        batch_size = len(indices)
        # TODO: when this FieldArray holds one-dimensional content such as seq_length, no padding is needed; to be discussed further.
         if isinstance(self.content[0], int) or isinstance(self.content[0], float):
             if self.dtype is None:
                 self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
-            array = np.array([self.content[i] for i in idxes], dtype=self.dtype)
+            array = np.array([self.content[i] for i in indices], dtype=self.dtype)
         else:
             if self.dtype is None:
                 self.dtype = np.int64
-            max_len = max([len(self.content[i]) for i in idxes])
+            max_len = max([len(self.content[i]) for i in indices])
             array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)
-            for i, idx in enumerate(idxes):
+            for i, idx in enumerate(indices):
                 array[i][:len(self.content[idx])] = self.content[idx]
         return array
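A minimal sketch of the padding behavior that `FieldArray.get` implies after this patch (not part of the diff; the field names and values are illustrative only):

    import numpy as np

    from fastNLP.core.fieldarray import FieldArray

    # Variable-length content is padded with padding_val up to the max length in the batch.
    fa = FieldArray("text", [[1, 2], [3, 4, 5]], padding_val=0, is_input=True)
    batch = fa.get([0, 1])  # numpy array of shape (2, 3)
    assert batch.tolist() == [[1, 2, 0], [3, 4, 5]]

    # An int index returns the raw instance; scalar content is batched without padding.
    assert fa.get(0) == [1, 2]
    lens = FieldArray("seq_len", [2, 3], is_input=True)
    assert lens.get([0, 1]).tolist() == [2, 3]
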
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
index d6029ab1..26140e59 100644
--- a/fastNLP/core/instance.py
+++ b/fastNLP/core/instance.py
@@ -1,16 +1,27 @@
 class Instance(object):
-    """An instance which consists of Fields is an example in the DataSet.
+    """An Instance is an example of data. It is a collection of Fields.
+
+    ::
+        Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])
 
     """
 
     def __init__(self, **fields):
+        """
+
+        :param fields: a dict of (field name: field)
+        """
         self.fields = fields
 
     def add_field(self, field_name, field):
+        """Add a new field to the instance.
+
+        :param field_name: str, the name of the field.
+        :param field: the field content.
+        """
         self.fields[field_name] = field
-        return self
 
     def __getitem__(self, name):
         if name in self.fields:
@@ -21,17 +32,5 @@ class Instance(object):
     def __setitem__(self, name, field):
         return self.add_field(name, field)
 
-    def __getattr__(self, item):
-        if hasattr(self, 'fields') and item in self.fields:
-            return self.fields[item]
-        else:
-            raise AttributeError('{} does not exist.'.format(item))
-
-    def __setattr__(self, key, value):
-        if hasattr(self, 'fields'):
-            self.__setitem__(key, value)
-        else:
-            super().__setattr__(key, value)
-
     def __repr__(self):
         return self.fields.__repr__()
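A short sketch of Instance usage under this patch (illustrative only; the field names are made up). Note that the patch removes `__getattr__`/`__setattr__`, so attribute-style access goes away, and `add_field` no longer returns `self`:

    from fastNLP.core.instance import Instance

    ins = Instance(words=["hello", "world"], label=1)
    ins.add_field("seq_len", 2)  # no longer chainable: add_field returns None now
    print(ins["words"])          # dict-style access still works
    # ins.words                  # attribute-style access was removed by this patch
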
+ """ self.add(word) def add_word_lst(self, word_lst): - self.update(word_lst) + """Add a list of words into the vocabulary. + :param list word_lst: a list of strings + """ + self.update(word_lst) def build_vocab(self): - """build 'word to index' dict, and filter the word using `max_size` and `min_freq` + """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. + """ if self.has_default: self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) @@ -85,11 +99,12 @@ class Vocabulary(object): if self.min_freq is not None: words = filter(lambda kv: kv[1] >= self.min_freq, words) start_idx = len(self.word2idx) - self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)}) + self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.build_reverse_vocab() def build_reverse_vocab(self): - """build 'index to word' dict based on 'word to index' dict + """Build 'index to word' dict based on 'word to index' dict. + """ self.idx2word = {i: w for w, i in self.word2idx.items()} @@ -97,6 +112,15 @@ class Vocabulary(object): def __len__(self): return len(self.word2idx) + @check_build_vocab + def __contains__(self, item): + """Check if a word in vocabulary. + + :param item: the word + :return: True or False + """ + return item in self.word2idx + def has_word(self, w): return self.__contains__(w) @@ -114,8 +138,8 @@ class Vocabulary(object): raise ValueError("word {} not in vocabulary".format(w)) def to_index(self, w): - """ like to_index(w) function, turn a word to the index - if w is not in Vocabulary, return the unknown label + """ Turn a word to an index. + If w is not in Vocabulary, return the unknown label. :param str w: """ @@ -144,12 +168,14 @@ class Vocabulary(object): def to_word(self, idx): """given a word's index, return the word itself - :param int idx: + :param int idx: the index + :return str word: the indexed word """ return self.idx2word[idx] def __getstate__(self): - """use to prepare data for pickle + """Use to prepare data for pickle. + """ state = self.__dict__.copy() # no need to pickle idx2word as it can be constructed from word2idx @@ -157,16 +183,9 @@ class Vocabulary(object): return state def __setstate__(self, state): - """use to restore state from pickle + """Use to restore state from pickle. + """ self.__dict__.update(state) self.build_reverse_vocab() - @check_build_vocab - def __contains__(self, item): - """Check if a word in vocabulary. 
diff --git a/test/core/test_batch.py b/test/core/test_batch.py
index b6d0460d..c820af57 100644
--- a/test/core/test_batch.py
+++ b/test/core/test_batch.py
@@ -1,17 +1,18 @@
 import unittest
 
 from fastNLP.core.batch import Batch
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
+from fastNLP.core.dataset import construct_dataset
 from fastNLP.core.sampler import SequentialSampler
 
 
 class TestCase1(unittest.TestCase):
-    def test(self):
-        dataset = DataSet([Instance(x=["I", "am", "here"])] * 40)
+    def test_simple(self):
+        dataset = construct_dataset(
+            [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
+        dataset.set_target()
         batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False)
 
-        for batch_x, batch_y in batch:
-            print(batch_x, batch_y)
-
-        # TODO: weird due to change in dataset.py
+        cnt = 0
+        for _, _ in batch:
+            cnt += 1
+        self.assertEqual(cnt, 10)
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index c6af4c43..3082db25 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -1,20 +1,20 @@
 import unittest
 
+from fastNLP.core.dataset import DataSet
+
 
 class TestDataSet(unittest.TestCase):
-    labeled_data_list = [
-        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
-        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
-        [["a", "b", "e", "d"], ["1", "2", "3", "4"]],
-    ]
-    unlabeled_data_list = [
-        ["a", "b", "e", "d"],
-        ["a", "b", "e", "d"],
-        ["a", "b", "e", "d"]
-    ]
-    word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
-    label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}
 
     def test_case_1(self):
-        # TODO:
-        pass
+        ds = DataSet()
+        ds.add_field(name="xx", fields=["a", "b", "e", "d"])
+
+        self.assertTrue("xx" in ds.field_arrays)
+        self.assertEqual(len(ds.field_arrays["xx"]), 4)
+        self.assertEqual(ds.get_length(), 4)
+        self.assertEqual(ds.get_fields(), ds.field_arrays)
+
+        try:
+            ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"])
+        except BaseException as e:
+            self.assertTrue(isinstance(e, AssertionError))
diff --git a/test/core/test_field.py b/test/core/test_field.py
deleted file mode 100644
index 7f1dc8c1..00000000
--- a/test/core/test_field.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import unittest
-
-from fastNLP.core.field import CharTextField, LabelField, SeqLabelField
-
-
-class TestField(unittest.TestCase):
-    def test_char_field(self):
-        text = "PhD applicants must submit a Research Plan and a resume " \
-               "specify your class ranking written in English and a list of research" \
-               " publications if any".split()
-        max_word_len = max([len(w) for w in text])
-        field = CharTextField(text, max_word_len, is_target=False)
-        all_char = set()
-        for word in text:
-            all_char.update([ch for ch in word])
-        char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)}
-
-        self.assertEqual(field.index(char_vocab),
-                         [[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text])
-        self.assertEqual(field.get_length(), len(text))
-        self.assertEqual(field.contents(), text)
-        tensor = field.to_tensor(50)
-        self.assertEqual(tuple(tensor.shape), (50, max_word_len))
-
-    def test_label_field(self):
-        label = LabelField("A", is_target=True)
-        self.assertEqual(label.get_length(), 1)
-        self.assertEqual(label.index({"A": 10}), 10)
-
-        label = LabelField(30, is_target=True)
-        self.assertEqual(label.get_length(), 1)
-        tensor = label.to_tensor(0)
-        self.assertEqual(tensor.shape, ())
-        self.assertEqual(int(tensor), 30)
-
-    def test_seq_label_field(self):
-        seq = ["a", "b", "c", "d", "a", "c", "a", "b"]
-        field = SeqLabelField(seq)
-        vocab = {"a": 10, "b": 20, "c": 30, "d": 40}
-        self.assertEqual(field.index(vocab), [vocab[x] for x in seq])
-        tensor = field.to_tensor(10)
-        self.assertEqual(tuple(tensor.shape), (10,))
diff --git a/test/core/test_fieldarray.py b/test/core/test_fieldarray.py
new file mode 100644
index 00000000..b5fd60ac
--- /dev/null
+++ b/test/core/test_fieldarray.py
@@ -0,0 +1,6 @@
+import unittest
+
+
+class TestFieldArray(unittest.TestCase):
+    def test(self):
+        pass
diff --git a/test/core/test_instance.py b/test/core/test_instance.py
new file mode 100644
index 00000000..abe6b7f7
--- /dev/null
+++ b/test/core/test_instance.py
@@ -0,0 +1,29 @@
+import unittest
+
+from fastNLP.core.instance import Instance
+
+
+class TestCase(unittest.TestCase):
+
+    def test_init(self):
+        fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
+        ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
+        self.assertTrue(isinstance(ins.fields, dict))
+        self.assertEqual(ins.fields, fields)
+
+        ins = Instance(**fields)
+        self.assertEqual(ins.fields, fields)
+
+    def test_add_field(self):
+        fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
+        ins = Instance(**fields)
+        ins.add_field("z", [1, 1, 1])
+        fields.update({"z": [1, 1, 1]})
+        self.assertEqual(ins.fields, fields)
+
+    def test_get_item(self):
+        fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
+        ins = Instance(**fields)
+        self.assertEqual(ins["x"], [1, 2, 3])
+        self.assertEqual(ins["y"], [4, 5, 6])
+        self.assertEqual(ins["z"], [1, 1, 1])
diff --git a/test/core/test_sampler.py b/test/core/test_sampler.py
index cf72fe18..5da0e6db 100644
--- a/test/core/test_sampler.py
+++ b/test/core/test_sampler.py
@@ -1,44 +1,42 @@
+import unittest
+
 import torch
 
 from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
     k_means_1d, k_means_bucketing, simple_sort_bucketing
 
 
-def test_convert_to_torch_tensor():
-    data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]]
-    ans = convert_to_torch_tensor(data, False)
-    assert isinstance(ans, torch.Tensor)
-    assert tuple(ans.shape) == (3, 5)
-
-
-def test_sequential_sampler():
-    sampler = SequentialSampler()
-    data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
-    for idx, i in enumerate(sampler(data)):
-        assert idx == i
-
-
-def test_random_sampler():
-    sampler = RandomSampler()
-    data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
-    ans = [data[i] for i in sampler(data)]
-    assert len(ans) == len(data)
-    for d in ans:
-        assert d in data
-
-
-def test_k_means():
-    centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
-    centroids, assign = list(centroids), list(assign)
-    assert len(centroids) == 2
-    assert len(assign) == 10
-
-
-def test_k_means_bucketing():
-    res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
-    assert len(res) == 2
-
-
-def test_simple_sort_bucketing():
-    _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
-    assert len(_) == 10
+class TestSampler(unittest.TestCase):
+    def test_convert_to_torch_tensor(self):
+        data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]]
+        ans = convert_to_torch_tensor(data, False)
+        assert isinstance(ans, torch.Tensor)
+        assert tuple(ans.shape) == (3, 5)
+
+    def test_sequential_sampler(self):
+        sampler = SequentialSampler()
+        data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
+        for idx, i in enumerate(sampler(data)):
+            assert idx == i
+
+    def test_random_sampler(self):
+        sampler = RandomSampler()
+        data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
+        ans = [data[i] for i in sampler(data)]
+        assert len(ans) == len(data)
+        for d in ans:
+            assert d in data
+
+    def test_k_means(self):
+        centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
+        centroids, assign = list(centroids), list(assign)
+        assert len(centroids) == 2
+        assert len(assign) == 10
+
+    def test_k_means_bucketing(self):
+        res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
+        assert len(res) == 2
+
+    def test_simple_sort_bucketing(self):
+        _ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
+        assert len(_) == 10
diff --git a/test/core/test_vocab.py b/test/core/test_vocab.py
deleted file mode 100644
index 89b0691a..00000000
--- a/test/core/test_vocab.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import unittest
-from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX
-
-class TestVocabulary(unittest.TestCase):
-    def test_vocab(self):
-        import _pickle as pickle
-        import os
-        vocab = Vocabulary()
-        filename = 'vocab'
-        vocab.update(filename)
-        vocab.update([filename, ['a'], [['b']], ['c']])
-        idx = vocab[filename]
-        before_pic = (vocab.to_word(idx), vocab[filename])
-
-        with open(filename, 'wb') as f:
-            pickle.dump(vocab, f)
-        with open(filename, 'rb') as f:
-            vocab = pickle.load(f)
-        os.remove(filename)
-
-        vocab.build_reverse_vocab()
-        after_pic = (vocab.to_word(idx), vocab[filename])
-        TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
-        TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
-        TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
-        self.assertEqual(before_pic, after_pic)
-        self.assertDictEqual(TRUE_DICT, vocab.word2idx)
-        self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
diff --git a/test/core/test_vocabulary.py b/test/core/test_vocabulary.py
new file mode 100644
index 00000000..e140b1aa
--- /dev/null
+++ b/test/core/test_vocabulary.py
@@ -0,0 +1,61 @@
+import unittest
+from collections import Counter
+
+from fastNLP.core.vocabulary import Vocabulary
+
+text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
+        "works", "well", "in", "most", "cases", "scales", "well"]
+counter = Counter(text)
+
+
+class TestAdd(unittest.TestCase):
+    def test_add(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        for word in text:
+            vocab.add(word)
+        self.assertEqual(vocab.word_count, counter)
+
+    def test_add_word(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        for word in text:
+            vocab.add_word(word)
+        self.assertEqual(vocab.word_count, counter)
+
+    def test_add_word_lst(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.add_word_lst(text)
+        self.assertEqual(vocab.word_count, counter)
+
+    def test_update(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.update(text)
+        self.assertEqual(vocab.word_count, counter)
+
+
+class TestIndexing(unittest.TestCase):
+    def test_len(self):
+        vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
+        vocab.update(text)
+        self.assertEqual(len(vocab), len(counter))
+
+    def test_contains(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.update(text)
+        self.assertTrue(text[-1] in vocab)
+        self.assertFalse("~!@#" in vocab)
+        self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
+        self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))
+
+    def test_index(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.update(text)
+        res = [vocab[w] for w in set(text)]
+        self.assertEqual(len(res), len(set(res)))
+
+        res = [vocab.to_index(w) for w in set(text)]
+        self.assertEqual(len(res), len(set(res)))
+
+    def test_to_word(self):
+        vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
+        vocab.update(text)
+        self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
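Taken together, the patch suggests an end-to-end flow like the following minimal sketch. It mirrors the new tests rather than a documented API; the field name, the integer content, and the no-argument set_target() call are assumptions, so treat it as illustrative only:

    from fastNLP.core.batch import Batch
    from fastNLP.core.dataset import DataSet
    from fastNLP.core.sampler import SequentialSampler

    ds = DataSet()
    # One input field of int sequences; FieldArray pads variable lengths per batch.
    ds.add_field(name="tokens", fields=[[1, 2, 3]] * 40, padding_val=0, is_input=True)
    ds.set_target()  # no target fields in this sketch

    batch = Batch(ds, batch_size=4, sampler=SequentialSampler(), use_cuda=False)
    cnt = sum(1 for batch_x, batch_y in batch)
    assert cnt == 10  # 40 instances / batch_size 4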