fix bug in tests

tags/v0.4.10
xuyige 5 years ago
parent commit a39dafac6b
6 changed files with 28 additions and 36 deletions
  1. +2  -7   fastNLP/modules/__init__.py
  2. +11 -1   fastNLP/modules/encoder/__init__.py
  3. +1  -2   fastNLP/modules/encoder/attention.py
  4. +5  -18  fastNLP/modules/encoder/bert.py
  5. +6  -5   fastNLP/modules/encoder/pooling.py
  6. +3  -3   reproduction/README.md

+2 -7  fastNLP/modules/__init__.py

@@ -1,11 +1,11 @@
 """
 Most neural networks for NLP tasks can be seen as composed of encoder :mod:`~fastNLP.modules.encoder`,
-aggregator :mod:`~fastNLP.modules.aggregator`, and decoder :mod:`~fastNLP.modules.decoder` modules (three kinds).
+and decoder :mod:`~fastNLP.modules.decoder` modules (two kinds).
 
 .. image:: figures/text_classification.png
 
 :mod:`~fastNLP.modules` implements the many module components provided by fastNLP, helping users quickly build the networks they need.
-The functionality and common components of the three module kinds are as follows:
+The functionality and common components of the two module kinds are as follows:
 
 +-----------------------+-----------------------+-----------------------+
 | module type           | functionality         | example               |
@@ -13,9 +13,6 @@
 | encoder               | encode the input into | embedding, RNN, CNN,  |
 |                       | expressive vectors    | transformer           |
 +-----------------------+-----------------------+-----------------------+
-| aggregator            | aggregate information | self-attention,       |
-|                       | from multiple vectors | max-pooling           |
-+-----------------------+-----------------------+-----------------------+
 | decoder               | decode vectors with a | MLP, CRF              |
 |                       | given representation  |                       |
 |                       | into the output form  |                       |
@@ -46,10 +43,8 @@ __all__ = [
     "allowed_transitions",
 ]
 
-from . import aggregator
 from . import decoder
 from . import encoder
-from .aggregator import *
 from .decoder import *
 from .dropout import TimestepDropout
 from .encoder import *
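
The restructured docstring frames a network as encoder plus decoder. As a quick illustration of that taxonomy in plain PyTorch (the classes below are stand-ins, not fastNLP components):

```python
import torch
from torch import nn

# Plain-PyTorch stand-ins for the encoder/decoder taxonomy above;
# not fastNLP classes, just an illustration of the two-module view.
class TinyTextClassifier(nn.Module):
    def __init__(self, vocab_size=1000, embed_dim=64, hidden=128, n_classes=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)             # encoder: embedding
        self.encoder = nn.LSTM(embed_dim, hidden, batch_first=True)  # encoder: RNN
        self.decoder = nn.Linear(hidden, n_classes)                  # decoder: MLP-style head

    def forward(self, words):
        out, _ = self.encoder(self.embed(words))  # [batch, seq_len, hidden]
        return self.decoder(out[:, -1])           # last state -> class logits

logits = TinyTextClassifier()(torch.randint(0, 1000, (4, 10)))
print(logits.shape)  # torch.Size([4, 2])
```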


+11 -1  fastNLP/modules/encoder/__init__.py

@@ -22,7 +22,14 @@ __all__ = [
     "VarRNN",
     "VarLSTM",
-    "VarGRU"
+    "VarGRU",
+
+    "MaxPool",
+    "MaxPoolWithMask",
+    "AvgPool",
+    "AvgPoolWithMask",
+
+    "MultiHeadAttention",
 ]
 from ._bert import BertModel
 from .bert import BertWordPieceEncoder
@@ -34,3 +41,6 @@ from .lstm import LSTM
 from .star_transformer import StarTransformer
 from .transformer import TransformerEncoder
 from .variational_rnn import VarRNN, VarLSTM, VarGRU
+
+from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
+from .attention import MultiHeadAttention
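
With the aggregator package dissolved, the pooling and attention classes are exported from the encoder package, and through ``from .encoder import *`` they surface in ``fastNLP.modules`` as well. A migration sketch for user code, assuming fastNLP v0.4.10:

```python
# Old paths (pre-commit), now gone:
#   from fastNLP.modules.aggregator import MaxPool, MultiHeadAttention
# New paths after this commit:
from fastNLP.modules.encoder import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
from fastNLP.modules import MultiHeadAttention  # re-exported at the top level too
```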

+1 -2  fastNLP/modules/encoder/attention.py

@@ -45,8 +45,7 @@ class DotAttention(nn.Module):
 
 class MultiHeadAttention(nn.Module):
     """
-    Alias: :class:`fastNLP.modules.MultiHeadAttention`  :class:`fastNLP.modules.aggregator.attention.MultiHeadAttention`
-
+    Alias: :class:`fastNLP.modules.MultiHeadAttention`  :class:`fastNLP.modules.encoder.attention.MultiHeadAttention`
 
     :param input_size: int, size of the input dimension; also the size of the output dimension.
     :param key_size: int, dimension of each head.
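
A usage sketch for the relocated class: ``input_size`` and ``key_size`` come from the docstring above, while the ``value_size``/``num_head`` parameter names and the Q/K/V call signature are assumptions about this version's API:

```python
import torch
from fastNLP.modules.encoder.attention import MultiHeadAttention

# input_size/key_size are documented above; value_size and num_head are
# assumed parameter names for this version of the constructor.
attn = MultiHeadAttention(input_size=256, key_size=32, value_size=32, num_head=8)
x = torch.randn(4, 20, 256)  # [batch, seq_len, input_size]
out = attn(x, x, x)          # self-attention: Q = K = V
print(out.shape)             # expected torch.Size([4, 20, 256]), per the docstring
```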


+5 -18  fastNLP/modules/encoder/bert.py

@@ -2,35 +2,22 @@
 import os
 from torch import nn
 import torch
-from ...io.file_utils import _get_base_url, cached_path
+from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ._bert import _WordPieceBertModel, BertModel
 
 
 class BertWordPieceEncoder(nn.Module):
     """
     Reads a BERT model; after loading, call the index_dataset method to generate the word_pieces column in a dataset.
 
-    :param fastNLP.Vocabulary vocab: the vocabulary
     :param str model_dir_or_name: directory containing the model, or the model's name. Defaults to ``en-base-uncased``
     :param str layers: which layers form the final representation; layer indices separated by ',', negative values index from the end
     :param bool requires_grad: whether gradients are required.
     """
-    def __init__(self, model_dir_or_name:str='en-base-uncased', layers:str='-1',
-                 requires_grad:bool=False):
+    def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
+                 requires_grad: bool=False):
         super().__init__()
         PRETRAIN_URL = _get_base_url('bert')
-        PRETRAINED_BERT_MODEL_DIR = {'en': 'bert-base-cased-f89bfe08.zip',
-                                     'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
-                                     'en-base-cased': 'bert-base-cased-f89bfe08.zip',
-                                     'en-large-uncased': 'bert-large-uncased-20939f45.zip',
-                                     'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
-
-                                     'cn': 'bert-base-chinese-29d0a84a.zip',
-                                     'cn-base': 'bert-base-chinese-29d0a84a.zip',
-
-                                     'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
-                                     'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
-                                     'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
-                                     }
-
 
         if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
             model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
@@ -89,4 +76,4 @@ class BertWordPieceEncoder(nn.Module):
         outputs = self.model(word_pieces, token_type_ids)
         outputs = torch.cat([*outputs], dim=-1)
 
-        return outputs
+        return outputs
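
The hard-coded model map now lives in ``fastNLP.io.file_utils`` as ``PRETRAINED_BERT_MODEL_DIR``, so the encoder resolves names like ``en-base-uncased`` or ``cn`` through one shared table. A usage sketch based on the docstring (the ``index_dataset`` call and forward signature are taken from it, not verified here):

```python
from fastNLP.modules.encoder.bert import BertWordPieceEncoder

# Constructor arguments follow the docstring; pretrained weights are fetched
# through _get_base_url/cached_path on first use, so network access is assumed.
encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased',
                               layers='-1,-2',  # comma-separated; negatives count from the end
                               requires_grad=False)

# Per the docstring, indexing a fastNLP DataSet adds a 'word_pieces' field,
# which the forward pass then consumes (sketch only):
# encoder.index_dataset(dataset)
# reps = encoder(word_pieces, token_type_ids)  # selected layers concatenated on the last dim
```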

+6 -5  fastNLP/modules/encoder/pooling.py

@@ -1,7 +1,8 @@
 __all__ = [
     "MaxPool",
     "MaxPoolWithMask",
-    "AvgPool"
+    "AvgPool",
+    "AvgPoolWithMask"
 ]
 import torch
 import torch.nn as nn
@@ -9,7 +10,7 @@ import torch.nn as nn
 
 class MaxPool(nn.Module):
     """
-    Alias: :class:`fastNLP.modules.MaxPool`  :class:`fastNLP.modules.aggregator.pooling.MaxPool`
+    Alias: :class:`fastNLP.modules.MaxPool`  :class:`fastNLP.modules.encoder.pooling.MaxPool`
 
     Max-pooling module.
@@ -58,7 +59,7 @@ class MaxPool(nn.Module):
 
 class MaxPoolWithMask(nn.Module):
     """
-    Alias: :class:`fastNLP.modules.MaxPoolWithMask`  :class:`fastNLP.modules.aggregator.pooling.MaxPoolWithMask`
+    Alias: :class:`fastNLP.modules.MaxPoolWithMask`  :class:`fastNLP.modules.encoder.pooling.MaxPoolWithMask`
 
     Max pooling with a mask matrix; positions where the mask is 0 are ignored when taking the max.
     """
@@ -98,7 +99,7 @@ class KMaxPool(nn.Module):
 
 class AvgPool(nn.Module):
     """
-    Alias: :class:`fastNLP.modules.AvgPool`  :class:`fastNLP.modules.aggregator.pooling.AvgPool`
+    Alias: :class:`fastNLP.modules.AvgPool`  :class:`fastNLP.modules.encoder.pooling.AvgPool`
 
     Given input shaped [batch_size, max_len, hidden_size], applies avg pooling over the last dimension; the output is [batch_size, hidden_size].
     """
@@ -125,7 +126,7 @@ class AvgPool(nn.Module):
 
 class AvgPoolWithMask(nn.Module):
     """
-    Alias: :class:`fastNLP.modules.AvgPoolWithMask`  :class:`fastNLP.modules.aggregator.pooling.AvgPoolWithMask`
+    Alias: :class:`fastNLP.modules.AvgPoolWithMask`  :class:`fastNLP.modules.encoder.pooling.AvgPoolWithMask`
 
     Given input shaped [batch_size, max_len, hidden_size], applies avg pooling over the last dimension; the output is
     [batch_size, hidden_size], and only positions where the mask is 1 are counted in the average.
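
A sketch of the masked pooling these docstrings describe; the ``forward(tensor, mask)`` call shape is an assumption about this version's API:

```python
import torch
from fastNLP.modules.encoder.pooling import MaxPoolWithMask, AvgPoolWithMask

# forward(tensor, mask) is an assumed call signature for this version.
x = torch.randn(2, 5, 8)                # [batch_size, max_len, hidden_size]
mask = torch.tensor([[1, 1, 1, 0, 0],
                     [1, 1, 1, 1, 1]])  # 0 marks padding positions
max_pooled = MaxPoolWithMask()(x, mask)  # mask == 0 positions never win the max
avg_pooled = AvgPoolWithMask()(x, mask)  # average over mask == 1 positions only
print(max_pooled.shape, avg_pooled.shape)  # both torch.Size([2, 8])
```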


+3 -3  reproduction/README.md

@@ -9,7 +9,7 @@
 
 # Task Reproduction
 ## Text Classification
-- still in progress
+- [Text Classification task reproduction](text_classification)
 
 
 ## Matching (natural language inference / sentence matching)
@@ -21,11 +21,11 @@
 
 ## Coreference resolution
-- still in progress
+- [Coreference resolution task reproduction](coreference_resolution)
 
 
 ## Summarization
-- still in progress
+- [BertSum](Summmarization)
 
 
 ## ...
