
1. Add a dev_batch_size keyword argument to Trainer; 2. Add a min_freq option to StaticEmbedding;

tags/v0.4.10
yh 6 years ago
commit fd37ed60a7
4 changed files with 85 additions and 42 deletions
  1. fastNLP/core/trainer.py (+3, -3)
  2. fastNLP/embeddings/static_embedding.py (+53, -29)
  3. fastNLP/io/pipe/conll.py (+1, -2)
  4. test/embeddings/test_static_embedding.py (+28, -8)

fastNLP/core/trainer.py (+3, -3)

@@ -422,7 +422,7 @@ class Trainer(object):
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
callbacks=None, check_code_level=0):
callbacks=None, check_code_level=0, **kwargs):
if prefetch and num_workers==0:
num_workers = 1
if prefetch:
@@ -550,12 +550,12 @@ class Trainer(object):
self.use_tqdm = use_tqdm
self.pbar = None
self.print_every = abs(self.print_every)
self.kwargs = kwargs
if self.dev_data is not None:
self.tester = Tester(model=self.model,
data=self.dev_data,
metrics=self.metrics,
batch_size=self.batch_size,
batch_size=kwargs.get("dev_batch_size", self.batch_size),
device=None,  # device has already been handled above
verbose=0,
use_tqdm=self.use_tqdm)
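
For reference, a minimal sketch of the pattern this change introduces: extra keyword arguments are collected in **kwargs, and the internal Tester's batch size falls back to the training batch_size when dev_batch_size is not supplied. TrainerSketch below is purely illustrative, not the real fastNLP Trainer:

# illustrative stand-in for the Trainer constructor shown above
class TrainerSketch:
    def __init__(self, batch_size=32, **kwargs):
        self.batch_size = batch_size
        self.kwargs = kwargs
        # same lookup as in the diff: use dev_batch_size when given,
        # otherwise fall back to the training batch size
        self.dev_batch_size = kwargs.get("dev_batch_size", self.batch_size)

trainer = TrainerSketch(batch_size=32, dev_batch_size=256)
assert trainer.dev_batch_size == 256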


fastNLP/embeddings/static_embedding.py (+53, -29)

@@ -10,6 +10,8 @@ from ..core.vocabulary import Vocabulary
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
from .embedding import TokenEmbedding
from ..modules.utils import _get_file_name_base_on_postfix
from copy import deepcopy
from collections import defaultdict

class StaticEmbedding(TokenEmbedding):
"""
@@ -46,12 +48,13 @@ class StaticEmbedding(TokenEmbedding):
:param callable init_method: how to initialize values that are not found in the pretrained file. Any of the torch.nn.init.* methods can be used; the method is called with a tensor object.
:param bool lower: whether to lowercase the words in vocab before matching them against the pretrained vocabulary. If your vocabulary contains uppercase words, or uppercase words
need a separate vector of their own, set lower to False.
:param float word_dropout: the probability of replacing a word with unk. This both trains the unk representation and acts as a form of regularization.
:param float dropout: the probability of applying dropout to the embedding output; 0.1 means 10% of the values are randomly set to 0.
:param float word_dropout: the probability of replacing a word with unk. This both trains the unk representation and acts as a form of regularization.
:param bool normalize: whether to normalize each vector so that its norm is 1.
:param int min_freq: words whose frequency in the Vocabulary is below this value are redirected to unk.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True,
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False):
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1):
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

# get the cache_path
@@ -70,6 +73,28 @@ class StaticEmbedding(TokenEmbedding):
else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.")

# shrink the vocab according to min_freq
truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq)
if truncate_vocab:
truncated_vocab = deepcopy(vocab)
truncated_vocab.min_freq = min_freq
truncated_vocab.word2idx = None
if lower: # when lower is set, the frequencies of upper- and lower-case variants must be counted together
lowered_word_count = defaultdict(int)
for word, count in truncated_vocab.word_count.items():
lowered_word_count[word.lower()] += count
for word in truncated_vocab.word_count.keys():
word_count = truncated_vocab.word_count[word]
if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq:
truncated_vocab.add_word_lst([word]*(min_freq-word_count),
no_create_entry=truncated_vocab._is_word_no_create_entry(word))
truncated_vocab.build_vocab()
truncated_words_to_words = torch.arange(len(vocab)).long()
for word, index in vocab:
truncated_words_to_words[index] = truncated_vocab.to_index(word)
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
vocab = truncated_vocab

# load the embedding
if lower:
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
@@ -84,9 +109,6 @@ class StaticEmbedding(TokenEmbedding):
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
else:
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
# needs to be adapted here
if not hasattr(self, 'words_to_words'):
self.words_to_words = torch.arange(len(lowered_vocab)).long()
if lowered_vocab.unknown:
unknown_idx = lowered_vocab.unknown_idx
else:
@@ -108,6 +130,14 @@ class StaticEmbedding(TokenEmbedding):
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
if normalize:
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)

if truncate_vocab:
for i in range(len(truncated_words_to_words)):
index_in_truncated_vocab = truncated_words_to_words[i]
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
del self.words_to_words
self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False)

self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
padding_idx=vocab.padding_idx,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
@@ -184,6 +214,10 @@ class StaticEmbedding(TokenEmbedding):
dim = len(parts) - 1
f.seek(0)
matrix = {}
if vocab.padding:
matrix[vocab.padding_idx] = torch.zeros(dim)
if vocab.unknown:
matrix[vocab.unknown_idx] = torch.zeros(dim)
found_count = 0
for idx, line in enumerate(f, start_idx):
try:
@@ -208,35 +242,25 @@ class StaticEmbedding(TokenEmbedding):
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
for word, index in vocab:
if index not in matrix and not vocab._is_word_no_create_entry(word):
if vocab.padding_idx == index:
matrix[index] = torch.zeros(dim)
elif vocab.unknown_idx in matrix: # if an unknown token exists, initialize with the unknown vector
if vocab.unknown_idx in matrix: # if an unknown token exists, initialize with the unknown vector
matrix[index] = matrix[vocab.unknown_idx]
else:
matrix[index] = None
# the entries in matrix are the words that need their own entry
vectors = self._randomly_init_embed(len(matrix), dim, init_method)

vectors = self._randomly_init_embed(len(vocab), dim, init_method)

if vocab._no_create_word_length>0:
if vocab.unknown is None: # create a dedicated unknown entry
unknown_idx = len(matrix)
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
else:
unknown_idx = vocab.unknown_idx
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
requires_grad=False)
for word, index in vocab:
vec = matrix.get(index, None)
if vec is not None:
vectors[index] = vec
words_to_words[index] = index
else:
vectors[index] = vectors[unknown_idx]
self.words_to_words = words_to_words
if vocab.unknown is None: # create a dedicated unknown entry
unknown_idx = len(matrix)
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
else:
for index, vec in matrix.items():
if vec is not None:
vectors[index] = vec
unknown_idx = vocab.unknown_idx
self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(),
requires_grad=False)

for index, (index_in_vocab, vec) in enumerate(matrix.items()):
if vec is not None:
vectors[index] = vec
self.words_to_words[index_in_vocab] = index

return vectors
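
To summarize the mechanics in isolation: low-frequency words are not removed from the caller's Vocabulary; the embedding instead keeps a words_to_words index map that redirects them to the unk row before the lookup. The following self-contained sketch uses a toy word list and a plain torch.nn.Embedding (these are assumptions of the sketch, not the fastNLP classes above):

import torch
from collections import Counter

# toy corpus frequencies; "rare" occurs only once
counts = Counter(["the", "the", "a", "a", "rare"])
words = ["<pad>", "<unk>"] + sorted(counts)   # original index space
unk_idx = 1
min_freq = 2

# redirect every word with frequency < min_freq to the unk index
words_to_words = torch.arange(len(words)).long()
for idx, w in enumerate(words):
    if w in counts and counts[w] < min_freq:
        words_to_words[idx] = unk_idx

# lookups go through the remap first, then into the embedding matrix
embedding = torch.nn.Embedding(len(words), 8)
token_ids = torch.tensor([words.index("the"), words.index("rare")])
vectors = embedding(words_to_words[token_ids])
assert torch.equal(vectors[1], embedding.weight[unk_idx])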



fastNLP/io/pipe/conll.py (+1, -2)

@@ -138,9 +138,8 @@ class OntoNotesNERPipe(_NERPipe):
"[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6
"[...]", "[...]", "[...]", .

:param str encoding_type: the encoding scheme used for the target column; bioes and bio are supported.
:param bool lower: whether to lowercase words before building the vocabulary; in the vast majority of cases this does not need to be set to True.
:param bool delete_unused_fields: whether to delete fields that are not used in the NER task.
:param int target_pad_val: the padding value for the target column; padded positions in the target are set to target_pad_val. Defaults to -100.
"""



test/embeddings/test_static_embedding.py (+28, -8)

@@ -34,6 +34,7 @@ class TestRandomSameEntry(unittest.TestCase):

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_same_vector3(self):
# verify the lower option
word_lst = ["The", "the"]
no_create_word_lst = ['of', 'Of', 'With', 'with']
vocab = Vocabulary().add_word_lst(word_lst)
@@ -60,13 +61,7 @@ class TestRandomSameEntry(unittest.TestCase):

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_same_vector4(self):
# words = []
# create_word_lst = [] # words that need to be created
# no_create_word_lst = []
# ignore_word_lst = []
# with open('/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', 'r', encoding='utf-8') as f:
# for line in f:
# words
# verify lower combined with min_freq
word_lst = ["The", "the", "the", "The", "a", "A"]
no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
all_words = word_lst[:-2] + no_create_word_lst[:-2]
@@ -89,4 +84,29 @@ class TestRandomSameEntry(unittest.TestCase):
for idx in range(len(all_words)):
word_i, word_j = words[0, idx], lowered_words[0, idx]
with self.subTest(idx=idx, word=all_words[idx]):
assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
def test_same_vector5(self):
# check that the word vectors stay consistent when min_freq is used
word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
all_words = word_lst[:-2] + no_create_word_lst[:-2]
vocab = Vocabulary().add_word_lst(word_lst)
vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
lower=False, min_freq=2)
words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
words = embed(words)

min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
lower=False)
min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
min_freq_words = min_freq_embed(min_freq_words)

for idx in range(len(all_words)):
word_i, word_j = words[0, idx], min_freq_words[0, idx]
with self.subTest(idx=idx, word=all_words[idx]):
assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)
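
Usage-wise, the new tests boil down to something like the sketch below. The embedding file path is a placeholder for a local GloVe-style text file (as in the tests above), and the import paths are assumed from fastNLP's public API:

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary().add_word_lst(["the", "the", "a", "a", "rare"])
# "rare" occurs only once, so with min_freq=2 its lookup is redirected to unk,
# while the Vocabulary itself keeps its original indices
embed = StaticEmbedding(vocab, model_dir_or_name='path/to/glove.demo.txt',  # placeholder path
                        min_freq=2)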
