
* add unit tests for instance, vocabulary

* remove and fix other unit tests
* add more code comments
tags/v0.2.0
FengZiYjun yunfan 6 years ago
parent commit 837bef47dc
13 changed files with 242 additions and 195 deletions
  1. +4 -12 fastNLP/core/batch.py
  2. +7 -9 fastNLP/core/dataset.py
  3. +24 -7 fastNLP/core/fieldarray.py
  4. +13 -14 fastNLP/core/instance.py
  5. +39 -20 fastNLP/core/vocabulary.py
  6. +9 -8 test/core/test_batch.py
  7. +14 -14 test/core/test_dataset.py
  8. +0 -42 test/core/test_field.py
  9. +6 -0 test/core/test_fieldarray.py
  10. +29 -0 test/core/test_instance.py
  11. +36 -38 test/core/test_sampler.py
  12. +0 -31 test/core/test_vocab.py
  13. +61 -0 test/core/test_vocabulary.py

+4 -12 fastNLP/core/batch.py

@@ -5,7 +5,8 @@ class Batch(object):
"""Batch is an iterable object which iterates over mini-batches.

::
for batch_x, batch_y in Batch(data_set):
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):


"""

@@ -15,6 +16,8 @@ class Batch(object):
:param dataset: a DataSet object
:param batch_size: int, the size of the batch
:param sampler: a Sampler object
:param as_numpy: bool. If True, return numpy arrays. Otherwise, return torch tensors.

"""
self.dataset = dataset
self.batch_size = batch_size
@@ -30,17 +33,6 @@ class Batch(object):
return self

def __next__(self):
"""

:return batch_x: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
E.g.
::
{'text': tensor([[ 0, 1, 2, 3, 0, 0, 0], [ 4, 5, 2, 6, 7, 8, 9]]), 'text_origin_len': [4, 7]}

batch_y: dict of (str: torch.LongTensor), which means (field name: tensor of shape [batch_size, padding_length])
All tensors in both batch_x and batch_y will be cuda tensors if use_cuda is True.

"""
if self.curidx >= len(self.idx_list):
raise StopIteration
else:
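To make the new interface concrete, here is a minimal usage sketch that mirrors test_batch.py further down (the sentence content is invented; construct_dataset and SequentialSampler are taken from the tests):

    from fastNLP.core.batch import Batch
    from fastNLP.core.dataset import construct_dataset
    from fastNLP.core.sampler import SequentialSampler

    # Build a toy dataset of 40 identical sentences, as the test below does.
    dataset = construct_dataset([["FastNLP", "is", "here"] for _ in range(40)])
    dataset.set_target()

    # Each iteration yields (batch_x, batch_y): two dicts mapping field names
    # to tensors of shape [batch_size, padding_length].
    for batch_x, batch_y in Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False):
        print(batch_x, batch_y)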


+7 -9 fastNLP/core/dataset.py

@@ -117,22 +117,20 @@ class DataSet(object):
assert name in self.field_arrays
self.field_arrays[name].append(field)

def add_field(self, name, fields, padding_val=0, need_tensor=False, is_target=False):
def add_field(self, name, fields, padding_val=0, is_input=False, is_target=False):
"""
:param name:
:param str name:
:param fields:
:param padding_val:
:param need_tensor:
:param is_target:
:param int padding_val:
:param bool is_input:
:param bool is_target:
:return:
"""
if len(self.field_arrays) != 0:
assert len(self) == len(fields)
self.field_arrays[name] = FieldArray(name, fields,
padding_val=padding_val,
need_tensor=need_tensor,
is_target=is_target)
self.field_arrays[name] = FieldArray(name, fields, padding_val=padding_val, is_target=is_target,
is_input=is_input)

def delete_field(self, name):
self.field_arrays.pop(name)
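A hedged sketch of the revised add_field signature (the field names and values here are invented for illustration):

    from fastNLP.core.dataset import DataSet

    ds = DataSet()
    # Each field is a list with one value per instance; once the DataSet is
    # non-empty, add_field asserts that new fields have the same length.
    ds.add_field(name="words", fields=[["a", "b"], ["c", "d"]], is_input=True)
    ds.add_field(name="label", fields=["1", "0"], padding_val=0, is_target=True)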


+24 -7 fastNLP/core/fieldarray.py

@@ -2,7 +2,19 @@ import numpy as np


class FieldArray(object):
"""FieldArray is the collection of Instances of the same Field.
It is the basic element of DataSet class.

"""
def __init__(self, name, content, padding_val=0, is_target=False, is_input=False):
"""

:param str name: the name of the FieldArray
:param list content: a list of int, float, or other objects.
:param int padding_val: the integer for padding. Default: 0.
:param bool is_target: If True, this FieldArray is used to compute loss.
:param bool is_input: If True, this FieldArray is used as the model input.
"""
self.name = name
self.content = content
self.padding_val = padding_val
@@ -24,23 +36,28 @@ class FieldArray(object):
assert isinstance(name, int)
self.content[name] = val

def get(self, idxes):
if isinstance(idxes, int):
return self.content[idxes]
def get(self, indices):
"""Fetch instances based on indices.

:param indices: an int, or a list of int.
:return: a single instance if indices is an int; otherwise a numpy array of the requested instances, padded to the longest sequence in the batch.
"""
if isinstance(indices, int):
return self.content[indices]
assert self.is_input is True or self.is_target is True
batch_size = len(idxes)
batch_size = len(indices)
# TODO: when this FieldArray holds single-number content such as seq_length, no padding is needed; needs further discussion
if isinstance(self.content[0], int) or isinstance(self.content[0], float):
if self.dtype is None:
self.dtype = np.int64 if isinstance(self.content[0], int) else np.double
array = np.array([self.content[i] for i in idxes], dtype=self.dtype)
array = np.array([self.content[i] for i in indices], dtype=self.dtype)
else:
if self.dtype is None:
self.dtype = np.int64
max_len = max([len(self.content[i]) for i in idxes])
max_len = max([len(self.content[i]) for i in indices])
array = np.full((batch_size, max_len), self.padding_val, dtype=self.dtype)

for i, idx in enumerate(idxes):
for i, idx in enumerate(indices):
array[i][:len(self.content[idx])] = self.content[idx]
return array
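A small sketch of the padding behaviour of get as implemented above (values are illustrative; note that batch access asserts is_input or is_target is set):

    from fastNLP.core.fieldarray import FieldArray

    fa = FieldArray("words", [[1, 2, 3], [4, 5]], padding_val=0, is_input=True)
    fa.get(0)          # a single instance: [1, 2, 3]
    batch = fa.get([0, 1])
    # batch is a numpy array padded to the longest sequence in the request:
    # [[1 2 3]
    #  [4 5 0]]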



+13 -14 fastNLP/core/instance.py

@@ -1,16 +1,27 @@


class Instance(object):
"""An instance which consists of Fields is an example in the DataSet.
"""An Instance is an example of data. It is the collection of Fields.

::
Instance(field_1=[1, 1, 1], field_2=[2, 2, 2])

"""

def __init__(self, **fields):
"""

:param fields: a dict of (field name: field)
"""
self.fields = fields

def add_field(self, field_name, field):
"""Add a new field to the instance.

:param field_name: str, the name of the field.
:param field: the content of the field.
"""
self.fields[field_name] = field
return self

def __getitem__(self, name):
if name in self.fields:
@@ -21,17 +32,5 @@ class Instance(object):
def __setitem__(self, name, field):
return self.add_field(name, field)

def __getattr__(self, item):
if hasattr(self, 'fields') and item in self.fields:
return self.fields[item]
else:
raise AttributeError('{} does not exist.'.format(item))

def __setattr__(self, key, value):
if hasattr(self, 'fields'):
self.__setitem__(key, value)
else:
super().__setattr__(key, value)

def __repr__(self):
return self.fields.__repr__()
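With __getattr__ and __setattr__ removed above, fields are reached through dict-style access only. A minimal sketch (field values invented):

    from fastNLP.core.instance import Instance

    ins = Instance(words=["I", "am", "here"], label="1")
    ins.add_field("seq_len", 3)        # add_field returns the instance itself
    ins["seq_len"] = 3                 # equivalent, via __setitem__
    print(ins["words"], ins["label"])  # dict-style lookup via __getitem__
    print(ins)                         # __repr__ prints the underlying fields dict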

+39 -20 fastNLP/core/vocabulary.py

@@ -1,5 +1,5 @@
from copy import deepcopy
from collections import Counter
from copy import deepcopy

DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1
@@ -20,6 +20,7 @@ def check_build_vocab(func):
if self.word2idx is None:
self.build_vocab()
return func(self, *args, **kwargs)

return _wrapper


@@ -34,6 +35,7 @@ class Vocabulary(object):
vocab["word"]
vocab.to_word(5)
"""

def __init__(self, need_default=True, max_size=None, min_freq=None):
"""
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
@@ -54,24 +56,36 @@ class Vocabulary(object):
self.idx2word = None

def update(self, word_lst):
"""add word or list of words into Vocabulary
"""Add a list of words into the vocabulary.

:param word: a list of string or a single string
:param list word_lst: a list of strings
"""
self.word_count.update(word_lst)

def add(self, word):
"""Add a single word into the vocabulary.

:param str word: a word or token.
"""
self.word_count[word] += 1

def add_word(self, word):
"""Add a single word into the vocabulary.

:param str word: a word or token.
"""
self.add(word)

def add_word_lst(self, word_lst):
self.update(word_lst)
"""Add a list of words into the vocabulary.

:param list word_lst: a list of strings
"""
self.update(word_lst)

def build_vocab(self):
"""build 'word to index' dict, and filter the word using `max_size` and `min_freq`
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`.

"""
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
@@ -85,11 +99,12 @@ class Vocabulary(object):
if self.min_freq is not None:
words = filter(lambda kv: kv[1] >= self.min_freq, words)
start_idx = len(self.word2idx)
self.word2idx.update({w:i+start_idx for i, (w,_) in enumerate(words)})
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab()

def build_reverse_vocab(self):
"""build 'index to word' dict based on 'word to index' dict
"""Build 'index to word' dict based on 'word to index' dict.

"""
self.idx2word = {i: w for w, i in self.word2idx.items()}

@@ -97,6 +112,15 @@ class Vocabulary(object):
def __len__(self):
return len(self.word2idx)

@check_build_vocab
def __contains__(self, item):
"""Check if a word in vocabulary.

:param item: the word
:return: True or False
"""
return item in self.word2idx

def has_word(self, w):
return self.__contains__(w)

@@ -114,8 +138,8 @@ class Vocabulary(object):
raise ValueError("word {} not in vocabulary".format(w))

def to_index(self, w):
""" like to_index(w) function, turn a word to the index
if w is not in Vocabulary, return the unknown label
""" Turn a word to an index.
If w is not in Vocabulary, return the unknown label.

:param str w:
"""
@@ -144,12 +168,14 @@ class Vocabulary(object):
def to_word(self, idx):
"""given a word's index, return the word itself

:param int idx:
:param int idx: the index
:return str word: the indexed word
"""
return self.idx2word[idx]

def __getstate__(self):
"""use to prepare data for pickle
"""Use to prepare data for pickle.

"""
state = self.__dict__.copy()
# no need to pickle idx2word as it can be constructed from word2idx
@@ -157,16 +183,9 @@ class Vocabulary(object):
return state

def __setstate__(self, state):
"""use to restore state from pickle
"""Use to restore state from pickle.

"""
self.__dict__.update(state)
self.build_reverse_vocab()

@check_build_vocab
def __contains__(self, item):
"""Check if a word in vocabulary.

:param item: the word
:return: True or False
"""
return item in self.word2idx
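Pulling the vocabulary changes together, a hedged end-to-end sketch (the words mirror the new tests; the <unk> fallback assumes need_default=True):

    from fastNLP.core.vocabulary import Vocabulary

    vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
    vocab.update(["FastNLP", "works", "well", "works"])  # or add / add_word / add_word_lst

    idx = vocab.to_index("works")      # vocab["works"] is equivalent
    assert vocab.to_word(idx) == "works"
    assert "works" in vocab            # the new __contains__
    # Unseen words fall back to the unknown label:
    assert vocab.to_index("never-seen") == vocab.to_index("<unk>")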

+9 -8 test/core/test_batch.py

@@ -1,17 +1,18 @@
import unittest

from fastNLP.core.batch import Batch
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.dataset import construct_dataset
from fastNLP.core.sampler import SequentialSampler


class TestCase1(unittest.TestCase):
def test(self):
dataset = DataSet([Instance(x=["I", "am", "here"])] * 40)
def test_simple(self):
dataset = construct_dataset(
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
dataset.set_target()
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), use_cuda=False)

for batch_x, batch_y in batch:
print(batch_x, batch_y)
# TODO: weird due to change in dataset.py
cnt = 0
for _, _ in batch:
cnt += 1
self.assertEqual(cnt, 10)

+14 -14 test/core/test_dataset.py

@@ -1,20 +1,20 @@
import unittest

from fastNLP.core.dataset import DataSet


class TestDataSet(unittest.TestCase):
labeled_data_list = [
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
[["a", "b", "e", "d"], ["1", "2", "3", "4"]],
]
unlabeled_data_list = [
["a", "b", "e", "d"],
["a", "b", "e", "d"],
["a", "b", "e", "d"]
]
word_vocab = {"a": 0, "b": 1, "e": 2, "d": 3}
label_vocab = {"1": 1, "2": 2, "3": 3, "4": 4}

def test_case_1(self):
# TODO:
pass
ds = DataSet()
ds.add_field(name="xx", fields=["a", "b", "e", "d"])

self.assertTrue("xx" in ds.field_arrays)
self.assertEqual(len(ds.field_arrays["xx"]), 4)
self.assertEqual(ds.get_length(), 4)
self.assertEqual(ds.get_fields(), ds.field_arrays)

try:
ds.add_field(name="yy", fields=["x", "y", "z", "w", "f"])
except BaseException as e:
self.assertTrue(isinstance(e, AssertionError))

+0 -42 test/core/test_field.py

@@ -1,42 +0,0 @@
import unittest

from fastNLP.core.field import CharTextField, LabelField, SeqLabelField


class TestField(unittest.TestCase):
def test_char_field(self):
text = "PhD applicants must submit a Research Plan and a resume " \
"specify your class ranking written in English and a list of research" \
" publications if any".split()
max_word_len = max([len(w) for w in text])
field = CharTextField(text, max_word_len, is_target=False)
all_char = set()
for word in text:
all_char.update([ch for ch in word])
char_vocab = {ch: idx + 1 for idx, ch in enumerate(all_char)}

self.assertEqual(field.index(char_vocab),
[[char_vocab[ch] for ch in word] + [0] * (max_word_len - len(word)) for word in text])
self.assertEqual(field.get_length(), len(text))
self.assertEqual(field.contents(), text)
tensor = field.to_tensor(50)
self.assertEqual(tuple(tensor.shape), (50, max_word_len))

def test_label_field(self):
label = LabelField("A", is_target=True)
self.assertEqual(label.get_length(), 1)
self.assertEqual(label.index({"A": 10}), 10)

label = LabelField(30, is_target=True)
self.assertEqual(label.get_length(), 1)
tensor = label.to_tensor(0)
self.assertEqual(tensor.shape, ())
self.assertEqual(int(tensor), 30)

def test_seq_label_field(self):
seq = ["a", "b", "c", "d", "a", "c", "a", "b"]
field = SeqLabelField(seq)
vocab = {"a": 10, "b": 20, "c": 30, "d": 40}
self.assertEqual(field.index(vocab), [vocab[x] for x in seq])
tensor = field.to_tensor(10)
self.assertEqual(tuple(tensor.shape), (10,))

+6 -0 test/core/test_fieldarray.py

@@ -0,0 +1,6 @@
import unittest


class TestFieldArray(unittest.TestCase):
def test(self):
pass

+29 -0 test/core/test_instance.py

@@ -0,0 +1,29 @@
import unittest

from fastNLP.core.instance import Instance


class TestCase(unittest.TestCase):

def test_init(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
self.assertTrue(isinstance(ins.fields, dict))
self.assertEqual(ins.fields, fields)

ins = Instance(**fields)
self.assertEqual(ins.fields, fields)

def test_add_field(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
ins = Instance(**fields)
ins.add_field("z", [1, 1, 1])
fields.update({"z": [1, 1, 1]})
self.assertEqual(ins.fields, fields)

def test_get_item(self):
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
ins = Instance(**fields)
self.assertEqual(ins["x"], [1, 2, 3])
self.assertEqual(ins["y"], [4, 5, 6])
self.assertEqual(ins["z"], [1, 1, 1])

+36 -38 test/core/test_sampler.py

@@ -1,44 +1,42 @@
import unittest

import torch

from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \
k_means_1d, k_means_bucketing, simple_sort_bucketing


def test_convert_to_torch_tensor():
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]]
ans = convert_to_torch_tensor(data, False)
assert isinstance(ans, torch.Tensor)
assert tuple(ans.shape) == (3, 5)


def test_sequential_sampler():
sampler = SequentialSampler()
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
for idx, i in enumerate(sampler(data)):
assert idx == i


def test_random_sampler():
sampler = RandomSampler()
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
ans = [data[i] for i in sampler(data)]
assert len(ans) == len(data)
for d in ans:
assert d in data


def test_k_means():
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
centroids, assign = list(centroids), list(assign)
assert len(centroids) == 2
assert len(assign) == 10


def test_k_means_bucketing():
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
assert len(res) == 2


def test_simple_sort_bucketing():
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
assert len(_) == 10
class TestSampler(unittest.TestCase):
def test_convert_to_torch_tensor(self):
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]]
ans = convert_to_torch_tensor(data, False)
assert isinstance(ans, torch.Tensor)
assert tuple(ans.shape) == (3, 5)

def test_sequential_sampler(self):
sampler = SequentialSampler()
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
for idx, i in enumerate(sampler(data)):
assert idx == i

def test_random_sampler(self):
sampler = RandomSampler()
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10]
ans = [data[i] for i in sampler(data)]
assert len(ans) == len(data)
for d in ans:
assert d in data

def test_k_means(self):
centroids, assign = k_means_1d([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], 2, max_iter=5)
centroids, assign = list(centroids), list(assign)
assert len(centroids) == 2
assert len(assign) == 10

def test_k_means_bucketing(self):
res = k_means_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10], [None, None])
assert len(res) == 2

def test_simple_sort_bucketing(self):
_ = simple_sort_bucketing([21, 3, 25, 7, 9, 22, 4, 6, 28, 10])
assert len(_) == 10

+0 -31 test/core/test_vocab.py

@@ -1,31 +0,0 @@
import unittest
from fastNLP.core.vocabulary import Vocabulary, DEFAULT_WORD_TO_INDEX

class TestVocabulary(unittest.TestCase):
def test_vocab(self):
import _pickle as pickle
import os
vocab = Vocabulary()
filename = 'vocab'
vocab.update(filename)
vocab.update([filename, ['a'], [['b']], ['c']])
idx = vocab[filename]
before_pic = (vocab.to_word(idx), vocab[filename])

with open(filename, 'wb') as f:
pickle.dump(vocab, f)
with open(filename, 'rb') as f:
vocab = pickle.load(f)
os.remove(filename)
vocab.build_reverse_vocab()
after_pic = (vocab.to_word(idx), vocab[filename])
TRUE_DICT = {'vocab': 5, 'a': 6, 'b': 7, 'c': 8}
TRUE_DICT.update(DEFAULT_WORD_TO_INDEX)
TRUE_IDXDICT = {0: '<pad>', 1: '<unk>', 2: '<reserved-2>', 3: '<reserved-3>', 4: '<reserved-4>', 5: 'vocab', 6: 'a', 7: 'b', 8: 'c'}
self.assertEqual(before_pic, after_pic)
self.assertDictEqual(TRUE_DICT, vocab.word2idx)
self.assertDictEqual(TRUE_IDXDICT, vocab.idx2word)
if __name__ == '__main__':
unittest.main()

+61 -0 test/core/test_vocabulary.py

@@ -0,0 +1,61 @@
import unittest
from collections import Counter

from fastNLP.core.vocabulary import Vocabulary

text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
"works", "well", "in", "most", "cases", "scales", "well"]
counter = Counter(text)


class TestAdd(unittest.TestCase):
def test_add(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
for word in text:
vocab.add(word)
self.assertEqual(vocab.word_count, counter)

def test_add_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
for word in text:
vocab.add_word(word)
self.assertEqual(vocab.word_count, counter)

def test_add_word_lst(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.add_word_lst(text)
self.assertEqual(vocab.word_count, counter)

def test_update(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(vocab.word_count, counter)


class TestIndexing(unittest.TestCase):
def test_len(self):
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(len(vocab), len(counter))

def test_contains(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.update(text)
self.assertTrue(text[-1] in vocab)
self.assertFalse("~!@#" in vocab)
self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))

def test_index(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.update(text)
res = [vocab[w] for w in set(text)]
self.assertEqual(len(res), len(set(res)))

res = [vocab.to_index(w) for w in set(text)]
self.assertEqual(len(res), len(set(res)))

def test_to_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])
