
1. Fix a possible bug in char_embedding where the char encoder looks up the char embedding dimension; 2. change the default of only_train_min_freq in StaticEmbedding to False; 3. fix check_loader_paths() so that it now reports an error for duplicate file names (previously it did not).

tags/v0.5.5
yh 5 years ago
commit 08c4bc86c7
3 changed files with 10 additions and 6 deletions
  1. fastNLP/embeddings/char_embedding.py  +4 -3
  2. fastNLP/embeddings/static_embedding.py  +1 -1
  3. fastNLP/io/utils.py  +5 -2

+4 -3  fastNLP/embeddings/char_embedding.py

@@ -114,7 +114,8 @@ class CNNCharEmbedding(TokenEmbedding):
         self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
         self.convs = nn.ModuleList([nn.Conv1d(
-            char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
+            self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True,
+            padding=kernel_sizes[i] // 2)
             for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
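
The substance of the fix: the Conv1d input channels now come from the embedding module that was actually constructed rather than from the char_emb_size argument, which matters when the char embedding ends up with a different dimension (for example when built from pretrained vectors). A minimal sketch of the pattern in plain PyTorch, not the fastNLP class itself (the CharCNN name and its parameters are made up):

import torch.nn as nn

class CharCNN(nn.Module):  # hypothetical illustration, not fastNLP's CNNCharEmbedding
    def __init__(self, num_chars, char_emb_size=50, filter_num=30, kernel_size=3,
                 pretrained_weight=None):
        super().__init__()
        if pretrained_weight is not None:
            # pretrained vectors decide the dimension, which may differ from char_emb_size
            self.char_embedding = nn.Embedding.from_pretrained(pretrained_weight)
        else:
            self.char_embedding = nn.Embedding(num_chars, char_emb_size)
        # the fix illustrated: size the convolution from the embedding that actually exists
        self.conv = nn.Conv1d(self.char_embedding.embedding_dim, filter_num,
                              kernel_size=kernel_size, padding=kernel_size // 2)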
@@ -238,12 +239,12 @@ class LSTMCharEmbedding(TokenEmbedding):
         if pre_train_char_embed:
             self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
         else:
-            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+            self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
         self.fc = nn.Linear(hidden_size, embed_size)
         hidden_size = hidden_size // 2 if bidirectional else hidden_size
-        self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True)
+        self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True)
         self._embed_size = embed_size
         self.bidirectional = bidirectional
         self.requires_grad = requires_grad
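
The LSTM variant gets the same treatment: when pre_train_char_embed supplies a pretrained char embedding whose width differs from char_emb_size, the LSTM input size has to be read from the embedding itself. A hedged sketch in plain PyTorch (CharLSTM and its parameters are invented names, not the fastNLP class):

import torch.nn as nn

class CharLSTM(nn.Module):  # hypothetical illustration, not fastNLP's LSTMCharEmbedding
    def __init__(self, num_chars, char_emb_size=50, hidden_size=50,
                 bidirectional=True, pretrained_weight=None):
        super().__init__()
        if pretrained_weight is not None:
            # e.g. a pretrained char table whose width is not char_emb_size
            self.char_embedding = nn.Embedding.from_pretrained(pretrained_weight)
        else:
            self.char_embedding = nn.Embedding(num_chars, char_emb_size)
        hidden_size = hidden_size // 2 if bidirectional else hidden_size
        # the fix illustrated: feed the LSTM the real embedding dimension
        self.lstm = nn.LSTM(self.char_embedding.embedding_dim, hidden_size,
                            bidirectional=bidirectional, batch_first=True)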


+1 -1  fastNLP/embeddings/static_embedding.py

@@ -68,7 +68,7 @@ class StaticEmbedding(TokenEmbedding):
     :param float word_dropout: the probability of replacing a word with unk; this both trains unk and provides some regularization.
     :param bool normalize: whether to normalize each vector so that its norm is 1.
     :param int min_freq: words whose frequency in the Vocabulary is below this value will be mapped to unk.
-    :param dict kwarngs:
+    :param dict kwargs:
             bool only_train_min_freq: apply the min_freq filter only to words from the training data;
             bool only_norm_found_vector: whether to normalize only the vectors of words found in the pretrained embeddings;
             bool only_use_pretrain_word: only use words that appear in the pretrained vocabulary; a word not found there is mapped to unk. If the vocabulary
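
To make the min_freq / only_train_min_freq interaction concrete, here is a simplified sketch of the behavior described in the docstring; it is not StaticEmbedding's actual implementation, and the helper name and arguments are invented for illustration:

def filter_vocab(word_counts, train_words, min_freq=2, only_train_min_freq=False):
    # words below min_freq are redirected to unk; with only_train_min_freq=True the
    # frequency filter is applied only to words that came from the training data
    kept = []
    for word, count in word_counts.items():
        if count >= min_freq:
            kept.append(word)
        elif only_train_min_freq and word not in train_words:
            kept.append(word)  # non-train words are exempt from the filter
        # otherwise the word will share the unk vector
    return kept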


+5 -2  fastNLP/io/utils.py

@@ -44,14 +44,17 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
                 if 'dev' in filename:
                     if path_pair:
                         raise Exception(
-                            "File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
+                            "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
                     path_pair = ('dev', filename)
                 if 'test' in filename:
                     if path_pair:
                         raise Exception(
-                            "File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+                            "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
                     path_pair = ('test', filename)
                 if path_pair:
+                    if path_pair[0] in files:
+                        raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the "
+                                              f"filepath for `{path_pair[0]}`.")
                     files[path_pair[0]] = os.path.join(paths, path_pair[1])
             if 'train' not in files:
                 raise KeyError(f"There is no train file in {paths}.")
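
After this change, pointing check_loader_paths at a directory in which two file names match the same split (say dev.txt and dev2.txt) raises FileExistsError instead of silently keeping the last match. A hedged usage sketch; the directory and file names below are hypothetical:

from fastNLP.io.utils import check_loader_paths

try:
    # e.g. a folder containing train.txt, dev.txt and dev2.txt
    files = check_loader_paths('data/my_dataset')
    print(files)  # {'train': ..., 'dev': ..., 'test': ...}
except FileExistsError as err:
    # two candidates for the same split: pass explicit file paths instead of a directory
    print(err)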

