@@ -13,7 +13,8 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||||
__all__ = ["Instance", "FieldArray", "Batch", "Vocabulary", "DataSet", | __all__ = ["Instance", "FieldArray", "Batch", "Vocabulary", "DataSet", | ||||
"Trainer", "Tester", "Callback", | "Trainer", "Tester", "Callback", | ||||
"Padder", "AutoPadder", "EngChar2DPadder", | "Padder", "AutoPadder", "EngChar2DPadder", | ||||
"AccuracyMetric", "Optimizer", "SGD", "Adam", | |||||
"AccuracyMetric", "BMESF1PreRecMetric", "SpanFPreRecMetric", "SQuADMetric", | |||||
"Optimizer", "SGD", "Adam", | |||||
"Sampler", "SequentialSampler", "BucketSampler", "RandomSampler", | "Sampler", "SequentialSampler", "BucketSampler", "RandomSampler", | ||||
"LossFunc", "CrossEntropyLoss", "L1Loss", "BCELoss", "NLLLoss", "LossInForward", | "LossFunc", "CrossEntropyLoss", "L1Loss", "BCELoss", "NLLLoss", "LossInForward", | ||||
"cache_results"] | "cache_results"] | ||||
@@ -17,7 +17,7 @@ from .dataset import DataSet | |||||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | ||||
from .instance import Instance | from .instance import Instance | ||||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | ||||
from .metrics import AccuracyMetric | |||||
from .metrics import AccuracyMetric, BMESF1PreRecMetric, SpanFPreRecMetric, SQuADMetric | |||||
from .optimizer import Optimizer, SGD, Adam | from .optimizer import Optimizer, SGD, Adam | ||||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | ||||
from .tester import Tester | from .tester import Tester | ||||
@@ -236,6 +236,7 @@ class CallbackManager(Callback): | |||||
for env_name, env_val in env.items(): | for env_name, env_val in env.items(): | ||||
for callback in self.callbacks: | for callback in self.callbacks: | ||||
print(callback, env_name, env_val ) | |||||
setattr(callback, '_' + env_name, env_val) # Callback.trainer | setattr(callback, '_' + env_name, env_val) # Callback.trainer | ||||
@_transfer | @_transfer | ||||
@@ -425,19 +426,25 @@ class LRFinder(Callback): | |||||
super(LRFinder, self).__init__() | super(LRFinder, self).__init__() | ||||
self.start_lr, self.end_lr = start_lr, end_lr | self.start_lr, self.end_lr = start_lr, end_lr | ||||
self.num_it = self.batch_per_epoch | |||||
self.stop = False | self.stop = False | ||||
self.best_loss = 0. | self.best_loss = 0. | ||||
self.best_lr = None | self.best_lr = None | ||||
self.loss_history = [] | self.loss_history = [] | ||||
self.smooth_value = SmoothValue(0.8) | self.smooth_value = SmoothValue(0.8) | ||||
self.opt = None | self.opt = None | ||||
scale = (self.end_lr - self.start_lr) / self.num_it | |||||
self.lr_gen = (self.start_lr + scale * (step + 1) for step in range(self.num_it)) | |||||
self.find = None | self.find = None | ||||
self.loader = ModelLoader() | self.loader = ModelLoader() | ||||
@property | |||||
def lr_gen(self): | |||||
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch | |||||
return (self.start_lr + scale * (step + 1) for step in range(self.batch_per_epoch)) | |||||
@property | |||||
def num_it(self): | |||||
return self.batch_per_epoch | |||||
def on_epoch_begin(self): | def on_epoch_begin(self): | ||||
if self.epoch == 1: # first epoch | if self.epoch == 1: # first epoch | ||||
self.opt = self.trainer.optimizer # pytorch optimizer | self.opt = self.trainer.optimizer # pytorch optimizer | ||||
@@ -418,6 +418,7 @@ class AutoPadder(Padder): | |||||
return False | return False | ||||
def __call__(self, contents, field_name, field_ele_dtype): | def __call__(self, contents, field_name, field_ele_dtype): | ||||
if not _is_iterable(contents[0]): | if not _is_iterable(contents[0]): | ||||
array = np.array([content for content in contents], dtype=field_ele_dtype) | array = np.array([content for content in contents], dtype=field_ele_dtype) | ||||
elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | elif field_ele_dtype in (np.int64, np.float64) and self._is_two_dimension(contents): | ||||
@@ -430,6 +430,7 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||||
class SpanFPreRecMetric(MetricBase): | class SpanFPreRecMetric(MetricBase): | ||||
""" | """ | ||||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||||
在序列标注问题中,以span的方式计算F, pre, rec. | 在序列标注问题中,以span的方式计算F, pre, rec. | ||||
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | 比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) | ||||
@@ -619,6 +620,8 @@ class SpanFPreRecMetric(MetricBase): | |||||
class BMESF1PreRecMetric(MetricBase): | class BMESF1PreRecMetric(MetricBase): | ||||
""" | """ | ||||
别名::class:`fastNLP.BMESF1PreRecMetric` :class:`fastNLP.core.metrics.BMESF1PreRecMetric` | |||||
按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, | ||||
next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B | ||||
@@ -826,6 +829,8 @@ def _pred_topk(y_prob, k=1): | |||||
class SQuADMetric(MetricBase): | class SQuADMetric(MetricBase): | ||||
""" | """ | ||||
别名::class:`fastNLP.SQuADMetric` :class:`fastNLP.core.metrics.SQuADMetric` | |||||
SQuAD数据集metric | SQuAD数据集metric | ||||
:param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | :param pred1: 参数映射表中`pred1`的映射关系,None表示映射关系为`pred1`->`pred1` | ||||
@@ -350,7 +350,7 @@ class Trainer(object): | |||||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | :param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | ||||
:param nn.modules model: 待训练的模型 | :param nn.modules model: 待训练的模型 | ||||
:param torch.optim.Optimizer optimizer: 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||||
:param int batch_size: 训练和验证的时候的batch大小。 | :param int batch_size: 训练和验证的时候的batch大小。 | ||||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | :param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | ||||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | :param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | ||||
@@ -403,7 +403,6 @@ class Trainer(object): | |||||
callbacks=None, | callbacks=None, | ||||
check_code_level=0): | check_code_level=0): | ||||
super(Trainer, self).__init__() | super(Trainer, self).__init__() | ||||
if not isinstance(train_data, DataSet): | if not isinstance(train_data, DataSet): | ||||
raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | raise TypeError(f"The type of train_data must be fastNLP.DataSet, got {type(train_data)}.") | ||||
if not isinstance(model, nn.Module): | if not isinstance(model, nn.Module): | ||||
@@ -468,7 +467,7 @@ class Trainer(object): | |||||
len(self.train_data) % self.batch_size != 0)) * self.n_epochs | len(self.train_data) % self.batch_size != 0)) * self.n_epochs | ||||
self.model = _move_model_to_device(self.model, device=device) | self.model = _move_model_to_device(self.model, device=device) | ||||
if isinstance(optimizer, torch.optim.Optimizer): | if isinstance(optimizer, torch.optim.Optimizer): | ||||
self.optimizer = optimizer | self.optimizer = optimizer | ||||
elif isinstance(optimizer, Optimizer): | elif isinstance(optimizer, Optimizer): | ||||
@@ -493,6 +492,7 @@ class Trainer(object): | |||||
self.step = 0 | self.step = 0 | ||||
self.start_time = None # start timestamp | self.start_time = None # start timestamp | ||||
print("callback_manager") | |||||
self.callback_manager = CallbackManager(env={"trainer": self}, | self.callback_manager = CallbackManager(env={"trainer": self}, | ||||
callbacks=callbacks) | callbacks=callbacks) | ||||
@@ -616,7 +616,7 @@ def seq_lens_to_masks(seq_lens, float=False): | |||||
assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | assert len(seq_lens.size()) == 1, f"seq_lens can only have one dimension, got {len(seq_lens.size())==1}." | ||||
batch_size = seq_lens.size(0) | batch_size = seq_lens.size(0) | ||||
max_len = seq_lens.max() | max_len = seq_lens.max() | ||||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device) | |||||
indexes = torch.arange(max_len).view(1, -1).repeat(batch_size, 1).to(seq_lens.device).long() | |||||
masks = indexes.lt(seq_lens.unsqueeze(1)) | masks = indexes.lt(seq_lens.unsqueeze(1)) | ||||
if float: | if float: | ||||
@@ -2,16 +2,18 @@ from functools import wraps | |||||
from collections import Counter | from collections import Counter | ||||
from .dataset import DataSet | from .dataset import DataSet | ||||
def _check_build_vocab(func): | def _check_build_vocab(func): | ||||
"""A decorator to make sure the indexing is built before used. | """A decorator to make sure the indexing is built before used. | ||||
""" | """ | ||||
@wraps(func) # to solve missing docstring | |||||
@wraps(func) # to solve missing docstring | |||||
def _wrapper(self, *args, **kwargs): | def _wrapper(self, *args, **kwargs): | ||||
if self.word2idx is None or self.rebuild is True: | if self.word2idx is None or self.rebuild is True: | ||||
self.build_vocab() | self.build_vocab() | ||||
return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
return _wrapper | return _wrapper | ||||
@@ -19,7 +21,8 @@ def _check_build_status(func): | |||||
"""A decorator to check whether the vocabulary updates after the last build. | """A decorator to check whether the vocabulary updates after the last build. | ||||
""" | """ | ||||
@wraps(func) # to solve missing docstring | |||||
@wraps(func) # to solve missing docstring | |||||
def _wrapper(self, *args, **kwargs): | def _wrapper(self, *args, **kwargs): | ||||
if self.rebuild is False: | if self.rebuild is False: | ||||
self.rebuild = True | self.rebuild = True | ||||
@@ -28,7 +31,7 @@ def _check_build_status(func): | |||||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | "Adding more words may cause unexpected behaviour of Vocabulary. ".format( | ||||
self.max_size, func.__name__)) | self.max_size, func.__name__)) | ||||
return func(self, *args, **kwargs) | return func(self, *args, **kwargs) | ||||
return _wrapper | return _wrapper | ||||
@@ -50,15 +53,15 @@ class Vocabulary(object): | |||||
若为 ``None`` , 则不限制大小. Default: ``None`` | 若为 ``None`` , 则不限制大小. Default: ``None`` | ||||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | :param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | ||||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | 若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | ||||
:param str padding: padding的字符. 如果设置为 ``None`` , | |||||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | 则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | ||||
Default: '<pad>' | Default: '<pad>' | ||||
:param str unknow: unknow的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | 如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | ||||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | 为 ``None`` 的情况多在为label建立Vocabulary的情况. | ||||
Default: '<unk>' | Default: '<unk>' | ||||
""" | """ | ||||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
@@ -68,7 +71,7 @@ class Vocabulary(object): | |||||
self.word2idx = None | self.word2idx = None | ||||
self.idx2word = None | self.idx2word = None | ||||
self.rebuild = True | self.rebuild = True | ||||
@_check_build_status | @_check_build_status | ||||
def update(self, word_lst): | def update(self, word_lst): | ||||
"""依次增加序列中词在词典中的出现频率 | """依次增加序列中词在词典中的出现频率 | ||||
@@ -76,7 +79,7 @@ class Vocabulary(object): | |||||
:param list word_lst: a list of strings | :param list word_lst: a list of strings | ||||
""" | """ | ||||
self.word_count.update(word_lst) | self.word_count.update(word_lst) | ||||
@_check_build_status | @_check_build_status | ||||
def add(self, word): | def add(self, word): | ||||
""" | """ | ||||
@@ -85,7 +88,7 @@ class Vocabulary(object): | |||||
:param str word: 新词 | :param str word: 新词 | ||||
""" | """ | ||||
self.word_count[word] += 1 | self.word_count[word] += 1 | ||||
@_check_build_status | @_check_build_status | ||||
def add_word(self, word): | def add_word(self, word): | ||||
""" | """ | ||||
@@ -94,7 +97,7 @@ class Vocabulary(object): | |||||
:param str word: 新词 | :param str word: 新词 | ||||
""" | """ | ||||
self.add(word) | self.add(word) | ||||
@_check_build_status | @_check_build_status | ||||
def add_word_lst(self, word_lst): | def add_word_lst(self, word_lst): | ||||
""" | """ | ||||
@@ -103,7 +106,7 @@ class Vocabulary(object): | |||||
:param list[str] word_lst: 词的序列 | :param list[str] word_lst: 词的序列 | ||||
""" | """ | ||||
self.update(word_lst) | self.update(word_lst) | ||||
def build_vocab(self): | def build_vocab(self): | ||||
""" | """ | ||||
根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | 根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, | ||||
@@ -116,7 +119,7 @@ class Vocabulary(object): | |||||
self.word2idx[self.padding] = len(self.word2idx) | self.word2idx[self.padding] = len(self.word2idx) | ||||
if self.unknown is not None: | if self.unknown is not None: | ||||
self.word2idx[self.unknown] = len(self.word2idx) | self.word2idx[self.unknown] = len(self.word2idx) | ||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
words = self.word_count.most_common(max_size) | words = self.word_count.most_common(max_size) | ||||
if self.min_freq is not None: | if self.min_freq is not None: | ||||
@@ -127,18 +130,18 @@ class Vocabulary(object): | |||||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
self.rebuild = False | self.rebuild = False | ||||
def build_reverse_vocab(self): | def build_reverse_vocab(self): | ||||
""" | """ | ||||
基于 "word to index" dict, 构建 "index to word" dict. | 基于 "word to index" dict, 构建 "index to word" dict. | ||||
""" | """ | ||||
self.idx2word = {i: w for w, i in self.word2idx.items()} | self.idx2word = {i: w for w, i in self.word2idx.items()} | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def __len__(self): | def __len__(self): | ||||
return len(self.word2idx) | return len(self.word2idx) | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def __contains__(self, item): | def __contains__(self, item): | ||||
""" | """ | ||||
@@ -148,7 +151,7 @@ class Vocabulary(object): | |||||
:return: True or False | :return: True or False | ||||
""" | """ | ||||
return item in self.word2idx | return item in self.word2idx | ||||
def has_word(self, w): | def has_word(self, w): | ||||
""" | """ | ||||
检查词是否被记录 | 检查词是否被记录 | ||||
@@ -163,7 +166,7 @@ class Vocabulary(object): | |||||
:return: ``True`` or ``False`` | :return: ``True`` or ``False`` | ||||
""" | """ | ||||
return self.__contains__(w) | return self.__contains__(w) | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def __getitem__(self, w): | def __getitem__(self, w): | ||||
""" | """ | ||||
@@ -177,7 +180,7 @@ class Vocabulary(object): | |||||
return self.word2idx[self.unknown] | return self.word2idx[self.unknown] | ||||
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def index_dataset(self, *datasets, field_name, new_field_name=None): | def index_dataset(self, *datasets, field_name, new_field_name=None): | ||||
""" | """ | ||||
@@ -194,6 +197,7 @@ class Vocabulary(object): | |||||
:param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | :param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | ||||
Default: ``None`` | Default: ``None`` | ||||
""" | """ | ||||
def index_instance(ins): | def index_instance(ins): | ||||
""" | """ | ||||
有几种情况, str, 1d-list, 2d-list | 有几种情况, str, 1d-list, 2d-list | ||||
@@ -209,8 +213,8 @@ class Vocabulary(object): | |||||
else: | else: | ||||
if isinstance(field[0][0], list): | if isinstance(field[0][0], list): | ||||
raise RuntimeError("Only support field with 2 dimensions.") | raise RuntimeError("Only support field with 2 dimensions.") | ||||
return[[self.to_index(c) for c in w] for w in field] | |||||
return [[self.to_index(c) for c in w] for w in field] | |||||
if new_field_name is None: | if new_field_name is None: | ||||
new_field_name = field_name | new_field_name = field_name | ||||
for idx, dataset in enumerate(datasets): | for idx, dataset in enumerate(datasets): | ||||
@@ -222,7 +226,7 @@ class Vocabulary(object): | |||||
raise e | raise e | ||||
else: | else: | ||||
raise RuntimeError("Only DataSet type is allowed.") | raise RuntimeError("Only DataSet type is allowed.") | ||||
def from_dataset(self, *datasets, field_name): | def from_dataset(self, *datasets, field_name): | ||||
""" | """ | ||||
使用dataset的对应field中词构建词典 | 使用dataset的对应field中词构建词典 | ||||
@@ -243,7 +247,7 @@ class Vocabulary(object): | |||||
field_name = [field_name] | field_name = [field_name] | ||||
elif not isinstance(field_name, list): | elif not isinstance(field_name, list): | ||||
raise TypeError('invalid argument field_name: {}'.format(field_name)) | raise TypeError('invalid argument field_name: {}'.format(field_name)) | ||||
def construct_vocab(ins): | def construct_vocab(ins): | ||||
for fn in field_name: | for fn in field_name: | ||||
field = ins[fn] | field = ins[fn] | ||||
@@ -256,6 +260,7 @@ class Vocabulary(object): | |||||
if isinstance(field[0][0], list): | if isinstance(field[0][0], list): | ||||
raise RuntimeError("Only support field with 2 dimensions.") | raise RuntimeError("Only support field with 2 dimensions.") | ||||
[self.add_word_lst(w) for w in field] | [self.add_word_lst(w) for w in field] | ||||
for idx, dataset in enumerate(datasets): | for idx, dataset in enumerate(datasets): | ||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
try: | try: | ||||
@@ -266,7 +271,7 @@ class Vocabulary(object): | |||||
else: | else: | ||||
raise RuntimeError("Only DataSet type is allowed.") | raise RuntimeError("Only DataSet type is allowed.") | ||||
return self | return self | ||||
def to_index(self, w): | def to_index(self, w): | ||||
""" | """ | ||||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 | 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 | ||||
@@ -282,7 +287,7 @@ class Vocabulary(object): | |||||
:return int index: the number | :return int index: the number | ||||
""" | """ | ||||
return self.__getitem__(w) | return self.__getitem__(w) | ||||
@property | @property | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def unknown_idx(self): | def unknown_idx(self): | ||||
@@ -292,7 +297,7 @@ class Vocabulary(object): | |||||
if self.unknown is None: | if self.unknown is None: | ||||
return None | return None | ||||
return self.word2idx[self.unknown] | return self.word2idx[self.unknown] | ||||
@property | @property | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
@@ -302,7 +307,7 @@ class Vocabulary(object): | |||||
if self.padding is None: | if self.padding is None: | ||||
return None | return None | ||||
return self.word2idx[self.padding] | return self.word2idx[self.padding] | ||||
@_check_build_vocab | @_check_build_vocab | ||||
def to_word(self, idx): | def to_word(self, idx): | ||||
""" | """ | ||||
@@ -312,26 +317,26 @@ class Vocabulary(object): | |||||
:return str word: the word | :return str word: the word | ||||
""" | """ | ||||
return self.idx2word[idx] | return self.idx2word[idx] | ||||
def __getstate__(self): | def __getstate__(self): | ||||
"""Use to prepare data for pickle. | """Use to prepare data for pickle. | ||||
""" | """ | ||||
len(self) # make sure vocab has been built | |||||
len(self) # make sure vocab has been built | |||||
state = self.__dict__.copy() | state = self.__dict__.copy() | ||||
# no need to pickle idx2word as it can be constructed from word2idx | # no need to pickle idx2word as it can be constructed from word2idx | ||||
del state['idx2word'] | del state['idx2word'] | ||||
return state | return state | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
"""Use to restore state from pickle. | """Use to restore state from pickle. | ||||
""" | """ | ||||
self.__dict__.update(state) | self.__dict__.update(state) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
def __repr__(self): | def __repr__(self): | ||||
return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | return "Vocabulary({}...)".format(list(self.word_count.keys())[:5]) | ||||
def __iter__(self): | def __iter__(self): | ||||
return iter(list(self.word_count.keys())) | return iter(list(self.word_count.keys())) |
@@ -1,13 +1,12 @@ | |||||
import time | |||||
import unittest | import unittest | ||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP.core.batch import Batch | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.sampler import SequentialSampler | |||||
from fastNLP import Batch | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import SequentialSampler | |||||
def generate_fake_dataset(num_samples=1000): | def generate_fake_dataset(num_samples=1000): | ||||
@@ -16,11 +15,11 @@ def generate_fake_dataset(num_samples=1000): | |||||
:param num_samples: sample的数量 | :param num_samples: sample的数量 | ||||
:return: | :return: | ||||
""" | """ | ||||
max_len = 50 | max_len = 50 | ||||
min_len = 10 | min_len = 10 | ||||
num_features = 4 | num_features = 4 | ||||
data_dict = {} | data_dict = {} | ||||
for i in range(num_features): | for i in range(num_features): | ||||
data = [] | data = [] | ||||
@@ -28,9 +27,9 @@ def generate_fake_dataset(num_samples=1000): | |||||
for length in lengths: | for length in lengths: | ||||
data.append(np.random.randint(100, size=length)) | data.append(np.random.randint(100, size=length)) | ||||
data_dict[str(i)] = data | data_dict[str(i)] = data | ||||
dataset = DataSet(data_dict) | dataset = DataSet(data_dict) | ||||
for i in range(num_features): | for i in range(num_features): | ||||
if np.random.randint(2) == 0: | if np.random.randint(2) == 0: | ||||
dataset.set_input(str(i)) | dataset.set_input(str(i)) | ||||
@@ -38,6 +37,7 @@ def generate_fake_dataset(num_samples=1000): | |||||
dataset.set_target(str(i)) | dataset.set_target(str(i)) | ||||
return dataset | return dataset | ||||
def construct_dataset(sentences): | def construct_dataset(sentences): | ||||
"""Construct a data set from a list of sentences. | """Construct a data set from a list of sentences. | ||||
@@ -51,18 +51,19 @@ def construct_dataset(sentences): | |||||
dataset.append(instance) | dataset.append(instance) | ||||
return dataset | return dataset | ||||
class TestCase1(unittest.TestCase): | class TestCase1(unittest.TestCase): | ||||
def test_simple(self): | def test_simple(self): | ||||
dataset = construct_dataset( | dataset = construct_dataset( | ||||
[["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)]) | ||||
dataset.set_target() | dataset.set_target() | ||||
batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True) | ||||
cnt = 0 | cnt = 0 | ||||
for _, _ in batch: | for _, _ in batch: | ||||
cnt += 1 | cnt += 1 | ||||
self.assertEqual(cnt, 10) | self.assertEqual(cnt, 10) | ||||
def test_dataset_batching(self): | def test_dataset_batching(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
ds.set_input("x") | ds.set_input("x") | ||||
@@ -74,7 +75,7 @@ class TestCase1(unittest.TestCase): | |||||
self.assertEqual(len(y["y"]), 4) | self.assertEqual(len(y["y"]), 4) | ||||
self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4]) | ||||
self.assertListEqual(list(y["y"][-1]), [5, 6]) | self.assertListEqual(list(y["y"][-1]), [5, 6]) | ||||
def test_list_padding(self): | def test_list_padding(self): | ||||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | ||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | ||||
@@ -84,7 +85,7 @@ class TestCase1(unittest.TestCase): | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertEqual(x["x"].shape, (4, 4)) | self.assertEqual(x["x"].shape, (4, 4)) | ||||
self.assertEqual(y["y"].shape, (4, 4)) | self.assertEqual(y["y"].shape, (4, 4)) | ||||
def test_numpy_padding(self): | def test_numpy_padding(self): | ||||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | ||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | ||||
@@ -94,7 +95,7 @@ class TestCase1(unittest.TestCase): | |||||
for x, y in iter: | for x, y in iter: | ||||
self.assertEqual(x["x"].shape, (4, 4)) | self.assertEqual(x["x"].shape, (4, 4)) | ||||
self.assertEqual(y["y"].shape, (4, 4)) | self.assertEqual(y["y"].shape, (4, 4)) | ||||
def test_list_to_tensor(self): | def test_list_to_tensor(self): | ||||
ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, | ||||
"y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10}) | ||||
@@ -106,7 +107,7 @@ class TestCase1(unittest.TestCase): | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | self.assertEqual(tuple(x["x"].shape), (4, 4)) | ||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | self.assertTrue(isinstance(y["y"], torch.Tensor)) | ||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | self.assertEqual(tuple(y["y"].shape), (4, 4)) | ||||
def test_numpy_to_tensor(self): | def test_numpy_to_tensor(self): | ||||
ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10), | ||||
"y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10)}) | ||||
@@ -118,7 +119,7 @@ class TestCase1(unittest.TestCase): | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | self.assertEqual(tuple(x["x"].shape), (4, 4)) | ||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | self.assertTrue(isinstance(y["y"], torch.Tensor)) | ||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | self.assertEqual(tuple(y["y"].shape), (4, 4)) | ||||
def test_list_of_list_to_tensor(self): | def test_list_of_list_to_tensor(self): | ||||
ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] + | ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] + | ||||
[Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)]) | ||||
@@ -130,7 +131,7 @@ class TestCase1(unittest.TestCase): | |||||
self.assertEqual(tuple(x["x"].shape), (4, 4)) | self.assertEqual(tuple(x["x"].shape), (4, 4)) | ||||
self.assertTrue(isinstance(y["y"], torch.Tensor)) | self.assertTrue(isinstance(y["y"], torch.Tensor)) | ||||
self.assertEqual(tuple(y["y"].shape), (4, 4)) | self.assertEqual(tuple(y["y"].shape), (4, 4)) | ||||
def test_list_of_numpy_to_tensor(self): | def test_list_of_numpy_to_tensor(self): | ||||
ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + | ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] + | ||||
[Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)]) | ||||
@@ -139,16 +140,16 @@ class TestCase1(unittest.TestCase): | |||||
iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | iter = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False) | ||||
for x, y in iter: | for x, y in iter: | ||||
print(x, y) | print(x, y) | ||||
def test_sequential_batch(self): | def test_sequential_batch(self): | ||||
batch_size = 32 | batch_size = 32 | ||||
num_samples = 1000 | num_samples = 1000 | ||||
dataset = generate_fake_dataset(num_samples) | dataset = generate_fake_dataset(num_samples) | ||||
batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler()) | ||||
for batch_x, batch_y in batch: | for batch_x, batch_y in batch: | ||||
pass | pass | ||||
""" | """ | ||||
def test_multi_workers_batch(self): | def test_multi_workers_batch(self): | ||||
batch_size = 32 | batch_size = 32 | ||||
@@ -4,14 +4,13 @@ import numpy as np | |||||
import torch | import torch | ||||
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback, LRScheduler, ControlC, \ | ||||
LRFinder, \ | |||||
TensorboardCallback | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.trainer import Trainer | |||||
LRFinder, TensorboardCallback | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import BCELoss | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import SGD | |||||
from fastNLP import Trainer | |||||
from fastNLP.models.base_model import NaiveClassifier | from fastNLP.models.base_model import NaiveClassifier | ||||
@@ -20,15 +19,15 @@ def prepare_env(): | |||||
mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
mean = np.array([3, 3]) | mean = np.array([3, 3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | ||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | ||||
return data_set | return data_set | ||||
data_set = prepare_fake_dataset() | data_set = prepare_fake_dataset() | ||||
data_set.set_input("x") | data_set.set_input("x") | ||||
data_set.set_target("y") | data_set.set_target("y") | ||||
@@ -37,19 +36,7 @@ def prepare_env(): | |||||
class TestCallback(unittest.TestCase): | class TestCallback(unittest.TestCase): | ||||
def test_echo_callback(self): | |||||
data_set, model = prepare_env() | |||||
trainer = Trainer(data_set, model, | |||||
loss=BCELoss(pred="predict", target="y"), | |||||
n_epochs=2, | |||||
batch_size=32, | |||||
print_every=50, | |||||
optimizer=SGD(lr=0.1), | |||||
check_code_level=2, | |||||
use_tqdm=False, | |||||
callbacks=[EchoCallback()]) | |||||
trainer.train() | |||||
def test_gradient_clip(self): | def test_gradient_clip(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
@@ -64,7 +51,7 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | callbacks=[GradientClipCallback(model.parameters(), clip_value=2)]) | ||||
trainer.train() | trainer.train() | ||||
def test_early_stop(self): | def test_early_stop(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
@@ -79,7 +66,7 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[EarlyStopCallback(5)]) | callbacks=[EarlyStopCallback(5)]) | ||||
trainer.train() | trainer.train() | ||||
def test_lr_scheduler(self): | def test_lr_scheduler(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | optimizer = torch.optim.SGD(model.parameters(), lr=0.01) | ||||
@@ -95,7 +82,7 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | callbacks=[LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1))]) | ||||
trainer.train() | trainer.train() | ||||
def test_KeyBoardInterrupt(self): | def test_KeyBoardInterrupt(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
@@ -108,7 +95,7 @@ class TestCallback(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
callbacks=[ControlC(False)]) | callbacks=[ControlC(False)]) | ||||
trainer.train() | trainer.train() | ||||
def test_LRFinder(self): | def test_LRFinder(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
@@ -121,7 +108,7 @@ class TestCallback(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
callbacks=[LRFinder(len(data_set) // 32)]) | callbacks=[LRFinder(len(data_set) // 32)]) | ||||
trainer.train() | trainer.train() | ||||
def test_TensorboardCallback(self): | def test_TensorboardCallback(self): | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
@@ -136,21 +123,22 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[TensorboardCallback("loss", "metric")]) | callbacks=[TensorboardCallback("loss", "metric")]) | ||||
trainer.train() | trainer.train() | ||||
def test_readonly_property(self): | def test_readonly_property(self): | ||||
from fastNLP.core.callback import Callback | from fastNLP.core.callback import Callback | ||||
passed_epochs = [] | passed_epochs = [] | ||||
total_epochs = 5 | total_epochs = 5 | ||||
class MyCallback(Callback): | class MyCallback(Callback): | ||||
def __init__(self): | def __init__(self): | ||||
super(MyCallback, self).__init__() | super(MyCallback, self).__init__() | ||||
def on_epoch_begin(self): | def on_epoch_begin(self): | ||||
passed_epochs.append(self.epoch) | passed_epochs.append(self.epoch) | ||||
print(self.n_epochs, self.n_steps, self.batch_size) | print(self.n_epochs, self.n_steps, self.batch_size) | ||||
print(self.model) | print(self.model) | ||||
print(self.optimizer) | print(self.optimizer) | ||||
data_set, model = prepare_env() | data_set, model = prepare_env() | ||||
trainer = Trainer(data_set, model, | trainer = Trainer(data_set, model, | ||||
loss=BCELoss(pred="predict", target="y"), | loss=BCELoss(pred="predict", target="y"), | ||||
@@ -164,4 +152,4 @@ class TestCallback(unittest.TestCase): | |||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
callbacks=[MyCallback()]) | callbacks=[MyCallback()]) | ||||
trainer.train() | trainer.train() | ||||
assert passed_epochs == list(range(1, total_epochs+1)) | |||||
assert passed_epochs == list(range(1, total_epochs + 1)) |
@@ -1,9 +1,10 @@ | |||||
import os | import os | ||||
import unittest | import unittest | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP import DataSet | |||||
from fastNLP import FieldArray | |||||
from fastNLP import Instance | |||||
from fastNLP.io import CSVLoader | |||||
class TestDataSetInit(unittest.TestCase): | class TestDataSetInit(unittest.TestCase): | ||||
@@ -167,13 +168,11 @@ class TestDataSetMethods(unittest.TestCase): | |||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||
d1, d2 = ds.split(0.1) | d1, d2 = ds.split(0.1) | ||||
def test_apply2(self): | def test_apply2(self): | ||||
def split_sent(ins): | def split_sent(ins): | ||||
return ins['raw_sentence'].split() | return ins['raw_sentence'].split() | ||||
dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'), | |||||
sep='\t') | |||||
csv_loader = CSVLoader(headers=['raw_sentence', 'label'],sep='\t') | |||||
dataset = csv_loader.load('../data_for_tests/tutorial_sample_dataset.csv') | |||||
dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0, inplace=True) | ||||
dataset.apply(split_sent, new_field_name='words', is_input=True) | dataset.apply(split_sent, new_field_name='words', is_input=True) | ||||
# print(dataset) | # print(dataset) | ||||
@@ -208,7 +207,7 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertEqual(ans.content, [[5, 6]] * 10) | self.assertEqual(ans.content, [[5, 6]] * 10) | ||||
def test_add_null(self): | def test_add_null(self): | ||||
# TODO test failed because 'fastNLP\core\fieldarray.py:143: RuntimeError' | |||||
# TODO test failed because 'fastNLP\core\field.py:143: RuntimeError' | |||||
ds = DataSet() | ds = DataSet() | ||||
with self.assertRaises(RuntimeError) as RE: | with self.assertRaises(RuntimeError) as RE: | ||||
ds.add_field('test', []) | ds.add_field('test', []) | ||||
@@ -2,7 +2,7 @@ import unittest | |||||
import numpy as np | import numpy as np | ||||
from fastNLP.core.fieldarray import FieldArray | |||||
from fastNLP import FieldArray | |||||
class TestFieldArrayInit(unittest.TestCase): | class TestFieldArrayInit(unittest.TestCase): | ||||
@@ -170,7 +170,7 @@ class TestPadder(unittest.TestCase): | |||||
测试AutoPadder能否正常工作 | 测试AutoPadder能否正常工作 | ||||
:return: | :return: | ||||
""" | """ | ||||
from fastNLP.core.fieldarray import AutoPadder | |||||
from fastNLP import AutoPadder | |||||
padder = AutoPadder() | padder = AutoPadder() | ||||
content = ['This is a str', 'this is another str'] | content = ['This is a str', 'this is another str'] | ||||
self.assertListEqual(content, padder(content, None, np.str).tolist()) | self.assertListEqual(content, padder(content, None, np.str).tolist()) | ||||
@@ -194,7 +194,7 @@ class TestPadder(unittest.TestCase): | |||||
测试EngChar2DPadder能不能正确使用 | 测试EngChar2DPadder能不能正确使用 | ||||
:return: | :return: | ||||
""" | """ | ||||
from fastNLP.core.fieldarray import EngChar2DPadder | |||||
from fastNLP import EngChar2DPadder | |||||
padder = EngChar2DPadder(pad_length=0) | padder = EngChar2DPadder(pad_length=0) | ||||
contents = [1, 2] | contents = [1, 2] | ||||
@@ -225,11 +225,11 @@ class TestPadder(unittest.TestCase): | |||||
) | ) | ||||
def test_None_dtype(self): | def test_None_dtype(self): | ||||
from fastNLP.core.fieldarray import AutoPadder | |||||
from fastNLP import AutoPadder | |||||
padder = AutoPadder() | padder = AutoPadder() | ||||
content = [ | content = [ | ||||
[[1, 2, 3], [4, 5], [7, 8, 9, 10]], | [[1, 2, 3], [4, 5], [7, 8, 9, 10]], | ||||
[[1]] | [[1]] | ||||
] | ] | ||||
ans = padder(content, None, None) | |||||
ans = padder(content, None, None).tolist() | |||||
self.assertListEqual(content, ans) | self.assertListEqual(content, ans) |
@@ -1,33 +1,33 @@ | |||||
import unittest | import unittest | ||||
from fastNLP.core.instance import Instance | |||||
from fastNLP import Instance | |||||
class TestCase(unittest.TestCase): | class TestCase(unittest.TestCase): | ||||
def test_init(self): | def test_init(self): | ||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | ||||
ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) | ||||
self.assertTrue(isinstance(ins.fields, dict)) | self.assertTrue(isinstance(ins.fields, dict)) | ||||
self.assertEqual(ins.fields, fields) | self.assertEqual(ins.fields, fields) | ||||
ins = Instance(**fields) | ins = Instance(**fields) | ||||
self.assertEqual(ins.fields, fields) | self.assertEqual(ins.fields, fields) | ||||
def test_add_field(self): | def test_add_field(self): | ||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | fields = {"x": [1, 2, 3], "y": [4, 5, 6]} | ||||
ins = Instance(**fields) | ins = Instance(**fields) | ||||
ins.add_field("z", [1, 1, 1]) | ins.add_field("z", [1, 1, 1]) | ||||
fields.update({"z": [1, 1, 1]}) | fields.update({"z": [1, 1, 1]}) | ||||
self.assertEqual(ins.fields, fields) | self.assertEqual(ins.fields, fields) | ||||
def test_get_item(self): | def test_get_item(self): | ||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | ||||
ins = Instance(**fields) | ins = Instance(**fields) | ||||
self.assertEqual(ins["x"], [1, 2, 3]) | self.assertEqual(ins["x"], [1, 2, 3]) | ||||
self.assertEqual(ins["y"], [4, 5, 6]) | self.assertEqual(ins["y"], [4, 5, 6]) | ||||
self.assertEqual(ins["z"], [1, 1, 1]) | self.assertEqual(ins["z"], [1, 1, 1]) | ||||
def test_repr(self): | def test_repr(self): | ||||
fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} | ||||
ins = Instance(**fields) | ins = Instance(**fields) | ||||
@@ -3,7 +3,7 @@ import unittest | |||||
import torch | import torch | ||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
import fastNLP.core.losses as loss | |||||
import fastNLP as loss | |||||
from fastNLP.core.losses import squash, unpad | from fastNLP.core.losses import squash, unpad | ||||
@@ -14,21 +14,21 @@ class TestLoss(unittest.TestCase): | |||||
b = torch.empty(3, dtype=torch.long).random_(5) | b = torch.empty(3, dtype=torch.long).random_(5) | ||||
ans = ce({"my_predict": a}, {"my_truth": b}) | ans = ce({"my_predict": a}, {"my_truth": b}) | ||||
self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | self.assertEqual(ans, torch.nn.functional.cross_entropy(a, b)) | ||||
def test_BCELoss(self): | def test_BCELoss(self): | ||||
bce = loss.BCELoss(pred="my_predict", target="my_truth") | bce = loss.BCELoss(pred="my_predict", target="my_truth") | ||||
a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) | a = torch.sigmoid(torch.randn((3, 5), requires_grad=False)) | ||||
b = torch.randn((3, 5), requires_grad=False) | b = torch.randn((3, 5), requires_grad=False) | ||||
ans = bce({"my_predict": a}, {"my_truth": b}) | ans = bce({"my_predict": a}, {"my_truth": b}) | ||||
self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) | self.assertEqual(ans, torch.nn.functional.binary_cross_entropy(a, b)) | ||||
def test_L1Loss(self): | def test_L1Loss(self): | ||||
l1 = loss.L1Loss(pred="my_predict", target="my_truth") | l1 = loss.L1Loss(pred="my_predict", target="my_truth") | ||||
a = torch.randn(3, 5, requires_grad=False) | a = torch.randn(3, 5, requires_grad=False) | ||||
b = torch.randn(3, 5) | b = torch.randn(3, 5) | ||||
ans = l1({"my_predict": a}, {"my_truth": b}) | ans = l1({"my_predict": a}, {"my_truth": b}) | ||||
self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) | self.assertEqual(ans, torch.nn.functional.l1_loss(a, b)) | ||||
def test_NLLLoss(self): | def test_NLLLoss(self): | ||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | ||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | ||||
@@ -43,34 +43,34 @@ class TestLosserError(unittest.TestCase): | |||||
pred_dict = {"pred": torch.zeros(4, 3)} | pred_dict = {"pred": torch.zeros(4, 3)} | ||||
target_dict = {'target': torch.zeros(4).long()} | target_dict = {'target': torch.zeros(4).long()} | ||||
los = loss.CrossEntropyLoss() | los = loss.CrossEntropyLoss() | ||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | print(los(pred_dict=pred_dict, target_dict=target_dict)) | ||||
# | # | ||||
def test_losser2(self): | def test_losser2(self): | ||||
# (2) with corrupted size | # (2) with corrupted size | ||||
pred_dict = {"pred": torch.zeros(16, 3)} | pred_dict = {"pred": torch.zeros(16, 3)} | ||||
target_dict = {'target': torch.zeros(16, 3).long()} | target_dict = {'target': torch.zeros(16, 3).long()} | ||||
los = loss.CrossEntropyLoss() | los = loss.CrossEntropyLoss() | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | print(los(pred_dict=pred_dict, target_dict=target_dict)) | ||||
def test_losser3(self): | def test_losser3(self): | ||||
# (2) with corrupted size | # (2) with corrupted size | ||||
pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0} | pred_dict = {"pred": torch.zeros(16, 3), 'stop_fast_param': 0} | ||||
target_dict = {'target': torch.zeros(16).long()} | target_dict = {'target': torch.zeros(16).long()} | ||||
los = loss.CrossEntropyLoss() | los = loss.CrossEntropyLoss() | ||||
print(los(pred_dict=pred_dict, target_dict=target_dict)) | print(los(pred_dict=pred_dict, target_dict=target_dict)) | ||||
def test_check_error(self): | def test_check_error(self): | ||||
l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | l1 = loss.NLLLoss(pred="my_predict", target="my_truth") | ||||
a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | a = F.log_softmax(torch.randn(3, 5, requires_grad=False), dim=0) | ||||
b = torch.tensor([1, 0, 4]) | b = torch.tensor([1, 0, 4]) | ||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | ans = l1({"wrong_predict": a, "my": b}, {"my_truth": b}) | ||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
ans = l1({"my_predict": a}, {"truth": b, "my": a}) | ans = l1({"my_predict": a}, {"truth": b, "my": a}) | ||||
@@ -80,7 +80,7 @@ class TestLossUtils(unittest.TestCase): | |||||
a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) | a, b = squash(torch.randn(3, 5), torch.randn(3, 5)) | ||||
self.assertEqual(tuple(a.size()), (3, 5)) | self.assertEqual(tuple(a.size()), (3, 5)) | ||||
self.assertEqual(tuple(b.size()), (15,)) | self.assertEqual(tuple(b.size()), (15,)) | ||||
def test_unpad(self): | def test_unpad(self): | ||||
a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) | a, b = unpad(torch.randn(5, 8, 3), torch.randn(5, 8)) | ||||
self.assertEqual(tuple(a.size()), (5, 8, 3)) | self.assertEqual(tuple(a.size()), (5, 8, 3)) | ||||
@@ -3,8 +3,8 @@ import unittest | |||||
import numpy as np | import numpy as np | ||||
import torch | import torch | ||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.metrics import BMESF1PreRecMetric | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import BMESF1PreRecMetric | |||||
from fastNLP.core.metrics import _pred_topk, _accuracy_topk | from fastNLP.core.metrics import _pred_topk, _accuracy_topk | ||||
@@ -14,24 +14,24 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
pred_dict = {"pred": torch.zeros(4, 3)} | pred_dict = {"pred": torch.zeros(4, 3)} | ||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
def test_AccuracyMetric2(self): | def test_AccuracyMetric2(self): | ||||
# (2) with corrupted size | # (2) with corrupted size | ||||
try: | try: | ||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
metric(pred_dict=pred_dict, target_dict=target_dict, ) | metric(pred_dict=pred_dict, target_dict=target_dict, ) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
return | return | ||||
print("No exception catches.") | print("No exception catches.") | ||||
def test_AccuracyMetric3(self): | def test_AccuracyMetric3(self): | ||||
# (3) the second batch is corrupted size | # (3) the second batch is corrupted size | ||||
try: | try: | ||||
@@ -39,17 +39,17 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4, 3)} | target_dict = {'target': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
pred_dict = {"pred": torch.zeros(4, 3, 2)} | pred_dict = {"pred": torch.zeros(4, 3, 2)} | ||||
target_dict = {'target': torch.zeros(4)} | target_dict = {'target': torch.zeros(4)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
print(metric.get_metric()) | print(metric.get_metric()) | ||||
except Exception as e: | except Exception as e: | ||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_AccuaryMetric4(self): | def test_AccuaryMetric4(self): | ||||
# (5) check reset | # (5) check reset | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
@@ -61,7 +61,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
self.assertTrue(isinstance(res, dict)) | self.assertTrue(isinstance(res, dict)) | ||||
self.assertTrue("acc" in res) | self.assertTrue("acc" in res) | ||||
self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) | self.assertAlmostEqual(res["acc"], float(ans.float().mean()), places=3) | ||||
def test_AccuaryMetric5(self): | def test_AccuaryMetric5(self): | ||||
# (5) check reset | # (5) check reset | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
@@ -71,7 +71,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
res = metric.get_metric(reset=False) | res = metric.get_metric(reset=False) | ||||
ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() | ans = (torch.argmax(pred_dict["pred"], dim=2).float() == target_dict["target"]).float().mean() | ||||
self.assertAlmostEqual(res["acc"], float(ans), places=4) | self.assertAlmostEqual(res["acc"], float(ans), places=4) | ||||
def test_AccuaryMetric6(self): | def test_AccuaryMetric6(self): | ||||
# (6) check numpy array is not acceptable | # (6) check numpy array is not acceptable | ||||
try: | try: | ||||
@@ -83,7 +83,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_AccuaryMetric7(self): | def test_AccuaryMetric7(self): | ||||
# (7) check map, match | # (7) check map, match | ||||
metric = AccuracyMetric(pred='predictions', target='targets') | metric = AccuracyMetric(pred='predictions', target='targets') | ||||
@@ -93,7 +93,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
res = metric.get_metric() | res = metric.get_metric() | ||||
ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() | ans = (torch.argmax(pred_dict["predictions"], dim=2).float() == target_dict["targets"]).float().mean() | ||||
self.assertAlmostEqual(res["acc"], float(ans), places=4) | self.assertAlmostEqual(res["acc"], float(ans), places=4) | ||||
def test_AccuaryMetric8(self): | def test_AccuaryMetric8(self): | ||||
try: | try: | ||||
metric = AccuracyMetric(pred='predictions', target='targets') | metric = AccuracyMetric(pred='predictions', target='targets') | ||||
@@ -105,7 +105,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_AccuaryMetric9(self): | def test_AccuaryMetric9(self): | ||||
# (9) check map, include unused | # (9) check map, include unused | ||||
try: | try: | ||||
@@ -118,12 +118,12 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_AccuaryMetric10(self): | def test_AccuaryMetric10(self): | ||||
# (10) check _fast_metric | # (10) check _fast_metric | ||||
try: | try: | ||||
metric = AccuracyMetric() | metric = AccuracyMetric() | ||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3)*3} | |||||
pred_dict = {"predictions": torch.zeros(4, 3, 2), "seq_len": torch.ones(3) * 3} | |||||
target_dict = {'targets': torch.zeros(4, 3)} | target_dict = {'targets': torch.zeros(4, 3)} | ||||
metric(pred_dict=pred_dict, target_dict=target_dict) | metric(pred_dict=pred_dict, target_dict=target_dict) | ||||
self.assertDictEqual(metric.get_metric(), {'acc': 1}) | self.assertDictEqual(metric.get_metric(), {'acc': 1}) | ||||
@@ -131,7 +131,7 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
print(e) | print(e) | ||||
return | return | ||||
self.assertTrue(True, False), "No exception catches." | self.assertTrue(True, False), "No exception catches." | ||||
def test_seq_len(self): | def test_seq_len(self): | ||||
N = 256 | N = 256 | ||||
seq_len = torch.zeros(N).long() | seq_len = torch.zeros(N).long() | ||||
@@ -145,20 +145,21 @@ class TestAccuracyMetric(unittest.TestCase): | |||||
metric(pred_dict=pred, target_dict=target) | metric(pred_dict=pred, target_dict=target) | ||||
self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | self.assertDictEqual(metric.get_metric(), {'acc': 1.}) | ||||
class SpanF1PreRecMetric(unittest.TestCase): | class SpanF1PreRecMetric(unittest.TestCase): | ||||
def test_case1(self): | def test_case1(self): | ||||
from fastNLP.core.metrics import _bmes_tag_to_spans | from fastNLP.core.metrics import _bmes_tag_to_spans | ||||
from fastNLP.core.metrics import _bio_tag_to_spans | from fastNLP.core.metrics import _bio_tag_to_spans | ||||
bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | bmes_lst = ['M-8', 'S-2', 'S-0', 'B-9', 'B-6', 'E-5', 'B-7', 'S-2', 'E-7', 'S-8'] | ||||
bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | bio_lst = ['O-8', 'O-2', 'B-0', 'O-9', 'I-6', 'I-5', 'I-7', 'I-2', 'I-7', 'O-8'] | ||||
expect_bmes_res = set() | expect_bmes_res = set() | ||||
expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)), | expect_bmes_res.update([('8', (0, 1)), ('2', (1, 2)), ('0', (2, 3)), ('9', (3, 4)), ('6', (4, 5)), | ||||
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) | |||||
('5', (5, 6)), ('7', (6, 7)), ('2', (7, 8)), ('7', (8, 9)), ('8', (9, 10))]) | |||||
expect_bio_res = set() | expect_bio_res = set() | ||||
expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), | expect_bio_res.update([('7', (8, 9)), ('0', (2, 3)), ('2', (7, 8)), ('5', (5, 6)), | ||||
('6', (4, 5)), ('7', (6, 7))]) | |||||
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) | |||||
('6', (4, 5)), ('7', (6, 7))]) | |||||
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | ||||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | ||||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | ||||
@@ -171,19 +172,19 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
# bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | # bio_strs = [str_ + '-' + tag for tag, str_ in zip(strs, np.random.choice(bio, size=len(strs)))] | ||||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | ||||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | ||||
def test_case2(self): | def test_case2(self): | ||||
# 测试不带label的 | # 测试不带label的 | ||||
from fastNLP.core.metrics import _bmes_tag_to_spans | from fastNLP.core.metrics import _bmes_tag_to_spans | ||||
from fastNLP.core.metrics import _bio_tag_to_spans | from fastNLP.core.metrics import _bio_tag_to_spans | ||||
bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | bmes_lst = ['B', 'E', 'B', 'S', 'B', 'M', 'E', 'M', 'B', 'E'] | ||||
bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | bio_lst = ['I', 'B', 'O', 'O', 'I', 'O', 'I', 'B', 'O', 'O'] | ||||
expect_bmes_res = set() | expect_bmes_res = set() | ||||
expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) | expect_bmes_res.update([('', (0, 2)), ('', (2, 3)), ('', (3, 4)), ('', (4, 7)), ('', (7, 8)), ('', (8, 10))]) | ||||
expect_bio_res = set() | expect_bio_res = set() | ||||
expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) | expect_bio_res.update([('', (7, 8)), ('', (6, 7)), ('', (4, 5)), ('', (0, 1)), ('', (1, 2))]) | ||||
self.assertSetEqual(expect_bmes_res,set(_bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bmes_res, set(_bmes_tag_to_spans(bmes_lst))) | |||||
self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | self.assertSetEqual(expect_bio_res, set(_bio_tag_to_spans(bio_lst))) | ||||
# 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | # 已与allennlp对应函数做过验证,但由于测试不能依赖allennlp,所以这里只是截取上面的例子做固定测试 | ||||
# from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | # from allennlp.data.dataset_readers.dataset_utils import bio_tags_to_spans as allen_bio_tags_to_spans | ||||
@@ -195,7 +196,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
# bio_strs = np.random.choice(bio, size=100) | # bio_strs = np.random.choice(bio, size=100) | ||||
# self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | # self.assertSetEqual(set(allen_bmes_tags_to_spans(bmes_strs)),set(bmes_tag_to_spans(bmes_strs))) | ||||
# self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | # self.assertSetEqual(set(allen_bio_tags_to_spans(bio_strs)), set(bio_tag_to_spans(bio_strs))) | ||||
def tese_case3(self): | def tese_case3(self): | ||||
from fastNLP.core.vocabulary import Vocabulary | from fastNLP.core.vocabulary import Vocabulary | ||||
from collections import Counter | from collections import Counter | ||||
@@ -213,7 +214,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
continue | continue | ||||
vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count | vocab['{}-{}'.format(tag, label)] = len(vocab) + 1 # 其实表达的是这个的count | ||||
return vocab | return vocab | ||||
number_labels = 4 | number_labels = 4 | ||||
# bio tag | # bio tag | ||||
fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) | ||||
@@ -221,26 +222,26 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) | ||||
bio_sequence = torch.FloatTensor( | bio_sequence = torch.FloatTensor( | ||||
[[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | [[[-0.9543, -1.4357, -0.2365, 0.2438, 1.0312, -1.4302, 0.3011, | ||||
0.0470, 0.0971], | |||||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||||
0.7987, -0.3970], | |||||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||||
0.6880, 1.4348], | |||||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||||
-1.6876, -0.8917], | |||||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||||
1.4217, 0.2622]], | |||||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||||
1.3592, -0.8973], | |||||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||||
-0.4025, -0.3417], | |||||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||||
0.2861, -0.3966], | |||||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||||
0.0213, 1.4777], | |||||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||||
1.3024, 0.2001]]] | |||||
0.0470, 0.0971], | |||||
[-0.6638, -0.7116, -1.9804, 0.2787, -0.2732, -0.9501, -1.4523, | |||||
0.7987, -0.3970], | |||||
[0.2939, 0.8132, -0.0903, -2.8296, 0.2080, -0.9823, -0.1898, | |||||
0.6880, 1.4348], | |||||
[-0.1886, 0.0067, -0.6862, -0.4635, 2.2776, 0.0710, -1.6793, | |||||
-1.6876, -0.8917], | |||||
[-0.7663, 0.6377, 0.8669, 0.1237, 1.7628, 0.0313, -1.0824, | |||||
1.4217, 0.2622]], | |||||
[[0.1529, 0.7474, -0.9037, 1.5287, 0.2771, 0.2223, 0.8136, | |||||
1.3592, -0.8973], | |||||
[0.4515, -0.5235, 0.3265, -1.1947, 0.8308, 1.8754, -0.4887, | |||||
-0.4025, -0.3417], | |||||
[-0.7855, 0.1615, -0.1272, -1.9289, -0.5181, 1.9742, -0.9698, | |||||
0.2861, -0.3966], | |||||
[-0.8291, -0.8823, -1.1496, 0.2164, 1.3390, -0.3964, -0.5275, | |||||
0.0213, 1.4777], | |||||
[-1.1299, 0.0627, -0.1358, -1.5951, 0.4484, -0.6081, -1.9566, | |||||
1.3024, 0.2001]]] | |||||
) | ) | ||||
bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | bio_target = torch.LongTensor([[5., 0., 3., 3., 3.], | ||||
[5., 6., 8., 6., 0.]]) | [5., 6., 8., 6., 0.]]) | ||||
@@ -250,8 +251,8 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | 'rec-0': 0.0, 'f-0': 0.0, 'pre': 0.12499999999999845, 'rec': 0.12499999999999845, | ||||
'f': 0.12499999999994846} | 'f': 0.12499999999994846} | ||||
self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | self.assertDictEqual(expect_bio_res, fastnlp_bio_metric.get_metric()) | ||||
#bmes tag | |||||
# bmes tag | |||||
bmes_sequence = torch.FloatTensor( | bmes_sequence = torch.FloatTensor( | ||||
[[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | [[[0.6536, -0.7179, 0.6579, 1.2503, 0.4176, 0.6696, 0.2352, | ||||
-0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | -0.4085, 0.4084, -0.4185, 1.4172, -0.9162, -0.2679, 0.3332, | ||||
@@ -268,7 +269,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
[0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | [0.9088, -0.4955, -0.5076, 0.3732, 0.0283, -0.0263, -1.0393, | ||||
0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | 0.7734, 1.0968, 0.4132, -1.3647, -0.5762, 0.6678, 0.8809, | ||||
-0.3779, -0.3195]], | -0.3779, -0.3195]], | ||||
[[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | [[-0.4638, -0.5939, -0.1052, -0.5573, 0.4600, -1.3484, 0.1753, | ||||
0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | 0.0685, 0.3663, -0.6789, 0.0097, 1.0327, -0.0212, -0.9957, | ||||
-0.1103, 0.4417], | -0.1103, 0.4417], | ||||
@@ -285,22 +286,22 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | 2.6973, -0.8308, -1.4939, 0.9865, -0.3935, 0.2743, 0.1142, | ||||
-0.7344, -1.2046]]] | -0.7344, -1.2046]]] | ||||
) | ) | ||||
bmes_target = torch.LongTensor([[ 9., 6., 1., 9., 15.], | |||||
[ 6., 15., 6., 15., 5.]]) | |||||
bmes_target = torch.LongTensor([[9., 6., 1., 9., 15.], | |||||
[6., 15., 6., 15., 5.]]) | |||||
fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | fastnlp_bmes_vocab = Vocabulary(unknown=None, padding=None) | ||||
fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | fastnlp_bmes_vocab.word_count = Counter(generate_allen_tags('BMES', number_labels)) | ||||
fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | fastnlp_bmes_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bmes_vocab, only_gross=False, encoding_type='bmes') | ||||
fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | fastnlp_bmes_metric({'pred': bmes_sequence, 'seq_lens': torch.LongTensor([20, 20])}, {'target': bmes_target}) | ||||
expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | expect_bmes_res = {'f-3': 0.6666666666665778, 'pre-3': 0.499999999999975, 'rec-3': 0.9999999999999001, | ||||
'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | 'f-0': 0.0, 'pre-0': 0.0, 'rec-0': 0.0, 'f-1': 0.33333333333327775, | ||||
'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | 'pre-1': 0.24999999999999373, 'rec-1': 0.499999999999975, 'f-2': 0.7499999999999314, | ||||
'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | 'pre-2': 0.7499999999999812, 'rec-2': 0.7499999999999812, 'f': 0.49999999999994504, | ||||
'pre': 0.499999999999995, 'rec': 0.499999999999995} | 'pre': 0.499999999999995, 'rec': 0.499999999999995} | ||||
self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | self.assertDictEqual(fastnlp_bmes_metric.get_metric(), expect_bmes_res) | ||||
# 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | # 已经和allennlp做过验证,但由于不能依赖allennlp,所以注释了以下代码 | ||||
# from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | # from allennlp.data.vocabulary import Vocabulary as allen_Vocabulary | ||||
# from allennlp.training.metrics import SpanBasedF1Measure | # from allennlp.training.metrics import SpanBasedF1Measure | ||||
@@ -349,6 +350,7 @@ class SpanF1PreRecMetric(unittest.TestCase): | |||||
# self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | # self.assertDictEqual(convert_allen_res_to_fastnlp_res(allen_bmes_metric.get_metric()), | ||||
# fastnlp_bmes_metric.get_metric()) | # fastnlp_bmes_metric.get_metric()) | ||||
class TestBMESF1PreRecMetric(unittest.TestCase): | class TestBMESF1PreRecMetric(unittest.TestCase): | ||||
def test_case1(self): | def test_case1(self): | ||||
seq_lens = torch.LongTensor([4, 2]) | seq_lens = torch.LongTensor([4, 2]) | ||||
@@ -356,20 +358,20 @@ class TestBMESF1PreRecMetric(unittest.TestCase): | |||||
target = torch.LongTensor([[0, 1, 2, 3], | target = torch.LongTensor([[0, 1, 2, 3], | ||||
[3, 3, 0, 0]]) | [3, 3, 0, 0]]) | ||||
pred_dict = {'pred': pred} | pred_dict = {'pred': pred} | ||||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||||
target_dict = {'target': target, 'seq_len': seq_lens} | |||||
metric = BMESF1PreRecMetric() | metric = BMESF1PreRecMetric() | ||||
metric(pred_dict, target_dict) | metric(pred_dict, target_dict) | ||||
metric.get_metric() | metric.get_metric() | ||||
def test_case2(self): | def test_case2(self): | ||||
# 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | # 测试相同两个seqence,应该给出{f1: 1, precision:1, recall:1} | ||||
seq_lens = torch.LongTensor([4, 2]) | seq_lens = torch.LongTensor([4, 2]) | ||||
target = torch.LongTensor([[0, 1, 2, 3], | target = torch.LongTensor([[0, 1, 2, 3], | ||||
[3, 3, 0, 0]]) | [3, 3, 0, 0]]) | ||||
pred_dict = {'pred': target} | pred_dict = {'pred': target} | ||||
target_dict = {'target': target, 'seq_lens': seq_lens} | |||||
target_dict = {'target': target, 'seq_len': seq_lens} | |||||
metric = BMESF1PreRecMetric() | metric = BMESF1PreRecMetric() | ||||
metric(pred_dict, target_dict) | metric(pred_dict, target_dict) | ||||
self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) | self.assertDictEqual(metric.get_metric(), {'f': 1.0, 'pre': 1.0, 'rec': 1.0}) | ||||
@@ -381,5 +383,5 @@ class TestUsefulFunctions(unittest.TestCase): | |||||
# multi-class | # multi-class | ||||
_ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) | _ = _accuracy_topk(np.random.randint(0, 3, size=(10, 1)), np.random.randint(0, 3, size=(10, 1)), k=3) | ||||
_ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) | _ = _pred_topk(np.random.randint(0, 3, size=(10, 1))) | ||||
# 跑通即可 | # 跑通即可 |
@@ -2,7 +2,7 @@ import unittest | |||||
import torch | import torch | ||||
from fastNLP.core.optimizer import SGD, Adam | |||||
from fastNLP import SGD, Adam | |||||
class TestOptim(unittest.TestCase): | class TestOptim(unittest.TestCase): | ||||
@@ -12,42 +12,42 @@ class TestOptim(unittest.TestCase): | |||||
self.assertTrue("momentum" in optim.__dict__["settings"]) | self.assertTrue("momentum" in optim.__dict__["settings"]) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | self.assertTrue(isinstance(res, torch.optim.SGD)) | ||||
optim = SGD(lr=0.001) | optim = SGD(lr=0.001) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | self.assertTrue(isinstance(res, torch.optim.SGD)) | ||||
optim = SGD(lr=0.002, momentum=0.989) | optim = SGD(lr=0.002, momentum=0.989) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | ||||
self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | self.assertEqual(optim.__dict__["settings"]["momentum"], 0.989) | ||||
optim = SGD(0.001) | optim = SGD(0.001) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue(isinstance(res, torch.optim.SGD)) | self.assertTrue(isinstance(res, torch.optim.SGD)) | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
_ = SGD("???") | _ = SGD("???") | ||||
with self.assertRaises(TypeError): | with self.assertRaises(TypeError): | ||||
_ = SGD(0.001, lr=0.002) | _ = SGD(0.001, lr=0.002) | ||||
def test_Adam(self): | def test_Adam(self): | ||||
optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | optim = Adam(model_params=torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue("lr" in optim.__dict__["settings"]) | self.assertTrue("lr" in optim.__dict__["settings"]) | ||||
self.assertTrue("weight_decay" in optim.__dict__["settings"]) | self.assertTrue("weight_decay" in optim.__dict__["settings"]) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue(isinstance(res, torch.optim.Adam)) | self.assertTrue(isinstance(res, torch.optim.Adam)) | ||||
optim = Adam(lr=0.001) | optim = Adam(lr=0.001) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
self.assertTrue(isinstance(res, torch.optim.Adam)) | self.assertTrue(isinstance(res, torch.optim.Adam)) | ||||
optim = Adam(lr=0.002, weight_decay=0.989) | optim = Adam(lr=0.002, weight_decay=0.989) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.002) | ||||
self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | self.assertEqual(optim.__dict__["settings"]["weight_decay"], 0.989) | ||||
optim = Adam(0.001) | optim = Adam(0.001) | ||||
self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | self.assertEqual(optim.__dict__["settings"]["lr"], 0.001) | ||||
res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | res = optim.construct_from_pytorch(torch.nn.Linear(10, 3).parameters()) | ||||
@@ -3,9 +3,9 @@ import unittest | |||||
import torch | import torch | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.sampler import SequentialSampler, RandomSampler, \ | |||||
k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | |||||
from fastNLP import DataSet | |||||
from fastNLP import SequentialSampler, RandomSampler, BucketSampler | |||||
from fastNLP.core.sampler import k_means_1d, k_means_bucketing, simple_sort_bucketing | |||||
class TestSampler(unittest.TestCase): | class TestSampler(unittest.TestCase): | ||||
@@ -1,32 +1,25 @@ | |||||
import unittest | import unittest | ||||
import numpy as np | |||||
from torch import nn | |||||
import time | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import Tester | |||||
data_name = "pku_training.utf8" | data_name = "pku_training.utf8" | ||||
pickle_path = "data_for_tests" | pickle_path = "data_for_tests" | ||||
import numpy as np | |||||
import torch.nn.functional as F | |||||
from torch import nn | |||||
import time | |||||
from fastNLP.core.utils import _CheckError | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.tester import Tester | |||||
from fastNLP.models.base_model import NaiveClassifier | |||||
def prepare_fake_dataset(): | def prepare_fake_dataset(): | ||||
mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
mean = np.array([3, 3]) | mean = np.array([3, 3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | ||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | ||||
return data_set | return data_set | ||||
@@ -39,6 +32,7 @@ def prepare_fake_dataset2(*args, size=100): | |||||
data[arg] = np.random.randn(size, 5) | data[arg] = np.random.randn(size, 5) | ||||
return DataSet(data=data) | return DataSet(data=data) | ||||
class TestTester(unittest.TestCase): | class TestTester(unittest.TestCase): | ||||
def test_case_1(self): | def test_case_1(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
@@ -46,10 +40,12 @@ class TestTester(unittest.TestCase): | |||||
dataset.rename_field('x_unused', 'x2') | dataset.rename_field('x_unused', 'x2') | ||||
dataset.set_input('x1', 'x2') | dataset.set_input('x1', 'x2') | ||||
dataset.set_target('y', 'x1') | dataset.set_target('y', 'x1') | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
@@ -57,7 +53,7 @@ class TestTester(unittest.TestCase): | |||||
time.sleep(0.1) | time.sleep(0.1) | ||||
# loss = F.cross_entropy(x, y) | # loss = F.cross_entropy(x, y) | ||||
return {'preds': x} | return {'preds': x} | ||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
tester = Tester( | tester = Tester( | ||||
@@ -5,25 +5,24 @@ import numpy as np | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | from torch import nn | ||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP.core.losses import BCELoss | |||||
from fastNLP.core.losses import CrossEntropyLoss | |||||
from fastNLP.core.metrics import AccuracyMetric | |||||
from fastNLP.core.optimizer import SGD | |||||
from fastNLP.core.trainer import Trainer | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
from fastNLP import BCELoss | |||||
from fastNLP import CrossEntropyLoss | |||||
from fastNLP import AccuracyMetric | |||||
from fastNLP import SGD | |||||
from fastNLP import Trainer | |||||
from fastNLP.models.base_model import NaiveClassifier | from fastNLP.models.base_model import NaiveClassifier | ||||
def prepare_fake_dataset(): | def prepare_fake_dataset(): | ||||
mean = np.array([-3, -3]) | mean = np.array([-3, -3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_A = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
mean = np.array([3, 3]) | mean = np.array([3, 3]) | ||||
cov = np.array([[1, 0], [0, 1]]) | cov = np.array([[1, 0], [0, 1]]) | ||||
class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | class_B = np.random.multivariate_normal(mean, cov, size=(1000,)) | ||||
data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] + | ||||
[Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B]) | ||||
return data_set | return data_set | ||||
@@ -42,11 +41,11 @@ class TrainerTestGround(unittest.TestCase): | |||||
data_set = prepare_fake_dataset() | data_set = prepare_fake_dataset() | ||||
data_set.set_input("x", flag=True) | data_set.set_input("x", flag=True) | ||||
data_set.set_target("y", flag=True) | data_set.set_target("y", flag=True) | ||||
train_set, dev_set = data_set.split(0.3) | train_set, dev_set = data_set.split(0.3) | ||||
model = NaiveClassifier(2, 1) | model = NaiveClassifier(2, 1) | ||||
trainer = Trainer(train_set, model, | trainer = Trainer(train_set, model, | ||||
loss=BCELoss(pred="predict", target="y"), | loss=BCELoss(pred="predict", target="y"), | ||||
metrics=AccuracyMetric(pred="predict", target="y"), | metrics=AccuracyMetric(pred="predict", target="y"), | ||||
@@ -63,26 +62,26 @@ class TrainerTestGround(unittest.TestCase): | |||||
""" | """ | ||||
# 应该正确运行 | # 应该正确运行 | ||||
""" | """ | ||||
def test_trainer_suggestion1(self): | def test_trainer_suggestion1(self): | ||||
# 检查报错提示能否正确提醒用户。 | # 检查报错提示能否正确提醒用户。 | ||||
# 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。 | ||||
dataset = prepare_fake_dataset2('x') | dataset = prepare_fake_dataset2('x') | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
x = x1 + x2 | x = x1 + x2 | ||||
loss = F.cross_entropy(x, y) | loss = F.cross_entropy(x, y) | ||||
return {'loss': loss} | return {'loss': loss} | ||||
model = Model() | model = Model() | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
train_data=dataset, | train_data=dataset, | ||||
@@ -97,25 +96,25 @@ class TrainerTestGround(unittest.TestCase): | |||||
(2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | (2). You need to provide ['x1', 'x2'] in DataSet and set it as input. | ||||
""" | """ | ||||
def test_trainer_suggestion2(self): | def test_trainer_suggestion2(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
# 这里传入forward需要的数据,看是否可以运行 | # 这里传入forward需要的数据,看是否可以运行 | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
dataset.set_input('x1', 'x2', 'y', flag=True) | dataset.set_input('x1', 'x2', 'y', flag=True) | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
x = x1 + x2 | x = x1 + x2 | ||||
loss = F.cross_entropy(x, y) | loss = F.cross_entropy(x, y) | ||||
return {'loss': loss} | return {'loss': loss} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | trainer = Trainer( | ||||
train_data=dataset, | train_data=dataset, | ||||
@@ -127,25 +126,25 @@ class TrainerTestGround(unittest.TestCase): | |||||
""" | """ | ||||
# 应该正确运行 | # 应该正确运行 | ||||
""" | """ | ||||
def test_trainer_suggestion3(self): | def test_trainer_suggestion3(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
# 这里传入forward需要的数据,但是forward没有返回loss这个key | # 这里传入forward需要的数据,但是forward没有返回loss这个key | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
dataset.set_input('x1', 'x2', 'y', flag=True) | dataset.set_input('x1', 'x2', 'y', flag=True) | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
x = x1 + x2 | x = x1 + x2 | ||||
loss = F.cross_entropy(x, y) | loss = F.cross_entropy(x, y) | ||||
return {'wrong_loss_key': loss} | return {'wrong_loss_key': loss} | ||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -155,23 +154,25 @@ class TrainerTestGround(unittest.TestCase): | |||||
print_every=2 | print_every=2 | ||||
) | ) | ||||
trainer.train() | trainer.train() | ||||
def test_trainer_suggestion4(self): | def test_trainer_suggestion4(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
# 这里传入forward需要的数据,是否可以正确提示unused | # 这里传入forward需要的数据,是否可以正确提示unused | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
dataset.set_input('x1', 'x2', 'y', flag=True) | dataset.set_input('x1', 'x2', 'y', flag=True) | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
x = x1 + x2 | x = x1 + x2 | ||||
loss = F.cross_entropy(x, y) | loss = F.cross_entropy(x, y) | ||||
return {'losses': loss} | return {'losses': loss} | ||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -180,7 +181,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
print_every=2 | print_every=2 | ||||
) | ) | ||||
def test_trainer_suggestion5(self): | def test_trainer_suggestion5(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
# 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错 | # 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错 | ||||
@@ -188,17 +189,19 @@ class TrainerTestGround(unittest.TestCase): | |||||
dataset.rename_field('x_unused', 'x2') | dataset.rename_field('x_unused', 'x2') | ||||
dataset.set_input('x1', 'x2', 'y') | dataset.set_input('x1', 'x2', 'y') | ||||
dataset.set_target('y') | dataset.set_target('y') | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2, y): | def forward(self, x1, x2, y): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
x = x1 + x2 | x = x1 + x2 | ||||
loss = F.cross_entropy(x, y) | loss = F.cross_entropy(x, y) | ||||
return {'loss': loss} | return {'loss': loss} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | trainer = Trainer( | ||||
train_data=dataset, | train_data=dataset, | ||||
@@ -206,7 +209,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
use_tqdm=False, | use_tqdm=False, | ||||
print_every=2 | print_every=2 | ||||
) | ) | ||||
def test_trainer_suggestion6(self): | def test_trainer_suggestion6(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
# 这里传入多余参数,让其duplicate | # 这里传入多余参数,让其duplicate | ||||
@@ -214,10 +217,12 @@ class TrainerTestGround(unittest.TestCase): | |||||
dataset.rename_field('x_unused', 'x2') | dataset.rename_field('x_unused', 'x2') | ||||
dataset.set_input('x1', 'x2') | dataset.set_input('x1', 'x2') | ||||
dataset.set_target('y', 'x1') | dataset.set_target('y', 'x1') | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.fc = nn.Linear(5, 4) | self.fc = nn.Linear(5, 4) | ||||
def forward(self, x1, x2): | def forward(self, x1, x2): | ||||
x1 = self.fc(x1) | x1 = self.fc(x1) | ||||
x2 = self.fc(x2) | x2 = self.fc(x2) | ||||
@@ -225,7 +230,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
time.sleep(0.1) | time.sleep(0.1) | ||||
# loss = F.cross_entropy(x, y) | # loss = F.cross_entropy(x, y) | ||||
return {'preds': x} | return {'preds': x} | ||||
model = Model() | model = Model() | ||||
with self.assertRaises(NameError): | with self.assertRaises(NameError): | ||||
trainer = Trainer( | trainer = Trainer( | ||||
@@ -236,7 +241,7 @@ class TrainerTestGround(unittest.TestCase): | |||||
metrics=AccuracyMetric(), | metrics=AccuracyMetric(), | ||||
use_tqdm=False, | use_tqdm=False, | ||||
print_every=2) | print_every=2) | ||||
""" | """ | ||||
def test_trainer_multiprocess(self): | def test_trainer_multiprocess(self): | ||||
dataset = prepare_fake_dataset2('x1', 'x2') | dataset = prepare_fake_dataset2('x1', 'x2') | ||||
@@ -1,8 +1,7 @@ | |||||
import unittest | import unittest | ||||
import _pickle | import _pickle | ||||
from fastNLP import cache_results | from fastNLP import cache_results | ||||
from fastNLP.io.embed_loader import EmbedLoader | |||||
from fastNLP.io import EmbedLoader | |||||
from fastNLP import DataSet | from fastNLP import DataSet | ||||
from fastNLP import Instance | from fastNLP import Instance | ||||
import time | import time | ||||
@@ -11,11 +10,13 @@ import torch | |||||
from torch import nn | from torch import nn | ||||
from fastNLP.core.utils import _move_model_to_device, _get_model_device | from fastNLP.core.utils import _move_model_to_device, _get_model_device | ||||
class Model(nn.Module): | class Model(nn.Module): | ||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.param = nn.Parameter(torch.zeros(0)) | self.param = nn.Parameter(torch.zeros(0)) | ||||
class TestMoveModelDeivce(unittest.TestCase): | class TestMoveModelDeivce(unittest.TestCase): | ||||
def test_case1(self): | def test_case1(self): | ||||
# 测试str | # 测试str | ||||
@@ -35,36 +36,36 @@ class TestMoveModelDeivce(unittest.TestCase): | |||||
_move_model_to_device(model, 'cuda:1000') | _move_model_to_device(model, 'cuda:1000') | ||||
# 测试None | # 测试None | ||||
model = _move_model_to_device(model, None) | model = _move_model_to_device(model, None) | ||||
def test_case2(self): | def test_case2(self): | ||||
# 测试使用int初始化 | # 测试使用int初始化 | ||||
model = Model() | model = Model() | ||||
if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
model = _move_model_to_device(model, 0) | model = _move_model_to_device(model, 0) | ||||
assert model.param.device == torch.device('cuda:0') | assert model.param.device == torch.device('cuda:0') | ||||
assert model.param.device==torch.device('cuda:0'), "The model should be in " | |||||
assert model.param.device == torch.device('cuda:0'), "The model should be in " | |||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
_move_model_to_device(model, 100) | _move_model_to_device(model, 100) | ||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
_move_model_to_device(model, -1) | _move_model_to_device(model, -1) | ||||
def test_case3(self): | def test_case3(self): | ||||
# 测试None | # 测试None | ||||
model = Model() | model = Model() | ||||
device = _get_model_device(model) | device = _get_model_device(model) | ||||
model = _move_model_to_device(model, None) | model = _move_model_to_device(model, None) | ||||
assert device==_get_model_device(model), "The device should not change." | |||||
assert device == _get_model_device(model), "The device should not change." | |||||
if torch.cuda.is_available(): | if torch.cuda.is_available(): | ||||
model.cuda() | model.cuda() | ||||
device = _get_model_device(model) | device = _get_model_device(model) | ||||
model = _move_model_to_device(model, None) | model = _move_model_to_device(model, None) | ||||
assert device==_get_model_device(model), "The device should not change." | |||||
assert device == _get_model_device(model), "The device should not change." | |||||
model = nn.DataParallel(model, device_ids=[0]) | model = nn.DataParallel(model, device_ids=[0]) | ||||
_move_model_to_device(model, None) | _move_model_to_device(model, None) | ||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
_move_model_to_device(model, 'cpu') | _move_model_to_device(model, 'cpu') | ||||
def test_case4(self): | def test_case4(self): | ||||
# 测试传入list的内容 | # 测试传入list的内容 | ||||
model = Model() | model = Model() | ||||
@@ -78,15 +79,17 @@ class TestMoveModelDeivce(unittest.TestCase): | |||||
device = [torch.device('cuda:0'), torch.device('cuda:0')] | device = [torch.device('cuda:0'), torch.device('cuda:0')] | ||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
_model = _move_model_to_device(model, device) | _model = _move_model_to_device(model, device) | ||||
if torch.cuda.device_count()>1: | |||||
if torch.cuda.device_count() > 1: | |||||
device = [0, 1] | device = [0, 1] | ||||
_model = _move_model_to_device(model, device) | _model = _move_model_to_device(model, device) | ||||
assert isinstance(_model, nn.DataParallel) | assert isinstance(_model, nn.DataParallel) | ||||
device = ['cuda', 'cuda:1'] | device = ['cuda', 'cuda:1'] | ||||
with self.assertRaises(Exception): | with self.assertRaises(Exception): | ||||
_move_model_to_device(model, device) | _move_model_to_device(model, device) | ||||
def test_case5(self): | def test_case5(self): | ||||
if not torch.cuda.is_available(): | |||||
return | |||||
# torch.device() | # torch.device() | ||||
device = torch.device('cpu') | device = torch.device('cpu') | ||||
model = Model() | model = Model() | ||||
@@ -106,10 +109,11 @@ def process_data_1(embed_file, cws_train): | |||||
d = DataSet() | d = DataSet() | ||||
for line in f: | for line in f: | ||||
line = line.strip() | line = line.strip() | ||||
if len(line)>0: | |||||
if len(line) > 0: | |||||
d.append(Instance(raw=line)) | d.append(Instance(raw=line)) | ||||
return embed, vocab, d | return embed, vocab, d | ||||
class TestCache(unittest.TestCase): | class TestCache(unittest.TestCase): | ||||
def test_cache_save(self): | def test_cache_save(self): | ||||
try: | try: | ||||
@@ -127,10 +131,10 @@ class TestCache(unittest.TestCase): | |||||
end_time = time.time() | end_time = time.time() | ||||
read_time = end_time - start_time | read_time = end_time - start_time | ||||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | ||||
self.assertGreater(pre_time-0.5, read_time) | |||||
self.assertGreater(pre_time - 0.5, read_time) | |||||
finally: | finally: | ||||
os.remove('test/demo1.pkl') | os.remove('test/demo1.pkl') | ||||
def test_cache_save_overwrite_path(self): | def test_cache_save_overwrite_path(self): | ||||
try: | try: | ||||
start_time = time.time() | start_time = time.time() | ||||
@@ -149,10 +153,10 @@ class TestCache(unittest.TestCase): | |||||
end_time = time.time() | end_time = time.time() | ||||
read_time = end_time - start_time | read_time = end_time - start_time | ||||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | ||||
self.assertGreater(pre_time-0.5, read_time) | |||||
self.assertGreater(pre_time - 0.5, read_time) | |||||
finally: | finally: | ||||
os.remove('test/demo_overwrite.pkl') | os.remove('test/demo_overwrite.pkl') | ||||
def test_cache_refresh(self): | def test_cache_refresh(self): | ||||
try: | try: | ||||
start_time = time.time() | start_time = time.time() | ||||
@@ -171,34 +175,38 @@ class TestCache(unittest.TestCase): | |||||
end_time = time.time() | end_time = time.time() | ||||
read_time = end_time - start_time | read_time = end_time - start_time | ||||
print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | print("Read using {:.3f}, while prepare using:{:.3f}".format(read_time, pre_time)) | ||||
self.assertGreater(0.1, pre_time-read_time) | |||||
self.assertGreater(0.1, pre_time - read_time) | |||||
finally: | finally: | ||||
os.remove('test/demo1.pkl') | os.remove('test/demo1.pkl') | ||||
def test_duplicate_keyword(self): | def test_duplicate_keyword(self): | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
@cache_results(None) | @cache_results(None) | ||||
def func_verbose(a, _verbose): | def func_verbose(a, _verbose): | ||||
pass | pass | ||||
func_verbose(0, 1) | func_verbose(0, 1) | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
@cache_results(None) | @cache_results(None) | ||||
def func_cache(a, _cache_fp): | def func_cache(a, _cache_fp): | ||||
pass | pass | ||||
func_cache(1, 2) | func_cache(1, 2) | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
@cache_results(None) | @cache_results(None) | ||||
def func_refresh(a, _refresh): | def func_refresh(a, _refresh): | ||||
pass | pass | ||||
func_refresh(1, 2) | func_refresh(1, 2) | ||||
def test_create_cache_dir(self): | def test_create_cache_dir(self): | ||||
@cache_results('test/demo1/demo.pkl') | @cache_results('test/demo1/demo.pkl') | ||||
def cache(): | def cache(): | ||||
return 1, 2 | return 1, 2 | ||||
try: | try: | ||||
results = cache() | results = cache() | ||||
print(results) | print(results) | ||||
finally: | finally: | ||||
os.remove('test/demo1/demo.pkl') | os.remove('test/demo1/demo.pkl') | ||||
os.rmdir('test/demo1') | |||||
os.rmdir('test/demo1') |
@@ -1,9 +1,9 @@ | |||||
import unittest | import unittest | ||||
from collections import Counter | from collections import Counter | ||||
from fastNLP.core.vocabulary import Vocabulary | |||||
from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.instance import Instance | |||||
from fastNLP import Vocabulary | |||||
from fastNLP import DataSet | |||||
from fastNLP import Instance | |||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | ||||
"works", "well", "in", "most", "cases", "scales", "well"] | "works", "well", "in", "most", "cases", "scales", "well"] | ||||
@@ -12,92 +12,93 @@ counter = Counter(text) | |||||
class TestAdd(unittest.TestCase): | class TestAdd(unittest.TestCase): | ||||
def test_add(self): | def test_add(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
for word in text: | for word in text: | ||||
vocab.add(word) | vocab.add(word) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_add_word(self): | def test_add_word(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
for word in text: | for word in text: | ||||
vocab.add_word(word) | vocab.add_word(word) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_add_word_lst(self): | def test_add_word_lst(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
vocab.add_word_lst(text) | vocab.add_word_lst(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_update(self): | def test_update(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_from_dataset(self): | def test_from_dataset(self): | ||||
start_char = 65 | start_char = 65 | ||||
num_samples = 10 | num_samples = 10 | ||||
# 0 dim | # 0 dim | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for i in range(num_samples): | for i in range(num_samples): | ||||
ins = Instance(char=chr(start_char+i)) | |||||
ins = Instance(char=chr(start_char + i)) | |||||
dataset.append(ins) | dataset.append(ins) | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
vocab.from_dataset(dataset, field_name='char') | vocab.from_dataset(dataset, field_name='char') | ||||
for i in range(num_samples): | for i in range(num_samples): | ||||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||||
vocab.index_dataset(dataset, field_name='char') | vocab.index_dataset(dataset, field_name='char') | ||||
# 1 dim | # 1 dim | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for i in range(num_samples): | for i in range(num_samples): | ||||
ins = Instance(char=[chr(start_char+i)]*6) | |||||
ins = Instance(char=[chr(start_char + i)] * 6) | |||||
dataset.append(ins) | dataset.append(ins) | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
vocab.from_dataset(dataset, field_name='char') | vocab.from_dataset(dataset, field_name='char') | ||||
for i in range(num_samples): | for i in range(num_samples): | ||||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||||
vocab.index_dataset(dataset, field_name='char') | vocab.index_dataset(dataset, field_name='char') | ||||
# 2 dim | # 2 dim | ||||
dataset = DataSet() | dataset = DataSet() | ||||
for i in range(num_samples): | for i in range(num_samples): | ||||
ins = Instance(char=[[chr(start_char+i) for _ in range(6)] for _ in range(6)]) | |||||
ins = Instance(char=[[chr(start_char + i) for _ in range(6)] for _ in range(6)]) | |||||
dataset.append(ins) | dataset.append(ins) | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
vocab.from_dataset(dataset, field_name='char') | vocab.from_dataset(dataset, field_name='char') | ||||
for i in range(num_samples): | for i in range(num_samples): | ||||
self.assertEqual(vocab.to_index(chr(start_char+i)), i+2) | |||||
self.assertEqual(vocab.to_index(chr(start_char + i)), i + 2) | |||||
vocab.index_dataset(dataset, field_name='char') | vocab.index_dataset(dataset, field_name='char') | ||||
class TestIndexing(unittest.TestCase): | class TestIndexing(unittest.TestCase): | ||||
def test_len(self): | def test_len(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab = Vocabulary(unknown=None, padding=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(len(vocab), len(counter)) | self.assertEqual(len(vocab), len(counter)) | ||||
def test_contains(self): | def test_contains(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab = Vocabulary(unknown=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertTrue(text[-1] in vocab) | self.assertTrue(text[-1] in vocab) | ||||
self.assertFalse("~!@#" in vocab) | self.assertFalse("~!@#" in vocab) | ||||
self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | self.assertEqual(text[-1] in vocab, vocab.has_word(text[-1])) | ||||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | ||||
def test_index(self): | def test_index(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
vocab.update(text) | vocab.update(text) | ||||
res = [vocab[w] for w in set(text)] | res = [vocab[w] for w in set(text)] | ||||
self.assertEqual(len(res), len(set(res))) | self.assertEqual(len(res), len(set(res))) | ||||
res = [vocab.to_index(w) for w in set(text)] | res = [vocab.to_index(w) for w in set(text)] | ||||
self.assertEqual(len(res), len(set(res))) | self.assertEqual(len(res), len(set(res))) | ||||
def test_to_word(self): | def test_to_word(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | ||||
def test_iteration(self): | def test_iteration(self): | ||||
vocab = Vocabulary() | vocab = Vocabulary() | ||||
text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | text = ["FastNLP", "works", "well", "in", "most", "cases", "and", "scales", "well", "in", | ||||
@@ -110,26 +111,26 @@ class TestIndexing(unittest.TestCase): | |||||
class TestOther(unittest.TestCase): | class TestOther(unittest.TestCase): | ||||
def test_additional_update(self): | def test_additional_update(self): | ||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab = Vocabulary() | |||||
vocab.update(text) | vocab.update(text) | ||||
_ = vocab["well"] | _ = vocab["well"] | ||||
self.assertEqual(vocab.rebuild, False) | self.assertEqual(vocab.rebuild, False) | ||||
vocab.add("hahaha") | vocab.add("hahaha") | ||||
self.assertEqual(vocab.rebuild, True) | self.assertEqual(vocab.rebuild, True) | ||||
_ = vocab["hahaha"] | _ = vocab["hahaha"] | ||||
self.assertEqual(vocab.rebuild, False) | self.assertEqual(vocab.rebuild, False) | ||||
self.assertTrue("hahaha" in vocab) | self.assertTrue("hahaha" in vocab) | ||||
def test_warning(self): | def test_warning(self): | ||||
vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||||
vocab = Vocabulary(max_size=len(set(text))) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(vocab.rebuild, True) | self.assertEqual(vocab.rebuild, True) | ||||
print(len(vocab)) | print(len(vocab)) | ||||
self.assertEqual(vocab.rebuild, False) | self.assertEqual(vocab.rebuild, False) | ||||
vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | vocab.update(["hahahha", "hhh", "vvvv", "ass", "asss", "jfweiong", "eqgfeg", "feqfw"]) | ||||
# this will print a warning | # this will print a warning | ||||
self.assertEqual(vocab.rebuild, True) | self.assertEqual(vocab.rebuild, True) |