@@ -13,7 +13,8 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||
# Names that can be imported directly from the fastNLP package.
# Each name is listed exactly once.
__all__ = [
    "Instance", "FieldArray", "Batch", "Vocabulary", "DataSet",
    "Trainer", "Tester", "Callback",
    "Padder", "AutoPadder", "EngChar2DPadder",
    "AccuracyMetric", "BMESF1PreRecMetric", "SpanFPreRecMetric", "SQuADMetric",
    "Optimizer", "SGD", "Adam",
    "Sampler", "SequentialSampler", "BucketSampler", "RandomSampler",
    "LossFunc", "CrossEntropyLoss", "L1Loss", "BCELoss", "NLLLoss", "LossInForward",
    "cache_results",
]
@@ -17,7 +17,7 @@ from .dataset import DataSet | |||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
from .instance import Instance | |||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | |||
from .metrics import AccuracyMetric | |||
from .metrics import AccuracyMetric, BMESF1PreRecMetric, SpanFPreRecMetric, SQuADMetric | |||
from .optimizer import Optimizer, SGD, Adam | |||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
from .tester import Tester | |||
@@ -236,6 +236,7 @@ class CallbackManager(Callback): | |||
for env_name, env_val in env.items(): | |||
for callback in self.callbacks: | |||
print(callback, env_name, env_val ) | |||
setattr(callback, '_' + env_name, env_val) # Callback.trainer | |||
@_transfer | |||
@@ -425,19 +426,25 @@ class LRFinder(Callback): | |||
super(LRFinder, self).__init__() | |||
self.start_lr, self.end_lr = start_lr, end_lr | |||
self.num_it = self.batch_per_epoch | |||
self.stop = False | |||
self.best_loss = 0. | |||
self.best_lr = None | |||
self.loss_history = [] | |||
self.smooth_value = SmoothValue(0.8) | |||
self.opt = None | |||
scale = (self.end_lr - self.start_lr) / self.num_it | |||
self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) | |||
self.find = None | |||
self.loader = ModelLoader() | |||
@property
def lr_gen(self):
    """Return a generator of linearly spaced learning rates.

    The generator yields ``batch_per_epoch`` values. The first value is one
    step above ``start_lr`` and the last one equals ``end_lr``.

    NOTE(review): every attribute access builds a brand-new generator, so a
    caller has to keep the returned object in order to advance through the
    schedule -- confirm against the call sites.
    """
    n_steps = self.batch_per_epoch
    scale = (self.end_lr - self.start_lr) / n_steps
    return (self.start_lr + scale * (step + 1) for step in range(n_steps))
@property
def num_it(self):
    """Return the number of iterations, which is ``batch_per_epoch``."""
    return self.batch_per_epoch
def on_epoch_begin(self): | |||
if self.epoch == 1: # first epoch | |||
self.opt = self.trainer.optimizer # pytorch optimizer | |||
@@ -418,6 +418,7 @@ class AutoPadder(Padder): | |||
return False | |||
def __call__(self, contents, field_name, field_ele_dtype): | |||
if not _is_iterable(contents[0]): | |||
array = np.array([content for content in contents], dtype=field_ele_dtype) | |||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | |||
@@ -430,6 +430,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
class SpanFPreRecMetric(MetricBase): | |||
""" | |||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | |||
@@ -619,6 +620,8 @@ class SpanFPreRecMetric(MetricBase): | |||
class BMESF1PreRecMetric(MetricBase): | |||
""" | |||
别名::class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric` | |||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | |||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | |||
@@ -826,6 +829,8 @@ def _pred_topk(y_prob, k=1): | |||
class SQuADMetric(MetricBase): | |||
""" | |||
别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||
SQuAD数据集metric | |||
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | |||
@@ -350,7 +350,7 @@ class Trainer(object): | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
:param torch.optim.Optimizer optimizer: 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
@@ -403,7 +403,6 @@ class Trainer(object): | |||
callbacks=None, | |||
check_code_level=0): | |||
super(Trainer, self).__init__() | |||
if not isinstance(train_data, DataSet): | |||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | |||
if not isinstance(model, nn.Module): | |||
@@ -468,7 +467,7 @@ class Trainer(object): | |||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | |||
self.model = _move_model_to_device(self.model, device=device) | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
elif isinstance(optimizer, Optimizer): | |||
@@ -493,6 +492,7 @@ class Trainer(object): | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
print("callback_manager") | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
@@ -616,7 +616,7 @@ def seq_lens_to_masks(seq_lens, float=False): | |||
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | |||
batch_size = seq_lens.size(0) | |||
max_len = seq_lens.max() | |||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long() | |||
masks = indexes.lt(seq_lens.unsqueeze(1)) | |||
if float: | |||
@@ -2,16 +2,18 @@ from functools import wraps | |||
from collections import Counter | |||
from .dataset import DataSet | |||
def _check_build_vocab(func): | |||
"""A decorator to make sure the indexing is built before used. | |||
""" | |||
@wraps(func) # to solve missing docstring | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self.word2idx is None or self.rebuild is True: | |||
self.build_vocab() | |||
return func(self, *args, **kwargs) | |||
return _wrapper | |||
@@ -19,7 +21,8 @@ def _check_build_status(func): | |||
"""A decorator to check whether the vocabulary updates after the last build. | |||
""" | |||
@wraps(func) # to solve missing docstring | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self.rebuild is False: | |||
self.rebuild = True | |||
@@ -28,7 +31,7 @@ def _check_build_status(func): | |||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||
self.max_size, func.__name__)) | |||
return func(self, *args, **kwargs) | |||
return _wrapper | |||
@@ -50,15 +53,15 @@ class Vocabulary(object): | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
:param str padding: padding的字符. 如果设置为 ``None`` , | |||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<pad>' | |||
:param str unknow: unknow的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<unk>' | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
@@ -68,7 +71,7 @@ class Vocabulary(object): | |||
self.word2idx = None | |||
self.idx2word = None | |||
self.rebuild = True | |||
@_check_build_status | |||
def update(self, word_lst): | |||
"""依次增加序列中词在词典中的出现频率 | |||
@@ -76,7 +79,7 @@ class Vocabulary(object): | |||
:param list word_lst: a list of strings | |||
""" | |||
self.word_count.update(word_lst) | |||
@_check_build_status | |||
def add(self, word): | |||
""" | |||
@@ -85,7 +88,7 @@ class Vocabulary(object): | |||
:param str word: 新词 | |||
""" | |||
self.word_count[word] += 1 | |||
@_check_build_status | |||
def add_word(self, word): | |||
""" | |||
@@ -94,7 +97,7 @@ class Vocabulary(object): | |||
:param str word: 新词 | |||
""" | |||
self.add(word) | |||
@_check_build_status | |||
def add_word_lst(self, word_lst): | |||
""" | |||
@@ -103,7 +106,7 @@ class Vocabulary(object): | |||
:param list[str] word_lst: 词的序列 | |||
""" | |||
self.update(word_lst) | |||
def build_vocab(self): | |||
""" | |||
根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | |||
@@ -116,7 +119,7 @@ class Vocabulary(object): | |||
self.word2idx[self.padding] = len(self.word2idx) | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = len(self.word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
if self.min_freq is not None: | |||
@@ -127,18 +130,18 @@ class Vocabulary(object): | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
self.rebuild = False | |||
def build_reverse_vocab(self):
    """
    Build the "index to word" dict from the "word to index" dict.
    """
    self.idx2word = {idx: word for word, idx in self.word2idx.items()}
@_check_build_vocab
def __len__(self):
    """Return the number of entries in the "word to index" dict."""
    return len(self.word2idx)
@_check_build_vocab | |||
def __contains__(self, item): | |||
""" | |||
@@ -148,7 +151,7 @@ class Vocabulary(object): | |||
:return: True or False | |||
""" | |||
return item in self.word2idx | |||
def has_word(self, w): | |||
""" | |||
检查词是否被记录 | |||
@@ -163,7 +166,7 @@ class Vocabulary(object): | |||
:return: ``True`` or ``False`` | |||
""" | |||
return self.__contains__(w) | |||
@_check_build_vocab | |||
def __getitem__(self, w): | |||
""" | |||
@@ -177,7 +180,7 @@ class Vocabulary(object): | |||
return self.word2idx[self.unknown] | |||
else: | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
@_check_build_vocab | |||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
""" | |||
@@ -194,6 +197,7 @@ class Vocabulary(object): | |||
:param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
Default: ``None`` | |||
""" | |||
def index_instance(ins): | |||
""" | |||
有几种情况, str, 1d-list, 2d-list | |||
@@ -209,8 +213,8 @@ class Vocabulary(object): | |||
else: | |||
if isinstance(field[0][0], list): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
return[[self.to_index(c) for c in w] for w in field] | |||
return [[self.to_index(c) for c in w] for w in field] | |||
if new_field_name is None: | |||
new_field_name = field_name | |||
for idx, dataset in enumerate(datasets): | |||
@@ -222,7 +226,7 @@ class Vocabulary(object): | |||
raise e | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
def from_dataset(self, *datasets, field_name): | |||
""" | |||
使用dataset的对应field中词构建词典 | |||
@@ -243,7 +247,7 @@ class Vocabulary(object): | |||
field_name = [field_name] | |||
elif not isinstance(field_name, list): | |||
raise TypeError('invalid argument field_name: {}'.format(field_name)) | |||
def construct_vocab(ins): | |||
for fn in field_name: | |||
field = ins[fn] | |||
@@ -256,6 +260,7 @@ class Vocabulary(object): | |||
if isinstance(field[0][0], list): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
[self.add_word_lst(w) for w in field] | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
@@ -266,7 +271,7 @@ class Vocabulary(object): | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
return self | |||
def to_index(self, w): | |||
""" | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 | |||
@@ -282,7 +287,7 @@ class Vocabulary(object): | |||
:return int index: the number | |||
""" | |||
return self.__getitem__(w) | |||
@property | |||
@_check_build_vocab | |||
def unknown_idx(self): | |||
@@ -292,7 +297,7 @@ class Vocabulary(object): | |||
if self.unknown is None: | |||
return None | |||
return self.word2idx[self.unknown] | |||
@property | |||
@_check_build_vocab | |||
def padding_idx(self): | |||
@@ -302,7 +307,7 @@ class Vocabulary(object): | |||
if self.padding is None: | |||
return None | |||
return self.word2idx[self.padding] | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
""" | |||
@@ -312,26 +317,26 @@ class Vocabulary(object): | |||
:return str word: the word | |||
""" | |||
return self.idx2word[idx] | |||
def __getstate__(self): | |||
"""Use to prepare data for pickle. | |||
""" | |||
len(self) # make sure vocab has been built | |||
len(self) # make sure vocab has been built | |||
state = self.__dict__.copy() | |||
# no need to pickle idx2word as it can be constructed from word2idx | |||
del state['idx2word'] | |||
return state | |||
def __setstate__(self, state): | |||
"""Use to restore state from pickle. | |||
""" | |||
self.__dict__.update(state) | |||
self.build_reverse_vocab() | |||
def __repr__(self): | |||
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | |||
def __iter__(self): | |||
return iter(list(self.word_count.keys())) |
@@ -1,13 +1,12 @@ | |||
import time | |||
import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.batch import Batch | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.sampler import SequentialSampler | |||
from fastNLP import Batch | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import SequentialSampler | |||
def generate_fake_dataset(num_samples=1000): | |||
@@ -16,11 +15,11 @@ def generate_fake_dataset(num_samples=1000): | |||
:param num_samples: sample的数量 | |||
:return: | |||
""" | |||
max_len = 50 | |||
min_len = 10 | |||
num_features = 4 | |||
data_dict = {} | |||
for i in range(num_features): | |||
data = [] | |||
@@ -28,9 +27,9 @@ def generate_fake_dataset(num_samples=1000): | |||
for length in lengths: | |||
data.append(np.random.randint(100, size=length)) | |||
data_dict[str(i)] = data | |||
dataset = DataSet(data_dict) | |||
for i in range(num_features): | |||
if np.random.randint(2) == 0: | |||
dataset.set_input(str(i)) | |||
@@ -38,6 +37,7 @@ def generate_fake_dataset(num_samples=1000): | |||
dataset.set_target(str(i)) | |||
return dataset | |||
def construct_dataset(sentences): | |||
"""Construct a data set from a list of sentences. | |||
@@ -51,18 +51,19 @@ def construct_dataset(sentences): | |||
dataset.append(instance) | |||
return dataset | |||
class TestCase1(unittest.TestCase): | |||
def test_simple(self):
    """Batching 40 identical sentences with batch_size=4 must yield 10 batches."""
    dataset = construct_dataset(
        [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
    dataset.set_target()
    batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)

    cnt = 0
    # each item must unpack into exactly two parts
    for _, _ in batch:
        cnt += 1
    self.assertEqual(cnt, 10)
def test_dataset_batching(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
ds.set_input("x") | |||
@@ -74,7 +75,7 @@ class TestCase1(unittest.TestCase): | |||
self.assertEqual(len(y["y"]), 4) | |||
self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | |||
self.assertListEqual(list(y["y"][-1]), [5, 6]) | |||
def test_list_padding(self): | |||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
@@ -84,7 +85,7 @@ class TestCase1(unittest.TestCase): | |||
for x, y in iter: | |||
self.assertEqual(x["x"].shape, (4, 4)) | |||
self.assertEqual(y["y"].shape, (4, 4)) | |||
def test_numpy_padding(self): | |||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
@@ -94,7 +95,7 @@ class TestCase1(unittest.TestCase): | |||
for x, y in iter: | |||
self.assertEqual(x["x"].shape, (4, 4)) | |||
self.assertEqual(y["y"].shape, (4, 4)) | |||
def test_list_to_tensor(self): | |||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | |||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | |||
@@ -106,7 +107,7 @@ class TestCase1(unittest.TestCase): | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
def test_numpy_to_tensor(self): | |||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | |||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | |||
@@ -118,7 +119,7 @@ class TestCase1(unittest.TestCase): | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
def test_list_of_list_to_tensor(self): | |||
ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] + | |||
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | |||
@@ -130,7 +131,7 @@ class TestCase1(unittest.TestCase): | |||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | |||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | |||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | |||
def test_list_of_numpy_to_tensor(self): | |||
ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + | |||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | |||
@@ -139,16 +140,16 @@ class TestCase1(unittest.TestCase): | |||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | |||
for x, y in iter: | |||
print(x, y) | |||
def test_sequential_batch(self):
    """Smoke test: iterate once over a sequentially sampled fake dataset."""
    batch_size = 32
    num_samples = 1000
    dataset = generate_fake_dataset(num_samples)

    batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
    # nothing is asserted; the loop only has to finish without raising
    for batch_x, batch_y in batch:
        pass
""" | |||
def test_multi_workers_batch(self): | |||
batch_size = 32 | |||
@@ -4,14 +4,13 @@ import numpy as np | |||
import torch | |||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | |||
LRFinder, \ | |||
TensorboardCallback | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP.core.optimizer import SGD | |||
from fastNLP.core.trainer import Trainer | |||
LRFinder, TensorboardCallback | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import BCELoss | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import SGD | |||
from fastNLP import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier | |||
@@ -20,15 +19,15 @@ def prepare_env(): | |||
mean = np.array([-3, -3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
mean = np.array([3, 3]) | |||
cov = np.array([[1, 0], [0, 1]]) | |||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | |||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | |||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | |||
return data_set | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x") | |||
data_set.set_target("y") | |||
@@ -37,19 +36,7 @@ def prepare_env(): | |||
class TestCallback(unittest.TestCase): | |||
def test_echo_callback(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
n_epochs=2, | |||
batch_size=32, | |||
print_every=50, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=False, | |||
callbacks=[EchoCallback()]) | |||
trainer.train() | |||
def test_gradient_clip(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
@@ -64,7 +51,7 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | |||
trainer.train() | |||
def test_early_stop(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
@@ -79,7 +66,7 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[EarlyStopCallback(5)]) | |||
trainer.train() | |||
def test_lr_scheduler(self): | |||
data_set, model = prepare_env() | |||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | |||
@@ -95,7 +82,7 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | |||
trainer.train() | |||
def test_KeyBoardInterrupt(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
@@ -108,7 +95,7 @@ class TestCallback(unittest.TestCase): | |||
use_tqdm=False, | |||
callbacks=[ControlC(False)]) | |||
trainer.train() | |||
def test_LRFinder(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
@@ -121,7 +108,7 @@ class TestCallback(unittest.TestCase): | |||
use_tqdm=False, | |||
callbacks=[LRFinder(len(data_set) // 32)]) | |||
trainer.train() | |||
def test_TensorboardCallback(self): | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
@@ -136,21 +123,22 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[TensorboardCallback("loss", "metric")]) | |||
trainer.train() | |||
def test_readonly_property(self): | |||
from fastNLP.core.callback import Callback | |||
passed_epochs = [] | |||
total_epochs = 5 | |||
class MyCallback(Callback): | |||
def __init__(self): | |||
super(MyCallback, self).__init__() | |||
def on_epoch_begin(self): | |||
passed_epochs.append(self.epoch) | |||
print(self.n_epochs, self.n_steps, self.batch_size) | |||
print(self.model) | |||
print(self.optimizer) | |||
data_set, model = prepare_env() | |||
trainer = Trainer(data_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
@@ -164,4 +152,4 @@ class TestCallback(unittest.TestCase): | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
callbacks=[MyCallback()]) | |||
trainer.train() | |||
assert passed_epochs == list(range(1, total_epochs+1)) | |||
assert passed_epochs == list(range(1, total_epochs + 1)) |
@@ -1,9 +1,10 @@ | |||
import os | |||
import unittest | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP.core.instance import Instance | |||
from fastNLP import DataSet | |||
from fastNLP import FieldArray | |||
from fastNLP import Instance | |||
from fastNLP.io import CSVLoader | |||
class TestDataSetInit(unittest.TestCase): | |||
@@ -167,13 +168,11 @@ class TestDataSetMethods(unittest.TestCase): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||
d1, d2 = ds.split(0.1) | |||
def test_apply2(self): | |||
def split_sent(ins): | |||
return ins['raw_sentence'].split() | |||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | |||
sep='\t') | |||
csv_loader = CSVLoader(headers=['raw_sentence', 'label'],sep='\t') | |||
dataset = csv_loader.load('../data_for_tests/tutorial_sample_dataset.csv') | |||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | |||
dataset.apply(split_sent, new_field_name='words', is_input=True) | |||
# print(dataset) | |||
@@ -208,7 +207,7 @@ class TestDataSetMethods(unittest.TestCase): | |||
self.assertEqual(ans.content, [[5, 6]] * 10) | |||
def test_add_null(self): | |||
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||
# TODO test failed because 'fastNLP\core\field.py:143: RuntimeError' | |||
ds = DataSet() | |||
with self.assertRaises(RuntimeError) as RE: | |||
ds.add_field('test', []) | |||
@@ -2,7 +2,7 @@ import unittest | |||
import numpy as np | |||
from fastNLP.core.fieldarray import FieldArray | |||
from fastNLP import FieldArray | |||
class TestFieldArrayInit(unittest.TestCase): | |||
@@ -170,7 +170,7 @@ class TestPadder(unittest.TestCase): | |||
测试AutoPadder能否正常工作 | |||
:return: | |||
""" | |||
from fastNLP.core.fieldarray import AutoPadder | |||
from fastNLP import AutoPadder | |||
padder = AutoPadder() | |||
content = ['This is a str', 'this is another str'] | |||
self.assertListEqual(content, padder(content, None, np.str).tolist()) | |||
@@ -194,7 +194,7 @@ class TestPadder(unittest.TestCase): | |||
测试EngChar2DPadder能不能正确使用 | |||
:return: | |||
""" | |||
from fastNLP.core.fieldarray import EngChar2DPadder | |||
from fastNLP import EngChar2DPadder | |||
padder = EngChar2DPadder(pad_length=0) | |||
contents = [1, 2] | |||
@@ -225,11 +225,11 @@ class TestPadder(unittest.TestCase): | |||
) | |||
def test_None_dtype(self):
    """With dtype None, the padded result converted via tolist() must equal the input."""
    from fastNLP import AutoPadder
    padder = AutoPadder()
    content = [
        [[1, 2, 3], [4, 5], [7, 8, 9, 10]],
        [[1]]
    ]
    ans = padder(content, None, None).tolist()
    self.assertListEqual(content, ans)
@@ -1,33 +1,33 @@ | |||
import unittest | |||
from fastNLP.core.instance import Instance | |||
from fastNLP import Instance | |||
class TestCase(unittest.TestCase): | |||
def test_init(self):
    """An Instance built from keywords exposes them as a dict in ``fields``."""
    fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
    ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
    self.assertTrue(isinstance(ins.fields, dict))
    self.assertEqual(ins.fields, fields)

    # unpacking a dict must give the same result as explicit keywords
    ins = Instance(**fields)
    self.assertEqual(ins.fields, fields)
def test_add_field(self):
    """add_field must make the new field visible in ``fields``."""
    fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
    ins = Instance(**fields)
    ins.add_field("z", [1, 1, 1])
    fields.update({"z": [1, 1, 1]})
    self.assertEqual(ins.fields, fields)
def test_get_item(self):
    """Indexing an Instance by field name returns that field's value."""
    fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
    ins = Instance(**fields)
    self.assertEqual(ins["x"], [1, 2, 3])
    self.assertEqual(ins["y"], [4, 5, 6])
    self.assertEqual(ins["z"], [1, 1, 1])
def test_repr(self): | |||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | |||
ins = Instance(**fields) | |||
@@ -3,7 +3,7 @@ import unittest | |||
import torch | |||
import torch.nn.functional as F | |||
import fastNLP.core.losses as loss | |||
import fastNLP as loss | |||
from fastNLP.core.losses import squash, unpad | |||
@@ -14,21 +14,21 @@ class TestLoss(unittest.TestCase): | |||
b = torch.empty(3, dtype=torch.long).random_(5) | |||
ans = ce({"my_predict": a}, {"my_truth": b}) | |||
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | |||
def test_BCELoss(self):
    """BCELoss with remapped keys must agree with torch's binary_cross_entropy."""
    bce = loss.BCELoss(pred="my_predict", target="my_truth")
    a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
    # binary cross entropy targets are documented to lie in [0, 1], so draw
    # them from the uniform distribution rather than from a normal one
    b = torch.rand((3, 5), requires_grad=False)
    ans = bce({"my_predict": a}, {"my_truth": b})
    self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))
def test_L1Loss(self):
    """L1Loss with remapped keys must agree with torch's l1_loss."""
    l1 = loss.L1Loss(pred="my_predict", target="my_truth")
    a = torch.randn(3, 5, requires_grad=False)
    b = torch.randn(3, 5)
    ans = l1({"my_predict": a}, {"my_truth": b})
    self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))
def test_NLLLoss(self): | |||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | |||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | |||
@@ -43,34 +43,34 @@ class TestLosserError(unittest.TestCase): | |||
pred_dict = {"pred": torch.zeros(4, 3)} | |||
target_dict = {'target': torch.zeros(4).long()} | |||
los = loss.CrossEntropyLoss() | |||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | |||
# | |||
def test_losser2(self):
    # (2) with corrupted size: the target is (16, 3) while pred is (16, 3) logits
    pred_dict = {"pred": torch.zeros(16, 3)}
    target_dict = {'target': torch.zeros(16, 3).long()}
    los = loss.CrossEntropyLoss()

    with self.assertRaises(RuntimeError):
        print(los(pred_dict=pred_dict, target_dict=target_dict))
def test_losser3(self):
    # (3) consistent sizes, but pred_dict carries the extra key 'stop_fast_param'
    # NOTE(review): presumably the extra key is expected to be ignored -- confirm
    pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
    target_dict = {'target': torch.zeros(16).long()}
    los = loss.CrossEntropyLoss()

    print(los(pred_dict=pred_dict, target_dict=target_dict))
def test_check_error(self):
    """Calling the loss with keys that do not match its mapping must raise."""
    l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
    a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
    b = torch.tensor([1, 0, 4])
    # the prediction key is missing
    with self.assertRaises(Exception):
        l1({"wrong_predict": a, "my": b}, {"my_truth": b})

    # the target key is missing
    with self.assertRaises(Exception):
        l1({"my_predict": a}, {"truth": b, "my": a})
@@ -80,7 +80,7 @@ class TestLossUtils(unittest.TestCase): | |||
a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) | |||
self.assertEqual(tuple(a.size()), (3, 5)) | |||
self.assertEqual(tuple(b.size()), (15,)) | |||
def test_unpad(self): | |||
a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) | |||
self.assertEqual(tuple(a.size()), (5, 8, 3)) | |||
@@ -3,8 +3,8 @@ import unittest | |||
import numpy as np | |||
import torch | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import BMESF1PreRecMetric | |||
from fastNLP.core.metrics import _pred_topk, _accuracy_topk | |||
@@ -14,24 +14,24 @@ class TestAccuracyMetric(unittest.TestCase): | |||
pred_dict = {"pred": torch.zeros(4, 3)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric = AccuracyMetric() | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
def test_AccuracyMetric2(self):
    # (2) with corrupted size
    # NOTE(review): this test cannot fail -- it returns when an exception is
    # raised and only prints when none is; confirm whether an exception is
    # actually required here before tightening it with assertRaises.
    try:
        pred_dict = {"pred": torch.zeros(4, 3, 2)}
        target_dict = {'target': torch.zeros(4)}
        metric = AccuracyMetric()

        metric(pred_dict=pred_dict, target_dict=target_dict, )
        print(metric.get_metric())
    except Exception as e:
        print(e)
        return
    print("No exception catches.")
def test_AccuracyMetric3(self): | |||
# (3) the second batch is corrupted size | |||
try: | |||
@@ -39,17 +39,17 @@ class TestAccuracyMetric(unittest.TestCase): | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | |||
target_dict = {'target': torch.zeros(4)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
print(metric.get_metric()) | |||
except Exception as e: | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric4(self): | |||
# (5) check reset | |||
metric = AccuracyMetric() | |||
@@ -61,7 +61,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
self.assertTrue(isinstance(res, dict)) | |||
self.assertTrue("acc" in res) | |||
self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) | |||
def test_AccuaryMetric5(self): | |||
# (5) check reset | |||
metric = AccuracyMetric() | |||
@@ -71,7 +71,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
res = metric.get_metric(reset=False) | |||
ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() | |||
self.assertAlmostEqual(res["acc"], float(ans), places=4) | |||
def test_AccuaryMetric6(self): | |||
# (6) check numpy array is not acceptable | |||
try: | |||
@@ -83,7 +83,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric7(self): | |||
# (7) check map, match | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
@@ -93,7 +93,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
res = metric.get_metric() | |||
ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() | |||
self.assertAlmostEqual(res["acc"], float(ans), places=4) | |||
def test_AccuaryMetric8(self): | |||
try: | |||
metric = AccuracyMetric(pred='predictions', target='targets') | |||
@@ -105,7 +105,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric9(self): | |||
# (9) check map, include unused | |||
try: | |||
@@ -118,12 +118,12 @@ class TestAccuracyMetric(unittest.TestCase): | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_AccuaryMetric10(self): | |||
# (10) check _fast_metric | |||
try: | |||
metric = AccuracyMetric() | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3)*3} | |||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | |||
target_dict = {'targets': torch.zeros(4, 3)} | |||
metric(pred_dict=pred_dict, target_dict=target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | |||
@@ -131,7 +131,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||
print(e) | |||
return | |||
self.assertTrue(True, False), "No exception catches." | |||
def test_seq_len(self): | |||
N = 256 | |||
seq_len = torch.zeros(N).long() | |||
@@ -145,20 +145,21 @@ class TestAccuracyMetric(unittest.TestCase): | |||
metric(pred_dict=pred, target_dict=target) | |||
self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | |||
class SpanF1PreRecMetric(unittest.TestCase): | |||
def test_case1(self): | |||
from fastNLP.core.metrics import _bmes_tag_to_spans | |||
from fastNLP.core.metrics import _bio_tag_to_spans | |||
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | |||
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | |||
expect_bmes_res = set() | |||
expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)), | |||
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) | |||
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) | |||
expect_bio_res = set() | |||
expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), | |||
('6', (4, 5)), ('7', (6, 7))]) | |||
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) | |||
('6', (4, 5)), ('7', (6, 7))]) | |||
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) | |||
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | |||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
@@ -171,19 +172,19 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | |||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
def test_case2(self): | |||
# 测试不带label的 | |||
from fastNLP.core.metrics import _bmes_tag_to_spans | |||
from fastNLP.core.metrics import _bio_tag_to_spans | |||
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | |||
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | |||
expect_bmes_res = set() | |||
expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) | |||
expect_bio_res = set() | |||
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) | |||
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) | |||
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) | |||
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | |||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | |||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | |||
@@ -195,7 +196,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
# bio_strs = np.random.choice(bio, size=100) | |||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | |||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | |||
def tese_case3(self): | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from collections import Counter | |||
@@ -213,7 +214,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
continue | |||
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count | |||
return vocab | |||
number_labels = 4 | |||
# bio tag | |||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | |||
@@ -221,26 +222,26 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | |||
bio_sequence = torch.FloatTensor( | |||
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | |||
0.0470, 0.0971], | |||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||
0.7987, -0.3970], | |||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||
0.6880, 1.4348], | |||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||
-1.6876, -0.8917], | |||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||
1.4217, 0.2622]], | |||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||
1.3592, -0.8973], | |||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||
-0.4025, -0.3417], | |||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||
0.2861, -0.3966], | |||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||
0.0213, 1.4777], | |||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||
1.3024, 0.2001]]] | |||
0.0470, 0.0971], | |||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||
0.7987, -0.3970], | |||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||
0.6880, 1.4348], | |||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||
-1.6876, -0.8917], | |||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||
1.4217, 0.2622]], | |||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||
1.3592, -0.8973], | |||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||
-0.4025, -0.3417], | |||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||
0.2861, -0.3966], | |||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||
0.0213, 1.4777], | |||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||
1.3024, 0.2001]]] | |||
) | |||
bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | |||
[5., 6., 8., 6., 0.]]) | |||
@@ -250,8 +251,8 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | |||
'f': 0.12499999999994846} | |||
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | |||
#bmes tag | |||
# bmes tag | |||
bmes_sequence = torch.FloatTensor( | |||
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | |||
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | |||
@@ -268,7 +269,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | |||
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | |||
-0.3779, -0.3195]], | |||
[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | |||
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | |||
-0.1103, 0.4417], | |||
@@ -285,22 +286,22 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | |||
-0.7344, -1.2046]]] | |||
) | |||
bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], | |||
[ 6., 15., 6., 15., 5.]]) | |||
bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.], | |||
[6., 15., 6., 15., 5.]]) | |||
fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | |||
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | |||
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | |||
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | |||
expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | |||
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | |||
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | |||
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | |||
'pre': 0.499999999999995, 'rec': 0.499999999999995} | |||
self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | |||
# 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | |||
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | |||
# from allennlp.training.metrics import SpanBasedF1Measure | |||
@@ -349,6 +350,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | |||
# fastnlp_bmes_metric.get_metric()) | |||
class TestBMESF1PreRecMetric(unittest.TestCase): | |||
def test_case1(self): | |||
seq_lens = torch.LongTensor([4, 2]) | |||
@@ -356,20 +358,20 @@ class TestBMESF1PreRecMetric(unittest.TestCase): | |||
target = torch.LongTensor([[0, 1, 2, 3], | |||
[3, 3, 0, 0]]) | |||
pred_dict = {'pred': pred} | |||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||
target_dict = {'target': target, 'seq_len': seq_lens} | |||
metric = BMESF1PreRecMetric() | |||
metric(pred_dict, target_dict) | |||
metric.get_metric() | |||
def test_case2(self): | |||
# 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | |||
seq_lens = torch.LongTensor([4, 2]) | |||
target = torch.LongTensor([[0, 1, 2, 3], | |||
[3, 3, 0, 0]]) | |||
pred_dict = {'pred': target} | |||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||
target_dict = {'target': target, 'seq_len': seq_lens} | |||
metric = BMESF1PreRecMetric() | |||
metric(pred_dict, target_dict) | |||
self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) | |||
@@ -381,5 +383,5 @@ class TestUsefulFunctions(unittest.TestCase): | |||
# multi-class | |||
_ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) | |||
_ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) | |||
# 跑通即可 |
@@ -2,7 +2,7 @@ import unittest | |||
import torch | |||
from fastNLP.core.optimizer import SGD, Adam | |||
from fastNLP import SGD, Adam | |||
class TestOptim(unittest.TestCase): | |||
@@ -12,42 +12,42 @@ class TestOptim(unittest.TestCase): | |||
self.assertTrue("momentum" in optim.__dict__["settings"]) | |||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||
optim = SGD(lr=0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||
optim = SGD(lr=0.002, momentum=0.989) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | |||
optim = SGD(0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue(isinstance(res, torch.optim.SGD)) | |||
with self.assertRaises(TypeError): | |||
_ = SGD("???") | |||
with self.assertRaises(TypeError): | |||
_ = SGD(0.001, lr=0.002) | |||
def test_Adam(self): | |||
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue("lr" in optim.__dict__["settings"]) | |||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | |||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue(isinstance(res, torch.optim.Adam)) | |||
optim = Adam(lr=0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
self.assertTrue(isinstance(res, torch.optim.Adam)) | |||
optim = Adam(lr=0.002, weight_decay=0.989) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | |||
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | |||
optim = Adam(0.001) | |||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | |||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | |||
@@ -3,9 +3,9 @@ import unittest | |||
import torch | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.sampler import SequentialSampler, RandomSampler, \ | |||
k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | |||
from fastNLP import DataSet | |||
from fastNLP import SequentialSampler, RandomSampler, BucketSampler | |||
from fastNLP.core.sampler import k_means_1d, k_means_bucketing, simple_sort_bucketing | |||
class TestSampler(unittest.TestCase): | |||
@@ -1,32 +1,25 @@ | |||
import unittest | |||
import numpy as np | |||
from torch import nn | |||
import time | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import Tester | |||
data_name = "pku_training.utf8" | |||
pickle_path = "data_for_tests" | |||
import numpy as np | |||
import torch.nn.functional as F | |||
from torch import nn | |||
import time | |||
from fastNLP.core.utils import _CheckError | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||
from fastNLP.core.losses import CrossEntropyLoss | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP.core.optimizer import SGD | |||
from fastNLP.core.tester import Tester | |||
from fastNLP.models.base_model import NaiveClassifier | |||
def prepare_fake_dataset():
    """Build a two-class toy DataSet from two 2-D Gaussian clusters.

    Class A (label ``[0.0]``) is sampled around (-3, -3) and class B
    (label ``[1.0]``) around (3, 3); each class has 1000 samples drawn with
    an identity covariance matrix.

    :return: DataSet of Instances with fields ``x`` (two floats) and ``y``
        (one-element label list), class A samples first.
    """
    identity_cov = np.array([[1, 0], [0, 1]])
    # Sample class A before class B so the random stream is consumed in the same order.
    points_a = np.random.multivariate_normal(np.array([-3, -3]), identity_cov, size=(1000,))
    points_b = np.random.multivariate_normal(np.array([3, 3]), identity_cov, size=(1000,))

    instances = []
    for label, points in ((0.0, points_a), (1.0, points_b)):
        for point in points:
            instances.append(Instance(x=[float(point[0]), float(point[1])], y=[label]))
    return DataSet(instances)
@@ -39,6 +32,7 @@ def prepare_fake_dataset2(*args, size=100): | |||
data[arg] = np.random.randn(size, 5) | |||
return DataSet(data=data) | |||
class TestTester(unittest.TestCase): | |||
def test_case_1(self): | |||
# 检查报错提示能否正确提醒用户 | |||
@@ -46,10 +40,12 @@ class TestTester(unittest.TestCase): | |||
dataset.rename_field('x_unused', 'x2') | |||
dataset.set_input('x1', 'x2') | |||
dataset.set_target('y', 'x1') | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
@@ -57,7 +53,7 @@ class TestTester(unittest.TestCase): | |||
time.sleep(0.1) | |||
# loss = F.cross_entropy(x, y) | |||
return {'preds': x} | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
tester = Tester( | |||
@@ -5,25 +5,24 @@ import numpy as np | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||
from fastNLP.core.losses import CrossEntropyLoss | |||
from fastNLP.core.metrics import AccuracyMetric | |||
from fastNLP.core.optimizer import SGD | |||
from fastNLP.core.trainer import Trainer | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
from fastNLP import BCELoss | |||
from fastNLP import CrossEntropyLoss | |||
from fastNLP import AccuracyMetric | |||
from fastNLP import SGD | |||
from fastNLP import Trainer | |||
from fastNLP.models.base_model import NaiveClassifier | |||
def prepare_fake_dataset():
    """Create a linearly separable toy DataSet with two Gaussian classes.

    1000 points are drawn around (-3, -3) and labelled ``[0.0]``, then 1000
    points around (3, 3) and labelled ``[1.0]``; both use identity covariance.

    :return: DataSet whose Instances carry ``x`` (list of two floats) and
        ``y`` (one-element label list), first class first.
    """

    def _sample(center):
        # One cluster of 1000 two-dimensional points with identity covariance.
        return np.random.multivariate_normal(
            np.array(center), np.array([[1, 0], [0, 1]]), size=(1000,))

    def _as_instances(points, label):
        return [Instance(x=[float(p[0]), float(p[1])], y=[label]) for p in points]

    # Sampling order (negative cluster, then positive) matches the random stream usage.
    negatives = _sample([-3, -3])
    positives = _sample([3, 3])
    return DataSet(_as_instances(negatives, 0.0) + _as_instances(positives, 1.0))
@@ -42,11 +41,11 @@ class TrainerTestGround(unittest.TestCase): | |||
data_set = prepare_fake_dataset() | |||
data_set.set_input("x", flag=True) | |||
data_set.set_target("y", flag=True) | |||
train_set, dev_set = data_set.split(0.3) | |||
model = NaiveClassifier(2, 1) | |||
trainer = Trainer(train_set, model, | |||
loss=BCELoss(pred="predict", target="y"), | |||
metrics=AccuracyMetric(pred="predict", target="y"), | |||
@@ -63,26 +62,26 @@ class TrainerTestGround(unittest.TestCase): | |||
""" | |||
# 应该正确运行 | |||
""" | |||
def test_trainer_suggestion1(self): | |||
# 检查报错提示能否正确提醒用户。 | |||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | |||
dataset = prepare_fake_dataset2('x') | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'loss': loss} | |||
model = Model() | |||
with self.assertRaises(RuntimeError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
@@ -97,25 +96,25 @@ class TrainerTestGround(unittest.TestCase): | |||
(2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | |||
""" | |||
def test_trainer_suggestion2(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入forward需要的数据,看是否可以运行 | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
@@ -127,25 +126,25 @@ class TrainerTestGround(unittest.TestCase): | |||
""" | |||
# 应该正确运行 | |||
""" | |||
def test_trainer_suggestion3(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入forward需要的数据,但是forward没有返回loss这个key | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'wrong_loss_key': loss} | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
@@ -155,23 +154,25 @@ class TrainerTestGround(unittest.TestCase): | |||
print_every=2 | |||
) | |||
trainer.train() | |||
def test_trainer_suggestion4(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入forward需要的数据,是否可以正确提示unused | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
dataset.set_input('x1', 'x2', 'y', flag=True) | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'losses': loss} | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
@@ -180,7 +181,7 @@ class TrainerTestGround(unittest.TestCase): | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
def test_trainer_suggestion5(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错 | |||
@@ -188,17 +189,19 @@ class TrainerTestGround(unittest.TestCase): | |||
dataset.rename_field('x_unused', 'x2') | |||
dataset.set_input('x1', 'x2', 'y') | |||
dataset.set_target('y') | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2, y): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
x = x1 + x2 | |||
loss = F.cross_entropy(x, y) | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
@@ -206,7 +209,7 @@ class TrainerTestGround(unittest.TestCase): | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
def test_trainer_suggestion6(self): | |||
# 检查报错提示能否正确提醒用户 | |||
# 这里传入多余参数,让其duplicate | |||
@@ -214,10 +217,12 @@ class TrainerTestGround(unittest.TestCase): | |||
dataset.rename_field('x_unused', 'x2') | |||
dataset.set_input('x1', 'x2') | |||
dataset.set_target('y', 'x1') | |||
class Model(nn.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.fc = nn.Linear(5, 4) | |||
def forward(self, x1, x2): | |||
x1 = self.fc(x1) | |||
x2 = self.fc(x2) | |||
@@ -225,7 +230,7 @@ class TrainerTestGround(unittest.TestCase): | |||
time.sleep(0.1) | |||
# loss = F.cross_entropy(x, y) | |||
return {'preds': x} | |||
model = Model() | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
@@ -236,7 +241,7 @@ class TrainerTestGround(unittest.TestCase): | |||
metrics=AccuracyMetric(), | |||
use_tqdm=False, | |||
print_every=2) | |||
""" | |||
def test_trainer_multiprocess(self): | |||
dataset = prepare_fake_dataset2('x1', 'x2') | |||
@@ -1,8 +1,7 @@ | |||
import unittest | |||
import _pickle | |||
from fastNLP import cache_results | |||
from fastNLP.io.embed_loader import EmbedLoader | |||
from fastNLP.io import EmbedLoader | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
import time | |||
@@ -11,11 +10,13 @@ import torch | |||
from torch import nn | |||
from fastNLP.core.utils import _move_model_to_device, _get_model_device | |||
class Model(nn.Module):
    """Minimal module holding one empty parameter, so tests can inspect its device."""

    def __init__(self):
        super(Model, self).__init__()
        # A zero-element tensor is enough: only the parameter's device matters.
        empty = torch.zeros(0)
        self.param = nn.Parameter(empty)
class TestMoveModelDeivce(unittest.TestCase): | |||
def test_case1(self): | |||
# 测试str | |||
@@ -35,36 +36,36 @@ class TestMoveModelDeivce(unittest.TestCase): | |||
_move_model_to_device(model, 'cuda:1000') | |||
# 测试None | |||
model = _move_model_to_device(model, None) | |||
def test_case2(self): | |||
# 测试使用int初始化 | |||
model = Model() | |||
if torch.cuda.is_available(): | |||
model = _move_model_to_device(model, 0) | |||
assert model.param.device == torch.device('cuda:0') | |||
assert model.param.device==torch.device('cuda:0'), "The model should be in " | |||
assert model.param.device == torch.device('cuda:0'), "The model should be in " | |||
with self.assertRaises(Exception): | |||
_move_model_to_device(model, 100) | |||
with self.assertRaises(Exception): | |||
_move_model_to_device(model, -1) | |||
def test_case3(self): | |||
# 测试None | |||
model = Model() | |||
device = _get_model_device(model) | |||
model = _move_model_to_device(model, None) | |||
assert device==_get_model_device(model), "The device should not change." | |||
assert device == _get_model_device(model), "The device should not change." | |||
if torch.cuda.is_available(): | |||
model.cuda() | |||
device = _get_model_device(model) | |||
model = _move_model_to_device(model, None) | |||
assert device==_get_model_device(model), "The device should not change." | |||
assert device == _get_model_device(model), "The device should not change." | |||
model = nn.DataParallel(model, device_ids=[0]) | |||
_move_model_to_device(model, None) | |||
with self.assertRaises(Exception): | |||
_move_model_to_device(model, 'cpu') | |||
def test_case4(self): | |||
# 测试传入list的内容 | |||
model = Model() | |||
@@ -78,15 +79,17 @@ class TestMoveModelDeivce(unittest.TestCase): | |||
device = [torch.device('cuda:0'), torch.device('cuda:0')] | |||
with self.assertRaises(Exception): | |||
_model = _move_model_to_device(model, device) | |||
if torch.cuda.device_count()>1: | |||
if torch.cuda.device_count() > 1: | |||
device = [0, 1] | |||
_model = _move_model_to_device(model, device) | |||
assert isinstance(_model, nn.DataParallel) | |||
device = ['cuda', 'cuda:1'] | |||
with self.assertRaises(Exception): | |||
_move_model_to_device(model, device) | |||
def test_case5(self): | |||
if not torch.cuda.is_available(): | |||
return | |||
# torch.device() | |||
device = torch.device('cpu') | |||
model = Model() | |||
@@ -106,10 +109,11 @@ def process_data_1(embed_file, cws_train): | |||
d = DataSet() | |||
for line in f: | |||
line = line.strip() | |||
if len(line)>0: | |||
if len(line) > 0: | |||
d.append(Instance(raw=line)) | |||
return embed, vocab, d | |||
class TestCache(unittest.TestCase): | |||
def test_cache_save(self): | |||
try: | |||
@@ -127,10 +131,10 @@ class TestCache(unittest.TestCase): | |||
end_time = time.time() | |||
read_time = end_time - start_time | |||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||
self.assertGreater(pre_time-0.5, read_time) | |||
self.assertGreater(pre_time - 0.5, read_time) | |||
finally: | |||
os.remove('test/demo1.pkl') | |||
def test_cache_save_overwrite_path(self): | |||
try: | |||
start_time = time.time() | |||
@@ -149,10 +153,10 @@ class TestCache(unittest.TestCase): | |||
end_time = time.time() | |||
read_time = end_time - start_time | |||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||
self.assertGreater(pre_time-0.5, read_time) | |||
self.assertGreater(pre_time - 0.5, read_time) | |||
finally: | |||
os.remove('test/demo_overwrite.pkl') | |||
def test_cache_refresh(self): | |||
try: | |||
start_time = time.time() | |||
@@ -171,34 +175,38 @@ class TestCache(unittest.TestCase): | |||
end_time = time.time() | |||
read_time = end_time - start_time | |||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | |||
self.assertGreater(0.1, pre_time-read_time) | |||
self.assertGreater(0.1, pre_time - read_time) | |||
finally: | |||
os.remove('test/demo1.pkl') | |||
def test_duplicate_keyword(self):
    """Expect RuntimeError when a cached function declares ``_verbose``, ``_cache_fp`` or ``_refresh``.

    Each case wraps a function that declares one of those parameter names
    with ``cache_results(None)`` and calls it; both steps sit inside the
    ``assertRaises`` context so the error may surface at either point.
    """
    # parameter named _verbose
    with self.assertRaises(RuntimeError):
        def func_verbose(a, _verbose):
            pass

        func_verbose = cache_results(None)(func_verbose)
        func_verbose(0, 1)

    # parameter named _cache_fp
    with self.assertRaises(RuntimeError):
        def func_cache(a, _cache_fp):
            pass

        func_cache = cache_results(None)(func_cache)
        func_cache(1, 2)

    # parameter named _refresh
    with self.assertRaises(RuntimeError):
        def func_refresh(a, _refresh):
            pass

        func_refresh = cache_results(None)(func_refresh)
        func_refresh(1, 2)
def test_create_cache_dir(self): | |||
@cache_results('test/demo1/demo.pkl') | |||
def cache(): | |||
return 1, 2 | |||
try: | |||
results = cache() | |||
print(results) | |||
finally: | |||
os.remove('test/demo1/demo.pkl') | |||
os.rmdir('test/demo1') | |||
os.rmdir('test/demo1') |
@@ -1,9 +1,9 @@ | |||
import unittest | |||
from collections import Counter | |||
from fastNLP.core.vocabulary import Vocabulary | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP import Vocabulary | |||
from fastNLP import DataSet | |||
from fastNLP import Instance | |||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
"works", "well", "in", "most", "cases", "scales", "well"] | |||
@@ -12,92 +12,93 @@ counter = Counter(text) | |||
class TestAdd(unittest.TestCase): | |||
def test_add(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
for word in text: | |||
vocab.add(word) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_add_word(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
for word in text: | |||
vocab.add_word(word) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_add_word_lst(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
vocab.add_word_lst(text) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_update(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
vocab.update(text) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_from_dataset(self): | |||
start_char = 65 | |||
num_samples = 10 | |||
# 0 dim | |||
dataset = DataSet() | |||
for i in range(num_samples): | |||
ins = Instance(char=chr(start_char+i)) | |||
ins = Instance(char=chr(start_char + i)) | |||
dataset.append(ins) | |||
vocab = Vocabulary() | |||
vocab.from_dataset(dataset, field_name='char') | |||
for i in range(num_samples): | |||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
vocab.index_dataset(dataset, field_name='char') | |||
# 1 dim | |||
dataset = DataSet() | |||
for i in range(num_samples): | |||
ins = Instance(char=[chr(start_char+i)]*6) | |||
ins = Instance(char=[chr(start_char + i)] * 6) | |||
dataset.append(ins) | |||
vocab = Vocabulary() | |||
vocab.from_dataset(dataset, field_name='char') | |||
for i in range(num_samples): | |||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
vocab.index_dataset(dataset, field_name='char') | |||
# 2 dim | |||
dataset = DataSet() | |||
for i in range(num_samples): | |||
ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)]) | |||
ins = Instance(char=[[chr(start_char + i) for _ in range(6)] for _ in range(6)]) | |||
dataset.append(ins) | |||
vocab = Vocabulary() | |||
vocab.from_dataset(dataset, field_name='char') | |||
for i in range(num_samples): | |||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||
vocab.index_dataset(dataset, field_name='char') | |||
class TestIndexing(unittest.TestCase): | |||
def test_len(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
vocab = Vocabulary(unknown=None, padding=None) | |||
vocab.update(text) | |||
self.assertEqual(len(vocab), len(counter)) | |||
def test_contains(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
vocab = Vocabulary(unknown=None) | |||
vocab.update(text) | |||
self.assertTrue(text[-1] in vocab) | |||
self.assertFalse("~!@#" in vocab) | |||
self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | |||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||
def test_index(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
vocab.update(text) | |||
res = [vocab[w] for w in set(text)] | |||
self.assertEqual(len(res), len(set(res))) | |||
res = [vocab.to_index(w) for w in set(text)] | |||
self.assertEqual(len(res), len(set(res))) | |||
def test_to_word(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
vocab.update(text) | |||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||
def test_iteration(self): | |||
vocab = Vocabulary() | |||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | |||
@@ -110,26 +111,26 @@ class TestIndexing(unittest.TestCase): | |||
class TestOther(unittest.TestCase): | |||
def test_additional_update(self): | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab = Vocabulary() | |||
vocab.update(text) | |||
_ = vocab["well"] | |||
self.assertEqual(vocab.rebuild, False) | |||
vocab.add("hahaha") | |||
self.assertEqual(vocab.rebuild, True) | |||
_ = vocab["hahaha"] | |||
self.assertEqual(vocab.rebuild, False) | |||
self.assertTrue("hahaha" in vocab) | |||
def test_warning(self): | |||
vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||
vocab = Vocabulary(max_size=len(set(text))) | |||
vocab.update(text) | |||
self.assertEqual(vocab.rebuild, True) | |||
print(len(vocab)) | |||
self.assertEqual(vocab.rebuild, False) | |||
vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | |||
# this will print a warning | |||
self.assertEqual(vocab.rebuild, True) |