@@ -1,11 +1,4 @@ | |||
from collections import Counter | |||
from copy import deepcopy | |||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1} | |||
def isiterable(p_object): | |||
try: | |||
@@ -57,22 +50,16 @@ class Vocabulary(object): | |||
vocab.to_word(5) | |||
""" | |||
def __init__(self, need_default=True, max_size=None, min_freq=None): | |||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||
""" | |||
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||
:param int max_size: set the max number of words in Vocabulary. Default: None | |||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | |||
""" | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
self.word_count = Counter() | |||
self.has_default = need_default | |||
if self.has_default: | |||
self.padding_label = DEFAULT_PADDING_LABEL | |||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||
else: | |||
self.padding_label = None | |||
self.unknown_label = None | |||
self.unknown = unknown | |||
self.padding = padding | |||
self.word2idx = None | |||
self.idx2word = None | |||
self.rebuild = True | |||
@@ -113,17 +100,18 @@ class Vocabulary(object): | |||
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | |||
""" | |||
if self.has_default: | |||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) | |||
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) | |||
else: | |||
self.word2idx = {} | |||
self.word2idx = {} | |||
if self.padding is not None: | |||
self.word2idx[self.padding] = 0 | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = 1 | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
if self.min_freq is not None: | |||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
if self.word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||
start_idx = len(self.word2idx) | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
@@ -159,8 +147,8 @@ class Vocabulary(object): | |||
""" | |||
if w in self.word2idx: | |||
return self.word2idx[w] | |||
elif self.has_default: | |||
return self.word2idx[self.unknown_label] | |||
if self.unknown is not None: | |||
return self.word2idx[self.unknown] | |||
else: | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
@@ -175,21 +163,16 @@ class Vocabulary(object): | |||
@property | |||
@check_build_vocab | |||
def unknown_idx(self): | |||
if self.unknown_label is None: | |||
if self.unknown is None: | |||
return None | |||
return self.word2idx[self.unknown_label] | |||
def __setattr__(self, name, val): | |||
self.__dict__[name] = val | |||
if name in ["unknown_label", "padding_label"]: | |||
self.word2idx = None | |||
return self.word2idx[self.unknown] | |||
@property | |||
@check_build_vocab | |||
def padding_idx(self): | |||
if self.padding_label is None: | |||
if self.padding is None: | |||
return None | |||
return self.word2idx[self.padding_label] | |||
return self.word2idx[self.padding] | |||
@check_build_vocab | |||
def to_word(self, idx): | |||
@@ -4,6 +4,7 @@ import numpy as np | |||
import torch.nn.functional as F | |||
from torch import nn | |||
from fastNLP.core.utils import CheckError | |||
from fastNLP.core.dataset import DataSet | |||
from fastNLP.core.instance import Instance | |||
from fastNLP.core.losses import BCELoss | |||
@@ -56,7 +57,8 @@ class TrainerTestGround(unittest.TestCase): | |||
dev_data=dev_set, | |||
optimizer=SGD(lr=0.1), | |||
check_code_level=2, | |||
use_tqdm=True) | |||
use_tqdm=True, | |||
save_path=None) | |||
trainer.train() | |||
""" | |||
# 应该正确运行 | |||
@@ -145,16 +147,14 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'wrong_loss_key': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer.train() | |||
""" | |||
# 应该正确运行 | |||
""" | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
trainer.train() | |||
def test_trainer_suggestion4(self): | |||
# 检查报错提示能否正确提醒用户 | |||
@@ -173,12 +173,13 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'loss': loss} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
def test_trainer_suggestion5(self): | |||
# 检查报错提示能否正确提醒用户 | |||
@@ -225,14 +226,15 @@ class TrainerTestGround(unittest.TestCase): | |||
return {'pred': x} | |||
model = Model() | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
dev_data=dataset, | |||
metrics=AccuracyMetric(), | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
with self.assertRaises(NameError): | |||
trainer = Trainer( | |||
train_data=dataset, | |||
model=model, | |||
dev_data=dataset, | |||
metrics=AccuracyMetric(), | |||
use_tqdm=False, | |||
print_every=2 | |||
) | |||
def test_case2(self): | |||
# check metrics Wrong | |||
@@ -10,36 +10,36 @@ counter = Counter(text) | |||
class TestAdd(unittest.TestCase): | |||
def test_add(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
for word in text: | |||
vocab.add(word) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_add_word(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
for word in text: | |||
vocab.add_word(word) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_add_word_lst(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab.add_word_lst(text) | |||
self.assertEqual(vocab.word_count, counter) | |||
def test_update(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab.update(text) | |||
self.assertEqual(vocab.word_count, counter) | |||
class TestIndexing(unittest.TestCase): | |||
def test_len(self): | |||
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
vocab.update(text) | |||
self.assertEqual(len(vocab), len(counter)) | |||
def test_contains(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||
vocab.update(text) | |||
self.assertTrue(text[-1] in vocab) | |||
self.assertFalse("~!@#" in vocab) | |||
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase): | |||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | |||
def test_index(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab.update(text) | |||
res = [vocab[w] for w in set(text)] | |||
self.assertEqual(len(res), len(set(res))) | |||
@@ -56,14 +56,14 @@ class TestIndexing(unittest.TestCase): | |||
self.assertEqual(len(res), len(set(res))) | |||
def test_to_word(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab.update(text) | |||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | |||
class TestOther(unittest.TestCase): | |||
def test_additional_update(self): | |||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||
vocab = Vocabulary(max_size=None, min_freq=None) | |||
vocab.update(text) | |||
_ = vocab["well"] | |||
@@ -77,7 +77,7 @@ class TestOther(unittest.TestCase): | |||
self.assertTrue("hahaha" in vocab) | |||
def test_warning(self): | |||
vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None) | |||
vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||
vocab.update(text) | |||
self.assertEqual(vocab.rebuild, True) | |||
print(len(vocab)) | |||