@@ -114,7 +114,8 @@ class CNNCharEmbedding(TokenEmbedding):
         self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
         self.convs = nn.ModuleList([nn.Conv1d(
-            char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
+            self.char_embedding.embedding_dim, filter_nums[i], kernel_size=kernel_sizes[i], bias=True,
+            padding=kernel_sizes[i] // 2)
             for i in range(len(kernel_sizes))])
         self._embed_size = embed_size
         self.fc = nn.Linear(sum(filter_nums), embed_size)
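
For context, a minimal sketch of the pattern this hunk adopts: sizing the conv layers from the embedding module's `embedding_dim` rather than from the `char_emb_size` argument, so a pretrained character embedding with a different dimension still lines up. The shapes and values below are illustrative, not fastNLP's defaults.

```python
import torch
import torch.nn as nn

char_vocab_size, char_emb_size = 100, 50        # illustrative values
kernel_sizes, filter_nums = [1, 3, 5], [30, 30, 40]

# When the char embedding is loaded from a pretrained file, its real
# dimension may differ from the `char_emb_size` argument. Reading
# `embedding_dim` from the module keeps the conv input channels in sync.
char_embedding = nn.Embedding(char_vocab_size, char_emb_size)
convs = nn.ModuleList([
    nn.Conv1d(char_embedding.embedding_dim, filter_nums[i],
              kernel_size=kernel_sizes[i], bias=True,
              padding=kernel_sizes[i] // 2)
    for i in range(len(kernel_sizes))
])

# Conv1d expects (batch, channels, length), so (batch, word_len, emb_dim)
# char embeddings are transposed before the convolution.
chars = torch.randint(0, char_vocab_size, (2, 7))   # (batch=2, word_len=7)
x = char_embedding(chars).transpose(1, 2)           # (2, emb_dim, 7)
outs = [conv(x) for conv in convs]                  # each (2, filter_nums[i], 7)
```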
@@ -238,12 +239,12 @@ class LSTMCharEmbedding(TokenEmbedding):
         if pre_train_char_embed:
             self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
         else:
-            self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+            self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
         self.fc = nn.Linear(hidden_size, embed_size)
         hidden_size = hidden_size // 2 if bidirectional else hidden_size
-        self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True)
+        self.lstm = LSTM(self.char_embedding.embedding_dim, hidden_size, bidirectional=bidirectional, batch_first=True)
         self._embed_size = embed_size
         self.bidirectional = bidirectional
         self.requires_grad = requires_grad
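
The same fix applies here: `get_embeddings` replaces the bare `nn.Embedding`, and the LSTM input size is read back from the module. Below is a rough stand-in for what such a helper must do, judging only from how it is called in these hunks; this is not fastNLP's actual implementation.

```python
import numpy as np
import torch
import torch.nn as nn

def get_embeddings_sketch(init_embed):
    """Stand-in for a helper used like `get_embeddings` above: a
    (num_embeddings, embedding_dim) tuple builds a fresh nn.Embedding,
    an array wraps pretrained weights, a module passes through. Either
    way the caller can read `.embedding_dim` afterwards."""
    if isinstance(init_embed, tuple):
        return nn.Embedding(*init_embed)
    if isinstance(init_embed, np.ndarray):
        return nn.Embedding.from_pretrained(torch.as_tensor(init_embed, dtype=torch.float))
    if isinstance(init_embed, nn.Module):
        return init_embed
    raise TypeError(f"Unsupported init_embed type: {type(init_embed)}")

emb = get_embeddings_sketch((100, 50))
# The LSTM input size is read back from the module, as in the hunk above.
lstm = nn.LSTM(emb.embedding_dim, 25, bidirectional=True, batch_first=True)
```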
@@ -68,7 +68,7 @@ class StaticEmbedding(TokenEmbedding):
     :param float word_dropout: the probability of replacing a word with unk. This both trains unk and provides a degree of regularization.
     :param bool normalize: whether to normalize the vectors so that each vector has a norm of 1.
     :param int min_freq: words whose frequency in the Vocabulary is lower than this value are mapped to unk.
-    :param dict kwarngs:
+    :param dict kwargs:
        bool only_train_min_freq: apply the min_freq filter only to words from train;
        bool only_norm_found_vector: whether to normalize only the words found in the pretrained embeddings;
        bool only_use_pretrain_word: only use words that appear in the pretrain vocabulary; a word absent from the pretrained vocabulary becomes unk. If the vocabulary
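
A hedged usage sketch for the keyword arguments documented above; the embedding name `en-glove-6b-50d` and the flag combination are assumptions for illustration, so check them against the fastNLP release you use.

```python
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is an example sentence".split())

# `en-glove-6b-50d` is assumed to be one of fastNLP's downloadable
# embedding names; swap in a local filepath if you have vectors on disk.
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d',
                        min_freq=2,
                        only_train_min_freq=True,      # min_freq filter applies to train words only
                        only_norm_found_vector=False,  # normalize only vectors found in the file
                        only_use_pretrain_word=False)  # words absent from the file stay trainable
```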
@@ -44,14 +44,17 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
                 if 'dev' in filename:
                     if path_pair:
                         raise Exception(
-                            "File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
+                            "Directory:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
                     path_pair = ('dev', filename)
                 if 'test' in filename:
                     if path_pair:
                         raise Exception(
-                            "File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+                            "Directory:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
                     path_pair = ('test', filename)
                 if path_pair:
+                    if path_pair[0] in files:
+                        raise FileExistsError(f"Two files containing `{path_pair[0]}` were found, please specify the "
+                                              f"filepath for `{path_pair[0]}`.")
                     files[path_pair[0]] = os.path.join(paths, path_pair[1])
         if 'train' not in files:
             raise KeyError(f"There is no train file in {paths}.")