Browse Source

split the docs in models

tags/v0.4.10
ChenXin 5 years ago
parent
commit
32e99e3696
6 changed files with 185 additions and 135 deletions
  1. +30
    -15
      fastNLP/models/bert.py
  2. +54
    -40
      fastNLP/models/biaffine_parser.py
  3. +9
    -6
      fastNLP/models/cnn_text_classification.py
  4. +28
    -21
      fastNLP/models/sequence_labeling.py
  5. +8
    -5
      fastNLP/models/snli.py
  6. +56
    -48
      fastNLP/models/star_transformer.py

+ 30
- 15
fastNLP/models/bert.py View File

@@ -37,8 +37,8 @@ import torch
from torch import nn

from .base_model import BaseModel
from ..core.const import Const
from ..core._logger import logger
from ..core.const import Const
from ..embeddings import BertEmbedding


@@ -46,11 +46,14 @@ class BertForSequenceClassification(BaseModel):
"""
BERT model for classification.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 文本分类类别数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1):
"""
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 文本分类类别数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
super(BertForSequenceClassification, self).__init__()

self.num_labels = num_labels
@@ -89,11 +92,14 @@ class BertForSentenceMatching(BaseModel):
"""
BERT model for sentence matching.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: Matching任务类别数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_labels: int=2, dropout=0.1):
"""
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: Matching任务类别数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
super(BertForSentenceMatching, self).__init__()
self.num_labels = num_labels
self.bert = embed
@@ -131,11 +137,14 @@ class BertForMultipleChoice(BaseModel):
"""
BERT model for multiple choice.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_choices: 多选任务选项数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_choices=2, dropout=0.1):
"""
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_choices: 多选任务选项数目,默认值为2.
:param float dropout: dropout的大小,默认值为0.1.
"""
super(BertForMultipleChoice, self).__init__()

self.num_choices = num_choices
@@ -178,11 +187,14 @@ class BertForTokenClassification(BaseModel):
"""
BERT model for token classification.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 序列标注标签数目,无默认值.
:param float dropout: dropout的大小,默认值为0.1.
"""
def __init__(self, embed: BertEmbedding, num_labels, dropout=0.1):
"""
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 序列标注标签数目,无默认值.
:param float dropout: dropout的大小,默认值为0.1.
"""
super(BertForTokenClassification, self).__init__()

self.num_labels = num_labels
@@ -221,10 +233,13 @@ class BertForQuestionAnswering(BaseModel):
"""
BERT model for classification.

:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span).
"""
def __init__(self, embed: BertEmbedding, num_labels=2):
"""
:param fastNLP.embeddings.BertEmbedding embed: 下游模型的编码器(encoder).
:param int num_labels: 抽取式QA列数,默认值为2(即第一列为start_span, 第二列为end_span).
"""
super(BertForQuestionAnswering, self).__init__()

self.bert = embed


+ 54
- 40
fastNLP/models/biaffine_parser.py View File

@@ -6,23 +6,23 @@ __all__ = [
"GraphParser"
]

from collections import defaultdict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import defaultdict

from .base_model import BaseModel
from ..core.const import Const as C
from ..core.losses import LossFunc
from ..core.metrics import MetricBase
from ..core.utils import seq_len_to_mask
from ..embeddings.utils import get_embeddings
from ..modules.dropout import TimestepDropout
from ..modules.encoder.transformer import TransformerEncoder
from ..modules.encoder.variational_rnn import VarLSTM
from ..modules.utils import initial_parameter
from ..embeddings.utils import get_embeddings
from .base_model import BaseModel
from ..core.utils import seq_len_to_mask


def _mst(scores):
@@ -181,11 +181,14 @@ class ArcBiaffine(nn.Module):
"""
Biaffine Dependency Parser 的子模块, 用于构建预测边的图

:param hidden_size: 输入的特征维度
:param bias: 是否使用bias. Default: ``True``
"""
def __init__(self, hidden_size, bias=True):
"""
:param hidden_size: 输入的特征维度
:param bias: 是否使用bias. Default: ``True``
"""
super(ArcBiaffine, self).__init__()
self.U = nn.Parameter(torch.Tensor(hidden_size, hidden_size), requires_grad=True)
self.has_bias = bias
@@ -213,13 +216,16 @@ class LabelBilinear(nn.Module):
"""
Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图

:param in1_features: 输入的特征1维度
:param in2_features: 输入的特征2维度
:param num_label: 边类别的个数
:param bias: 是否使用bias. Default: ``True``
"""
def __init__(self, in1_features, in2_features, num_label, bias=True):
"""
:param in1_features: 输入的特征1维度
:param in2_features: 输入的特征2维度
:param num_label: 边类别的个数
:param bias: 是否使用bias. Default: ``True``
"""
super(LabelBilinear, self).__init__()
self.bilinear = nn.Bilinear(in1_features, in2_features, num_label, bias=bias)
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)
@@ -241,20 +247,6 @@ class BiaffineParser(GraphParser):
Biaffine Dependency Parser 实现.
论文参考 `Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) <https://arxiv.org/abs/1611.01734>`_ .

:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param pos_vocab_size: part-of-speech 词典大小
:param pos_emb_dim: part-of-speech 向量维度
:param num_label: 边的类别个数
:param rnn_layers: rnn encoder的层数
:param rnn_hidden_size: rnn encoder 的隐状态维度
:param arc_mlp_size: 边预测的MLP维度
:param label_mlp_size: 类别预测的MLP维度
:param dropout: dropout概率.
:param encoder: encoder类别, 可选 ('lstm', 'var-lstm', 'transformer'). Default: lstm
:param use_greedy_infer: 是否在inference时使用贪心算法.
若 ``False`` , 使用更加精确但相对缓慢的MST算法. Default: ``False``
"""
def __init__(self,
@@ -269,6 +261,23 @@ class BiaffineParser(GraphParser):
dropout=0.3,
encoder='lstm',
use_greedy_infer=False):
"""
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param pos_vocab_size: part-of-speech 词典大小
:param pos_emb_dim: part-of-speech 向量维度
:param num_label: 边的类别个数
:param rnn_layers: rnn encoder的层数
:param rnn_hidden_size: rnn encoder 的隐状态维度
:param arc_mlp_size: 边预测的MLP维度
:param label_mlp_size: 类别预测的MLP维度
:param dropout: dropout概率.
:param encoder: encoder类别, 可选 ('lstm', 'var-lstm', 'transformer'). Default: lstm
:param use_greedy_infer: 是否在inference时使用贪心算法.
若 ``False`` , 使用更加精确但相对缓慢的MST算法. Default: ``False``
"""
super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
word_hid_dim = pos_hid_dim = rnn_hidden_size
@@ -473,17 +482,20 @@ class ParserLoss(LossFunc):
"""
计算parser的loss

:param pred1: [batch_size, seq_len, seq_len] 边预测logits
:param pred2: [batch_size, seq_len, num_label] label预测logits
:param target1: [batch_size, seq_len] 真实边的标注
:param target2: [batch_size, seq_len] 真实类别的标注
:param seq_len: [batch_size, seq_len] 真实目标的长度
:return loss: scalar
"""
def __init__(self, pred1=None, pred2=None,
target1=None, target2=None,
seq_len=None):
"""
:param pred1: [batch_size, seq_len, seq_len] 边预测logits
:param pred2: [batch_size, seq_len, num_label] label预测logits
:param target1: [batch_size, seq_len] 真实边的标注
:param target2: [batch_size, seq_len] 真实类别的标注
:param seq_len: [batch_size, seq_len] 真实目标的长度
:return loss: scalar
"""
super(ParserLoss, self).__init__(BiaffineParser.loss,
pred1=pred1,
pred2=pred2,
@@ -496,20 +508,22 @@ class ParserMetric(MetricBase):
"""
评估parser的性能

:param pred1: 边预测logits
:param pred2: label预测logits
:param target1: 真实边的标注
:param target2: 真实类别的标注
:param seq_len: 序列长度
:return dict: 评估结果::

UAS: 不带label时, 边预测的准确率
LAS: 同时预测边和label的准确率
"""
def __init__(self, pred1=None, pred2=None,
target1=None, target2=None, seq_len=None):
"""
:param pred1: 边预测logits
:param pred2: label预测logits
:param target1: 真实边的标注
:param target2: 真实类别的标注
:param seq_len: 序列长度
:return dict: 评估结果::
UAS: 不带label时, 边预测的准确率
LAS: 同时预测边和label的准确率
"""
super().__init__()
self._init_param_map(pred1=pred1, pred2=pred2,
target1=target1, target2=target2,


+ 9
- 6
fastNLP/models/cnn_text_classification.py View File

@@ -21,12 +21,6 @@ class CNNText(torch.nn.Module):
使用CNN进行文本分类的模型
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.'
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param int num_classes: 一共有多少类
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param float dropout: Dropout的大小
"""

def __init__(self, embed,
@@ -34,6 +28,15 @@ class CNNText(torch.nn.Module):
kernel_nums=(30, 40, 50),
kernel_sizes=(1, 3, 5),
dropout=0.5):
"""
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param int num_classes: 一共有多少类
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param float dropout: Dropout的大小
"""
super(CNNText, self).__init__()

# no support for pre-trained embedding currently


+ 28
- 21
fastNLP/models/sequence_labeling.py View File

@@ -25,16 +25,19 @@ class BiLSTMCRF(BaseModel):
"""
结构为embedding + BiLSTM + FC + Dropout + CRF.

:param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100)
:param num_classes: 一共多少个类
:param num_layers: BiLSTM的层数
:param hidden_size: BiLSTM的hidden_size,实际hidden size为该值的两倍(前向、后向)
:param dropout: dropout的概率,0为不dropout
:param target_vocab: Vocabulary对象,target与index的对应关系
:param encoding_type: encoding的类型,支持'bioes', 'bmes', 'bio', 'bmeso'等
"""
def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5,
target_vocab=None, encoding_type=None):
"""
:param embed: 支持(1)fastNLP的各种Embedding, (2) tuple, 指明num_embedding, dimension, 如(1000, 100)
:param num_classes: 一共多少个类
:param num_layers: BiLSTM的层数
:param hidden_size: BiLSTM的hidden_size,实际hidden size为该值的两倍(前向、后向)
:param dropout: dropout的概率,0为不dropout
:param target_vocab: Vocabulary对象,target与index的对应关系
:param encoding_type: encoding的类型,支持'bioes', 'bmes', 'bio', 'bmeso'等
"""
super().__init__()
self.embed = get_embeddings(embed)

@@ -80,13 +83,16 @@ class SeqLabeling(BaseModel):
一个基础的Sequence labeling的模型。
用于做sequence labeling的基础类。结构包含一层Embedding,一层LSTM(单向,一层),一层FC,以及一层CRF。
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding
:param int hidden_size: LSTM隐藏层的大小
:param int num_classes: 一共有多少类
"""
def __init__(self, embed, hidden_size, num_classes):
"""
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, embedding, ndarray等则直接使用该值初始化Embedding
:param int hidden_size: LSTM隐藏层的大小
:param int num_classes: 一共有多少类
"""
super(SeqLabeling, self).__init__()
self.embedding = get_embeddings(embed)
@@ -155,20 +161,21 @@ class SeqLabeling(BaseModel):
class AdvSeqLabel(nn.Module):
"""
更复杂的Sequence Labelling模型。结构为Embedding, LayerNorm, 双向LSTM(两层),FC,LayerNorm,DropOut,FC,CRF。
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param int hidden_size: LSTM的隐层大小
:param int num_classes: 有多少个类
:param float dropout: LSTM中以及DropOut层的drop概率
:param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S'
不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证
'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。)
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 只有在id2words不为None的情况有用。
"""
def __init__(self, embed, hidden_size, num_classes, dropout=0.3, id2words=None, encoding_type='bmes'):
"""
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray embed: Embedding的大小(传入tuple(int, int),
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding
:param int hidden_size: LSTM的隐层大小
:param int num_classes: 有多少个类
:param float dropout: LSTM中以及DropOut层的drop概率
:param dict id2words: tag id转为其tag word的表。用于在CRF解码时防止解出非法的顺序,比如'BMES'这个标签规范中,'S'
不能出现在'B'之后。这里也支持类似与'B-NN',即'-'前为标签类型的指示,后面为具体的tag的情况。这里不但会保证
'B-NN'后面不为'S-NN'还会保证'B-NN'后面不会出现'M-xx'(任何非'M-NN'和'E-NN'的情况。)
:param str encoding_type: 支持"BIO", "BMES", "BEMSO", 只有在id2words不为None的情况有用。
"""
super().__init__()
self.Embedding = get_embeddings(embed)


+ 8
- 5
fastNLP/models/snli.py View File

@@ -22,15 +22,18 @@ class ESIM(BaseModel):
ESIM model的一个PyTorch实现
论文参见: https://arxiv.org/pdf/1609.06038.pdf

:param embed: 初始化的Embedding
:param int hidden_size: 隐藏层大小,默认值为Embedding的维度
:param int num_labels: 目标标签种类数量,默认值为3
:param float dropout_rate: dropout的比率,默认值为0.3
:param float dropout_embed: 对Embedding的dropout比率,默认值为0.1
"""

def __init__(self, embed, hidden_size=None, num_labels=3, dropout_rate=0.3,
dropout_embed=0.1):
"""
:param embed: 初始化的Embedding
:param int hidden_size: 隐藏层大小,默认值为Embedding的维度
:param int num_labels: 目标标签种类数量,默认值为3
:param float dropout_rate: dropout的比率,默认值为0.3
:param float dropout_embed: 对Embedding的dropout比率,默认值为0.1
"""
super(ESIM, self).__init__()

if isinstance(embed, TokenEmbedding) or isinstance(embed, Embedding):


+ 56
- 48
fastNLP/models/star_transformer.py View File

@@ -11,26 +11,16 @@ __all__ = [
import torch
from torch import nn

from ..modules.encoder.star_transformer import StarTransformer
from ..core.const import Const
from ..core.utils import seq_len_to_mask
from ..embeddings.utils import get_embeddings
from ..core.const import Const
from ..modules.encoder.star_transformer import StarTransformer


class StarTransEnc(nn.Module):
"""
带word embedding的Star-Transformer Encoder

:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param hidden_size: 模型中特征维度.
:param num_layers: 模型层数.
:param num_head: 模型中multi-head的head个数.
:param head_dim: 模型中multi-head中每个head特征维度.
:param max_len: 模型能接受的最大输入长度.
:param emb_dropout: 词嵌入的dropout概率.
:param dropout: 模型除词嵌入外的dropout概率.
"""

def __init__(self, embed,
@@ -41,6 +31,18 @@ class StarTransEnc(nn.Module):
max_len,
emb_dropout,
dropout):
"""
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,此时就以传入的对象作为embedding
:param hidden_size: 模型中特征维度.
:param num_layers: 模型层数.
:param num_head: 模型中multi-head的head个数.
:param head_dim: 模型中multi-head中每个head特征维度.
:param max_len: 模型能接受的最大输入长度.
:param emb_dropout: 词嵌入的dropout概率.
:param dropout: 模型除词嵌入外的dropout概率.
"""
super(StarTransEnc, self).__init__()
self.embedding = get_embeddings(embed)
emb_dim = self.embedding.embedding_dim
@@ -104,18 +106,6 @@ class STSeqLabel(nn.Module):
"""
用于序列标注的Star-Transformer模型

:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
:param hidden_size: 模型中特征维度. Default: 300
:param num_layers: 模型层数. Default: 4
:param num_head: 模型中multi-head的head个数. Default: 8
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32
:param max_len: 模型能接受的最大输入长度. Default: 512
:param cls_hidden_size: 分类器隐层维度. Default: 600
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""

def __init__(self, embed, num_cls,
@@ -127,6 +117,20 @@ class STSeqLabel(nn.Module):
cls_hidden_size=600,
emb_dropout=0.1,
dropout=0.1, ):
"""
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
:param hidden_size: 模型中特征维度. Default: 300
:param num_layers: 模型层数. Default: 4
:param num_head: 模型中multi-head的head个数. Default: 8
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32
:param max_len: 模型能接受的最大输入长度. Default: 512
:param cls_hidden_size: 分类器隐层维度. Default: 600
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""
super(STSeqLabel, self).__init__()
self.enc = StarTransEnc(embed=embed,
hidden_size=hidden_size,
@@ -167,18 +171,6 @@ class STSeqCls(nn.Module):
"""
用于分类任务的Star-Transformer

:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
:param hidden_size: 模型中特征维度. Default: 300
:param num_layers: 模型层数. Default: 4
:param num_head: 模型中multi-head的head个数. Default: 8
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32
:param max_len: 模型能接受的最大输入长度. Default: 512
:param cls_hidden_size: 分类器隐层维度. Default: 600
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""

def __init__(self, embed, num_cls,
@@ -190,6 +182,20 @@ class STSeqCls(nn.Module):
cls_hidden_size=600,
emb_dropout=0.1,
dropout=0.1, ):
"""
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
:param hidden_size: 模型中特征维度. Default: 300
:param num_layers: 模型层数. Default: 4
:param num_head: 模型中multi-head的head个数. Default: 8
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32
:param max_len: 模型能接受的最大输入长度. Default: 512
:param cls_hidden_size: 分类器隐层维度. Default: 600
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""
super(STSeqCls, self).__init__()
self.enc = StarTransEnc(embed=embed,
hidden_size=hidden_size,
@@ -230,18 +236,6 @@ class STNLICls(nn.Module):
"""
用于自然语言推断(NLI)的Star-Transformer

:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象,
此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
:param hidden_size: 模型中特征维度. Default: 300
:param num_layers: 模型层数. Default: 4
:param num_head: 模型中multi-head的head个数. Default: 8
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32
:param max_len: 模型能接受的最大输入长度. Default: 512
:param cls_hidden_size: 分类器隐层维度. Default: 600
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""

def __init__(self, embed, num_cls,
@@ -253,6 +247,20 @@ class STNLICls(nn.Module):
cls_hidden_size=600,
emb_dropout=0.1,
dropout=0.1, ):
"""
:param embed: 单词词典, 可以是 tuple, 包括(num_embedings, embedding_dim), 即
embedding的大小和每个词的维度. 也可以传入 nn.Embedding 对象, 此时就以传入的对象作为embedding
:param num_cls: 输出类别个数
:param hidden_size: 模型中特征维度. Default: 300
:param num_layers: 模型层数. Default: 4
:param num_head: 模型中multi-head的head个数. Default: 8
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32
:param max_len: 模型能接受的最大输入长度. Default: 512
:param cls_hidden_size: 分类器隐层维度. Default: 600
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""
super(STNLICls, self).__init__()
self.enc = StarTransEnc(embed=embed,
hidden_size=hidden_size,


Loading…
Cancel
Save