@@ -231,22 +231,29 @@ class Vocabulary(object):
            vocab.from_dataset(train_data1, train_data2, field_name='words')

        :param DataSet datasets: DataSet(s) to be indexed; one or more may be given.
        :param str field_name: the field used to build the vocabulary.
            If multiple DataSets are given, every DataSet must contain this field.
            Currently only ``str`` , ``list(str)`` and ``list(list(str))`` are supported.
        :param field_name: a ``str`` or ``list(str)`` .
            The field(s) used to build the vocabulary; one or more fields may be given.
            If multiple DataSets are given, every DataSet must contain these fields.
            Currently supported field structures: ``str`` , ``list(str)`` , ``list(list(str))``
        :return self:
        """
        if isinstance(field_name, str):
            field_name = [field_name]
        elif not isinstance(field_name, list):
            raise TypeError('invalid argument field_name: {}'.format(field_name))
        def construct_vocab(ins):
            field = ins[field_name]
            if isinstance(field, str):
                self.add_word(field)
            elif isinstance(field, list):
                if not isinstance(field[0], list):
                    self.add_word_lst(field)
                else:
                    if isinstance(field[0][0], list):
                        raise RuntimeError("Only support field with 2 dimensions.")
                    [self.add_word_lst(w) for w in field]
            for fn in field_name:
                field = ins[fn]
                if isinstance(field, str):
                    self.add_word(field)
                elif isinstance(field, list):
                    if not isinstance(field[0], list):
                        self.add_word_lst(field)
                    else:
                        if isinstance(field[0][0], list):
                            raise RuntimeError("Only support field with 2 dimensions.")
                        for words in field:  # add each inner word list to the vocabulary
                            self.add_word_lst(words)
        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
                try:
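# A hedged usage sketch of the new multi-field form above (assumes DataSet and
# Vocabulary are exposed at the fastNLP package root; field names and data are
# illustrative):
from fastNLP import DataSet, Vocabulary

data = DataSet({'words': [['the', 'cat'], ['sat']],           # list(str) per instance
                'chars': [[['t', 'h', 'e'], ['c', 'a', 't']],
                          [['s', 'a', 't']]]})                 # list(list(str)) per instance
vocab = Vocabulary()
vocab.from_dataset(data, field_name=['words', 'chars'])        # one or more fields per call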
@@ -16,7 +16,7 @@ from fastNLP.modules.encoder.transformer import TransformerEncoder
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.utils import seq_mask
from fastNLP.modules.utils import get_embeddings
def _mst(scores):
    """
@@ -230,8 +230,9 @@ class BiaffineParser(GraphParser):
    Reference: `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
    <https://arxiv.org/abs/1611.01734>`_ .
    :param word_vocab_size: size of the word vocabulary
    :param word_emb_dim: dimension of the word embedding vectors
    :param init_embed: the word embedding; may be a tuple ``(num_embeddings, embedding_dim)``,
        i.e. the size of the embedding and the dimension of each word vector,
        or an ``nn.Embedding`` object, in which case that object is used as the embedding
    :param pos_vocab_size: size of the part-of-speech vocabulary
    :param pos_emb_dim: dimension of the part-of-speech vectors
    :param num_label: number of edge label classes
@@ -245,8 +246,7 @@ class BiaffineParser(GraphParser):
        If ``False``, the more accurate but slower MST algorithm is used. Default: ``False``
    """
    def __init__(self,
                 word_vocab_size,
                 word_emb_dim,
                 init_embed,
                 pos_vocab_size,
                 pos_emb_dim,
                 num_label,
@@ -260,7 +260,8 @@ class BiaffineParser(GraphParser):
        super(BiaffineParser, self).__init__()
        rnn_out_size = 2 * rnn_hidden_size
        word_hid_dim = pos_hid_dim = rnn_hidden_size
        self.word_embedding = nn.Embedding(num_embeddings=word_vocab_size, embedding_dim=word_emb_dim)
        self.word_embedding = get_embeddings(init_embed)
        word_emb_dim = self.word_embedding.embedding_dim
        self.pos_embedding = nn.Embedding(num_embeddings=pos_vocab_size, embedding_dim=pos_emb_dim)
        self.word_fc = nn.Linear(word_emb_dim, word_hid_dim)
        self.pos_fc = nn.Linear(pos_emb_dim, pos_hid_dim)
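# Migration sketch for the signature change above (sizes illustrative): a
# (num_embeddings, embedding_dim) tuple reproduces the old
# word_vocab_size/word_emb_dim behaviour, and a pre-built nn.Embedding is
# used as-is.
import torch.nn as nn
from fastNLP.modules.utils import get_embeddings

emb = get_embeddings((10000, 100))               # equivalent to nn.Embedding(10000, 100)
assert isinstance(emb, nn.Embedding) and emb.embedding_dim == 100

pretrained = nn.Embedding(10000, 100)
assert get_embeddings(pretrained) is pretrained  # returned unchanged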
@@ -2,6 +2,7 @@
"""
from fastNLP.modules.encoder.star_transformer import StarTransformer
from fastNLP.core.utils import seq_lens_to_masks
from ..modules.utils import get_embeddings
import torch
from torch import nn
@@ -12,8 +13,9 @@ class StarTransEnc(nn.Module):
    """
    Star-Transformer encoder with word embedding.
    :param vocab_size: vocabulary size of the word embedding
    :param emb_dim: feature dimension of each word embedding
    :param init_embed: the word embedding; may be a tuple ``(num_embeddings, embedding_dim)``,
        i.e. the size of the embedding and the dimension of each word vector,
        or an ``nn.Embedding`` object, in which case that object is used as the embedding
    :param num_cls: number of output classes
    :param hidden_size: feature dimension used inside the model.
    :param num_layers: number of model layers.
@@ -24,7 +26,7 @@ class StarTransEnc(nn.Module):
    :param emb_dropout: dropout probability for the word embedding.
    :param dropout: dropout probability for the rest of the model.
    """
    def __init__(self, vocab_size, emb_dim,
    def __init__(self, init_embed,
                 hidden_size,
                 num_layers,
                 num_head,
@@ -33,9 +35,10 @@ class StarTransEnc(nn.Module):
                 emb_dropout,
                 dropout):
        super(StarTransEnc, self).__init__()
        self.embedding = get_embeddings(init_embed)
        emb_dim = self.embedding.embedding_dim
        self.emb_fc = nn.Linear(emb_dim, hidden_size)
        self.emb_drop = nn.Dropout(emb_dropout)
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.encoder = StarTransformer(hidden_size=hidden_size,
                                       num_layers=num_layers,
                                       num_head=num_head,
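# The same init_embed mechanism also accepts pretrained weights directly; a
# sketch using a random stand-in for real vectors (shape illustrative):
import numpy as np
from fastNLP.modules.utils import get_embeddings

weights = np.random.rand(10000, 300)   # e.g. word vectors loaded from disk
emb = get_embeddings(weights)          # cast to float32, wrapped with freeze=False
assert emb.weight.requires_grad        # the embedding stays trainable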
@@ -1,3 +1,4 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
@@ -88,3 +89,25 @@ def seq_mask(seq_len, max_len):
    seq_len = seq_len.view(-1, 1).long()  # [batch_size, 1]
    seq_range = torch.arange(start=0, end=max_len, dtype=torch.long, device=seq_len.device).view(1, -1)  # [1, max_len]
    return torch.gt(seq_len, seq_range)  # [batch_size, max_len]
def get_embeddings(init_embed):
    """Build the word embedding.

    :param init_embed: the embedding specification; may be a tuple
        ``(num_embeddings, embedding_dim)``, i.e. the size of the embedding and the
        dimension of each word vector, in which case a fresh ``nn.Embedding`` is created;
        an ``nn.Embedding`` object, which is used directly as the embedding;
        or a ``torch.Tensor`` / ``numpy.ndarray`` of pretrained weights, which is
        wrapped with ``nn.Embedding.from_pretrained`` and kept trainable.
    :return embeddings:
    """
    if isinstance(init_embed, tuple):
        res = nn.Embedding(num_embeddings=init_embed[0], embedding_dim=init_embed[1])
    elif isinstance(init_embed, nn.Embedding):
        res = init_embed
    elif isinstance(init_embed, torch.Tensor):
        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
    elif isinstance(init_embed, np.ndarray):
        init_embed = torch.tensor(init_embed, dtype=torch.float32)
        res = nn.Embedding.from_pretrained(init_embed, freeze=False)
    else:
        raise TypeError('invalid init_embed type: {}'.format(type(init_embed)))
    return res
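# Quick check of the two remaining branches (sizes illustrative):
import torch

emb = get_embeddings(torch.randn(100, 50))   # Tensor -> trainable embedding
assert emb.num_embeddings == 100 and emb.embedding_dim == 50

try:
    get_embeddings('not an embedding spec')  # anything else is rejected
except TypeError as err:
    print(err)                               # invalid init_embed type: <class 'str'>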