From 46d72c7439524a07672168dce735407501651811 Mon Sep 17 00:00:00 2001
From: yunfan
Date: Thu, 25 Apr 2019 22:14:26 +0800
Subject: [PATCH] - update doc - add get_embeddings

---
 fastNLP/core/vocabulary.py         | 33 ++++++++++++++++++------------
 fastNLP/models/biaffine_parser.py  | 13 ++++++------
 fastNLP/models/star_transformer.py | 11 ++++++----
 fastNLP/modules/utils.py           | 23 +++++++++++++++++++++
 4 files changed, 57 insertions(+), 23 deletions(-)

diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 6a1830ad..6779a282 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -231,22 +231,29 @@ class Vocabulary(object):
             vocab.from_dataset(train_data1, train_data2, field_name='words')
 
         :param DataSet datasets: the DataSet(s) to be indexed; one or more may be given.
-        :param str field_name: the field used to build the vocabulary.
-            If several DataSets are given, every DataSet must contain this field.
-            Only ``str`` , ``list(str)`` and ``list(list(str))`` are currently supported.
+        :param field_name: a ``str`` or ``list(str)`` .
+            The field(s) used to build the vocabulary; one or more fields may be given.
+            If several DataSets are given, every DataSet must contain these fields.
+            Currently supported field structures: ``str`` , ``list(str)`` , ``list(list(str))``
         :return self:
         """
+        if isinstance(field_name, str):
+            field_name = [field_name]
+        elif not isinstance(field_name, list):
+            raise TypeError('invalid argument field_name: {}'.format(field_name))
+
         def construct_vocab(ins):
-            field = ins[field_name]
-            if isinstance(field, str):
-                self.add_word(field)
-            elif isinstance(field, list):
-                if not isinstance(field[0], list):
-                    self.add_word_lst(field)
-                else:
-                    if isinstance(field[0][0], list):
-                        raise RuntimeError("Only support field with 2 dimensions.")
-                    [self.add_word_lst(w) for w in field]
+            for fn in field_name:
+                field = ins[fn]
+                if isinstance(field, str):
+                    self.add_word(field)
+                elif isinstance(field, list):
+                    if not isinstance(field[0], list):
+                        self.add_word_lst(field)
+                    else:
+                        if isinstance(field[0][0], list):
+                            raise RuntimeError("Only support field with 2 dimensions.")
+                        [self.add_word_lst(w) for w in field]
 
         for idx, dataset in enumerate(datasets):
             if isinstance(dataset, DataSet):
                 try:
diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index 59d95558..f2329dca 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -16,7 +16,7 @@ from fastNLP.modules.encoder.transformer import TransformerEncoder
 from fastNLP.modules.encoder.variational_rnn import VarLSTM
 from fastNLP.modules.utils import initial_parameter
 from fastNLP.modules.utils import seq_mask
-
+from fastNLP.modules.utils import get_embeddings
 
 def _mst(scores):
     """
@@ -230,8 +230,9 @@ class BiaffineParser(GraphParser):
     Reference: ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) `_ .
 
-    :param word_vocab_size: size of the word vocabulary
-    :param word_emb_dim: dimension of the word embedding vectors
+    :param init_embed: the word embedding. It can be a tuple ``(num_embeddings, embedding_dim)``,
+        i.e. the size of the embedding and the dimension of each word vector; an nn.Embedding
+        object can also be passed in and is then used as the embedding directly.
     :param pos_vocab_size: size of the part-of-speech vocabulary
     :param pos_emb_dim: dimension of the part-of-speech embeddings
     :param num_label: number of edge label classes
@@ -245,8 +246,7 @@ class BiaffineParser(GraphParser):
         If ``False`` , the more accurate but slower MST algorithm is used. Default: ``False``
     """
     def __init__(self,
-                 word_vocab_size,
-                 word_emb_dim,
+                 init_embed,
                  pos_vocab_size,
                  pos_emb_dim,
                  num_label,
@@ -260,7 +260,8 @@ class BiaffineParser(GraphParser):
         super(BiaffineParser, self).__init__()
         rnn_out_size = 2 * rnn_hidden_size
         word_hid_dim = pos_hid_dim = rnn_hidden_size
-        self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
+        self.word_embedding = get_embeddings(init_embed)
+        word_emb_dim = self.word_embedding.embedding_dim
         self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
         self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
         self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py
index f68aca42..e4fbeb28 100644
--- a/fastNLP/models/star_transformer.py
+++ b/fastNLP/models/star_transformer.py
@@ -2,6 +2,7 @@
 """
 from fastNLP.modules.encoder.star_transformer import StarTransformer
 from fastNLP.core.utils import seq_lens_to_masks
+from ..modules.utils import get_embeddings
 
 import torch
 from torch import nn
@@ -12,8 +13,9 @@ class StarTransEnc(nn.Module):
     """
     Star-Transformer Encoder with word embedding.
 
-    :param vocab_size: vocabulary size of the word embedding
-    :param emb_dim: feature dimension of each word embedding
+    :param init_embed: the word embedding. It can be a tuple ``(num_embeddings, embedding_dim)``,
+        i.e. the size of the embedding and the dimension of each word vector; an nn.Embedding
+        object can also be passed in and is then used as the embedding directly.
     :param num_cls: number of output classes
     :param hidden_size: feature dimension inside the model.
     :param num_layers: number of model layers.
@@ -24,7 +26,7 @@ class StarTransEnc(nn.Module):
     :param emb_dropout: dropout probability of the word embedding.
     :param dropout: dropout probability of the rest of the model.
     """
-    def __init__(self, vocab_size, emb_dim,
+    def __init__(self, init_embed,
                  hidden_size,
                  num_layers,
                  num_head,
                  head_dim,
                  max_len,
                  emb_dropout,
                  dropout):
         super(StarTransEnc, self).__init__()
+        self.embedding = get_embeddings(init_embed)
+        emb_dim = self.embedding.embedding_dim
         self.emb_fc = nn.Linear(emb_dim, hidden_size)
         self.emb_drop = nn.Dropout(emb_dropout)
-        self.embedding = nn.Embedding(vocab_size, emb_dim)
         self.encoder = StarTransformer(hidden_size=hidden_size,
                                        num_layers=num_layers,
                                        num_head=num_head,
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 4ae15b18..56dbb894 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -1,3 +1,4 @@
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.init as init
@@ -88,3 +89,25 @@ def seq_mask(seq_len, max_len):
     seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
     seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
     return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
+
+
+def get_embeddings(init_embed):
+    """Build the word embedding module.
+
+    :param init_embed: the word embedding. It can be a tuple ``(num_embeddings, embedding_dim)``,
+        i.e. the size of the embedding and the dimension of each word vector; an nn.Embedding
+        object can also be passed in and is then used as the embedding directly. A torch.Tensor
+        or numpy.ndarray of pretrained weights is also accepted and is wrapped with
+        ``nn.Embedding.from_pretrained(..., freeze=False)``.
+    :return embeddings: the resulting nn.Embedding object
+    """
+    if isinstance(init_embed, tuple):
+        res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+    elif isinstance(init_embed, nn.Embedding):
+        res = init_embed
+    elif isinstance(init_embed, torch.Tensor):
+        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
+    elif isinstance(init_embed, np.ndarray):
+        init_embed = torch.tensor(init_embed, dtype=torch.float32)
+        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
+    else:
+        raise TypeError('invalid init_embed type: {}'.format(type(init_embed)))
+    return res
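
Usage note (not part of the patch): with this change, Vocabulary.from_dataset accepts either a single field name or a list of field names. A minimal sketch, assuming the usual top-level fastNLP imports; the field names and toy data below are made up for illustration:

    from fastNLP import DataSet, Vocabulary

    # two fields: token sequences (list(str)) and per-token character lists (list(list(str)))
    ds = DataSet({'words': [['this', 'is', 'fastNLP'], ['another', 'sentence']],
                  'chars': [[list(w) for w in ['this', 'is', 'fastNLP']],
                            [list(w) for w in ['another', 'sentence']]]})

    vocab = Vocabulary()
    # field_name may now be a list; entries from every listed field go into the same vocabulary
    vocab.from_dataset(ds, field_name=['words', 'chars'])
    idx = vocab.to_index('fastNLP')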
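
Similarly, a sketch of how the new get_embeddings helper resolves the different init_embed forms; the sizes and the pretrained array are made-up values for illustration:

    import numpy as np
    import torch.nn as nn

    from fastNLP.modules.utils import get_embeddings

    # a (num_embeddings, embedding_dim) tuple -> a freshly initialised nn.Embedding
    embed = get_embeddings((10000, 300))
    assert isinstance(embed, nn.Embedding) and embed.embedding_dim == 300

    # an existing nn.Embedding instance is returned unchanged
    embed = get_embeddings(nn.Embedding(10000, 300))

    # a torch.Tensor or numpy.ndarray of pretrained weights is wrapped with
    # nn.Embedding.from_pretrained(..., freeze=False), so it remains trainable
    pretrained = np.random.rand(10000, 300).astype('float32')
    embed = get_embeddings(pretrained)

The models touched by this patch (BiaffineParser, StarTransEnc) forward their init_embed argument to this helper, so any of the forms above can be passed where the old vocab_size/emb_dim pair used to go.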