
Merge branch 'dev' of github.com:choosewhatulike/fastNLP-private into dev

tags/v0.4.10 | yh_cc, 5 years ago | parent commit f65c0935f6
4 changed files with 44 additions and 39 deletions:

  1. fastNLP/models/bert.py (+1, -1)
  2. fastNLP/modules/encoder/char_encoder.py (+25, -25)
  3. fastNLP/modules/encoder/embedding.py (+15, -10)
  4. test/modules/test_char_encoder.py (+3, -3)

fastNLP/models/bert.py (+1, -1)

@@ -280,7 +280,7 @@ class BertForQuestionAnswering(BaseModel):
             start_loss = loss_fct(start_logits, start_positions)
             end_loss = loss_fct(end_logits, end_positions)
             total_loss = (start_loss + end_loss) / 2
-            return {"loss": total_loss}
+            return {"pred1": start_logits, "pred2": end_logits, "loss": total_loss}
         else:
             return {"pred1": start_logits, "pred2": end_logits}



fastNLP/modules/encoder/char_embedding.py → fastNLP/modules/encoder/char_encoder.py (+25, -25)

@@ -5,19 +5,18 @@ from fastNLP.modules.utils import initial_parameter
 
 
 # from torch.nn.init import xavier_uniform
-class ConvCharEmbedding(nn.Module):
-    """Character-level Embedding with CNN.
-
-    :param int char_emb_size: the size of character level embedding. Default: 50
-        say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
-    :param tuple feature_maps: tuple of int. The length of the tuple is the number of convolution operations
-        over characters. The i-th integer is the number of filters (dim of out channels) for the i-th
-        convolution.
-    :param tuple kernels: tuple of int. The width of each kernel.
-    """
+class ConvolutionCharEncoder(nn.Module):
+    """Character-level convolutional encoder."""
 
     def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(3, 4, 5), initial_method=None):
-        super(ConvCharEmbedding, self).__init__()
+        """
+        :param int char_emb_size: dimension of the character-level embedding. Default: 50
+            e.g. with 26 characters, each embedded as a 50-dim vector, the input vector size is 50.
+        :param tuple feature_maps: a tuple of ints; its length is the number of character-level convolutions, and the i-th int is the number of filters of the i-th convolution.
+        :param tuple kernels: a tuple of ints; its length is the number of character-level convolutions, and the i-th int is the kernel size of the i-th convolution.
+        :param initial_method: parameter initialization method; defaults to `xavier normal`.
+        """
+        super(ConvolutionCharEncoder, self).__init__()
         self.convs = nn.ModuleList([
             nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, 4))
             for i in range(len(kernels))])
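
To make the kernel_size=(char_emb_size, kernels[i]) choice above concrete, here is a rough shape walk-through under assumed toy dimensions (a sketch, not part of the diff): forward below reshapes the input to [N, 1, char_emb_size, word_length], so each Conv2d collapses the embedding axis to height 1 and slides only along the character axis.

import torch
import torch.nn as nn

N, char_emb_size, word_length = 8, 50, 7           # toy sizes, chosen for illustration
x = torch.randn(N, 1, char_emb_size, word_length)  # [batch*sent_len, channel=1, emb, chars]
conv = nn.Conv2d(1, 40, kernel_size=(char_emb_size, 3), bias=True, padding=(0, 4))
y = conv(x)
# output height: 50 - 50 + 1 = 1; output width: 7 + 2*4 - 3 + 1 = 13
print(y.shape)  # torch.Size([8, 40, 1, 13])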
@@ -26,16 +25,16 @@ class ConvCharEmbedding(nn.Module):
 
     def forward(self, x):
         """
-        :param x: ``[batch_size * sent_length, word_length, char_emb_size]``
-        :return: feature map of shape [batch_size * sent_length, sum(feature_maps), 1]
+        :param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` embeddings of the input characters
+        :return: torch.Tensor : result of the convolution, of shape [batch_size * sent_length, sum(feature_maps), 1]
         """
         x = x.contiguous().view(x.size(0), 1, x.size(1), x.size(2))
         # [batch_size*sent_length, channel, width, height]
         x = x.transpose(2, 3)
         # [batch_size*sent_length, channel, height, width]
-        return self.convolute(x).unsqueeze(2)
+        return self._convolute(x).unsqueeze(2)
 
-    def convolute(self, x):
+    def _convolute(self, x):
         feats = []
         for conv in self.convs:
             y = conv(x)
@@ -49,15 +48,16 @@ class ConvCharEmbedding(nn.Module):
         return torch.cat(feats, 1)  # [batch_size*sent_length, sum(feature_maps)]
 
 
-class LSTMCharEmbedding(nn.Module):
-    """Character-level Embedding with LSTM.
-
-    :param int char_emb_size: the size of character level embedding. Default: 50
-        say 26 characters, each embedded to 50 dim vector, then the input_size is 50.
-    :param int hidden_size: the number of hidden units. Default: equal to char_emb_size.
-    """
+class LSTMCharEncoder(nn.Module):
+    """Character-level LSTM encoder."""
     def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
-        super(LSTMCharEmbedding, self).__init__()
+        """
+        :param int char_emb_size: dimension of the character-level embedding. Default: 50
+            e.g. with 26 characters, each embedded as a 50-dim vector, the input vector size is 50.
+        :param int hidden_size: size of the LSTM hidden state; defaults to the char embedding dimension.
+        :param initial_method: parameter initialization method; defaults to `xavier normal`.
+        """
+        super(LSTMCharEncoder, self).__init__()
         self.hidden_size = char_emb_size if hidden_size is None else hidden_size
 
         self.lstm = nn.LSTM(input_size=char_emb_size,
@@ -69,8 +69,8 @@ class LSTMCharEmbedding(nn.Module):
 
     def forward(self, x):
         """
-        :param x: ``[ n_batch*n_word, word_length, char_emb_size]``
-        :return: [ n_batch*n_word, char_emb_size]
+        :param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` embeddings of the input characters
+        :return: torch.Tensor : [ n_batch*n_word, char_emb_size], the LSTM-encoded result
         """
         batch_size = x.shape[0]
         h0 = torch.empty(1, batch_size, self.hidden_size)
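
A minimal usage sketch for the renamed encoders, based on the docstrings above (tensor sizes are arbitrary; all constructor arguments are left at their defaults, so the expected output shapes follow from feature_maps=(40, 30, 30) and hidden_size=char_emb_size):

import torch
from fastNLP.modules.encoder.char_encoder import ConvolutionCharEncoder, LSTMCharEncoder

chars = torch.randn(4, 6, 50)  # [n_batch*n_word, word_length, char_emb_size]

conv_enc = ConvolutionCharEncoder(char_emb_size=50)
print(conv_enc(chars).shape)   # expected: [4, 100, 1], since sum(feature_maps) = 100

lstm_enc = LSTMCharEncoder(char_emb_size=50)
print(lstm_enc(chars).shape)   # expected: [4, 50]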

fastNLP/modules/encoder/embedding.py (+15, -10)

@@ -2,20 +2,25 @@ import torch.nn as nn
 
 
 class Embedding(nn.Module):
-    """A simple lookup table.
+    """Embedding component."""
 
-    :param int nums: the size of the lookup table
-    :param int dims: the size of each vector
-    :param int padding_idx: pads the tensor with zeros whenever it encounters this index
-    :param bool sparse: If True, gradient matrix will be a sparse tensor. In this case, only optim.SGD(cuda and cpu) and optim.Adagrad(cpu) can be used
-    """
-    def __init__(self, nums, dims, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+    def __init__(self, vocab_size, embed_dim, padding_idx=0, sparse=False, init_emb=None, dropout=0.0):
+        """
+        :param int vocab_size: vocabulary size.
+        :param int embed_dim: embedding dimension.
+        :param int padding_idx: indices equal to padding_idx are automatically padded with 0.
+        :param bool sparse: if `True`, the weight matrix is a sparse matrix.
+        :param torch.Tensor init_emb: initial embedding matrix.
+        :param float dropout: dropout probability.
+        """
         super(Embedding, self).__init__()
-        self.embed = nn.Embedding(nums, dims, padding_idx, sparse=sparse)
-        if init_emb is not None:
-            self.embed.weight = nn.Parameter(init_emb)
+        self.embed = nn.Embedding(vocab_size, embed_dim, padding_idx, sparse=sparse, _weight=init_emb)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
+        """
+        :param torch.LongTensor x: [batch, seq_len]
+        :return: torch.Tensor : [batch, seq_len, embed_dim]
+        """
         x = self.embed(x)
         return self.dropout(x)
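
On the self.embed change above: passing init_emb through nn.Embedding's _weight argument should leave the module with exactly the supplied matrix, which is the same end state the removed nn.Parameter assignment produced. A small equivalence sketch (tensor sizes are arbitrary):

import torch
import torch.nn as nn

init_emb = torch.randn(10, 4)  # 10-word vocabulary, 4-dim embeddings

old_style = nn.Embedding(10, 4, padding_idx=0)
old_style.weight = nn.Parameter(init_emb.clone())  # previous code: overwrite after construction

new_style = nn.Embedding(10, 4, padding_idx=0, _weight=init_emb.clone())  # new code: pass at construction

assert torch.equal(old_style.weight, new_style.weight)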

test/modules/test_char_embedding.py → test/modules/test_char_encoder.py (+3, -3)

@@ -2,7 +2,7 @@ import unittest
 
 import torch
 
-from fastNLP.modules.encoder.char_embedding import ConvCharEmbedding, LSTMCharEmbedding
+from fastNLP.modules.encoder.char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
 
 
 class TestCharEmbed(unittest.TestCase):
@@ -13,14 +13,14 @@ class TestCharEmbed(unittest.TestCase):
         x = torch.Tensor(batch_size, char_emb, word_length)
         x = x.transpose(1, 2)
 
-        cce = ConvCharEmbedding(char_emb)
+        cce = ConvolutionCharEncoder(char_emb)
         y = cce(x)
         self.assertEqual(tuple(x.shape), (batch_size, word_length, char_emb))
         print("CNN Char Emb input: ", x.shape)
         self.assertEqual(tuple(y.shape), (batch_size, char_emb, 1))
         print("CNN Char Emb output: ", y.shape)  # [128, 100]
 
-        lce = LSTMCharEmbedding(char_emb)
+        lce = LSTMCharEncoder(char_emb)
         o = lce(x)
         self.assertEqual(tuple(x.shape), (batch_size, word_length, char_emb))
         print("LSTM Char Emb input: ", x.shape)
