Browse Source

一些符合 PEP8 的微调

tags/v0.4.10
ChenXin 6 years ago
parent
commit
bdec6187a2
43 changed files with 229 additions and 225 deletions
  1. +2
    -2
      fastNLP/__init__.py
  2. +4
    -4
      fastNLP/core/batch.py
  3. +12
    -12
      fastNLP/core/callback.py
  4. +4
    -4
      fastNLP/core/dataset.py
  5. +4
    -4
      fastNLP/core/field.py
  6. +12
    -12
      fastNLP/core/losses.py
  7. +7
    -7
      fastNLP/core/metrics.py
  8. +2
    -2
      fastNLP/core/optimizer.py
  9. +4
    -4
      fastNLP/core/sampler.py
  10. +3
    -0
      fastNLP/core/trainer.py
  11. +5
    -4
      fastNLP/core/utils.py
  12. +6
    -6
      fastNLP/core/vocabulary.py
  13. +5
    -5
      fastNLP/io/__init__.py
  14. +3
    -3
      fastNLP/io/base_loader.py
  15. +6
    -6
      fastNLP/io/config_io.py
  16. +6
    -6
      fastNLP/io/dataset_loader.py
  17. +4
    -4
      fastNLP/io/embed_loader.py
  18. +4
    -4
      fastNLP/io/model_io.py
  19. +9
    -9
      fastNLP/models/__init__.py
  20. +5
    -5
      fastNLP/models/biaffine_parser.py
  21. +4
    -4
      fastNLP/models/cnn_text_classification.py
  22. +0
    -1
      fastNLP/models/enas_utils.py
  23. +5
    -5
      fastNLP/models/sequence_labeling.py
  24. +4
    -4
      fastNLP/models/snli.py
  25. +7
    -7
      fastNLP/models/star_transformer.py
  26. +9
    -9
      fastNLP/modules/__init__.py
  27. +7
    -7
      fastNLP/modules/aggregator/__init__.py
  28. +4
    -4
      fastNLP/modules/aggregator/attention.py
  29. +7
    -2
      fastNLP/modules/aggregator/pooling.py
  30. +5
    -5
      fastNLP/modules/decoder/__init__.py
  31. +5
    -5
      fastNLP/modules/decoder/crf.py
  32. +4
    -4
      fastNLP/modules/decoder/mlp.py
  33. +1
    -2
      fastNLP/modules/decoder/utils.py
  34. +4
    -2
      fastNLP/modules/dropout.py
  35. +8
    -9
      fastNLP/modules/encoder/__init__.py
  36. +4
    -5
      fastNLP/modules/encoder/char_encoder.py
  37. +3
    -4
      fastNLP/modules/encoder/conv_maxpool.py
  38. +2
    -3
      fastNLP/modules/encoder/embedding.py
  39. +4
    -4
      fastNLP/modules/encoder/lstm.py
  40. +4
    -4
      fastNLP/modules/encoder/star_transformer.py
  41. +3
    -4
      fastNLP/modules/encoder/transformer.py
  42. +27
    -27
      fastNLP/modules/encoder/variational_rnn.py
  43. +1
    -1
      fastNLP/modules/utils.py

+ 2
- 2
fastNLP/__init__.py View File

@@ -52,8 +52,8 @@ __all__ = [
"cache_results"
]
__version__ = '0.4.0'

from .core import *
from . import models
from . import modules

__version__ = '0.4.0'

+ 4
- 4
fastNLP/core/batch.py View File

@@ -2,6 +2,10 @@
batch 模块实现了 fastNLP 所需的 Batch 类。

"""
__all__ = [
"Batch"
]

import atexit
from queue import Empty, Full

@@ -11,10 +15,6 @@ import torch.multiprocessing as mp

from .sampler import RandomSampler

__all__ = [
"Batch"
]

_python_is_exit = False




+ 12
- 12
fastNLP/core/callback.py View File

@@ -49,6 +49,18 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:
trainer.train()

"""
__all__ = [
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"TensorboardCallback",
"LRScheduler",
"ControlC",
"CallbackException",
"EarlyStopError"
]

import os

import torch
@@ -62,18 +74,6 @@ except:

from ..io.model_io import ModelSaver, ModelLoader

__all__ = [
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"TensorboardCallback",
"LRScheduler",
"ControlC",
"CallbackException",
"EarlyStopError"
]


class Callback(object):
"""


+ 4
- 4
fastNLP/core/dataset.py View File

@@ -272,6 +272,10 @@


"""
__all__ = [
"DataSet"
]

import _pickle as pickle
import warnings

@@ -282,10 +286,6 @@ from .field import FieldArray
from .instance import Instance
from .utils import _get_func_signature

__all__ = [
"DataSet"
]


class DataSet(object):
"""


+ 4
- 4
fastNLP/core/field.py View File

@@ -3,10 +3,6 @@ field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fas
原理部分请参考 :doc:`fastNLP.core.dataset`

"""
from copy import deepcopy

import numpy as np

__all__ = [
"FieldArray",
"Padder",
@@ -14,6 +10,10 @@ __all__ = [
"EngChar2DPadder"
]

from copy import deepcopy

import numpy as np


class FieldArray(object):
"""


+ 12
- 12
fastNLP/core/losses.py View File

@@ -2,6 +2,18 @@
losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。

"""
__all__ = [
"LossBase",
"LossFunc",
"LossInForward",
"CrossEntropyLoss",
"BCELoss",
"L1Loss",
"NLLLoss"
]

import inspect
from collections import defaultdict

@@ -15,18 +27,6 @@ from .utils import _check_arg_dict_list
from .utils import _check_function_or_method
from .utils import _get_func_signature

__all__ = [
"LossBase",
"LossFunc",
"LossInForward",
"CrossEntropyLoss",
"BCELoss",
"L1Loss",
"NLLLoss"
]


class LossBase(object):
"""


+ 7
- 7
fastNLP/core/metrics.py View File

@@ -2,6 +2,13 @@
metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。

"""
__all__ = [
"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric"
]

import inspect
from collections import defaultdict

@@ -16,13 +23,6 @@ from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary

__all__ = [
"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric"
]


class MetricBase(object):
"""


+ 2
- 2
fastNLP/core/optimizer.py View File

@@ -2,14 +2,14 @@
optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。

"""
import torch

__all__ = [
"Optimizer",
"SGD",
"Adam"
]

import torch


class Optimizer(object):
"""


+ 4
- 4
fastNLP/core/sampler.py View File

@@ -1,10 +1,6 @@
"""
sampler 子类实现了 fastNLP 所需的各种采样器。
"""
from itertools import chain

import numpy as np

__all__ = [
"Sampler",
"BucketSampler",
@@ -12,6 +8,10 @@ __all__ = [
"RandomSampler"
]

from itertools import chain

import numpy as np


class Sampler(object):
"""


+ 3
- 0
fastNLP/core/trainer.py View File

@@ -295,6 +295,9 @@ Example2.3
fastNLP已经自带了很多callback函数供使用,可以参考 :doc:`fastNLP.core.callback` 。

"""
__all__ = [
"Trainer"
]

import os
import time


+ 5
- 4
fastNLP/core/utils.py View File

@@ -1,6 +1,11 @@
"""
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
"""
__all__ = [
"cache_results",
"seq_len_to_mask"
]

import _pickle
import inspect
import os
@@ -11,10 +16,6 @@ import numpy as np
import torch
import torch.nn as nn

__all__ = [
"cache_results",
"seq_len_to_mask"
]

_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])


+ 6
- 6
fastNLP/core/vocabulary.py View File

@@ -1,12 +1,12 @@
__all__ = [
"Vocabulary"
]

from functools import wraps
from collections import Counter

from .dataset import DataSet

__all__ = [
"Vocabulary"
]


def _check_build_vocab(func):
"""A decorator to make sure the indexing is built before used.
@@ -322,7 +322,7 @@ class Vocabulary(object):
:return str word: the word
"""
return self.idx2word[idx]
def clear(self):
"""
删除Vocabulary中的词表数据。相当于重新初始化一下。
@@ -333,7 +333,7 @@ class Vocabulary(object):
self.word2idx = None
self.idx2word = None
self.rebuild = True
def __getstate__(self):
"""Use to prepare data for pickle.



+ 5
- 5
fastNLP/io/__init__.py View File

@@ -9,11 +9,6 @@

这些类的使用方法如下:
"""
from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver

__all__ = [
'EmbedLoader',
@@ -29,3 +24,8 @@ __all__ = [
'ModelLoader',
'ModelSaver',
]

from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver

+ 3
- 3
fastNLP/io/base_loader.py View File

@@ -1,10 +1,10 @@
import _pickle as pickle
import os

__all__ = [
"BaseLoader"
]

import _pickle as pickle
import os


class BaseLoader(object):
"""


+ 6
- 6
fastNLP/io/config_io.py View File

@@ -3,18 +3,18 @@
.. todo::
这个模块中的类可能被抛弃?
"""
import configparser
import json
import os

from .base_loader import BaseLoader

__all__ = [
"ConfigLoader",
"ConfigSection",
"ConfigSaver"
]

import configparser
import json
import os

from .base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""


+ 6
- 6
fastNLP/io/dataset_loader.py View File

@@ -10,12 +10,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的

# ... do stuff
"""
from nltk.tree import Tree

from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll

__all__ = [
'DataSetLoader',
'CSVLoader',
@@ -27,6 +21,12 @@ __all__ = [
'Conll2003Loader',
]

from nltk.tree import Tree

from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll


def _download_from_url(url, path):
try:


+ 4
- 4
fastNLP/io/embed_loader.py View File

@@ -1,3 +1,7 @@
__all__ = [
"EmbedLoader"
]

import os
import warnings

@@ -6,10 +10,6 @@ import numpy as np
from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader

__all__ = [
"EmbedLoader"
]


class EmbedLoader(BaseLoader):
"""


+ 4
- 4
fastNLP/io/model_io.py View File

@@ -1,15 +1,15 @@
"""
用于载入和保存模型
"""
import torch

from .base_loader import BaseLoader

__all__ = [
"ModelLoader",
"ModelSaver"
]

import torch

from .base_loader import BaseLoader


class ModelLoader(BaseLoader):
"""


+ 9
- 9
fastNLP/models/__init__.py View File

@@ -7,15 +7,6 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models


"""
from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel

__all__ = [
"CNNText",
@@ -32,3 +23,12 @@ __all__ = [
"BiaffineParser",
"GraphParser"
]

from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel

+ 5
- 5
fastNLP/models/biaffine_parser.py View File

@@ -1,6 +1,11 @@
"""
Biaffine Dependency Parser 的 Pytorch 实现.
"""
__all__ = [
"BiaffineParser",
"GraphParser"
]

import numpy as np
import torch
import torch.nn as nn
@@ -19,11 +24,6 @@ from ..modules.utils import get_embeddings
from .base_model import BaseModel
from ..core.utils import seq_len_to_mask

__all__ = [
"BiaffineParser",
"GraphParser"
]


def _mst(scores):
"""


+ 4
- 4
fastNLP/models/cnn_text_classification.py View File

@@ -1,13 +1,13 @@
__all__ = [
"CNNText"
]

import torch
import torch.nn as nn

from ..core.const import Const as C
from ..modules import encoder

__all__ = [
"CNNText"
]


class CNNText(torch.nn.Module):
"""


+ 0
- 1
fastNLP/models/enas_utils.py View File

@@ -1,6 +1,5 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch

from __future__ import print_function
from collections import defaultdict
import collections



+ 5
- 5
fastNLP/models/sequence_labeling.py View File

@@ -1,6 +1,11 @@
"""
本模块实现了两种序列标注模型
"""
__all__ = [
"SeqLabeling",
"AdvSeqLabel"
]

import torch
import torch.nn as nn

@@ -10,11 +15,6 @@ from ..modules.decoder.crf import allowed_transitions
from ..core.utils import seq_len_to_mask
from ..core.const import Const as C

__all__ = [
"SeqLabeling",
"AdvSeqLabel"
]


class SeqLabeling(BaseModel):
"""


+ 4
- 4
fastNLP/models/snli.py View File

@@ -1,3 +1,7 @@
__all__ = [
"ESIM"
]

import torch
import torch.nn as nn

@@ -8,10 +12,6 @@ from ..modules import encoder as Encoder
from ..modules import aggregator as Aggregator
from ..core.utils import seq_len_to_mask

__all__ = [
"ESIM"
]

my_inf = 10e12




+ 7
- 7
fastNLP/models/star_transformer.py View File

@@ -1,6 +1,13 @@
"""
Star-Transformer 的 Pytorch 实现。
"""
__all__ = [
"StarTransEnc",
"STNLICls",
"STSeqCls",
"STSeqLabel",
]

import torch
from torch import nn

@@ -9,13 +16,6 @@ from ..core.utils import seq_len_to_mask
from ..modules.utils import get_embeddings
from ..core.const import Const

__all__ = [
"StarTransEnc",
"STNLICls",
"STSeqCls",
"STSeqLabel",
]


class StarTransEnc(nn.Module):
"""


+ 9
- 9
fastNLP/modules/__init__.py View File

@@ -22,15 +22,6 @@
+-----------------------+-----------------------+-----------------------+

"""
from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings

__all__ = [
# "BertModel",
"ConvolutionCharEncoder",
@@ -54,3 +45,12 @@ __all__ = [
"viterbi_decode",
"allowed_transitions",
]

from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings

+ 7
- 7
fastNLP/modules/aggregator/__init__.py View File

@@ -1,10 +1,3 @@
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask

from .attention import MultiHeadAttention

__all__ = [
"MaxPool",
"MaxPoolWithMask",
@@ -12,3 +5,10 @@ __all__ = [
"MultiHeadAttention",
]

from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask

from .attention import MultiHeadAttention

+ 4
- 4
fastNLP/modules/aggregator/attention.py View File

@@ -1,3 +1,7 @@
__all__ = [
"MultiHeadAttention"
]

import math

import torch
@@ -8,10 +12,6 @@ from ..dropout import TimestepDropout

from ..utils import initial_parameter

__all__ = [
"MultiHeadAttention"
]


class DotAttention(nn.Module):
"""


+ 7
- 2
fastNLP/modules/aggregator/pooling.py View File

@@ -1,4 +1,8 @@
__all__ = ["MaxPool", "MaxPoolWithMask", "AvgPool"]
__all__ = [
"MaxPool",
"MaxPoolWithMask",
"AvgPool"
]
import torch
import torch.nn as nn

@@ -16,6 +20,7 @@ class MaxPool(nn.Module):
:param kernel_size: max pooling的窗口大小,默认为tensor最后k维,其中k为dimension
:param ceil_mode:
"""
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
super(MaxPool, self).__init__()
@@ -125,7 +130,7 @@ class AvgPoolWithMask(nn.Module):
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling
的时候只会考虑mask为1的位置
"""
def __init__(self):
super(AvgPoolWithMask, self).__init__()
self.inf = 10e12


+ 5
- 5
fastNLP/modules/decoder/__init__.py View File

@@ -1,11 +1,11 @@
from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions

__all__ = [
"MLP",
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions"
]

from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions

+ 5
- 5
fastNLP/modules/decoder/crf.py View File

@@ -1,13 +1,13 @@
import torch
from torch import nn

from ..utils import initial_parameter

__all__ = [
"ConditionalRandomField",
"allowed_transitions"
]

import torch
from torch import nn

from ..utils import initial_parameter


def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
"""


+ 4
- 4
fastNLP/modules/decoder/mlp.py View File

@@ -1,12 +1,12 @@
__all__ = [
"MLP"
]

import torch
import torch.nn as nn

from ..utils import initial_parameter

__all__ = [
"MLP"
]


class MLP(nn.Module):
"""


+ 1
- 2
fastNLP/modules/decoder/utils.py View File

@@ -1,8 +1,7 @@
import torch

__all__ = [
"viterbi_decode"
]
import torch


def viterbi_decode(logits, transitions, mask=None, unpad=False):


+ 4
- 2
fastNLP/modules/dropout.py View File

@@ -1,6 +1,8 @@
import torch
__all__ = []

import torch


class TimestepDropout(torch.nn.Dropout):
"""
别名::class:`fastNLP.modules.TimestepDropout`
@@ -8,7 +10,7 @@ class TimestepDropout(torch.nn.Dropout):
接受的参数shape为``[batch_size, num_timesteps, embedding_dim)]`` 使用同一个mask(shape为``(batch_size, embedding_dim)``)
在每个timestamp上做dropout。
"""
def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)


+ 8
- 9
fastNLP/modules/encoder/__init__.py View File

@@ -1,12 +1,3 @@
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU

__all__ = [
# "BertModel",
@@ -27,3 +18,11 @@ __all__ = [
"VarLSTM",
"VarGRU"
]
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU

+ 4
- 5
fastNLP/modules/encoder/char_encoder.py View File

@@ -1,12 +1,11 @@
import torch
import torch.nn as nn

from ..utils import initial_parameter

__all__ = [
"ConvolutionCharEncoder",
"LSTMCharEncoder"
]
import torch
import torch.nn as nn

from ..utils import initial_parameter


# from torch.nn.init import xavier_uniform


+ 3
- 4
fastNLP/modules/encoder/conv_maxpool.py View File

@@ -1,13 +1,12 @@
__all__ = [
"ConvMaxpool"
]
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import initial_parameter

__all__ = [
"ConvMaxpool"
]


class ConvMaxpool(nn.Module):
"""


+ 2
- 3
fastNLP/modules/encoder/embedding.py View File

@@ -1,9 +1,8 @@
import torch.nn as nn
from ..utils import get_embeddings

__all__ = [
"Embedding"
]
import torch.nn as nn
from ..utils import get_embeddings


class Embedding(nn.Embedding):


+ 4
- 4
fastNLP/modules/encoder/lstm.py View File

@@ -2,16 +2,16 @@
轻量封装的 Pytorch LSTM 模块.
可在 forward 时传入序列的长度, 自动对padding做合适的处理.
"""
__all__ = [
"LSTM"
]

import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn

from ..utils import initial_parameter

__all__ = [
"LSTM"
]


class LSTM(nn.Module):
"""


+ 4
- 4
fastNLP/modules/encoder/star_transformer.py View File

@@ -1,15 +1,15 @@
"""
Star-Transformer 的encoder部分的 Pytorch 实现
"""
__all__ = [
"StarTransformer"
]

import numpy as NP
import torch
from torch import nn
from torch.nn import functional as F

__all__ = [
"StarTransformer"
]


class StarTransformer(nn.Module):
"""


+ 3
- 4
fastNLP/modules/encoder/transformer.py View File

@@ -1,12 +1,11 @@
__all__ = [
"TransformerEncoder"
]
from torch import nn

from ..aggregator.attention import MultiHeadAttention
from ..dropout import TimestepDropout

__all__ = [
"TransformerEncoder"
]


class TransformerEncoder(nn.Module):
"""


+ 27
- 27
fastNLP/modules/encoder/variational_rnn.py View File

@@ -1,6 +1,12 @@
"""
Variational RNN 的 Pytorch 实现
"""
__all__ = [
"VarRNN",
"VarLSTM",
"VarGRU"
]

import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
@@ -17,25 +23,19 @@ except ImportError:

from ..utils import initial_parameter

__all__ = [
"VarRNN",
"VarLSTM",
"VarGRU"
]


class VarRnnCellWrapper(nn.Module):
"""
Wrapper for normal RNN Cells, make it support variational dropout
"""
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
"""
:param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -47,13 +47,13 @@ class VarRnnCellWrapper(nn.Module):
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
def get_hi(hi, h0, size):
h0_size = size - hi.size(0)
if h0_size > 0:
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
@@ -64,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
else:
batch_iter = batch_sizes
idx = 0
if is_lstm:
hn = (hidden[0].clone(), hidden[1].clone())
else:
@@ -91,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
hi = cell(input_i, hi)
hn[:size] = hi
output.append(hi)
if is_reversed:
output = list(reversed(output))
output = torch.cat(output, dim=0)
@@ -117,7 +117,7 @@ class VarRNNBase(nn.Module):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -141,7 +141,7 @@ class VarRNNBase(nn.Module):
cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)
self.is_lstm = (self.mode == "LSTM")
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
is_lstm = self.is_lstm
idx = self.num_directions * n_layer + n_direction
@@ -150,7 +150,7 @@ class VarRNNBase(nn.Module):
output_x, hidden_x = cell(
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
return output_x, hidden_x
def forward(self, x, hx=None):
"""

@@ -170,13 +170,13 @@ class VarRNNBase(nn.Module):
else:
max_batch_size = int(x.batch_sizes[0])
x, batch_sizes = x.data, x.batch_sizes
if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size, requires_grad=True)
if is_lstm:
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
mask_x = x.new_ones((max_batch_size, self.input_size))
mask_out = x.new_ones(
(max_batch_size, self.hidden_size * self.num_directions))
@@ -185,7 +185,7 @@ class VarRNNBase(nn.Module):
training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout,
training=self.training, inplace=True)
hidden = x.new_zeros(
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
if is_lstm:
@@ -207,16 +207,16 @@ class VarRNNBase(nn.Module):
else:
hidden[idx] = hidden_x
x = torch.cat(output_list, dim=-1)
if is_lstm:
hidden = (hidden, cellstate)
if is_packed:
output = PackedSequence(x, batch_sizes)
else:
x = PackedSequence(x, batch_sizes)
output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
return output, hidden


@@ -236,11 +236,11 @@ class VarLSTM(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarLSTM, self).forward(x, hx)

@@ -261,11 +261,11 @@ class VarRNN(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(
mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarRNN, self).forward(x, hx)

@@ -286,10 +286,10 @@ class VarGRU(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(
mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarGRU, self).forward(x, hx)

+ 1
- 1
fastNLP/modules/utils.py View File

@@ -1,5 +1,5 @@
from functools import reduce
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn


Loading…
Cancel
Save