Browse Source

修复CharacterEmbedding中的bug

tags/v0.4.10
yh 6 years ago
parent
commit
8d7c3ba140
1 changed file with 7 additions and 7 deletions
  1. +7
    -7
      fastNLP/embeddings/char_embedding.py

+ 7
- 7
fastNLP/embeddings/char_embedding.py View File

@@ -14,7 +14,7 @@ from ..modules.encoder.lstm import LSTM
from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding
from .utils import _construct_char_vocab_from_vocab
from .utils import get_embeddings

class CNNCharEmbedding(TokenEmbedding):
"""
@@ -50,7 +50,7 @@ class CNNCharEmbedding(TokenEmbedding):
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
dropout:float=0, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
pool_method: str='max', activation='relu', min_char_freq: int=2, pre_train_char_embed: str=None):
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

@@ -58,7 +58,6 @@ class CNNCharEmbedding(TokenEmbedding):
assert kernel % 2 == 1, "Only odd kernel is allowed."

assert pool_method in ('max', 'avg')
self.dropout = nn.Dropout(dropout)
self.pool_method = pool_method
# activation function
if isinstance(activation, str):
@@ -96,7 +95,7 @@ class CNNCharEmbedding(TokenEmbedding):
if pre_train_char_embed:
self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed)
else:
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))

self.convs = nn.ModuleList([nn.Conv1d(
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
@@ -164,6 +163,8 @@ class CNNCharEmbedding(TokenEmbedding):
for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset
continue
if 'char_embedding' in name:
continue
if param.data.dim()>1:
nn.init.xavier_uniform_(param, 1)
else:
@@ -203,15 +204,14 @@ class LSTMCharEmbedding(TokenEmbedding):
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
"""
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2,
dropout:float=0, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2,
bidirectional=True, pre_train_char_embed: str=None):
super(LSTMCharEmbedding, self).__init__(vocab)
super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

assert hidden_size % 2 == 0, "Only even kernel is allowed."

assert pool_method in ('max', 'avg')
self.pool_method = pool_method
self.dropout = nn.Dropout(dropout)
# activation function
if isinstance(activation, str):
if activation.lower() == 'relu':


Loading…
Cancel
Save