* remove and fix other unit tests * add more code commentstags/v0.2.0
@@ -5,7 +5,8 @@ class Batch(object): | |||
"""Batch is an iterable object which iterates over mini-batches. | |||
:: | |||
for batch_x, batch_y in Batch(data_set): | |||
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): | |||
""" | |||
@@ -15,6 +16,8 @@ class Batch(object): | |||
:param dataset: a DataSet object | |||
:param batch_size: int, the size of the batch | |||
:param sampler: a Sampler object | |||
:param as_numpy: bool. If True, return Numpy array. Otherwise, return torch tensors. | |||
""" | |||
self.dataset = dataset | |||
self.batch_size = batch_size | |||
@@ -30,17 +33,6 @@ class Batch(object): | |||
return self | |||
def __next__(self): | |||
""" | |||
:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||
E.g. | |||
:: | |||
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}) | |||
batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length]) | |||
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True. | |||
""" | |||
if self.curidx >= len(self.idx_list): | |||
raise StopIteration | |||
else: | |||
@@ -117,22 +117,20 @@ class DataSet(object): | |||
assert name in self.field_arrays | |||
self.field_arrays[name].append(field) | |||
def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False): | |||
def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False): | |||
""" | |||
:param name: | |||
:param str name: | |||
:param fields: | |||
:param padding_val: | |||
:param need_tensor: | |||
:param is_target: | |||
:param int padding_val: | |||
:param bool is_input: | |||
:param bool is_target: | |||
:return: | |||
""" | |||
if len(self.field_arrays) != 0: | |||
assert len(self) == len(fields) | |||
self.field_arrays[name] = FieldArray(name, fields, | |||
padding_val=padding_val, | |||
need_tensor=need_tensor, | |||
is_target=is_target) | |||
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target, | |||
is_input=is_input) | |||
def delete_field(self, name): | |||
self.field_arrays.pop(name) | |||
@@ -2,7 +2,19 @@ import numpy as np | |||
class FieldArray(object): | |||
"""FieldArray is the collection of Instances of the same Field. | |||
It is the basic element of DataSet class. | |||
""" | |||
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False): | |||
""" | |||
:param str name: the name of the FieldArray | |||
:param list content: a list of int, float, or other objects. | |||
:param int padding_val: the integer for padding. Default: 0. | |||
:param bool is_target: If True, this FieldArray is used to compute loss. | |||
:param bool is_input: If True, this FieldArray is used to the model input. | |||
""" | |||
self.name = name | |||
self.content = content | |||
self.padding_val = padding_val | |||
@@ -24,23 +36,28 @@ class FieldArray(object): | |||
assert isinstance(name, int) | |||
self.content[name] = val | |||
def get(self, idxes): | |||
if isinstance(idxes, int): | |||
return self.content[idxes] | |||
def get(self, indices): | |||
"""Fetch instances based on indices. | |||
:param indices: an int, or a list of int. | |||
:return: | |||
""" | |||
if isinstance(indices, int): | |||
return self.content[indices] | |||
assert self.is_input is True or self.is_target is True | |||
batch_size = len(idxes) | |||
batch_size = len(indices) | |||
# TODO 当这个fieldArray是seq_length这种只有一位的内容时,不需要padding,需要再讨论一下 | |||
if isinstance(self.content[0], int) or isinstance(self.content[0], float): | |||
if self.dtype is None: | |||
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double | |||
array = np.array([self.content[i] for i in idxes], dtype=self.dtype) | |||
array = np.array([self.content[i] for i in indices], dtype=self.dtype) | |||
else: | |||
if self.dtype is None: | |||
self.dtype = np.int64 | |||
max_len = max([len(self.content[i]) for i in idxes]) | |||
max_len = max([len(self.content[i]) for i in indices]) | |||
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype) | |||
for i, idx in enumerate(idxes): | |||
for i, idx in enumerate(indices): | |||
array[i][:len(self.content[idx])] = self.content[idx] | |||
return array | |||
@@ -1,16 +1,27 @@ | |||
class Instance(object): | |||
"""An instance which consists of Fields is an example in the DataSet. | |||
"""An Instance is an example of data. It is the collection of Fields. | |||
:: | |||
Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) | |||
""" | |||
def __init__(self, **fields): | |||
""" | |||
:param fields: a dict of (field name: field) | |||
""" | |||
self.fields = fields | |||
def add_field(self, field_name, field): | |||
"""Add a new field to the instance. | |||
:param field_name: str, the name of the field. | |||
:param field: | |||
""" | |||
self.fields[field_name] = field | |||
return self | |||
def __getitem__(self, name): | |||
if name in self.fields: | |||
@@ -21,17 +32,5 @@ class Instance(object): | |||
def __setitem__(self, name, field): | |||
return self.add_field(name, field) | |||
def __getattr__(self, item): | |||
if hasattr(self, 'fields') and item in self.fields: | |||
return self.fields[item] | |||
else: | |||
raise AttributeError('{} does not exist.'.format(item)) | |||
def __setattr__(self, key, value): | |||
if hasattr(self, 'fields'): | |||
self.__setitem__(key, value) | |||
else: | |||
super().__setattr__(key, value) | |||
def __repr__(self): | |||
return self.fields.__repr__() |
@@ -1,5 +1,5 @@ | |||
from copy import deepcopy | |||
from collections import Counter | |||
from copy import deepcopy | |||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
@@ -20,6 +20,7 @@ def check_build_vocab(func): | |||
if self.word2idx is None: | |||
self.build_vocab() | |||
return func(self, *args, **kwargs) | |||
return _wrapper | |||
@@ -34,6 +35,7 @@ class Vocabulary(object): | |||
vocab["word"] | |||
vocab.to_word(5) | |||
""" | |||
def __init__(self, need_default=True, max_size=None, min_freq=None): | |||
""" | |||
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||
@@ -54,24 +56,36 @@ class Vocabulary(object): | |||
self.idx2word = None | |||
def update(self, word_lst): | |||
"""add word or list of words into Vocabulary | |||
"""Add a list of words into the vocabulary. | |||
:param word: a list of string or a single string | |||
:param list word_lst: a list of strings | |||
""" | |||
self.word_count.update(word_lst) | |||
def add(self, word): | |||
"""Add a single word into the vocabulary. | |||
:param str word: a word or token. | |||
""" | |||
self.word_count[word] += 1 | |||
def add_word(self, word): | |||
"""Add a single word into the vocabulary. | |||
:param str word: a word or token. | |||
""" | |||
self.add(word) | |||
def add_word_lst(self, word_lst): | |||
self.update(word_lst) | |||
"""Add a list of words into the vocabulary. | |||
:param list word_lst: a list of strings | |||
""" | |||
self.update(word_lst) | |||
def build_vocab(self): | |||
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq` | |||
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||
""" | |||
if self.has_default: | |||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||
@@ -85,11 +99,12 @@ class Vocabulary(object): | |||
if self.min_freq is not None: | |||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
start_idx = len(self.word2idx) | |||
self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)}) | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
def build_reverse_vocab(self): | |||
"""build 'index to word' dict based on 'word to index' dict | |||
"""Build 'index to word' dict based on 'word to index' dict. | |||
""" | |||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
@@ -97,6 +112,15 @@ class Vocabulary(object): | |||
def __len__(self): | |||
return len(self.word2idx) | |||
@check_build_vocab | |||
def __contains__(self, item): | |||
"""Check if a word in vocabulary. | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return item in self.word2idx | |||
def has_word(self, w): | |||
return self.__contains__(w) | |||
@@ -114,8 +138,8 @@ class Vocabulary(object): | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
def to_index(self, w): | |||
""" like to_index(w) function, turn a word to the index | |||
if w is not in Vocabulary, return the unknown label | |||
""" Turn a word to an index. | |||
If w is not in Vocabulary, return the unknown label. | |||
:param str w: | |||
""" | |||
@@ -144,12 +168,14 @@ class Vocabulary(object): | |||
def to_word(self, idx): | |||
"""given a word's index, return the word itself | |||
:param int idx: | |||
:param int idx: the index | |||
:return str word: the indexed word | |||
""" | |||
return self.idx2word[idx] | |||
def __getstate__(self): | |||
"""use to prepare data for pickle | |||
"""Use to prepare data for pickle. | |||
""" | |||
state = self.__dict__.copy() | |||
# no need to pickle idx2word as it can be constructed from word2idx | |||
@@ -157,16 +183,9 @@ class Vocabulary(object): | |||
return state | |||
def __setstate__(self, state): | |||
"""use to restore state from pickle | |||
"""Use to restore state from pickle. | |||
""" | |||
self.__dict__.update(state) | |||
self.build_reverse_vocab() | |||
@check_build_vocab | |||
def __contains__(self, item): | |||
"""Check if a word in vocabulary. | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return item in self.word2idx |
@@ -1,17 +1,18 @@ | |||
import unittest | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.dataset import construct_dataset | |||
from fastNLP.core.sampler import SequentialSampler | |||
class TestCase1(unittest.TestCase): | |||
def test(self): | |||
dataset = DataSet([Instance(x=["I", "am", "here"])] * 40) | |||
def test_simple(self): | |||
dataset = construct_dataset( | |||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | |||
dataset.set_target() | |||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False) | |||
for batch_x, batch_y in batch: | |||
print(batch_x, batch_y) | |||
# TODO: weird due to change in dataset.py | |||
cnt = 0 | |||
for _, _ in batch: | |||
cnt += 1 | |||
self.assertEqual(cnt, 10) |
@@ -1,20 +1,20 @@ | |||
import unittest | |||
from fastNLP.core.dataset import DataSet | |||
class TestDataSet(unittest.TestCase): | |||
labeled_data_list = [ | |||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
[["a", "b", "e", "d"], ["1", "2", "3", "4"]], | |||
] | |||
unlabeled_data_list = [ | |||
["a", "b", "e", "d"], | |||
["a", "b", "e", "d"], | |||
["a", "b", "e", "d"] | |||
] | |||
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3} | |||
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4} | |||
def test_case_1(self): | |||
# TODO: | |||
pass | |||
ds = DataSet() | |||
ds.add_field(name="xx", fields=["a", "b", "e", "d"]) | |||
self.assertTrue("xx" in ds.field_arrays) | |||
self.assertEqual(len(ds.field_arrays["xx"]), 4) | |||
self.assertEqual(ds.get_length(), 4) | |||
self.assertEqual(ds.get_fields(), ds.field_arrays) | |||
try: | |||
ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"]) | |||
except BaseException as e: | |||
self.assertTrue(isinstance(e, AssertionError)) |
@@ -1,42 +0,0 @@ | |||
import unittest | |||
from fastNLP.core.field import CharTextField, LabelField, SeqLabelField | |||
class TestField(unittest.TestCase): | |||
def test_char_field(self): | |||
text = "PhD applicants must submit a Research Plan and a resume " \ | |||
"specify your class ranking written in English and a list of research" \ | |||
" publications if any".split() | |||
max_word_len = max([len(w) for w in text]) | |||
field = CharTextField(text, max_word_len, is_target=False) | |||
all_char = set() | |||
for word in text: | |||
all_char.update([ch for ch in word]) | |||
char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)} | |||
self.assertEqual(field.index(char_vocab), | |||
[[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text]) | |||
self.assertEqual(field.get_length(), len(text)) | |||
self.assertEqual(field.contents(), text) | |||
tensor = field.to_tensor(50) | |||
self.assertEqual(tuple(tensor.shape), (50, max_word_len)) | |||
def test_label_field(self): | |||
label = LabelField("A", is_target=True) | |||
self.assertEqual(label.get_length(), 1) | |||
self.assertEqual(label.index({"A": 10}), 10) | |||
label = LabelField(30, is_target=True) | |||
self.assertEqual(label.get_length(), 1) | |||
tensor = label.to_tensor(0) | |||
self.assertEqual(tensor.shape, ()) | |||
self.assertEqual(int(tensor), 30) | |||
def test_seq_label_field(self): | |||
seq = ["a", "b", "c", "d", "a", "c", "a", "b"] | |||
field = SeqLabelField(seq) | |||
vocab = {"a": 10, "b": 20, "c": 30, "d": 40} | |||
self.assertEqual(field.index(vocab), [vocab[x] for x in seq]) | |||
tensor = field.to_tensor(10) | |||
self.assertEqual(tuple(tensor.shape), (10,)) |
@@ -0,0 +1,6 @@ | |||
import unittest | |||
class TestFieldArray(unittest.TestCase): | |||
def test(self): | |||
pass |
@@ -0,0 +1,29 @@ | |||
import unittest | |||
from fastNLP.core.instance import Instance | |||
class TestCase(unittest.TestCase): | |||
def test_init(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | |||
self.assertTrue(isinstance(ins.fields, dict)) | |||
self.assertEqual(ins.fields, fields) | |||
ins = Instance(**fields) | |||
self.assertEqual(ins.fields, fields) | |||
def test_add_field(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | |||
ins = Instance(**fields) | |||
ins.add_field("z", [1, 1, 1]) | |||
fields.update({"z": [1, 1, 1]}) | |||
self.assertEqual(ins.fields, fields) | |||
def test_get_item(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
ins = Instance(**fields) | |||
self.assertEqual(ins["x"], [1, 2, 3]) | |||
self.assertEqual(ins["y"], [4, 5, 6]) | |||
self.assertEqual(ins["z"], [1, 1, 1]) |
@@ -1,44 +1,42 @@ | |||
import unittest | |||
import torch | |||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | |||
k_means_1d, k_means_bucketing, simple_sort_bucketing | |||
def test_convert_to_torch_tensor(): | |||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||
ans = convert_to_torch_tensor(data, False) | |||
assert isinstance(ans, torch.Tensor) | |||
assert tuple(ans.shape) == (3, 5) | |||
def test_sequential_sampler(): | |||
sampler = SequentialSampler() | |||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
for idx, i in enumerate(sampler(data)): | |||
assert idx == i | |||
def test_random_sampler(): | |||
sampler = RandomSampler() | |||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
ans = [data[i] for i in sampler(data)] | |||
assert len(ans) == len(data) | |||
for d in ans: | |||
assert d in data | |||
def test_k_means(): | |||
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||
centroids, assign = list(centroids), list(assign) | |||
assert len(centroids) == 2 | |||
assert len(assign) == 10 | |||
def test_k_means_bucketing(): | |||
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||
assert len(res) == 2 | |||
def test_simple_sort_bucketing(): | |||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||
assert len(_) == 10 | |||
class TestSampler(unittest.TestCase): | |||
def test_convert_to_torch_tensor(self): | |||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||
ans = convert_to_torch_tensor(data, False) | |||
assert isinstance(ans, torch.Tensor) | |||
assert tuple(ans.shape) == (3, 5) | |||
def test_sequential_sampler(self): | |||
sampler = SequentialSampler() | |||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
for idx, i in enumerate(sampler(data)): | |||
assert idx == i | |||
def test_random_sampler(self): | |||
sampler = RandomSampler() | |||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | |||
ans = [data[i] for i in sampler(data)] | |||
assert len(ans) == len(data) | |||
for d in ans: | |||
assert d in data | |||
def test_k_means(self): | |||
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5) | |||
centroids, assign = list(centroids), list(assign) | |||
assert len(centroids) == 2 | |||
assert len(assign) == 10 | |||
def test_k_means_bucketing(self): | |||
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None]) | |||
assert len(res) == 2 | |||
def test_simple_sort_bucketing(self): | |||
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10]) | |||
assert len(_) == 10 |
@@ -1,31 +0,0 @@ | |||
import unittest | |||
from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX | |||
class TestVocabulary(unittest.TestCase): | |||
def test_vocab(self): | |||
import _pickle as pickle | |||
import os | |||
vocab = Vocabulary() | |||
filename = 'vocab' | |||
vocab.update(filename) | |||
vocab.update([filename, ['a'], [['b']], ['c']]) | |||
idx = vocab[filename] | |||
before_pic = (vocab.to_word(idx), vocab[filename]) | |||
with open(filename, 'wb') as f: | |||
pickle.dump(vocab, f) | |||
with open(filename, 'rb') as f: | |||
vocab = pickle.load(f) | |||
os.remove(filename) | |||
vocab.build_reverse_vocab() | |||
after_pic = (vocab.to_word(idx), vocab[filename]) | |||
TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8} | |||
TRUE_DICT.update(DEFAULT_WORD_TO_INDEX) | |||
TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'} | |||
self.assertEqual(before_pic, after_pic) | |||
self.assertDictEqual(TRUE_DICT, vocab.word2idx) | |||
self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word) | |||
if __name__ == '__main__': | |||
unittest.main() |
@@ -0,0 +1,61 @@ | |||
import unittest | |||
from collections import Counter | |||
from fastNLP.core.vocabulary import Vocabulary | |||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
"works", "well", "in", "most", "cases", "scales", "well"] | |||
counter = Counter(text) | |||
class TestAdd(unittest.TestCase): | |||
def test_add(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
for word in text: | |||
vocab.add(word) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_add_word(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
for word in text: | |||
vocab.add_word(word) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_add_word_lst(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab.add_word_lst(text) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_update(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab.update(text) | |||
self.assertEqual(vocab.word_count, counter) | |||
class TestIndexing(unittest.TestCase): | |||
def test_len(self): | |||
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||
vocab.update(text) | |||
self.assertEqual(len(vocab), len(counter)) | |||
def test_contains(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab.update(text) | |||
self.assertTrue(text[-1] in vocab) | |||
self.assertFalse("~!@#" in vocab) | |||
self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||
def test_index(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab.update(text) | |||
res = [vocab[w] for w in set(text)] | |||
self.assertEqual(len(res), len(set(res))) | |||
res = [vocab.to_index(w) for w in set(text)] | |||
self.assertEqual(len(res), len(set(res))) | |||
def test_to_word(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab.update(text) | |||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) |