From fb82c66b4c8d2521816b7648d9e93eeef31a82fa Mon Sep 17 00:00:00 2001
From: YanqunJiang
Date: Fri, 16 Aug 2019 17:51:07 +0800
Subject: [PATCH] Add the ability for char_embedding to use a pretrained
 character embedding
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/embeddings/char_embedding.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
index b9e6659e..8243e148 100644
--- a/fastNLP/embeddings/char_embedding.py
+++ b/fastNLP/embeddings/char_embedding.py
@@ -9,6 +9,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import List
 
+from .static_embedding import StaticEmbedding
 from ..modules.encoder.lstm import LSTM
 from ..core.vocabulary import Vocabulary
 from .embedding import TokenEmbedding
@@ -41,10 +42,13 @@ class CNNCharEmbedding(TokenEmbedding):
     :param pool_method: pooling method used to combine the character representations into one; supports 'avg', 'max'.
     :param activation: activation applied after the CNN; supports 'relu', 'sigmoid', 'tanh', or a custom function.
     :param min_char_freq: minimum number of occurrences for a character. Defaults to 2.
+    :param pre_train_char_embed: a pretrained static embedding can be loaded in two ways: pass an embedding folder
+        (it should contain exactly one file with a .txt suffix) or a file path; or pass an embedding name, in which
+        case the cache is checked first and the model is downloaded automatically if absent. If None, a char_emb_size-dimensional embedding is randomly initialized.
     """
     def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
                  dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
-                 pool_method: str='max', activation='relu', min_char_freq: int=2):
+                 pool_method: str='max', activation='relu', min_char_freq: int=2, pre_train_char_embed: str=None):
         super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
 
         for kernel in kernel_sizes:
@@ -85,7 +89,11 @@ class CNNCharEmbedding(TokenEmbedding):
             self.words_to_chars_embedding[index, :len(word)] = \
                 torch.LongTensor([self.char_vocab.to_index(c) for c in word])
             self.word_lengths[index] = len(word)
-        self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+        # use pretrained character vectors when a source is given, otherwise fall back to random initialization
+        if pre_train_char_embed:
+            self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
+        else:
+            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
         self.convs = nn.ModuleList([nn.Conv1d(
             char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
             for i in range(len(kernel_sizes))])
@@ -184,10 +192,13 @@ class LSTMCharEmbedding(TokenEmbedding):
     :param activation: activation function; supports 'relu', 'sigmoid', 'tanh', or a custom function.
     :param min_char_freq: minimum number of occurrences for a character. Defaults to 2.
     :param bidirectional: whether to use a bidirectional LSTM for encoding. Defaults to True.
+    :param pre_train_char_embed: a pretrained static embedding can be loaded in two ways: pass an embedding folder
+        (it should contain exactly one file with a .txt suffix) or a file path; or pass an embedding name, in which
+        case the cache is checked first and the model is downloaded automatically if absent. If None, a char_emb_size-dimensional embedding is randomly initialized.
     """
     def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
                  dropout:float=0.5, hidden_size=50, pool_method: str='max', activation='relu', min_char_freq: int=2,
-                 bidirectional=True):
+                 bidirectional=True, pre_train_char_embed: str=None):
         super(LSTMCharEmbedding, self).__init__(vocab)
 
         assert hidden_size % 2 == 0, "Only even kernel is allowed."
@@ -227,7 +238,11 @@ class LSTMCharEmbedding(TokenEmbedding):
             self.words_to_chars_embedding[index, :len(word)] = \
                 torch.LongTensor([self.char_vocab.to_index(c) for c in word])
             self.word_lengths[index] = len(word)
-        self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+        # use pretrained character vectors when a source is given, otherwise fall back to random initialization
+        if pre_train_char_embed:
+            self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
+        else:
+            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
         self.fc = nn.Linear(hidden_size, embed_size)
 
         hidden_size = hidden_size // 2 if bidirectional else hidden_size
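Below is a minimal usage sketch of the new parameter, not part of the patch itself. It assumes fastNLP's public Vocabulary and CNNCharEmbedding API; 'path/to/char_vectors.txt' is a placeholder for a word2vec/GloVe-style text file of single-character vectors, and its vector dimension should match char_emb_size, since the convolutions still take their input channels from that argument.

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import CNNCharEmbedding

# Build a word vocabulary; the character vocabulary is derived from it internally.
vocab = Vocabulary()
vocab.add_word_lst("this is a demo sentence".split())

# Previous behaviour (still the default): characters are embedded with a
# randomly initialized nn.Embedding of size char_emb_size.
random_char_embed = CNNCharEmbedding(vocab, embed_size=50, char_emb_size=50)

# New behaviour: load pretrained character vectors from a .txt file.
# 'path/to/char_vectors.txt' is a placeholder; its vector dimension should
# match char_emb_size, because the Conv1d layers read that many channels.
pretrained_char_embed = CNNCharEmbedding(vocab, embed_size=50, char_emb_size=50,
                                         pre_train_char_embed='path/to/char_vectors.txt')

# Either variant maps word indices to a (batch, seq_len, embed_size) tensor.
words = torch.LongTensor([[vocab.to_index(w) for w in "this is a demo".split()]])
print(random_char_embed(words).shape)  # torch.Size([1, 4, 50])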