|
|
@@ -12,6 +12,14 @@ class CNNText(torch.nn.Module): |
|
|
|
""" |
|
|
|
使用CNN进行文本分类的模型 |
|
|
|
'Yoon Kim. 2014. Convolution Neural Networks for Sentence Classification.' |
|
|
|
|
|
|
|
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int), |
|
|
|
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding |
|
|
|
:param int num_classes: 一共有多少类 |
|
|
|
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 |
|
|
|
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 |
|
|
|
:param int padding: |
|
|
|
:param float dropout: Dropout的大小 |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, init_embed, |
|
|
@@ -20,16 +28,6 @@ class CNNText(torch.nn.Module): |
|
|
|
kernel_sizes=(3, 4, 5), |
|
|
|
padding=0, |
|
|
|
dropout=0.5): |
|
|
|
""" |
|
|
|
|
|
|
|
:param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int), |
|
|
|
第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding |
|
|
|
:param int num_classes: 一共有多少类 |
|
|
|
:param int,tuple(int) out_channels: 输出channel的数量。如果为list,则需要与kernel_sizes的数量保持一致 |
|
|
|
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。 |
|
|
|
:param int padding: |
|
|
|
:param float dropout: Dropout的大小 |
|
|
|
""" |
|
|
|
super(CNNText, self).__init__() |
|
|
|
|
|
|
|
# no support for pre-trained embedding currently |
|
|
|