Browse Source

Fix bugs in Vocabulary: replace the `need_default` flag with explicit, configurable `unknown`/`padding` tokens, simplify `build_vocab`, and update the trainer/vocabulary tests accordingly

tags/v0.2.0^2
yunfan 5 years ago
parent
commit
52b1b18a76
3 changed files with 53 additions and 68 deletions
  1. +16
    -33
      fastNLP/core/vocabulary.py
  2. +27
    -25
      test/core/test_trainer.py
  3. +10
    -10
      test/core/test_vocabulary.py

+ 16
- 33
fastNLP/core/vocabulary.py View File

@@ -1,11 +1,4 @@
from collections import Counter
from copy import deepcopy

DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1

DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1}


def isiterable(p_object):
try:
@@ -57,22 +50,16 @@ class Vocabulary(object):
vocab.to_word(5)
"""

def __init__(self, need_default=True, max_size=None, min_freq=None):
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'):
"""
:param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
:param str unknown: token to use for unknown words; set to None to disable it. Default: '<unk>'
:param str padding: token to use for padding; set to None to disable it. Default: '<pad>'
"""
self.max_size = max_size
self.min_freq = min_freq
self.word_count = Counter()
self.has_default = need_default
if self.has_default:
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.padding_label = None
self.unknown_label = None
self.unknown = unknown
self.padding = padding
self.word2idx = None
self.idx2word = None
self.rebuild = True
@@ -113,17 +100,18 @@ class Vocabulary(object):
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`.

"""
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
else:
self.word2idx = {}
self.word2idx = {}
if self.padding is not None:
self.word2idx[self.padding] = 0
if self.unknown is not None:
self.word2idx[self.unknown] = 1

max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
words = self.word_count.most_common(max_size)
if self.min_freq is not None:
words = filter(lambda kv: kv[1] >= self.min_freq, words)
if self.word2idx is not None:
words = filter(lambda kv: kv[0] not in self.word2idx, words)
start_idx = len(self.word2idx)
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab()
@@ -159,8 +147,8 @@ class Vocabulary(object):
"""
if w in self.word2idx:
return self.word2idx[w]
elif self.has_default:
return self.word2idx[self.unknown_label]
if self.unknown is not None:
return self.word2idx[self.unknown]
else:
raise ValueError("word {} not in vocabulary".format(w))

@@ -175,21 +163,16 @@ class Vocabulary(object):
@property
@check_build_vocab
def unknown_idx(self):
if self.unknown_label is None:
if self.unknown is None:
return None
return self.word2idx[self.unknown_label]

def __setattr__(self, name, val):
self.__dict__[name] = val
if name in ["unknown_label", "padding_label"]:
self.word2idx = None
return self.word2idx[self.unknown]

@property
@check_build_vocab
def padding_idx(self):
if self.padding_label is None:
if self.padding is None:
return None
return self.word2idx[self.padding_label]
return self.word2idx[self.padding]

@check_build_vocab
def to_word(self, idx):


+ 27
- 25
test/core/test_trainer.py View File

@@ -4,6 +4,7 @@ import numpy as np
import torch.nn.functional as F
from torch import nn

from fastNLP.core.utils import CheckError
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss
@@ -56,7 +57,8 @@ class TrainerTestGround(unittest.TestCase):
dev_data=dev_set,
optimizer=SGD(lr=0.1),
check_code_level=2,
use_tqdm=True)
use_tqdm=True,
save_path=None)
trainer.train()
"""
# 应该正确运行
@@ -145,16 +147,14 @@ class TrainerTestGround(unittest.TestCase):
return {'wrong_loss_key': loss}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer.train()
"""
# 应该正确运行
"""
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer.train()

def test_trainer_suggestion4(self):
# 检查报错提示能否正确提醒用户
@@ -173,12 +173,13 @@ class TrainerTestGround(unittest.TestCase):
return {'loss': loss}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)

def test_trainer_suggestion5(self):
# 检查报错提示能否正确提醒用户
@@ -225,14 +226,15 @@ class TrainerTestGround(unittest.TestCase):
return {'pred': x}

model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2
)
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2
)

def test_case2(self):
# check metrics Wrong


+ 10
- 10
test/core/test_vocabulary.py View File

@@ -10,36 +10,36 @@ counter = Counter(text)

class TestAdd(unittest.TestCase):
def test_add(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
for word in text:
vocab.add(word)
self.assertEqual(vocab.word_count, counter)

def test_add_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
for word in text:
vocab.add_word(word)
self.assertEqual(vocab.word_count, counter)

def test_add_word_lst(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.add_word_lst(text)
self.assertEqual(vocab.word_count, counter)

def test_update(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(vocab.word_count, counter)


class TestIndexing(unittest.TestCase):
def test_len(self):
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text)
self.assertEqual(len(vocab), len(counter))

def test_contains(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text)
self.assertTrue(text[-1] in vocab)
self.assertFalse("~!@#" in vocab)
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase):
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))

def test_index(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)
res = [vocab[w] for w in set(text)]
self.assertEqual(len(res), len(set(res)))
@@ -56,14 +56,14 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(len(res), len(set(res)))

def test_to_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])


class TestOther(unittest.TestCase):
def test_additional_update(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text)

_ = vocab["well"]
@@ -77,7 +77,7 @@ class TestOther(unittest.TestCase):
self.assertTrue("hahaha" in vocab)

def test_warning(self):
vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None)
vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
vocab.update(text)
self.assertEqual(vocab.rebuild, True)
print(len(vocab))


Loading…
Cancel
Save