Browse Source

Merge branch 'trainer' of github.com:FengZiYjun/fastNLP into trainer

tags/v0.2.0^2
yh 6 years ago
parent
commit
7c261faf19
3 changed files with 53 additions and 68 deletions
  1. +16
    -33
      fastNLP/core/vocabulary.py
  2. +27
    -25
      test/core/test_trainer.py
  3. +10
    -10
      test/core/test_vocabulary.py

+ 16
- 33
fastNLP/core/vocabulary.py View File

@@ -1,11 +1,4 @@
from collections import Counter from collections import Counter
from copy import deepcopy

DEFAULT_PADDING_LABEL = '<pad>' # dict index = 0
DEFAULT_UNKNOWN_LABEL = '<unk>' # dict index = 1

DEFAULT_WORD_TO_INDEX = {DEFAULT_PADDING_LABEL: 0, DEFAULT_UNKNOWN_LABEL: 1}



def isiterable(p_object): def isiterable(p_object):
try: try:
@@ -57,22 +50,16 @@ class Vocabulary(object):
vocab.to_word(5) vocab.to_word(5)
""" """


def __init__(self, need_default=True, max_size=None, min_freq=None):
def __init__(self, max_size=None, min_freq=None, unknown='<unk>', padding='<pad>'):
""" """
:param bool need_default: set if the Vocabulary has default labels reserved for sequences. Default: True.
:param int max_size: set the max number of words in Vocabulary. Default: None :param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
""" """
self.max_size = max_size self.max_size = max_size
self.min_freq = min_freq self.min_freq = min_freq
self.word_count = Counter() self.word_count = Counter()
self.has_default = need_default
if self.has_default:
self.padding_label = DEFAULT_PADDING_LABEL
self.unknown_label = DEFAULT_UNKNOWN_LABEL
else:
self.padding_label = None
self.unknown_label = None
self.unknown = unknown
self.padding = padding
self.word2idx = None self.word2idx = None
self.idx2word = None self.idx2word = None
self.rebuild = True self.rebuild = True
@@ -113,17 +100,18 @@ class Vocabulary(object):
"""Build 'word to index' dict, and filter the word using `max_size` and `min_freq`. """Build 'word to index' dict, and filter the word using `max_size` and `min_freq`.


""" """
if self.has_default:
self.word2idx = deepcopy(DEFAULT_WORD_TO_INDEX)
self.word2idx[self.unknown_label] = self.word2idx.pop(DEFAULT_UNKNOWN_LABEL)
self.word2idx[self.padding_label] = self.word2idx.pop(DEFAULT_PADDING_LABEL)
else:
self.word2idx = {}
self.word2idx = {}
if self.padding is not None:
self.word2idx[self.padding] = 0
if self.unknown is not None:
self.word2idx[self.unknown] = 1


max_size = min(self.max_size, len(self.word_count)) if self.max_size else None max_size = min(self.max_size, len(self.word_count)) if self.max_size else None
words = self.word_count.most_common(max_size) words = self.word_count.most_common(max_size)
if self.min_freq is not None: if self.min_freq is not None:
words = filter(lambda kv: kv[1] >= self.min_freq, words) words = filter(lambda kv: kv[1] >= self.min_freq, words)
if self.word2idx is not None:
words = filter(lambda kv: kv[0] not in self.word2idx, words)
start_idx = len(self.word2idx) start_idx = len(self.word2idx)
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab() self.build_reverse_vocab()
@@ -159,8 +147,8 @@ class Vocabulary(object):
""" """
if w in self.word2idx: if w in self.word2idx:
return self.word2idx[w] return self.word2idx[w]
elif self.has_default:
return self.word2idx[self.unknown_label]
if self.unknown is not None:
return self.word2idx[self.unknown]
else: else:
raise ValueError("word {} not in vocabulary".format(w)) raise ValueError("word {} not in vocabulary".format(w))


@@ -175,21 +163,16 @@ class Vocabulary(object):
@property @property
@check_build_vocab @check_build_vocab
def unknown_idx(self): def unknown_idx(self):
if self.unknown_label is None:
if self.unknown is None:
return None return None
return self.word2idx[self.unknown_label]

def __setattr__(self, name, val):
self.__dict__[name] = val
if name in ["unknown_label", "padding_label"]:
self.word2idx = None
return self.word2idx[self.unknown]


@property @property
@check_build_vocab @check_build_vocab
def padding_idx(self): def padding_idx(self):
if self.padding_label is None:
if self.padding is None:
return None return None
return self.word2idx[self.padding_label]
return self.word2idx[self.padding]


@check_build_vocab @check_build_vocab
def to_word(self, idx): def to_word(self, idx):


+ 27
- 25
test/core/test_trainer.py View File

@@ -4,6 +4,7 @@ import numpy as np
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn


from fastNLP.core.utils import CheckError
from fastNLP.core.dataset import DataSet from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance from fastNLP.core.instance import Instance
from fastNLP.core.losses import BCELoss from fastNLP.core.losses import BCELoss
@@ -56,7 +57,8 @@ class TrainerTestGround(unittest.TestCase):
dev_data=dev_set, dev_data=dev_set,
optimizer=SGD(lr=0.1), optimizer=SGD(lr=0.1),
check_code_level=2, check_code_level=2,
use_tqdm=True)
use_tqdm=True,
save_path=None)
trainer.train() trainer.train()
""" """
# 应该正确运行 # 应该正确运行
@@ -145,16 +147,14 @@ class TrainerTestGround(unittest.TestCase):
return {'wrong_loss_key': loss} return {'wrong_loss_key': loss}


model = Model() model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer.train()
"""
# 应该正确运行
"""
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
trainer.train()


def test_trainer_suggestion4(self): def test_trainer_suggestion4(self):
# 检查报错提示能否正确提醒用户 # 检查报错提示能否正确提醒用户
@@ -173,12 +173,13 @@ class TrainerTestGround(unittest.TestCase):
return {'loss': loss} return {'loss': loss}


model = Model() model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
use_tqdm=False,
print_every=2
)


def test_trainer_suggestion5(self): def test_trainer_suggestion5(self):
# 检查报错提示能否正确提醒用户 # 检查报错提示能否正确提醒用户
@@ -225,14 +226,15 @@ class TrainerTestGround(unittest.TestCase):
return {'pred': x} return {'pred': x}


model = Model() model = Model()
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2
)
with self.assertRaises(NameError):
trainer = Trainer(
train_data=dataset,
model=model,
dev_data=dataset,
metrics=AccuracyMetric(),
use_tqdm=False,
print_every=2
)


def test_case2(self): def test_case2(self):
# check metrics Wrong # check metrics Wrong


+ 10
- 10
test/core/test_vocabulary.py View File

@@ -10,36 +10,36 @@ counter = Counter(text)


class TestAdd(unittest.TestCase): class TestAdd(unittest.TestCase):
def test_add(self): def test_add(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
for word in text: for word in text:
vocab.add(word) vocab.add(word)
self.assertEqual(vocab.word_count, counter) self.assertEqual(vocab.word_count, counter)


def test_add_word(self): def test_add_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
for word in text: for word in text:
vocab.add_word(word) vocab.add_word(word)
self.assertEqual(vocab.word_count, counter) self.assertEqual(vocab.word_count, counter)


def test_add_word_lst(self): def test_add_word_lst(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.add_word_lst(text) vocab.add_word_lst(text)
self.assertEqual(vocab.word_count, counter) self.assertEqual(vocab.word_count, counter)


def test_update(self): def test_update(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text) vocab.update(text)
self.assertEqual(vocab.word_count, counter) self.assertEqual(vocab.word_count, counter)




class TestIndexing(unittest.TestCase): class TestIndexing(unittest.TestCase):
def test_len(self): def test_len(self):
vocab = Vocabulary(need_default=False, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text) vocab.update(text)
self.assertEqual(len(vocab), len(counter)) self.assertEqual(len(vocab), len(counter))


def test_contains(self): def test_contains(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None, unknown=None, padding=None)
vocab.update(text) vocab.update(text)
self.assertTrue(text[-1] in vocab) self.assertTrue(text[-1] in vocab)
self.assertFalse("~!@#" in vocab) self.assertFalse("~!@#" in vocab)
@@ -47,7 +47,7 @@ class TestIndexing(unittest.TestCase):
self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#")) self.assertEqual("~!@#" in vocab, vocab.has_word("~!@#"))


def test_index(self): def test_index(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text) vocab.update(text)
res = [vocab[w] for w in set(text)] res = [vocab[w] for w in set(text)]
self.assertEqual(len(res), len(set(res))) self.assertEqual(len(res), len(set(res)))
@@ -56,14 +56,14 @@ class TestIndexing(unittest.TestCase):
self.assertEqual(len(res), len(set(res))) self.assertEqual(len(res), len(set(res)))


def test_to_word(self): def test_to_word(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text) vocab.update(text)
self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]]) self.assertEqual(text, [vocab.to_word(idx) for idx in [vocab[w] for w in text]])




class TestOther(unittest.TestCase): class TestOther(unittest.TestCase):
def test_additional_update(self): def test_additional_update(self):
vocab = Vocabulary(need_default=True, max_size=None, min_freq=None)
vocab = Vocabulary(max_size=None, min_freq=None)
vocab.update(text) vocab.update(text)


_ = vocab["well"] _ = vocab["well"]
@@ -77,7 +77,7 @@ class TestOther(unittest.TestCase):
self.assertTrue("hahaha" in vocab) self.assertTrue("hahaha" in vocab)


def test_warning(self): def test_warning(self):
vocab = Vocabulary(need_default=True, max_size=len(set(text)), min_freq=None)
vocab = Vocabulary(max_size=len(set(text)), min_freq=None)
vocab.update(text) vocab.update(text)
self.assertEqual(vocab.rebuild, True) self.assertEqual(vocab.rebuild, True)
print(len(vocab)) print(len(vocab))


Loading…
Cancel
Save