|
@@ -9,6 +9,7 @@ import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.nn.functional as F |
|
|
from typing import List |
|
|
from typing import List |
|
|
|
|
|
|
|
|
|
|
|
from .static_embedding import StaticEmbedding |
|
|
from ..modules.encoder.lstm import LSTM |
|
|
from ..modules.encoder.lstm import LSTM |
|
|
from ..core.vocabulary import Vocabulary |
|
|
from ..core.vocabulary import Vocabulary |
|
|
from .embedding import TokenEmbedding |
|
|
from .embedding import TokenEmbedding |
|
@@ -44,10 +45,13 @@ class CNNCharEmbedding(TokenEmbedding): |
|
|
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. |
|
|
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. |
|
|
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. |
|
|
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. |
|
|
:param min_char_freq: character的最少出现次数。默认值为2. |
|
|
:param min_char_freq: character的最少出现次数。默认值为2. |
|
|
|
|
|
:param pre_train_char_embed:可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 |
|
|
|
|
|
以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 |
|
|
|
|
|
如果输入为None则使用embedding_dim的维度随机初始化一个embedding. |
|
|
""" |
|
|
""" |
|
|
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, |
|
|
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, |
|
|
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), |
|
|
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), |
|
|
pool_method: str='max', activation='relu', min_char_freq: int=2): |
|
|
|
|
|
|
|
|
pool_method: str='max', activation='relu', min_char_freq: int=2, pre_train_char_embed: str=''): |
|
|
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) |
|
|
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) |
|
|
|
|
|
|
|
|
for kernel in kernel_sizes: |
|
|
for kernel in kernel_sizes: |
|
@@ -88,7 +92,11 @@ class CNNCharEmbedding(TokenEmbedding): |
|
|
self.words_to_chars_embedding[index, :len(word)] = \ |
|
|
self.words_to_chars_embedding[index, :len(word)] = \ |
|
|
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) |
|
|
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) |
|
|
self.word_lengths[index] = len(word) |
|
|
self.word_lengths[index] = len(word) |
|
|
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) |
|
|
|
|
|
|
|
|
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) |
|
|
|
|
|
if len(pre_train_char_embed): |
|
|
|
|
|
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed) |
|
|
|
|
|
else: |
|
|
|
|
|
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) |
|
|
|
|
|
|
|
|
self.convs = nn.ModuleList([nn.Conv1d( |
|
|
self.convs = nn.ModuleList([nn.Conv1d( |
|
|
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2) |
|
|
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2) |
|
@@ -190,10 +198,13 @@ class LSTMCharEmbedding(TokenEmbedding): |
|
|
:param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数. |
|
|
:param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数. |
|
|
:param min_char_freq: character的最小出现次数。默认值为2. |
|
|
:param min_char_freq: character的最小出现次数。默认值为2. |
|
|
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 |
|
|
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 |
|
|
|
|
|
:param pre_train_char_embed:可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 |
|
|
|
|
|
以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 |
|
|
|
|
|
如果输入为None则使用embedding_dim的维度随机初始化一个embedding. |
|
|
""" |
|
|
""" |
|
|
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, |
|
|
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, |
|
|
dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2, |
|
|
dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2, |
|
|
bidirectional=True): |
|
|
|
|
|
|
|
|
bidirectional=True, pre_train_char_embed: str=''): |
|
|
super(LSTMCharEmbedding, self).__init__(vocab) |
|
|
super(LSTMCharEmbedding, self).__init__(vocab) |
|
|
|
|
|
|
|
|
assert hidden_size % 2 == 0, "Only even kernel is allowed." |
|
|
assert hidden_size % 2 == 0, "Only even kernel is allowed." |
|
@@ -233,7 +244,11 @@ class LSTMCharEmbedding(TokenEmbedding): |
|
|
self.words_to_chars_embedding[index, :len(word)] = \ |
|
|
self.words_to_chars_embedding[index, :len(word)] = \ |
|
|
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) |
|
|
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) |
|
|
self.word_lengths[index] = len(word) |
|
|
self.word_lengths[index] = len(word) |
|
|
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) |
|
|
|
|
|
|
|
|
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) |
|
|
|
|
|
if len(pre_train_char_embed): |
|
|
|
|
|
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed) |
|
|
|
|
|
else: |
|
|
|
|
|
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) |
|
|
|
|
|
|
|
|
self.fc = nn.Linear(hidden_size, embed_size) |
|
|
self.fc = nn.Linear(hidden_size, embed_size) |
|
|
hidden_size = hidden_size // 2 if bidirectional else hidden_size |
|
|
hidden_size = hidden_size // 2 if bidirectional else hidden_size |
|
|