From 08c4bc86c79da32b3e8c013398fdca954565a1bb Mon Sep 17 00:00:00 2001
From: yh
Date: Sun, 24 Nov 2019 23:02:52 +0800
Subject: [PATCH] 1. Fix a bug that could occur in char_embedding when the
 char encoder fetches the char embedding dimension; 2. Make
 only_train_min_freq in StaticEmbedding default to False; 3.
 check_loader_paths() did not report an error for duplicate file names; fix
 that bug
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/char_embedding.py   | 7 ++++---
 fastNLP/embeddings/static_embedding.py | 2 +-
 fastNLP/io/utils.py                    | 7 +++++--
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
index 93d3ce00..114280f9 100644
--- a/fastNLP/embeddings/char_embedding.py
+++ b/fastNLP/embeddings/char_embedding.py
@@ -114,7 +114,8 @@ class CNNCharEmbedding(TokenEmbedding):
             self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
 
         self.convs = nn.ModuleList([nn.Conv1d(
-            char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
+            self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True,
+            padding=kernel_sizes[i] // 2)
             for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
@@ -238,12 +239,12 @@ class LSTMCharEmbedding(TokenEmbedding):
         if pre_train_char_embed:
             self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
         else:
-            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+            self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
 
         self.fc = nn.Linear(hidden_size, embed_size)
         hidden_size = hidden_size // 2 if bidirectional else hidden_size
 
-        self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True)
+        self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True)
         self._embed_size = embed_size
         self.bidirectional = bidirectional
         self.requires_grad = requires_grad
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
index a50ce25d..b907a278 100644
--- a/fastNLP/embeddings/static_embedding.py
+++ b/fastNLP/embeddings/static_embedding.py
@@ -68,7 +68,7 @@ class StaticEmbedding(TokenEmbedding):
     :param float word_dropout: The probability of replacing a word with unk. This both lets unk be trained and provides some regularization.
     :param bool normalize: Whether to normalize each vector so that its norm is 1.
     :param int min_freq: Words whose frequency in the Vocabulary is lower than this value are mapped to unk.
-    :param dict kwarngs:
+    :param dict kwargs:
         bool only_train_min_freq: apply the min_freq filter only to words in train;
         bool only_norm_found_vector: whether to normalize only the vectors of words found in the pretrained embeddings;
         bool only_use_pretrain_word: only use words that appear in the pretrain vocabulary; a word absent from the pretrained vocabulary becomes unk. If the vocab
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index 496aee77..4b5230c0 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -44,14 +44,17 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
                 if 'dev' in filename:
                     if path_pair:
                         raise Exception(
-                            "File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
+                            "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
                     path_pair = ('dev', filename)
                 if 'test' in filename:
                     if path_pair:
                         raise Exception(
-                            "File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+                            "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
                     path_pair = ('test', filename)
                 if path_pair:
+                    if path_pair[0] in files:
+                        raise FileExistsError(f"Two files containing `{path_pair[0]}` were found, please specify the "
+                                              f"filepath for `{path_pair[0]}`.")
                     files[path_pair[0]] = os.path.join(paths, path_pair[1])
         if 'train' not in files:
             raise KeyError(f"There is no train file in {paths}.")