@@ -1,11 +1,4 @@ | |||||
from collections import Counter | from collections import Counter | ||||
from copy import deepcopy | |||||
DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0 | |||||
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1 | |||||
DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1} | |||||
def isiterable(p_object): | def isiterable(p_object): | ||||
try: | try: | ||||
@@ -57,22 +50,16 @@ class Vocabulary(object): | |||||
vocab.to_word(5) | vocab.to_word(5) | ||||
""" | """ | ||||
def __init__(self, need_default=True, max_size=None, min_freq=None): | |||||
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'): | |||||
""" | """ | ||||
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True. | |||||
:param int max_size: set the max number of words in Vocabulary. Default: None | :param int max_size: set the max number of words in Vocabulary. Default: None | ||||
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None | ||||
""" | """ | ||||
self.max_size = max_size | self.max_size = max_size | ||||
self.min_freq = min_freq | self.min_freq = min_freq | ||||
self.word_count = Counter() | self.word_count = Counter() | ||||
self.has_default = need_default | |||||
if self.has_default: | |||||
self.padding_label = DEFAULT_PADDING_LABEL | |||||
self.unknown_label = DEFAULT_UNKNOWN_LABEL | |||||
else: | |||||
self.padding_label = None | |||||
self.unknown_label = None | |||||
self.unknown = unknown | |||||
self.padding = padding | |||||
self.word2idx = None | self.word2idx = None | ||||
self.idx2word = None | self.idx2word = None | ||||
self.rebuild = True | self.rebuild = True | ||||
@@ -113,17 +100,18 @@ class Vocabulary(object): | |||||
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. | ||||
""" | """ | ||||
if self.has_default: | |||||
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX) | |||||
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL) | |||||
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL) | |||||
else: | |||||
self.word2idx = {} | |||||
self.word2idx = {} | |||||
if self.padding is not None: | |||||
self.word2idx[self.padding] = 0 | |||||
if self.unknown is not None: | |||||
self.word2idx[self.unknown] = 1 | |||||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | ||||
words = self.word_count.most_common(max_size) | words = self.word_count.most_common(max_size) | ||||
if self.min_freq is not None: | if self.min_freq is not None: | ||||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | words = filter(lambda kv: kv[1] >= self.min_freq, words) | ||||
if self.word2idx is not None: | |||||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||||
start_idx = len(self.word2idx) | start_idx = len(self.word2idx) | ||||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | ||||
self.build_reverse_vocab() | self.build_reverse_vocab() | ||||
@@ -159,8 +147,8 @@ class Vocabulary(object): | |||||
""" | """ | ||||
if w in self.word2idx: | if w in self.word2idx: | ||||
return self.word2idx[w] | return self.word2idx[w] | ||||
elif self.has_default: | |||||
return self.word2idx[self.unknown_label] | |||||
if self.unknown is not None: | |||||
return self.word2idx[self.unknown] | |||||
else: | else: | ||||
raise ValueError("word {} not in vocabulary".format(w)) | raise ValueError("word {} not in vocabulary".format(w)) | ||||
@@ -175,21 +163,16 @@ class Vocabulary(object): | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def unknown_idx(self): | def unknown_idx(self): | ||||
if self.unknown_label is None: | |||||
if self.unknown is None: | |||||
return None | return None | ||||
return self.word2idx[self.unknown_label] | |||||
def __setattr__(self, name, val): | |||||
self.__dict__[name] = val | |||||
if name in ["unknown_label", "padding_label"]: | |||||
self.word2idx = None | |||||
return self.word2idx[self.unknown] | |||||
@property | @property | ||||
@check_build_vocab | @check_build_vocab | ||||
def padding_idx(self): | def padding_idx(self): | ||||
if self.padding_label is None: | |||||
if self.padding is None: | |||||
return None | return None | ||||
return self.word2idx[self.padding_label] | |||||
return self.word2idx[self.padding] | |||||
@check_build_vocab | @check_build_vocab | ||||
def to_word(self, idx): | def to_word(self, idx): | ||||
@@ -4,6 +4,7 @@ import numpy as np | |||||
import torch.nn.functional as F | import torch.nn.functional as F | ||||
from torch import nn | from torch import nn | ||||
from fastNLP.core.utils import CheckError | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.instance import Instance | from fastNLP.core.instance import Instance | ||||
from fastNLP.core.losses import BCELoss | from fastNLP.core.losses import BCELoss | ||||
@@ -56,7 +57,8 @@ class TrainerTestGround(unittest.TestCase): | |||||
dev_data=dev_set, | dev_data=dev_set, | ||||
optimizer=SGD(lr=0.1), | optimizer=SGD(lr=0.1), | ||||
check_code_level=2, | check_code_level=2, | ||||
use_tqdm=True) | |||||
use_tqdm=True, | |||||
save_path=None) | |||||
trainer.train() | trainer.train() | ||||
""" | """ | ||||
# 应该正确运行 | # 应该正确运行 | ||||
@@ -145,16 +147,14 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'wrong_loss_key': loss} | return {'wrong_loss_key': loss} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer.train() | |||||
""" | |||||
# 应该正确运行 | |||||
""" | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
trainer.train() | |||||
def test_trainer_suggestion4(self): | def test_trainer_suggestion4(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
@@ -173,12 +173,13 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'loss': loss} | return {'loss': loss} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_trainer_suggestion5(self): | def test_trainer_suggestion5(self): | ||||
# 检查报错提示能否正确提醒用户 | # 检查报错提示能否正确提醒用户 | ||||
@@ -225,14 +226,15 @@ class TrainerTestGround(unittest.TestCase): | |||||
return {'pred': x} | return {'pred': x} | ||||
model = Model() | model = Model() | ||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
with self.assertRaises(NameError): | |||||
trainer = Trainer( | |||||
train_data=dataset, | |||||
model=model, | |||||
dev_data=dataset, | |||||
metrics=AccuracyMetric(), | |||||
use_tqdm=False, | |||||
print_every=2 | |||||
) | |||||
def test_case2(self): | def test_case2(self): | ||||
# check metrics Wrong | # check metrics Wrong | ||||
@@ -10,36 +10,36 @@ counter = Counter(text) | |||||
class TestAdd(unittest.TestCase): | class TestAdd(unittest.TestCase): | ||||
def test_add(self): | def test_add(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
for word in text: | for word in text: | ||||
vocab.add(word) | vocab.add(word) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_add_word(self): | def test_add_word(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
for word in text: | for word in text: | ||||
vocab.add_word(word) | vocab.add_word(word) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_add_word_lst(self): | def test_add_word_lst(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.add_word_lst(text) | vocab.add_word_lst(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
def test_update(self): | def test_update(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(vocab.word_count, counter) | self.assertEqual(vocab.word_count, counter) | ||||
class TestIndexing(unittest.TestCase): | class TestIndexing(unittest.TestCase): | ||||
def test_len(self): | def test_len(self): | ||||
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(len(vocab), len(counter)) | self.assertEqual(len(vocab), len(counter)) | ||||
def test_contains(self): | def test_contains(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertTrue(text[-1] in vocab) | self.assertTrue(text[-1] in vocab) | ||||
self.assertFalse("~!@#" in vocab) | self.assertFalse("~!@#" in vocab) | ||||
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase): | |||||
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) | ||||
def test_index(self): | def test_index(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
res = [vocab[w] for w in set(text)] | res = [vocab[w] for w in set(text)] | ||||
self.assertEqual(len(res), len(set(res))) | self.assertEqual(len(res), len(set(res))) | ||||
@@ -56,14 +56,14 @@ class TestIndexing(unittest.TestCase): | |||||
self.assertEqual(len(res), len(set(res))) | self.assertEqual(len(res), len(set(res))) | ||||
def test_to_word(self): | def test_to_word(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) | ||||
class TestOther(unittest.TestCase): | class TestOther(unittest.TestCase): | ||||
def test_additional_update(self): | def test_additional_update(self): | ||||
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None) | |||||
vocab = Vocabulary(max_size=None, min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
_ = vocab["well"] | _ = vocab["well"] | ||||
@@ -77,7 +77,7 @@ class TestOther(unittest.TestCase): | |||||
self.assertTrue("hahaha" in vocab) | self.assertTrue("hahaha" in vocab) | ||||
def test_warning(self): | def test_warning(self): | ||||
vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None) | |||||
vocab = Vocabulary(max_size=len(set(text)), min_freq=None) | |||||
vocab.update(text) | vocab.update(text) | ||||
self.assertEqual(vocab.rebuild, True) | self.assertEqual(vocab.rebuild, True) | ||||
print(len(vocab)) | print(len(vocab)) | ||||