Tests for the core module and some minor fixes

tags/v0.4.10
ChenXin 5 years ago
parent commit 4926b33df0
21 changed files with 324 additions and 305 deletions
  1. +2  -1   fastNLP/__init__.py
  2. +1  -1   fastNLP/core/__init__.py
  3. +11 -4   fastNLP/core/callback.py
  4. +1  -0   fastNLP/core/field.py
  5. +5  -0   fastNLP/core/metrics.py
  6. +3  -3   fastNLP/core/trainer.py
  7. +1  -1   fastNLP/core/utils.py
  8. +37 -32  fastNLP/core/vocabulary.py
  9. +21 -20  test/core/test_batch.py
 10. +21 -33  test/core/test_callbacks.py
 11. +7  -8   test/core/test_dataset.py
 12. +5  -5   test/core/test_field.py
 13. +6  -6   test/core/test_instance.py
 14. +12 -12  test/core/test_loss.py
 15. +63 -61  test/core/test_metrics.py
 16. +9  -9   test/core/test_optimizer.py
 17. +3  -3   test/core/test_sampler.py
 18. +13 -17  test/core/test_tester.py
 19. +38 -33  test/core/test_trainer.py
 20. +28 -20  test/core/test_utils.py
 21. +37 -36  test/core/test_vocabulary.py

+ 2  - 1   fastNLP/__init__.py

@@ -13,7 +13,8 @@ The most commonly used components of fastNLP can be imported directly from the fastNLP package; their
 __all__ = ["Instance", "FieldArray", "Batch", "Vocabulary", "DataSet",
            "Trainer", "Tester", "Callback",
            "Padder", "AutoPadder", "EngChar2DPadder",
-           "AccuracyMetric", "Optimizer", "SGD", "Adam",
+           "AccuracyMetric", "BMESF1PreRecMetric", "SpanFPreRecMetric", "SQuADMetric",
+           "Optimizer", "SGD", "Adam",
           "Sampler", "SequentialSampler", "BucketSampler", "RandomSampler",
           "LossFunc", "CrossEntropyLoss", "L1Loss", "BCELoss", "NLLLoss", "LossInForward",
           "cache_results"]


+ 1  - 1   fastNLP/core/__init__.py

@@ -17,7 +17,7 @@ from .dataset import DataSet
 from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
 from .instance import Instance
 from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
-from .metrics import AccuracyMetric
+from .metrics import AccuracyMetric, BMESF1PreRecMetric, SpanFPreRecMetric, SQuADMetric
 from .optimizer import Optimizer, SGD, Adam
 from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
 from .tester import Tester


+ 11  - 4   fastNLP/core/callback.py

@@ -236,6 +236,7 @@ class CallbackManager(Callback):
         for env_name, env_val in env.items():
             for callback in self.callbacks:
+                print(callback, env_name, env_val )
                 setattr(callback, '_' + env_name, env_val)  # Callback.trainer

     @_transfer
@@ -425,19 +426,25 @@ class LRFinder(Callback):
         super(LRFinder, self).__init__()
         self.start_lr, self.end_lr = start_lr, end_lr
-        self.num_it = self.batch_per_epoch
         self.stop = False
         self.best_loss = 0.
         self.best_lr = None
         self.loss_history = []
         self.smooth_value = SmoothValue(0.8)
         self.opt = None
-        scale = (self.end_lr - self.start_lr) / self.num_it
-        self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it))
         self.find = None
         self.loader = ModelLoader()

+    @property
+    def lr_gen(self):
+        scale = (self.end_lr - self.start_lr) / self.batch_per_epoch
+        return (self.start_lr + scale * (step + 1) for step in range(self.batch_per_epoch))
+
+    @property
+    def num_it(self):
+        return self.batch_per_epoch
+
     def on_epoch_begin(self):
         if self.epoch == 1:  # first epoch
             self.opt = self.trainer.optimizer  # pytorch optimizer
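The point of the LRFinder change: batch_per_epoch is only known once the callback is attached to a trainer, so computing lr_gen eagerly in __init__ froze a stale value. As a property, the generator is rebuilt from the current batch_per_epoch on every read. A self-contained sketch of the same lazy-property pattern (the class below is illustrative, not fastNLP code):

    class LRSchedule:
        """Linearly interpolate num_steps learning rates in (start_lr, end_lr]."""

        def __init__(self, start_lr, end_lr):
            self.start_lr, self.end_lr = start_lr, end_lr
            self.num_steps = 1  # may be updated later, once batch_per_epoch is known

        @property
        def lr_gen(self):
            # evaluated lazily, so it always reflects the current num_steps
            scale = (self.end_lr - self.start_lr) / self.num_steps
            return (self.start_lr + scale * (step + 1) for step in range(self.num_steps))


    schedule = LRSchedule(1e-6, 10)
    schedule.num_steps = 4          # set after construction, as the trainer would
    print(list(schedule.lr_gen))    # ~[2.5, 5.0, 7.5, 10.0]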


+ 1  - 0   fastNLP/core/field.py

@@ -418,6 +418,7 @@ class AutoPadder(Padder):
             return False

+
     def __call__(self, contents, field_name, field_ele_dtype):
         if not _is_iterable(contents[0]):
             array = np.array([content for content in contents], dtype=field_ele_dtype)
         elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents):
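For context, the __call__ shown above dispatches on the field's shape: non-iterable elements are stacked directly, while two-dimensional int64/float64 content gets padded to the longest row. A minimal sketch of that padding step (pad_2d is a stand-in helper; 0 is assumed as the pad value here):

    import numpy as np

    def pad_2d(contents, pad_val=0, dtype=np.int64):
        # pad variable-length rows up to the longest row in the batch
        max_len = max(len(row) for row in contents)
        array = np.full((len(contents), max_len), pad_val, dtype=dtype)
        for i, row in enumerate(contents):
            array[i, :len(row)] = row
        return array

    print(pad_2d([[1, 2], [3, 4, 5]]))
    # [[1 2 0]
    #  [3 4 5]]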


+ 5  - 0   fastNLP/core/metrics.py

@@ -430,6 +430,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None):
 class SpanFPreRecMetric(MetricBase):
     """
+    Alias: :class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`

     In sequence labeling tasks, computes F, pre, and rec over spans.
     For example, Chinese part-of-speech tagging is annotated per character, so the POS tags of the sentence '中国在亚洲' might be (using BMES)
@@ -619,6 +620,8 @@ class SpanFPreRecMetric(MetricBase):
 class BMESF1PreRecMetric(MetricBase):
     """
+    Alias: :class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric`
+
     Computes f1, precision, and recall under the BMES tagging scheme. Since illegal tags such as "BS" may occur, the table below is needed for conversion, where cur_B means the current tag is B
     and next_B means the next tag is B. Then cur_B=S means a current tag predicted as B is relabeled S, and next_M=B means a next tag predicted as M is relabeled B
@@ -826,6 +829,8 @@ def _pred_topk(y_prob, k=1):
 class SQuADMetric(MetricBase):
     """
+    Alias: :class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric`
+
     Metric for the SQuAD dataset
     :param pred1: the mapping for `pred1` in the parameter map; None means the mapping is `pred1`->`pred1`
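A minimal sketch of how the span metric is driven, loosely modeled on the test_metrics.py diff further down this page. The tag vocabulary contents and the 'seq_len' key are assumptions here, following the AccuracyMetric tests below rather than documented API:

    import torch
    from fastNLP import SpanFPreRecMetric, Vocabulary

    # a BIO tag vocabulary over two made-up labels; no padding/unknown, as in the tests
    tag_vocab = Vocabulary(unknown=None, padding=None)
    tag_vocab.add_word_lst(['O', 'B-loc', 'I-loc', 'B-per', 'I-per'])

    metric = SpanFPreRecMetric(tag_vocab=tag_vocab, only_gross=False)

    batch_size, max_len = 2, 5
    pred = torch.randn(batch_size, max_len, len(tag_vocab))        # per-tag scores
    target = torch.randint(len(tag_vocab), (batch_size, max_len))  # gold tag indices
    seq_len = torch.LongTensor([5, 3])

    metric(pred_dict={'pred': pred, 'seq_len': seq_len}, target_dict={'target': target})
    print(metric.get_metric())  # {'f': ..., 'pre': ..., 'rec': ...} plus per-label entries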


+ 3  - 3   fastNLP/core/trainer.py

@@ -350,7 +350,7 @@ class Trainer(object):
     :param train_data: the training set, of type :class:`~fastNLP.DataSet`.
     :param nn.modules model: the model to be trained
-    :param torch.optim.Optimizer optimizer: the optimizer. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3)
+    :param optimizer: a `torch.optim.Optimizer` optimizer. If None, the Trainer uses the default Adam(model.parameters(), lr=4e-3)
     :param int batch_size: the batch size used for training and validation.
     :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None, :class:`~fastNLP.LossInForward` is used by default
     :param sampler: the order in which batches are generated, of type :class:`~fastNLP.Sampler`. If None, :class:`~fastNLP.RandomSampler` is used by default
@@ -403,7 +403,6 @@ class Trainer(object):
                  callbacks=None,
                  check_code_level=0):
         super(Trainer, self).__init__()
-
         if not isinstance(train_data, DataSet):
             raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.")
         if not isinstance(model, nn.Module):
@@ -468,7 +467,7 @@ class Trainer(object):
             len(self.train_data) % self.batch_size != 0)) * self.n_epochs
         self.model = _move_model_to_device(self.model, device=device)
-
+
         if isinstance(optimizer, torch.optim.Optimizer):
             self.optimizer = optimizer
         elif isinstance(optimizer, Optimizer):
@@ -493,6 +492,7 @@ class Trainer(object):
         self.step = 0
         self.start_time = None  # start timestamp
+        print("callback_manager")
         self.callback_manager = CallbackManager(env={"trainer": self},
                                                 callbacks=callbacks)
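The reworded optimizer docstring matches the isinstance branches above: both a raw torch.optim.Optimizer and a fastNLP Optimizer wrapper are accepted. A sketch of the two call styles; the toy data and NaiveClassifier setup mirror prepare_env() in the callback tests below, and keyword arguments such as dev_data are assumptions:

    import numpy as np
    import torch
    from fastNLP import Trainer, SGD, BCELoss, AccuracyMetric, DataSet, Instance
    from fastNLP.models.base_model import NaiveClassifier

    # tiny two-class toy set, as in the callback tests
    a = np.random.multivariate_normal([-3, -3], np.eye(2), size=(100,))
    b = np.random.multivariate_normal([3, 3], np.eye(2), size=(100,))
    data_set = DataSet([Instance(x=[float(p[0]), float(p[1])], y=[0.0]) for p in a] +
                       [Instance(x=[float(p[0]), float(p[1])], y=[1.0]) for p in b])
    data_set.set_input("x")
    data_set.set_target("y")
    model = NaiveClassifier(2, 1)

    # fastNLP wrapper form
    trainer = Trainer(data_set, model, dev_data=data_set,
                      loss=BCELoss(pred="predict", target="y"),
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      optimizer=SGD(lr=0.1),
                      n_epochs=1, batch_size=32, use_tqdm=False)
    trainer.train()

    # raw PyTorch optimizer form
    trainer = Trainer(data_set, model, dev_data=data_set,
                      loss=BCELoss(pred="predict", target="y"),
                      metrics=AccuracyMetric(pred="predict", target="y"),
                      optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
                      n_epochs=1, batch_size=32, use_tqdm=False)
    trainer.train()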


+ 1  - 1   fastNLP/core/utils.py

@@ -616,7 +616,7 @@ def seq_lens_to_masks(seq_lens, float=False):
         assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}."
         batch_size = seq_lens.size(0)
         max_len = seq_lens.max()
-        indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device)
+        indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long()
         masks = indexes.lt(seq_lens.unsqueeze(1))

         if float:
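The appended .long() pins the index tensor's dtype before the lt comparison against the LongTensor of lengths. The masking trick itself, as a self-contained runnable sketch:

    import torch

    seq_lens = torch.LongTensor([2, 4, 3])
    max_len = int(seq_lens.max())

    # one row of positions per sequence: position < length -> True
    indexes = torch.arange(max_len).view(1, -1).repeat(len(seq_lens), 1).long()
    masks = indexes.lt(seq_lens.unsqueeze(1))
    print(masks)
    # [[1, 1, 0, 0],
    #  [1, 1, 1, 1],
    #  [1, 1, 1, 0]]  (boolean/uint8 depending on the PyTorch version)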


+ 37  - 32   fastNLP/core/vocabulary.py

@@ -2,16 +2,18 @@ from functools import wraps
 from collections import Counter
 from .dataset import DataSet
+
+
 def _check_build_vocab(func):
     """A decorator to make sure the indexing is built before used.

     """
-    @wraps(func) # to solve missing docstring
+    @wraps(func)  # to solve missing docstring
     def _wrapper(self, *args, **kwargs):
         if self.word2idx is None or self.rebuild is True:
             self.build_vocab()
         return func(self, *args, **kwargs)
+
     return _wrapper
@@ -19,7 +21,8 @@ def _check_build_status(func):
     """A decorator to check whether the vocabulary updates after the last build.

     """
-    @wraps(func) # to solve missing docstring
+
+    @wraps(func)  # to solve missing docstring
     def _wrapper(self, *args, **kwargs):
         if self.rebuild is False:
             self.rebuild = True
@@ -28,7 +31,7 @@ def _check_build_status(func):
                           "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
                               self.max_size, func.__name__))
         return func(self, *args, **kwargs)
-
+
     return _wrapper
@@ -50,15 +53,15 @@ class Vocabulary(object):
         If ``None``, the size is unlimited. Default: ``None``
     :param int min_freq: the minimum frequency a word must appear with in the text to be recorded; should be at least 1.
         Words below this frequency are treated as `unknown`. If ``None``, every word in the text is recorded. Default: ``None``
-    :param str padding: the padding token. If set to ``None``,
+    :param str optional padding: the padding token. If set to ``None``,
         padding is left out of the vocabulary and does not count towards its size; ``None`` is mostly used when building a Vocabulary for labels.
         Default: '<pad>'
-    :param str unknow: the unknow token; all unrecorded words are treated as unknown when converted to `int`.
+    :param str optional unknown: the unknown token; all unrecorded words are treated as unknown when converted to `int`.
         If set to ``None``, unknown is left out of the vocabulary and does not count towards its size.
        ``None`` is mostly used when building a Vocabulary for labels.
        Default: '<unk>'
     """
     def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'):
         self.max_size = max_size
         self.min_freq = min_freq
@@ -68,7 +71,7 @@ class Vocabulary(object):
         self.word2idx = None
         self.idx2word = None
         self.rebuild = True
-
+
     @_check_build_status
     def update(self, word_lst):
         """Increase the recorded frequency of each word in the sequence.
@@ -76,7 +79,7 @@ class Vocabulary(object):
        :param list word_lst: a list of strings
        """
        self.word_count.update(word_lst)
-
+
     @_check_build_status
     def add(self, word):
        """
@@ -85,7 +88,7 @@ class Vocabulary(object):
        :param str word: the new word
        """
        self.word_count[word] += 1
-
+
     @_check_build_status
     def add_word(self, word):
        """
@@ -94,7 +97,7 @@ class Vocabulary(object):
        :param str word: the new word
        """
        self.add(word)
-
+
     @_check_build_status
     def add_word_lst(self, word_lst):
        """
@@ -103,7 +106,7 @@ class Vocabulary(object):
        :param list[str] word_lst: a sequence of words
        """
        self.update(word_lst)
-
+
     def build_vocab(self):
        """
        Build the dictionary from the words seen so far and their frequencies. Note: rebuilding may change the size of the dictionary,
@@ -116,7 +119,7 @@ class Vocabulary(object):
            self.word2idx[self.padding] = len(self.word2idx)
        if self.unknown is not None:
            self.word2idx[self.unknown] = len(self.word2idx)
-
+
        max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
        words = self.word_count.most_common(max_size)
        if self.min_freq is not None:
@@ -127,18 +130,18 @@ class Vocabulary(object):
        self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
        self.build_reverse_vocab()
        self.rebuild = False
-
+
     def build_reverse_vocab(self):
        """
        Build the "index to word" dict from the "word to index" dict.

        """
        self.idx2word = {i: w for w, i in self.word2idx.items()}
-
+
     @_check_build_vocab
     def __len__(self):
        return len(self.word2idx)
-
+
     @_check_build_vocab
     def __contains__(self, item):
        """
@@ -148,7 +151,7 @@ class Vocabulary(object):
        :return: True or False
        """
        return item in self.word2idx
-
+
     def has_word(self, w):
        """
        Check whether a word has been recorded
@@ -163,7 +166,7 @@ class Vocabulary(object):
        :return: ``True`` or ``False``
        """
        return self.__contains__(w)
-
+
     @_check_build_vocab
     def __getitem__(self, w):
        """
@@ -177,7 +180,7 @@ class Vocabulary(object):
            return self.word2idx[self.unknown]
        else:
            raise ValueError("word {} not in vocabulary".format(w))
-
+
     @_check_build_vocab
     def index_dataset(self, *datasets, field_name, new_field_name=None):
        """
@@ -194,6 +197,7 @@ class Vocabulary(object):
        :param str new_field_name: the field_name to store the result under. If ``None``, the original field is overwritten.
            Default: ``None``
        """
+
        def index_instance(ins):
            """
            Several cases are possible: str, 1d-list, 2d-list
@@ -209,8 +213,8 @@ class Vocabulary(object):
            else:
                if isinstance(field[0][0], list):
                    raise RuntimeError("Only support field with 2 dimensions.")
-                return[[self.to_index(c) for c in w] for w in field]
-
+                return [[self.to_index(c) for c in w] for w in field]
+
        if new_field_name is None:
            new_field_name = field_name
        for idx, dataset in enumerate(datasets):
@@ -222,7 +226,7 @@ class Vocabulary(object):
                raise e
            else:
                raise RuntimeError("Only DataSet type is allowed.")
-
+
     def from_dataset(self, *datasets, field_name):
        """
        Build the dictionary from the words in the given field of the dataset(s)
@@ -243,7 +247,7 @@ class Vocabulary(object):
            field_name = [field_name]
        elif not isinstance(field_name, list):
            raise TypeError('invalid argument field_name: {}'.format(field_name))
-
+
        def construct_vocab(ins):
            for fn in field_name:
                field = ins[fn]
@@ -256,6 +260,7 @@ class Vocabulary(object):
                    if isinstance(field[0][0], list):
                        raise RuntimeError("Only support field with 2 dimensions.")
                    [self.add_word_lst(w) for w in field]
+
        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
                try:
@@ -266,7 +271,7 @@ class Vocabulary(object):
            else:
                raise RuntimeError("Only DataSet type is allowed.")
        return self
-
+
     def to_index(self, w):
        """
        Convert a word to its index. A word not recorded in the dictionary is treated as unknown; if ``unknown=None``, a
@@ -282,7 +287,7 @@ class Vocabulary(object):
        :return int index: the number
        """
        return self.__getitem__(w)
-
+
     @property
     @_check_build_vocab
     def unknown_idx(self):
@@ -292,7 +297,7 @@ class Vocabulary(object):
        if self.unknown is None:
            return None
        return self.word2idx[self.unknown]
-
+
     @property
     @_check_build_vocab
     def padding_idx(self):
@@ -302,7 +307,7 @@ class Vocabulary(object):
        if self.padding is None:
            return None
        return self.word2idx[self.padding]
-
+
     @_check_build_vocab
     def to_word(self, idx):
        """
@@ -312,26 +317,26 @@ class Vocabulary(object):
        :return str word: the word
        """
        return self.idx2word[idx]
-
+
     def __getstate__(self):
        """Use to prepare data for pickle.

        """
-        len(self) # make sure vocab has been built
+        len(self)  # make sure vocab has been built
        state = self.__dict__.copy()
        # no need to pickle idx2word as it can be constructed from word2idx
        del state['idx2word']
        return state
-
+
     def __setstate__(self, state):
        """Use to restore state from pickle.

        """
        self.__dict__.update(state)
        self.build_reverse_vocab()
-
+
     def __repr__(self):
        return "Vocabulary({}...)".format(list(self.word_count.keys())[:5])
-
+
     def __iter__(self):
        return iter(list(self.word_count.keys()))
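Taken together, the two decorators make the vocabulary lazily (re)built: mutating calls go through _check_build_status and mark the vocab dirty, and any lookup goes through _check_build_vocab and triggers build_vocab() first. A short usage sketch (the printed size assumes the default '<pad>'/'<unk>' entries):

    from fastNLP import Vocabulary

    vocab = Vocabulary()                              # default '<pad>' and '<unk>' entries
    vocab.add_word_lst(["apple", "banana", "apple"])  # marks the vocab dirty

    # __len__/__getitem__ are wrapped by _check_build_vocab, so the first
    # access builds word2idx automatically
    print(len(vocab))                # 4: <pad>, <unk>, apple, banana
    print(vocab.to_index("apple"))   # index of a recorded word
    print(vocab.to_index("cherry"))  # unseen word -> index of '<unk>'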

+ 21  - 20   test/core/test_batch.py

@@ -1,13 +1,12 @@
-import time
 import unittest

 import numpy as np
 import torch

-from fastNLP.core.batch import Batch
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
-from fastNLP.core.sampler import SequentialSampler
+from fastNLP import Batch
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import SequentialSampler


 def generate_fake_dataset(num_samples=1000):
@@ -16,11 +15,11 @@ def generate_fake_dataset(num_samples=1000):
     :param num_samples: the number of samples
     :return:
     """
     max_len = 50
     min_len = 10
     num_features = 4
-
+
     data_dict = {}
     for i in range(num_features):
         data = []
@@ -28,9 +27,9 @@ def generate_fake_dataset(num_samples=1000):
         for length in lengths:
             data.append(np.random.randint(100, size=length))
         data_dict[str(i)] = data
-
+
     dataset = DataSet(data_dict)
     for i in range(num_features):
         if np.random.randint(2) == 0:
             dataset.set_input(str(i))
@@ -38,6 +37,7 @@ def generate_fake_dataset(num_samples=1000):
             dataset.set_target(str(i))
     return dataset

+
 def construct_dataset(sentences):
     """Construct a data set from a list of sentences.

@@ -51,18 +51,19 @@ def construct_dataset(sentences):
         dataset.append(instance)
     return dataset

+
 class TestCase1(unittest.TestCase):
     def test_simple(self):
         dataset = construct_dataset(
             [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
         dataset.set_target()
         batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
-
+
         cnt = 0
         for _, _ in batch:
             cnt += 1
         self.assertEqual(cnt, 10)
-
+
     def test_dataset_batching(self):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
         ds.set_input("x")
@@ -74,7 +75,7 @@ class TestCase1(unittest.TestCase):
         self.assertEqual(len(y["y"]), 4)
         self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4])
         self.assertListEqual(list(y["y"][-1]), [5, 6])
-
+
     def test_list_padding(self):
         ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
                       "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
@@ -84,7 +85,7 @@ class TestCase1(unittest.TestCase):
         for x, y in iter:
             self.assertEqual(x["x"].shape, (4, 4))
             self.assertEqual(y["y"].shape, (4, 4))
-
+
     def test_numpy_padding(self):
         ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10),
                       "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
@@ -94,7 +95,7 @@ class TestCase1(unittest.TestCase):
         for x, y in iter:
             self.assertEqual(x["x"].shape, (4, 4))
             self.assertEqual(y["y"].shape, (4, 4))
-
+
     def test_list_to_tensor(self):
         ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
                       "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
@@ -106,7 +107,7 @@ class TestCase1(unittest.TestCase):
             self.assertEqual(tuple(x["x"].shape), (4, 4))
             self.assertTrue(isinstance(y["y"], torch.Tensor))
             self.assertEqual(tuple(y["y"].shape), (4, 4))
-
+
     def test_numpy_to_tensor(self):
         ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10),
                       "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)})
@@ -118,7 +119,7 @@ class TestCase1(unittest.TestCase):
             self.assertEqual(tuple(x["x"].shape), (4, 4))
             self.assertTrue(isinstance(y["y"], torch.Tensor))
             self.assertEqual(tuple(y["y"].shape), (4, 4))
-
+
     def test_list_of_list_to_tensor(self):
         ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] +
                      [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
@@ -130,7 +131,7 @@ class TestCase1(unittest.TestCase):
             self.assertEqual(tuple(x["x"].shape), (4, 4))
             self.assertTrue(isinstance(y["y"], torch.Tensor))
             self.assertEqual(tuple(y["y"].shape), (4, 4))
-
+
     def test_list_of_numpy_to_tensor(self):
         ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] +
                      [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
@@ -139,16 +140,16 @@ class TestCase1(unittest.TestCase):
         iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
         for x, y in iter:
             print(x, y)
-
+
     def test_sequential_batch(self):
         batch_size = 32
         num_samples = 1000
         dataset = generate_fake_dataset(num_samples)
-
+
         batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
         for batch_x, batch_y in batch:
             pass
-
+
     """
     def test_multi_workers_batch(self):
         batch_size = 32

+ 21  - 33   test/core/test_callbacks.py

@@ -4,14 +4,13 @@ import numpy as np
 import torch

 from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \
-    LRFinder, \
-    TensorboardCallback
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
-from fastNLP.core.losses import BCELoss
-from fastNLP.core.metrics import AccuracyMetric
-from fastNLP.core.optimizer import SGD
-from fastNLP.core.trainer import Trainer
+    LRFinder, TensorboardCallback
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import BCELoss
+from fastNLP import AccuracyMetric
+from fastNLP import SGD
+from fastNLP import Trainer
 from fastNLP.models.base_model import NaiveClassifier

@@ -20,15 +19,15 @@ def prepare_env():
         mean = np.array([-3, -3])
         cov = np.array([[1, 0], [0, 1]])
         class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
-
+
         mean = np.array([3, 3])
         cov = np.array([[1, 0], [0, 1]])
         class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
-
+
         data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                            [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
         return data_set
-
+
     data_set = prepare_fake_dataset()
     data_set.set_input("x")
     data_set.set_target("y")
@@ -37,19 +36,7 @@ def prepare_env():


 class TestCallback(unittest.TestCase):
-    def test_echo_callback(self):
-        data_set, model = prepare_env()
-        trainer = Trainer(data_set, model,
-                          loss=BCELoss(pred="predict", target="y"),
-                          n_epochs=2,
-                          batch_size=32,
-                          print_every=50,
-                          optimizer=SGD(lr=0.1),
-                          check_code_level=2,
-                          use_tqdm=False,
-                          callbacks=[EchoCallback()])
-        trainer.train()
-
     def test_gradient_clip(self):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
@@ -64,7 +51,7 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[GradientClipCallback(model.parameters(), clip_value=2)])
         trainer.train()
-
+
     def test_early_stop(self):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
@@ -79,7 +66,7 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[EarlyStopCallback(5)])
         trainer.train()
-
+
     def test_lr_scheduler(self):
         data_set, model = prepare_env()
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
@@ -95,7 +82,7 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))])
         trainer.train()
-
+
     def test_KeyBoardInterrupt(self):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
@@ -108,7 +95,7 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[ControlC(False)])
         trainer.train()
-
+
     def test_LRFinder(self):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
@@ -121,7 +108,7 @@ class TestCallback(unittest.TestCase):
                           use_tqdm=False,
                           callbacks=[LRFinder(len(data_set) // 32)])
         trainer.train()
-
+
     def test_TensorboardCallback(self):
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
@@ -136,21 +123,22 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[TensorboardCallback("loss", "metric")])
         trainer.train()
-
+
     def test_readonly_property(self):
         from fastNLP.core.callback import Callback
         passed_epochs = []
         total_epochs = 5
-
+
         class MyCallback(Callback):
             def __init__(self):
                 super(MyCallback, self).__init__()
-
+
             def on_epoch_begin(self):
                 passed_epochs.append(self.epoch)
                 print(self.n_epochs, self.n_steps, self.batch_size)
                 print(self.model)
                 print(self.optimizer)
+
         data_set, model = prepare_env()
         trainer = Trainer(data_set, model,
                           loss=BCELoss(pred="predict", target="y"),
@@ -164,4 +152,4 @@ class TestCallback(unittest.TestCase):
                           metrics=AccuracyMetric(pred="predict", target="y"),
                           callbacks=[MyCallback()])
         trainer.train()
-        assert passed_epochs == list(range(1, total_epochs+1))
+        assert passed_epochs == list(range(1, total_epochs + 1))

+ 7  - 8   test/core/test_dataset.py

@@ -1,9 +1,10 @@
 import os
 import unittest

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.fieldarray import FieldArray
-from fastNLP.core.instance import Instance
+from fastNLP import DataSet
+from fastNLP import FieldArray
+from fastNLP import Instance
+from fastNLP.io import CSVLoader


 class TestDataSetInit(unittest.TestCase):
@@ -167,13 +168,11 @@ class TestDataSetMethods(unittest.TestCase):
         ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
         d1, d2 = ds.split(0.1)
-
+
     def test_apply2(self):
         def split_sent(ins):
             return ins['raw_sentence'].split()
-
-        dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
-                                   sep='\t')
+
+        csv_loader = CSVLoader(headers=['raw_sentence', 'label'],sep='\t')
+        dataset = csv_loader.load('../data_for_tests/tutorial_sample_dataset.csv')
         dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True)
         dataset.apply(split_sent, new_field_name='words', is_input=True)
         # print(dataset)
@@ -208,7 +207,7 @@ class TestDataSetMethods(unittest.TestCase):
         self.assertEqual(ans.content, [[5, 6]] * 10)

     def test_add_null(self):
-        # TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError'
+        # TODO test failed because 'fastNLP\core\field.py:143: RuntimeError'
         ds = DataSet()
         with self.assertRaises(RuntimeError) as RE:
             ds.add_field('test', [])

+ 5  - 5   test/core/test_fieldarray.py → test/core/test_field.py

@@ -2,7 +2,7 @@ import unittest

 import numpy as np

-from fastNLP.core.fieldarray import FieldArray
+from fastNLP import FieldArray


 class TestFieldArrayInit(unittest.TestCase):
@@ -170,7 +170,7 @@ class TestPadder(unittest.TestCase):
         Test whether AutoPadder works correctly
         :return:
         """
-        from fastNLP.core.fieldarray import AutoPadder
+        from fastNLP import AutoPadder
         padder = AutoPadder()
         content = ['This is a str', 'this is another str']
         self.assertListEqual(content, padder(content, None, np.str).tolist())
@@ -194,7 +194,7 @@ class TestPadder(unittest.TestCase):
         Test whether EngChar2DPadder can be used correctly
         :return:
         """
-        from fastNLP.core.fieldarray import EngChar2DPadder
+        from fastNLP import EngChar2DPadder
         padder = EngChar2DPadder(pad_length=0)

         contents = [1, 2]
@@ -225,11 +225,11 @@ class TestPadder(unittest.TestCase):
         )

     def test_None_dtype(self):
-        from fastNLP.core.fieldarray import AutoPadder
+        from fastNLP import AutoPadder
         padder = AutoPadder()
         content = [
             [[1, 2, 3], [4, 5], [7, 8, 9, 10]],
             [[1]]
         ]
-        ans = padder(content, None, None)
+        ans = padder(content, None, None).tolist()
         self.assertListEqual(content, ans)

+ 6  - 6   test/core/test_instance.py

@@ -1,33 +1,33 @@
 import unittest

-from fastNLP.core.instance import Instance
+from fastNLP import Instance


 class TestCase(unittest.TestCase):
     def test_init(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
         ins = Instance(x=[1, 2, 3], y=[4, 5, 6])
         self.assertTrue(isinstance(ins.fields, dict))
         self.assertEqual(ins.fields, fields)
-
+
         ins = Instance(**fields)
         self.assertEqual(ins.fields, fields)
-
+
     def test_add_field(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6]}
         ins = Instance(**fields)
         ins.add_field("z", [1, 1, 1])
         fields.update({"z": [1, 1, 1]})
-
+
         self.assertEqual(ins.fields, fields)
-
+
     def test_get_item(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
         ins = Instance(**fields)
         self.assertEqual(ins["x"], [1, 2, 3])
         self.assertEqual(ins["y"], [4, 5, 6])
         self.assertEqual(ins["z"], [1, 1, 1])
-
+
     def test_repr(self):
         fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]}
         ins = Instance(**fields)

+ 12  - 12   test/core/test_loss.py

@@ -3,7 +3,7 @@ import unittest
 import torch
 import torch.nn.functional as F

-import fastNLP.core.losses as loss
+import fastNLP as loss
 from fastNLP.core.losses import squash, unpad


@@ -14,21 +14,21 @@ class TestLoss(unittest.TestCase):
         b = torch.empty(3, dtype=torch.long).random_(5)
         ans = ce({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b))
-
+
     def test_BCELoss(self):
         bce = loss.BCELoss(pred="my_predict", target="my_truth")
         a = torch.sigmoid(torch.randn((3, 5), requires_grad=False))
         b = torch.randn((3, 5), requires_grad=False)
         ans = bce({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b))
-
+
     def test_L1Loss(self):
         l1 = loss.L1Loss(pred="my_predict", target="my_truth")
         a = torch.randn(3, 5, requires_grad=False)
         b = torch.randn(3, 5)
         ans = l1({"my_predict": a}, {"my_truth": b})
         self.assertEqual(ans, torch.nn.functional.l1_loss(a, b))
-
+
     def test_NLLLoss(self):
         l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
         a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
@@ -43,34 +43,34 @@ class TestLosserError(unittest.TestCase):
         pred_dict = {"pred": torch.zeros(4, 3)}
         target_dict = {'target': torch.zeros(4).long()}
         los = loss.CrossEntropyLoss()
-
+
         print(los(pred_dict=pred_dict, target_dict=target_dict))
-
+
     #
     def test_losser2(self):
         # (2) with corrupted size
         pred_dict = {"pred": torch.zeros(16, 3)}
         target_dict = {'target': torch.zeros(16, 3).long()}
         los = loss.CrossEntropyLoss()
-
+
         with self.assertRaises(RuntimeError):
             print(los(pred_dict=pred_dict, target_dict=target_dict))
-
+
     def test_losser3(self):
         # (2) with corrupted size
         pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0}
         target_dict = {'target': torch.zeros(16).long()}
         los = loss.CrossEntropyLoss()
-
+
         print(los(pred_dict=pred_dict, target_dict=target_dict))
-
+
     def test_check_error(self):
         l1 = loss.NLLLoss(pred="my_predict", target="my_truth")
         a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0)
         b = torch.tensor([1, 0, 4])
-
+
         with self.assertRaises(Exception):
             ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b})
-
+
         with self.assertRaises(Exception):
             ans = l1({"my_predict": a}, {"truth": b, "my": a})

@@ -80,7 +80,7 @@ class TestLossUtils(unittest.TestCase):
         a, b = squash(torch.randn(3, 5), torch.randn(3, 5))
         self.assertEqual(tuple(a.size()), (3, 5))
         self.assertEqual(tuple(b.size()), (15,))
-
+
     def test_unpad(self):
         a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8))
         self.assertEqual(tuple(a.size()), (5, 8, 3))

+ 63
- 61
test/core/test_metrics.py View File

@@ -3,8 +3,8 @@ import unittest
import numpy as np import numpy as np
import torch import torch


from fastNLP.core.metrics import AccuracyMetric
from fastNLP.core.metrics import BMESF1PreRecMetric
from fastNLP import AccuracyMetric
from fastNLP import BMESF1PreRecMetric
from fastNLP.core.metrics import _pred_topk, _accuracy_topk from fastNLP.core.metrics import _pred_topk, _accuracy_topk




@@ -14,24 +14,24 @@ class TestAccuracyMetric(unittest.TestCase):
pred_dict = {"pred": torch.zeros(4, 3)} pred_dict = {"pred": torch.zeros(4, 3)}
target_dict = {'target': torch.zeros(4)} target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric() metric = AccuracyMetric()
metric(pred_dict=pred_dict, target_dict=target_dict) metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric()) print(metric.get_metric())
def test_AccuracyMetric2(self): def test_AccuracyMetric2(self):
# (2) with corrupted size # (2) with corrupted size
try: try:
pred_dict = {"pred": torch.zeros(4, 3, 2)} pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)} target_dict = {'target': torch.zeros(4)}
metric = AccuracyMetric() metric = AccuracyMetric()
metric(pred_dict=pred_dict, target_dict=target_dict, ) metric(pred_dict=pred_dict, target_dict=target_dict, )
print(metric.get_metric()) print(metric.get_metric())
except Exception as e: except Exception as e:
print(e) print(e)
return return
print("No exception catches.") print("No exception catches.")
def test_AccuracyMetric3(self): def test_AccuracyMetric3(self):
# (3) the second batch is corrupted size # (3) the second batch is corrupted size
try: try:
@@ -39,17 +39,17 @@ class TestAccuracyMetric(unittest.TestCase):
pred_dict = {"pred": torch.zeros(4, 3, 2)} pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4, 3)} target_dict = {'target': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict) metric(pred_dict=pred_dict, target_dict=target_dict)
pred_dict = {"pred": torch.zeros(4, 3, 2)} pred_dict = {"pred": torch.zeros(4, 3, 2)}
target_dict = {'target': torch.zeros(4)} target_dict = {'target': torch.zeros(4)}
metric(pred_dict=pred_dict, target_dict=target_dict) metric(pred_dict=pred_dict, target_dict=target_dict)
print(metric.get_metric()) print(metric.get_metric())
except Exception as e: except Exception as e:
print(e) print(e)
return return
self.assertTrue(True, False), "No exception catches." self.assertTrue(True, False), "No exception catches."
def test_AccuaryMetric4(self): def test_AccuaryMetric4(self):
# (5) check reset # (5) check reset
metric = AccuracyMetric() metric = AccuracyMetric()
@@ -61,7 +61,7 @@ class TestAccuracyMetric(unittest.TestCase):
self.assertTrue(isinstance(res, dict)) self.assertTrue(isinstance(res, dict))
self.assertTrue("acc" in res) self.assertTrue("acc" in res)
self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3)
def test_AccuaryMetric5(self): def test_AccuaryMetric5(self):
# (5) check reset # (5) check reset
metric = AccuracyMetric() metric = AccuracyMetric()
@@ -71,7 +71,7 @@ class TestAccuracyMetric(unittest.TestCase):
res = metric.get_metric(reset=False) res = metric.get_metric(reset=False)
ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean()
self.assertAlmostEqual(res["acc"], float(ans), places=4) self.assertAlmostEqual(res["acc"], float(ans), places=4)
def test_AccuaryMetric6(self): def test_AccuaryMetric6(self):
# (6) check numpy array is not acceptable # (6) check numpy array is not acceptable
try: try:
@@ -83,7 +83,7 @@ class TestAccuracyMetric(unittest.TestCase):
print(e) print(e)
return return
self.assertTrue(True, False), "No exception catches." self.assertTrue(True, False), "No exception catches."
def test_AccuaryMetric7(self): def test_AccuaryMetric7(self):
# (7) check map, match # (7) check map, match
metric = AccuracyMetric(pred='predictions', target='targets') metric = AccuracyMetric(pred='predictions', target='targets')
@@ -93,7 +93,7 @@ class TestAccuracyMetric(unittest.TestCase):
res = metric.get_metric() res = metric.get_metric()
ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean()
self.assertAlmostEqual(res["acc"], float(ans), places=4) self.assertAlmostEqual(res["acc"], float(ans), places=4)
def test_AccuaryMetric8(self): def test_AccuaryMetric8(self):
try: try:
metric = AccuracyMetric(pred='predictions', target='targets') metric = AccuracyMetric(pred='predictions', target='targets')
@@ -105,7 +105,7 @@ class TestAccuracyMetric(unittest.TestCase):
print(e) print(e)
return return
self.assertTrue(True, False), "No exception catches." self.assertTrue(True, False), "No exception catches."
def test_AccuaryMetric9(self): def test_AccuaryMetric9(self):
# (9) check map, include unused # (9) check map, include unused
try: try:
@@ -118,12 +118,12 @@ class TestAccuracyMetric(unittest.TestCase):
print(e) print(e)
return return
self.assertTrue(True, False), "No exception catches." self.assertTrue(True, False), "No exception catches."
def test_AccuaryMetric10(self): def test_AccuaryMetric10(self):
# (10) check _fast_metric # (10) check _fast_metric
try: try:
metric = AccuracyMetric() metric = AccuracyMetric()
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3)*3}
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3}
target_dict = {'targets': torch.zeros(4, 3)} target_dict = {'targets': torch.zeros(4, 3)}
metric(pred_dict=pred_dict, target_dict=target_dict) metric(pred_dict=pred_dict, target_dict=target_dict)
self.assertDictEqual(metric.get_metric(), {'acc': 1}) self.assertDictEqual(metric.get_metric(), {'acc': 1})
@@ -131,7 +131,7 @@ class TestAccuracyMetric(unittest.TestCase):
print(e) print(e)
return return
self.assertTrue(True, False), "No exception catches." self.assertTrue(True, False), "No exception catches."
def test_seq_len(self): def test_seq_len(self):
N = 256 N = 256
seq_len = torch.zeros(N).long() seq_len = torch.zeros(N).long()
@@ -145,20 +145,21 @@ class TestAccuracyMetric(unittest.TestCase):
metric(pred_dict=pred, target_dict=target) metric(pred_dict=pred, target_dict=target)
self.assertDictEqual(metric.get_metric(), {'acc': 1.}) self.assertDictEqual(metric.get_metric(), {'acc': 1.})



class SpanF1PreRecMetric(unittest.TestCase): class SpanF1PreRecMetric(unittest.TestCase):
def test_case1(self): def test_case1(self):
from fastNLP.core.metrics import _bmes_tag_to_spans from fastNLP.core.metrics import _bmes_tag_to_spans
from fastNLP.core.metrics import _bio_tag_to_spans from fastNLP.core.metrics import _bio_tag_to_spans
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8']
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8']
expect_bmes_res = set() expect_bmes_res = set()
expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)), expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)),
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))])
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))])
expect_bio_res = set() expect_bio_res = set()
expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)),
('6', (4, 5)), ('7', (6, 7))])
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst)))
('6', (4, 5)), ('7', (6, 7))])
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
@@ -171,19 +172,19 @@ class SpanF1PreRecMetric(unittest.TestCase):
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))]
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs)))
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))
def test_case2(self): def test_case2(self):
# 测试不带label的 # 测试不带label的
from fastNLP.core.metrics import _bmes_tag_to_spans from fastNLP.core.metrics import _bmes_tag_to_spans
from fastNLP.core.metrics import _bio_tag_to_spans from fastNLP.core.metrics import _bio_tag_to_spans
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E']
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O']
expect_bmes_res = set() expect_bmes_res = set()
expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))])
expect_bio_res = set() expect_bio_res = set()
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))])
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst)))
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst)))
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans
@@ -195,7 +196,7 @@ class SpanF1PreRecMetric(unittest.TestCase):
# bio_strs = np.random.choice(bio, size=100) # bio_strs = np.random.choice(bio, size=100)
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs)))
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs)))
def tese_case3(self): def tese_case3(self):
from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.vocabulary import Vocabulary
from collections import Counter from collections import Counter
@@ -213,7 +214,7 @@ class SpanF1PreRecMetric(unittest.TestCase):
continue continue
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count
return vocab return vocab
        number_labels = 4
        # bio tag
        fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None)
@@ -221,26 +222,26 @@ class SpanF1PreRecMetric(unittest.TestCase):
        fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False)
        bio_sequence = torch.FloatTensor(
            [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011,
               0.0470, 0.0971],
              [-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523,
               0.7987, -0.3970],
              [0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898,
               0.6880, 1.4348],
              [-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793,
               -1.6876, -0.8917],
              [-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824,
               1.4217, 0.2622]],
             [[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136,
               1.3592, -0.8973],
              [0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887,
               -0.4025, -0.3417],
              [-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698,
               0.2861, -0.3966],
              [-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275,
               0.0213, 1.4777],
              [-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566,
               1.3024, 0.2001]]]
        )
        bio_target = torch.LongTensor([[5., 0., 3., 3., 3.],
                                       [5., 6., 8., 6., 0.]])
@@ -250,8 +251,8 @@ class SpanF1PreRecMetric(unittest.TestCase):
                          'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845,
                          'f': 0.12499999999994846}
        self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric())
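
The gross numbers above follow from plain set arithmetic over predicted and gold spans. A sketch of the computation (the exact epsilon SpanFPreRecMetric uses internally is an assumption here):

def span_f1_sketch(pred_spans, gold_spans, eps=1e-13):
    tp = len(pred_spans & gold_spans)          # spans with correct label and boundaries
    pre = tp / (len(pred_spans) + eps)
    rec = tp / (len(gold_spans) + eps)
    return {'pre': pre, 'rec': rec, 'f': 2 * pre * rec / (pre + rec + eps)}

For example, one true positive out of eight predicted and eight gold spans gives pre = rec = 0.125, matching the asserted values up to the epsilon's effect on the trailing digits.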
-        #bmes tag
+        # bmes tag
        bmes_sequence = torch.FloatTensor(
            [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352,
               -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332,
@@ -268,7 +269,7 @@ class SpanF1PreRecMetric(unittest.TestCase):
              [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393,
               0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809,
               -0.3779, -0.3195]],
             [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753,
               0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957,
               -0.1103, 0.4417],
@@ -285,22 +286,22 @@ class SpanF1PreRecMetric(unittest.TestCase):
               2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142,
               -0.7344, -1.2046]]]
        )
-        bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.],
-                                        [ 6., 15., 6., 15., 5.]])
+        bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.],
+                                        [6., 15., 6., 15., 5.]])
        fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None)
        fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels))
        fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes')
        fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target})
        expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001,
                           'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775,
                           'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314,
                           'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504,
                           'pre': 0.499999999999995, 'rec': 0.499999999999995}
        self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res)
        # verified against allennlp; the code below is commented out because the tests cannot depend on allennlp
        # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary
        # from allennlp.training.metrics import SpanBasedF1Measure
@@ -349,6 +350,7 @@ class SpanF1PreRecMetric(unittest.TestCase):
        #     self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()),
        #                          fastnlp_bmes_metric.get_metric())


class TestBMESF1PreRecMetric(unittest.TestCase):
    def test_case1(self):
        seq_lens = torch.LongTensor([4, 2])
@@ -356,20 +358,20 @@ class TestBMESF1PreRecMetric(unittest.TestCase):
        target = torch.LongTensor([[0, 1, 2, 3],
                                   [3, 3, 0, 0]])
        pred_dict = {'pred': pred}
-        target_dict = {'target': target, 'seq_lens': seq_lens}
+        target_dict = {'target': target, 'seq_len': seq_lens}
        metric = BMESF1PreRecMetric()
        metric(pred_dict, target_dict)
        metric.get_metric()

    def test_case2(self):
        # two identical sequences should give {f1: 1, precision: 1, recall: 1}
        seq_lens = torch.LongTensor([4, 2])
        target = torch.LongTensor([[0, 1, 2, 3],
                                   [3, 3, 0, 0]])
        pred_dict = {'pred': target}
-        target_dict = {'target': target, 'seq_lens': seq_lens}
+        target_dict = {'target': target, 'seq_len': seq_lens}
        metric = BMESF1PreRecMetric()
        metric(pred_dict, target_dict)
        self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0})
@@ -381,5 +383,5 @@ class TestUsefulFunctions(unittest.TestCase):
        # multi-class
        _ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3)
        _ = _pred_topk(np.random.randint(0, 3, size=(10, 1)))

        # running without error is enough

+9 -9  test/core/test_optimizer.py

@@ -2,7 +2,7 @@ import unittest

import torch

-from fastNLP.core.optimizer import SGD, Adam
+from fastNLP import SGD, Adam


class TestOptim(unittest.TestCase):
@@ -12,42 +12,42 @@ class TestOptim(unittest.TestCase):
        self.assertTrue("momentum" in optim.__dict__["settings"])
        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
        self.assertTrue(isinstance(res, torch.optim.SGD))

        optim = SGD(lr=0.001)
        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
        self.assertTrue(isinstance(res, torch.optim.SGD))

        optim = SGD(lr=0.002, momentum=0.989)
        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
        self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989)

        optim = SGD(0.001)
        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
        self.assertTrue(isinstance(res, torch.optim.SGD))

        with self.assertRaises(TypeError):
            _ = SGD("???")
        with self.assertRaises(TypeError):
            _ = SGD(0.001, lr=0.002)
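
The pattern both tests exercise is a thin lazy wrapper: the fastNLP optimizer classes record their keyword settings and only instantiate the real torch optimizer in construct_from_pytorch. A simplified sketch (the real classes add the type validation that the TypeError cases above rely on):

import torch

class LazySGD:
    def __init__(self, lr=0.01, momentum=0, model_params=None):
        self.settings = {'lr': lr, 'momentum': momentum}
        self.model_params = model_params

    def construct_from_pytorch(self, model_params):
        # parameters given at construction time take precedence
        params = self.model_params if self.model_params is not None else model_params
        return torch.optim.SGD(params, **self.settings)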
    def test_Adam(self):
        optim = Adam(model_params=torch.nn.Linear(10, 3).parameters())
        self.assertTrue("lr" in optim.__dict__["settings"])
        self.assertTrue("weight_decay" in optim.__dict__["settings"])
        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
        self.assertTrue(isinstance(res, torch.optim.Adam))

        optim = Adam(lr=0.001)
        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())
        self.assertTrue(isinstance(res, torch.optim.Adam))

        optim = Adam(lr=0.002, weight_decay=0.989)
        self.assertEqual(optim.__dict__["settings"]["lr"], 0.002)
        self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989)

        optim = Adam(0.001)
        self.assertEqual(optim.__dict__["settings"]["lr"], 0.001)
        res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters())

+3 -3  test/core/test_sampler.py

@@ -3,9 +3,9 @@ import unittest

import torch

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.sampler import SequentialSampler, RandomSampler, \
-    k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler
+from fastNLP import DataSet
+from fastNLP import SequentialSampler, RandomSampler, BucketSampler
+from fastNLP.core.sampler import k_means_1d, k_means_bucketing, simple_sort_bucketing
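
Of the imported samplers, BucketSampler is the only non-obvious one: it groups instances of similar length so that each batch needs minimal padding, then serves the batches in random order. A minimal sketch of the idea (fastNLP's real class also takes num_buckets and the name of the sequence-length field):

import random

def bucketed_batches(lengths, batch_size):
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])  # sort indices by length
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    random.shuffle(batches)  # batches stay length-homogeneous, but arrive in random order
    return batches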




class TestSampler(unittest.TestCase):


+13 -17  test/core/test_tester.py

@@ -1,32 +1,25 @@
import unittest
+import numpy as np
+from torch import nn
+import time
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import AccuracyMetric
+from fastNLP import Tester

data_name = "pku_training.utf8"
pickle_path = "data_for_tests"

-import numpy as np
-import torch.nn.functional as F
-from torch import nn
-import time
-from fastNLP.core.utils import _CheckError
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
-from fastNLP.core.losses import BCELoss
-from fastNLP.core.losses import CrossEntropyLoss
-from fastNLP.core.metrics import AccuracyMetric
-from fastNLP.core.optimizer import SGD
-from fastNLP.core.tester import Tester
-from fastNLP.models.base_model import NaiveClassifier


def prepare_fake_dataset():
    mean = np.array([-3, -3])
    cov = np.array([[1, 0], [0, 1]])
    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

    mean = np.array([3, 3])
    cov = np.array([[1, 0], [0, 1]])
    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
    return data_set
@@ -39,6 +32,7 @@ def prepare_fake_dataset2(*args, size=100):
        data[arg] = np.random.randn(size, 5)
    return DataSet(data=data)
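
With the fixtures above, a minimal Tester run looks roughly like this (a sketch: NaiveClassifier's constructor arguments follow the trainer tests later in this commit, and keyword names follow fastNLP 0.4.x):

from fastNLP.models.base_model import NaiveClassifier

data_set = prepare_fake_dataset()
data_set.set_input('x')
data_set.set_target('y')
model = NaiveClassifier(2, 1)
tester = Tester(data_set, model, metrics=AccuracyMetric(pred='predict', target='y'))
tester.test()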



class TestTester(unittest.TestCase):
    def test_case_1(self):
        # check that the error message correctly informs the user
@@ -46,10 +40,12 @@ class TestTester(unittest.TestCase):
        dataset.rename_field('x_unused', 'x2')
        dataset.set_input('x1', 'x2')
        dataset.set_target('y', 'x1')

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
@@ -57,7 +53,7 @@ class TestTester(unittest.TestCase):
                time.sleep(0.1)
                # loss = F.cross_entropy(x, y)
                return {'preds': x}

        model = Model()
        with self.assertRaises(NameError):
            tester = Tester(


+38 -33  test/core/test_trainer.py

@@ -5,25 +5,24 @@ import numpy as np
import torch.nn.functional as F
from torch import nn

-from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
-from fastNLP.core.losses import BCELoss
-from fastNLP.core.losses import CrossEntropyLoss
-from fastNLP.core.metrics import AccuracyMetric
-from fastNLP.core.optimizer import SGD
-from fastNLP.core.trainer import Trainer
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import BCELoss
+from fastNLP import CrossEntropyLoss
+from fastNLP import AccuracyMetric
+from fastNLP import SGD
+from fastNLP import Trainer
from fastNLP.models.base_model import NaiveClassifier


def prepare_fake_dataset():
    mean = np.array([-3, -3])
    cov = np.array([[1, 0], [0, 1]])
    class_A = np.random.multivariate_normal(mean, cov, size=(1000,))

    mean = np.array([3, 3])
    cov = np.array([[1, 0], [0, 1]])
    class_B = np.random.multivariate_normal(mean, cov, size=(1000,))

    data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
                       [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
    return data_set
@@ -42,11 +41,11 @@ class TrainerTestGround(unittest.TestCase):
        data_set = prepare_fake_dataset()
        data_set.set_input("x", flag=True)
        data_set.set_target("y", flag=True)

        train_set, dev_set = data_set.split(0.3)

        model = NaiveClassifier(2, 1)

        trainer = Trainer(train_set, model,
                          loss=BCELoss(pred="predict", target="y"),
                          metrics=AccuracyMetric(pred="predict", target="y"),
@@ -63,26 +62,26 @@ class TrainerTestGround(unittest.TestCase):
        """
        # should run correctly
        """

    def test_trainer_suggestion1(self):
        # check that the error message correctly informs the user:
        # the data required by forward is not provided; the trainer should tell the user how to set it up
        dataset = prepare_fake_dataset2('x')

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2, y):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
                x = x1 + x2
                loss = F.cross_entropy(x, y)
                return {'loss': loss}

        model = Model()

        with self.assertRaises(RuntimeError):
            trainer = Trainer(
                train_data=dataset,
@@ -97,25 +96,25 @@ class TrainerTestGround(unittest.TestCase):
        (2). You need to provide ['x1', 'x2'] in DataSet and set it as input.

        """

    def test_trainer_suggestion2(self):
        # check that the error message correctly informs the user:
        # forward gets the data it needs here, so training should run
        dataset = prepare_fake_dataset2('x1', 'x2')
        dataset.set_input('x1', 'x2', 'y', flag=True)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2, y):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
                x = x1 + x2
                loss = F.cross_entropy(x, y)
                return {'loss': loss}

        model = Model()
        trainer = Trainer(
            train_data=dataset,
@@ -127,25 +126,25 @@ class TrainerTestGround(unittest.TestCase):
        """
        # should run correctly
        """

    def test_trainer_suggestion3(self):
        # check that the error message correctly informs the user:
        # forward gets the data it needs, but does not return a 'loss' key
        dataset = prepare_fake_dataset2('x1', 'x2')
        dataset.set_input('x1', 'x2', 'y', flag=True)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2, y):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
                x = x1 + x2
                loss = F.cross_entropy(x, y)
                return {'wrong_loss_key': loss}

        model = Model()
        with self.assertRaises(NameError):
            trainer = Trainer(
@@ -155,23 +154,25 @@ class TrainerTestGround(unittest.TestCase):
                print_every=2
            )
            trainer.train()

    def test_trainer_suggestion4(self):
        # check that the error message correctly informs the user:
        # forward gets the data it needs; unused fields should be reported correctly
        dataset = prepare_fake_dataset2('x1', 'x2')
        dataset.set_input('x1', 'x2', 'y', flag=True)

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2, y):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
                x = x1 + x2
                loss = F.cross_entropy(x, y)
                return {'losses': loss}

        model = Model()
        with self.assertRaises(NameError):
            trainer = Trainer(
@@ -180,7 +181,7 @@ class TrainerTestGround(unittest.TestCase):
                use_tqdm=False,
                print_every=2
            )

    def test_trainer_suggestion5(self):
        # check that the error message correctly informs the user:
        # redundant fields are passed in to create duplicates, but y is never used, so no error is actually raised
@@ -188,17 +189,19 @@ class TrainerTestGround(unittest.TestCase):
        dataset.rename_field('x_unused', 'x2')
        dataset.set_input('x1', 'x2', 'y')
        dataset.set_target('y')

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2, y):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
                x = x1 + x2
                loss = F.cross_entropy(x, y)
                return {'loss': loss}

        model = Model()
        trainer = Trainer(
            train_data=dataset,
    def test_trainer_suggestion6(self):
        # check that the error message correctly informs the user:
        # redundant fields are passed in to create duplicates
@@ -214,10 +217,12 @@ class TrainerTestGround(unittest.TestCase):
        dataset.rename_field('x_unused', 'x2')
        dataset.set_input('x1', 'x2')
        dataset.set_target('y', 'x1')

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.fc = nn.Linear(5, 4)

            def forward(self, x1, x2):
                x1 = self.fc(x1)
                x2 = self.fc(x2)
@@ -225,7 +230,7 @@ class TrainerTestGround(unittest.TestCase):
                time.sleep(0.1)
                # loss = F.cross_entropy(x, y)
                return {'preds': x}

        model = Model()
        with self.assertRaises(NameError):
            trainer = Trainer(
@@ -236,7 +241,7 @@ class TrainerTestGround(unittest.TestCase):
                metrics=AccuracyMetric(),
                use_tqdm=False,
                print_every=2)

    """
    def test_trainer_multiprocess(self):
        dataset = prepare_fake_dataset2('x1', 'x2')


+28 -20  test/core/test_utils.py

@@ -1,8 +1,7 @@
import unittest
import _pickle
from fastNLP import cache_results
-from fastNLP.io.embed_loader import EmbedLoader
+from fastNLP.io import EmbedLoader
from fastNLP import DataSet
from fastNLP import Instance
import time
@@ -11,11 +10,13 @@ import torch
from torch import nn
from fastNLP.core.utils import _move_model_to_device, _get_model_device


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.param = nn.Parameter(torch.zeros(0))


class TestMoveModelDeivce(unittest.TestCase):
    def test_case1(self):
        # test str devices
@@ -35,36 +36,36 @@ class TestMoveModelDeivce(unittest.TestCase):
            _move_model_to_device(model, 'cuda:1000')
        # test None
        model = _move_model_to_device(model, None)

    def test_case2(self):
        # test initialization with an int
        model = Model()
        if torch.cuda.is_available():
            model = _move_model_to_device(model, 0)
            assert model.param.device == torch.device('cuda:0')
-            assert model.param.device==torch.device('cuda:0'), "The model should be in "
+            assert model.param.device == torch.device('cuda:0'), "The model should be in "
            with self.assertRaises(Exception):
                _move_model_to_device(model, 100)
            with self.assertRaises(Exception):
                _move_model_to_device(model, -1)

    def test_case3(self):
        # test None
        model = Model()
        device = _get_model_device(model)
        model = _move_model_to_device(model, None)
-        assert device==_get_model_device(model), "The device should not change."
+        assert device == _get_model_device(model), "The device should not change."
        if torch.cuda.is_available():
            model.cuda()
            device = _get_model_device(model)
            model = _move_model_to_device(model, None)
-            assert device==_get_model_device(model), "The device should not change."
+            assert device == _get_model_device(model), "The device should not change."

            model = nn.DataParallel(model, device_ids=[0])
            _move_model_to_device(model, None)
            with self.assertRaises(Exception):
                _move_model_to_device(model, 'cpu')
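
Taken together, test_case1 through test_case4 pin down the device-argument semantics, which can be summarised in a small helper (simplified; fastNLP's _move_model_to_device also rejects moving an nn.DataParallel model and validates CUDA availability, as the assertRaises cases show):

import torch
from torch import nn

def move_model_to_device_sketch(model, device):
    if device is None:                      # None: leave the model where it is
        return model
    if isinstance(device, (list, tuple)):   # list/tuple of ids: wrap in DataParallel
        return nn.DataParallel(model, device_ids=list(device))
    return model.to(torch.device(device))   # str / int / torch.device: plain move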
    def test_case4(self):
        # test passing a list
        model = Model()
@@ -78,15 +79,17 @@ class TestMoveModelDeivce(unittest.TestCase):
        device = [torch.device('cuda:0'), torch.device('cuda:0')]
        with self.assertRaises(Exception):
            _model = _move_model_to_device(model, device)
-        if torch.cuda.device_count()>1:
+        if torch.cuda.device_count() > 1:
            device = [0, 1]
            _model = _move_model_to_device(model, device)
            assert isinstance(_model, nn.DataParallel)
        device = ['cuda', 'cuda:1']
        with self.assertRaises(Exception):
            _move_model_to_device(model, device)

    def test_case5(self):
+        if not torch.cuda.is_available():
+            return
        # torch.device()
        device = torch.device('cpu')
        model = Model()
@@ -106,10 +109,11 @@ def process_data_1(embed_file, cws_train):
    d = DataSet()
    for line in f:
        line = line.strip()
-        if len(line)>0:
+        if len(line) > 0:
            d.append(Instance(raw=line))
    return embed, vocab, d
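
The cache_results decorator exercised by the TestCache class below works like this in practice (a self-contained sketch; the file name and sleep are illustrative, while the reserved _cache_fp/_refresh/_verbose keywords are confirmed by test_duplicate_keyword):

import time
from fastNLP import cache_results

@cache_results('demo_cache.pkl')
def slow_prepare(n):
    time.sleep(1)              # stand-in for expensive preprocessing
    return list(range(n))

data = slow_prepare(10)                  # first call: computes, then pickles to demo_cache.pkl
data = slow_prepare(10)                  # second call: loads the pickle instead
data = slow_prepare(10, _refresh=True)   # forces recomputation, as test_cache_refresh checks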



class TestCache(unittest.TestCase):
    def test_cache_save(self):
        try:
@@ -127,10 +131,10 @@ class TestCache(unittest.TestCase):
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
-            self.assertGreater(pre_time-0.5, read_time)
+            self.assertGreater(pre_time - 0.5, read_time)
        finally:
            os.remove('test/demo1.pkl')

    def test_cache_save_overwrite_path(self):
        try:
            start_time = time.time()
@@ -149,10 +153,10 @@ class TestCache(unittest.TestCase):
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
-            self.assertGreater(pre_time-0.5, read_time)
+            self.assertGreater(pre_time - 0.5, read_time)
        finally:
            os.remove('test/demo_overwrite.pkl')

    def test_cache_refresh(self):
        try:
            start_time = time.time()
@@ -171,34 +175,38 @@ class TestCache(unittest.TestCase):
            end_time = time.time()
            read_time = end_time - start_time
            print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time))
-            self.assertGreater(0.1, pre_time-read_time)
+            self.assertGreater(0.1, pre_time - read_time)
        finally:
            os.remove('test/demo1.pkl')

    def test_duplicate_keyword(self):
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_verbose(a, _verbose):
                pass

            func_verbose(0, 1)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_cache(a, _cache_fp):
                pass

            func_cache(1, 2)
        with self.assertRaises(RuntimeError):
            @cache_results(None)
            def func_refresh(a, _refresh):
                pass

            func_refresh(1, 2)

    def test_create_cache_dir(self):
        @cache_results('test/demo1/demo.pkl')
        def cache():
            return 1, 2

        try:
            results = cache()
            print(results)
        finally:
            os.remove('test/demo1/demo.pkl')
            os.rmdir('test/demo1')

+37 -36  test/core/test_vocabulary.py

@@ -1,9 +1,9 @@
import unittest
from collections import Counter

-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.core.dataset import DataSet
-from fastNLP.core.instance import Instance
+from fastNLP import Vocabulary
+from fastNLP import DataSet
+from fastNLP import Instance

text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
        "works", "well", "in", "most", "cases", "scales", "well"]
@@ -12,92 +12,93 @@ counter = Counter(text)

class TestAdd(unittest.TestCase):
    def test_add(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        for word in text:
            vocab.add(word)
        self.assertEqual(vocab.word_count, counter)

    def test_add_word(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        for word in text:
            vocab.add_word(word)
        self.assertEqual(vocab.word_count, counter)

    def test_add_word_lst(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        vocab.add_word_lst(text)
        self.assertEqual(vocab.word_count, counter)

    def test_update(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        vocab.update(text)
        self.assertEqual(vocab.word_count, counter)

    def test_from_dataset(self):
        start_char = 65
        num_samples = 10

        # 0 dim
        dataset = DataSet()
        for i in range(num_samples):
-            ins = Instance(char=chr(start_char+i))
+            ins = Instance(char=chr(start_char + i))
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
-            self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
+            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')

        # 1 dim
        dataset = DataSet()
        for i in range(num_samples):
-            ins = Instance(char=[chr(start_char+i)]*6)
+            ins = Instance(char=[chr(start_char + i)] * 6)
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
-            self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
+            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')

        # 2 dim
        dataset = DataSet()
        for i in range(num_samples):
-            ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)])
+            ins = Instance(char=[[chr(start_char + i) for _ in range(6)] for _ in range(6)])
            dataset.append(ins)
        vocab = Vocabulary()
        vocab.from_dataset(dataset, field_name='char')
        for i in range(num_samples):
-            self.assertEqual(vocab.to_index(chr(start_char+i)), i+2)
+            self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2)
        vocab.index_dataset(dataset, field_name='char')
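
The i + 2 offsets asserted above come from the Vocabulary defaults: index 0 is reserved for the padding token and index 1 for the unknown token, so the first real word lands at index 2. A quick sketch of that default behaviour:

from fastNLP import Vocabulary

vocab = Vocabulary()             # defaults reserve 0 for padding and 1 for unknown
vocab.update(['A', 'B', 'A'])
assert vocab.to_index('A') == 2  # most frequent word gets the first free index
assert vocab.to_index('B') == 3
assert vocab.to_word(2) == 'A'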



class TestIndexing(unittest.TestCase):
    def test_len(self):
-        vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
+        vocab = Vocabulary(unknown=None, padding=None)
        vocab.update(text)
        self.assertEqual(len(vocab), len(counter))

    def test_contains(self):
-        vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
+        vocab = Vocabulary(unknown=None)
        vocab.update(text)
        self.assertTrue(text[-1] in vocab)
        self.assertFalse("~!@#" in vocab)
        self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1]))
        self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))

    def test_index(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        vocab.update(text)
        res = [vocab[w] for w in set(text)]
        self.assertEqual(len(res), len(set(res)))
        res = [vocab.to_index(w) for w in set(text)]
        self.assertEqual(len(res), len(set(res)))

    def test_to_word(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        vocab.update(text)
        self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])

    def test_iteration(self):
        vocab = Vocabulary()
        text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in",
@@ -110,26 +111,26 @@ class TestIndexing(unittest.TestCase):

class TestOther(unittest.TestCase):
    def test_additional_update(self):
-        vocab = Vocabulary(max_size=None, min_freq=None)
+        vocab = Vocabulary()
        vocab.update(text)
        _ = vocab["well"]
        self.assertEqual(vocab.rebuild, False)

        vocab.add("hahaha")
        self.assertEqual(vocab.rebuild, True)

        _ = vocab["hahaha"]
        self.assertEqual(vocab.rebuild, False)
        self.assertTrue("hahaha" in vocab)

    def test_warning(self):
-        vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
+        vocab = Vocabulary(max_size=len(set(text)))
        vocab.update(text)
        self.assertEqual(vocab.rebuild, True)
        print(len(vocab))
        self.assertEqual(vocab.rebuild, False)

        vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"])
        # this will print a warning
        self.assertEqual(vocab.rebuild, True)
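
What these two tests document, in short: Vocabulary builds its word-to-index map lazily. Adding words only marks the mapping stale (rebuild == True); the first lookup or len() call rebuilds it and clears the flag. A sketch of the observable contract (reusing the text list defined at the top of this file):

vocab = Vocabulary(max_size=len(set(text)))
vocab.update(text)
assert vocab.rebuild is True   # words added, index not built yet
_ = len(vocab)                 # any lookup or len() triggers the rebuild
assert vocab.rebuild is False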
