From fd37ed60a715e1154f8a642f701f9da042cc90f3 Mon Sep 17 00:00:00 2001
From: yh
Date: Fri, 16 Aug 2019 02:19:23 +0800
Subject: [PATCH] 1. Add a dev_batch_size parameter to Trainer; 2. Add min_freq to StaticEmbedding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/trainer.py                  |  6 +-
 fastNLP/embeddings/static_embedding.py   | 82 +++++++++++++++---------
 fastNLP/io/pipe/conll.py                 |  3 +-
 test/embeddings/test_static_embedding.py | 36 ++++++++---
 4 files changed, 85 insertions(+), 42 deletions(-)

diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index 6d18fd48..a6f4f823 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -422,7 +422,7 @@ class Trainer(object):
                  num_workers=0, n_epochs=10, print_every=5,
                  dev_data=None, metrics=None, metric_key=None,
                  validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
-                 callbacks=None, check_code_level=0):
+                 callbacks=None, check_code_level=0, **kwargs):
         if prefetch and num_workers==0:
             num_workers = 1
         if prefetch:
@@ -550,12 +550,12 @@ class Trainer(object):
         self.use_tqdm = use_tqdm
         self.pbar = None
         self.print_every = abs(self.print_every)
-
+        self.kwargs = kwargs
         if self.dev_data is not None:
             self.tester = Tester(model=self.model,
                                  data=self.dev_data,
                                  metrics=self.metrics,
-                                 batch_size=self.batch_size,
+                                 batch_size=kwargs.get("dev_batch_size", self.batch_size),
                                  device=None,  # device is handled by the code above
                                  verbose=0,
                                  use_tqdm=self.use_tqdm)
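
With this change, Trainer accepts arbitrary keyword arguments, and the Tester built for dev_data reads an optional dev_batch_size from them, falling back to the training batch_size when it is absent. A minimal usage sketch, assuming a model, datasets and metric are already prepared (those names are placeholders, not part of this patch):

    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric

    trainer = Trainer(train_data=train_data, model=model,
                      loss=CrossEntropyLoss(),
                      metrics=AccuracyMetric(),
                      dev_data=dev_data,
                      batch_size=32,        # batch size used for training
                      dev_batch_size=128)   # picked up via **kwargs, used only for evaluation on dev_data
    trainer.train()

If dev_batch_size is omitted, evaluation keeps using batch_size, so existing code behaves as before.
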
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index 78f615f6..12011128 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -10,6 +10,8 @@ from ..core.vocabulary import Vocabulary
 from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
 from .embedding import TokenEmbedding
 from ..modules.utils import _get_file_name_base_on_postfix
+from copy import deepcopy
+from collections import defaultdict
 
 class StaticEmbedding(TokenEmbedding):
     """
@@ -46,12 +48,13 @@
     :param callable init_method: How to initialize vectors that are not found in the pretrained file. Any method from torch.nn.init.* can be used; it is called with the tensor to initialize as its argument.
     :param bool lower: Whether to lowercase the words in vocab before matching them against the pretrained vocabulary. If your vocabulary
         contains uppercase words, or uppercase words need a separate vector representation, set lower to False.
-    :param float word_dropout: The probability of replacing a word with unk. This both trains the unk vector and acts as a form of regularization.
     :param float dropout: The probability of applying Dropout to the embedding output. 0.1 means 10% of the values are randomly set to 0.
+    :param float word_dropout: The probability of replacing a word with unk. This both trains the unk vector and acts as a form of regularization.
     :param bool normalize: Whether to normalize each vector so that its norm is 1.
+    :param int min_freq: Words whose frequency in the Vocabulary is lower than this value are mapped to unk.
     """
     def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True,
-                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False):
+                 init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1):
         super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         # resolve the cache_path
@@ -70,6 +73,28 @@ class StaticEmbedding(TokenEmbedding):
         else:
             raise ValueError(f"Cannot recognize {model_dir_or_name}.")
 
+        # shrink the vocab according to min_freq
+        truncate_vocab = (vocab.min_freq is None and min_freq>1) or (vocab.min_freq and vocab.min_freq<min_freq)
+        if truncate_vocab:
+            truncated_vocab = deepcopy(vocab)
+            truncated_vocab.min_freq = min_freq
+            truncated_vocab.word2idx = None
+            if lower:  # when lower is set, cased and lowercased forms are counted together
+                lowered_word_count = defaultdict(int)
+                for word, count in truncated_vocab.word_count.items():
+                    lowered_word_count[word.lower()] += count
+                for word in truncated_vocab.word_count.keys():
+                    word_count = truncated_vocab.word_count[word]
+                    if lowered_word_count[word.lower()]>=min_freq and word_count<min_freq:
+                        truncated_vocab.add_word_lst([word] * (min_freq - word_count),
+                                                     no_create_entry=truncated_vocab._is_word_no_create_entry(word))
+            truncated_vocab.build_vocab()
+            truncated_words_to_words = torch.arange(len(vocab)).long()
+            for word, index in vocab:
+                truncated_words_to_words[index] = truncated_vocab.to_index(word)
+            vocab = truncated_vocab
+
         # load the embedding
         if lower:
             lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
@@ ... @@ class StaticEmbedding(TokenEmbedding):
 
-        if vocab._no_create_word_length>0:
-            if vocab.unknown is None:  # create a dedicated unknown entry
-                unknown_idx = len(matrix)
-                vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
-            else:
-                unknown_idx = vocab.unknown_idx
-            words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
-                                          requires_grad=False)
-            for word, index in vocab:
-                vec = matrix.get(index, None)
-                if vec is not None:
-                    vectors[index] = vec
-                    words_to_words[index] = index
-                else:
-                    vectors[index] = vectors[unknown_idx]
-            self.words_to_words = words_to_words
+        if vocab.unknown is None:  # create a dedicated unknown entry
+            unknown_idx = len(matrix)
+            vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
         else:
-            for index, vec in matrix.items():
-                if vec is not None:
-                    vectors[index] = vec
+            unknown_idx = vocab.unknown_idx
+        self.words_to_words = nn.Parameter(torch.full((len(vocab), ), fill_value=unknown_idx).long(),
+                                           requires_grad=False)
+
+        for index, (index_in_vocab, vec) in enumerate(matrix.items()):
+            if vec is not None:
+                vectors[index] = vec
+                self.words_to_words[index_in_vocab] = index
         return vectors
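
The new min_freq argument makes StaticEmbedding map words whose frequency in the Vocabulary falls below the threshold to unk; when lower=True, frequencies of cased and lowercased forms are considered jointly (this is what test_same_vector4 below exercises). A small sketch of the intended behaviour, using a randomly initialized embedding (model_dir_or_name=None) so no pretrained file is needed; the expectation in the comments follows the new docstring rather than a guaranteed contract:

    import torch
    from fastNLP import Vocabulary
    from fastNLP.embeddings import StaticEmbedding

    vocab = Vocabulary().add_word_lst("the the the a sample sentence".split())
    embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50, min_freq=2)

    words = torch.LongTensor([[vocab.to_index(w) for w in ["a", "sample", "the"]]])
    vectors = embed(words)
    # "a" and "sample" occur only once, so both are expected to share the unk vector,
    # while "the" (frequency 3) keeps its own vector.
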
diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py
index a49e68b1..0379a45b 100644
--- a/fastNLP/io/pipe/conll.py
+++ b/fastNLP/io/pipe/conll.py
@@ -138,9 +138,8 @@ class OntoNotesNERPipe(_NERPipe):
        "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 6
        "[...]", "[...]", "[...]", .
 
-
+    :param str encoding_type: Which encoding scheme to use for the target column; bioes and bio are supported.
     :param bool lower: Whether to lowercase the words before building the vocabulary; in most cases this does not need to be True.
-    :param bool delete_unused_fields: Whether to delete fields that are not used in the NER task.
     :param int target_pad_val: Padding value for the target column; padded positions in the target are filled with target_pad_val. Defaults to -100.
     """
diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
index 6fd33072..ca97dd75 100644
--- a/test/embeddings/test_static_embedding.py
+++ b/test/embeddings/test_static_embedding.py
@@ -34,6 +34,7 @@ class TestRandomSameEntry(unittest.TestCase):
 
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_same_vector3(self):
+        # verify the lower option
         word_lst = ["The", "the"]
         no_create_word_lst = ['of', 'Of', 'With', 'with']
         vocab = Vocabulary().add_word_lst(word_lst)
@@ -60,13 +61,7 @@ class TestRandomSameEntry(unittest.TestCase):
 
     @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
     def test_same_vector4(self):
-        # words = []
-        # create_word_lst = []  # words that need entries
-        # no_create_word_lst = []
-        # ignore_word_lst = []
-        # with open('/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt', 'r', encoding='utf-8') as f:
-        #     for line in f:
-        #         words
+        # verify lower combined with min_freq
         word_lst = ["The", "the", "the", "The", "a", "A"]
         no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
         all_words = word_lst[:-2] + no_create_word_lst[:-2]
@@ -89,4 +84,29 @@ class TestRandomSameEntry(unittest.TestCase):
         for idx in range(len(all_words)):
             word_i, word_j = words[0, idx], lowered_words[0, idx]
             with self.subTest(idx=idx, word=all_words[idx]):
-                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
\ No newline at end of file
+                assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
+
+    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+    def test_same_vector5(self):
+        # check that word vectors stay the same after applying min_freq
+        word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
+        no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
+        all_words = word_lst[:-2] + no_create_word_lst[:-2]
+        vocab = Vocabulary().add_word_lst(word_lst)
+        vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+        embed = StaticEmbedding(vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                lower=False, min_freq=2)
+        words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
+        words = embed(words)
+
+        min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
+        min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+        min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='/remote-home/source/fastnlp_caches/glove.6B.100d/glove.demo.txt',
+                                         lower=False)
+        min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
+        min_freq_words = min_freq_embed(min_freq_words)
+
+        for idx in range(len(all_words)):
+            word_i, word_j = words[0, idx], min_freq_words[0, idx]
+            with self.subTest(idx=idx, word=all_words[idx]):
+                assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)
\ No newline at end of file
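
As a usage note for the OntoNotesNERPipe docstring touched above, a minimal sketch; the data path is a placeholder, and process_from_file is assumed to be the usual Pipe entry point:

    from fastNLP.io.pipe.conll import OntoNotesNERPipe

    # encoding_type selects the tagging scheme for the target column; bioes and bio are supported
    pipe = OntoNotesNERPipe(encoding_type='bioes', lower=False)
    data_bundle = pipe.process_from_file('/path/to/OntoNotes')
    print(data_bundle)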