diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py index e772703a..520e85e6 100644 --- a/fastNLP/embeddings/char_embedding.py +++ b/fastNLP/embeddings/char_embedding.py @@ -14,7 +14,7 @@ from ..modules.encoder.lstm import LSTM from ..core.vocabulary import Vocabulary from .embedding import TokenEmbedding from .utils import _construct_char_vocab_from_vocab - +from .utils import get_embeddings class CNNCharEmbedding(TokenEmbedding): """ @@ -50,7 +50,7 @@ class CNNCharEmbedding(TokenEmbedding): 没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. """ def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, - dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), + dropout:float=0, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), pool_method: str='max', activation='relu', min_char_freq: int=2, pre_train_char_embed: str=None): super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) @@ -58,7 +58,6 @@ class CNNCharEmbedding(TokenEmbedding): assert kernel % 2 == 1, "Only odd kernel is allowed." assert pool_method in ('max', 'avg') - self.dropout = nn.Dropout(dropout) self.pool_method = pool_method # activation function if isinstance(activation, str): @@ -96,7 +95,7 @@ class CNNCharEmbedding(TokenEmbedding): if pre_train_char_embed: self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed) else: - self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) + self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) self.convs = nn.ModuleList([nn.Conv1d( char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2) @@ -164,6 +163,8 @@ class CNNCharEmbedding(TokenEmbedding): for name, param in self.named_parameters(): if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset continue + if 'char_embedding' in name: + continue if param.data.dim()>1: nn.init.xavier_uniform_(param, 1) else: @@ -203,15 +204,14 @@ class LSTMCharEmbedding(TokenEmbedding): 没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. """ def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, - dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2, + dropout:float=0, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2, bidirectional=True, pre_train_char_embed: str=None): - super(LSTMCharEmbedding, self).__init__(vocab) + super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) assert hidden_size % 2 == 0, "Only even kernel is allowed." assert pool_method in ('max', 'avg') self.pool_method = pool_method - self.dropout = nn.Dropout(dropout) # activation function if isinstance(activation, str): if activation.lower() == 'relu':