@@ -52,8 +52,8 @@ __all__ = [
"cache_results"
]
__version__ = '0.4.0'
from .core import *
from . import models
from . import modules
__version__ = '0.4.0'
@@ -2,6 +2,10 @@
batch 模块实现了 fastNLP 所需的 Batch 类。
"""
__all__ = [
"Batch"
]
import atexit
from queue import Empty, Full
@@ -11,10 +15,6 @@ import torch.multiprocessing as mp
from .sampler import RandomSampler
__all__ = [
"Batch"
]
_python_is_exit = False
@@ -49,6 +49,18 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:
trainer.train()
"""
__all__ = [
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"TensorboardCallback",
"LRScheduler",
"ControlC",
"CallbackException",
"EarlyStopError"
]
import os
import torch
@@ -62,18 +74,6 @@ except:
from ..io.model_io import ModelSaver, ModelLoader
__all__ = [
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"TensorboardCallback",
"LRScheduler",
"ControlC",
"CallbackException",
"EarlyStopError"
]
class Callback(object):
"""
@@ -272,6 +272,10 @@
"""
__all__ = [
"DataSet"
]
import _pickle as pickle
import warnings
@@ -282,10 +286,6 @@ from .field import FieldArray
from .instance import Instance
from .utils import _get_func_signature
__all__ = [
"DataSet"
]
class DataSet(object):
"""
@@ -3,10 +3,6 @@ field模块实现了 FieldArray 和若干 Padder。 FieldArray 是 :class:`~fas
原理部分请参考 :doc:`fastNLP.core.dataset`
"""
from copy import deepcopy
import numpy as np
__all__ = [
"FieldArray",
"Padder",
@@ -14,6 +10,10 @@ __all__ = [
"EngChar2DPadder"
]
from copy import deepcopy
import numpy as np
class FieldArray(object):
"""
@@ -2,6 +2,18 @@
losses 模块定义了 fastNLP 中所需的各种损失函数,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
"""
__all__ = [
"LossBase",
"LossFunc",
"LossInForward",
"CrossEntropyLoss",
"BCELoss",
"L1Loss",
"NLLLoss"
]
import inspect
from collections import defaultdict
@@ -15,18 +27,6 @@ from .utils import _check_arg_dict_list
from .utils import _check_function_or_method
from .utils import _get_func_signature
__all__ = [
"LossBase",
"LossFunc",
"LossInForward",
"CrossEntropyLoss",
"BCELoss",
"L1Loss",
"NLLLoss"
]
class LossBase(object):
"""
@@ -2,6 +2,13 @@
metrics 模块实现了 fastNLP 所需的各种常用衡量指标,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
"""
__all__ = [
"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric"
]
import inspect
from collections import defaultdict
@@ -16,13 +23,6 @@ from .utils import _get_func_signature
from .utils import seq_len_to_mask
from .vocabulary import Vocabulary
__all__ = [
"MetricBase",
"AccuracyMetric",
"SpanFPreRecMetric",
"SQuADMetric"
]
class MetricBase(object):
"""
@@ -2,14 +2,14 @@
optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :class:`~fastNLP.Trainer` 的参数使用。
"""
import torch
__all__ = [
"Optimizer",
"SGD",
"Adam"
]
import torch
class Optimizer(object):
"""
@@ -1,10 +1,6 @@
"""
sampler 子类实现了 fastNLP 所需的各种采样器。
"""
from itertools import chain
import numpy as np
__all__ = [
"Sampler",
"BucketSampler",
@@ -12,6 +8,10 @@ __all__ = [
"RandomSampler"
]
from itertools import chain
import numpy as np
class Sampler(object):
"""
@@ -295,6 +295,9 @@ Example2.3
fastNLP已经自带了很多callback函数供使用,可以参考 :doc:`fastNLP.core.callback` 。
"""
__all__ = [
"Trainer"
]
import os
import time
@@ -1,6 +1,11 @@
"""
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。
"""
__all__ = [
"cache_results",
"seq_len_to_mask"
]
import _pickle
import inspect
import os
@@ -11,10 +16,6 @@ import numpy as np
import torch
import torch.nn as nn
__all__ = [
"cache_results",
"seq_len_to_mask"
]
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
'varargs'])
@@ -1,12 +1,12 @@
__all__ = [
"Vocabulary"
]
from functools import wraps
from collections import Counter
from .dataset import DataSet
__all__ = [
"Vocabulary"
]
def _check_build_vocab(func):
"""A decorator to make sure the indexing is built before used.
@@ -322,7 +322,7 @@ class Vocabulary(object):
:return str word: the word
"""
return self.idx2word[idx]
def clear(self):
"""
删除Vocabulary中的词表数据。相当于重新初始化一下。
@@ -333,7 +333,7 @@ class Vocabulary(object):
self.word2idx = None
self.idx2word = None
self.rebuild = True
def __getstate__(self):
"""Use to prepare data for pickle.
@@ -9,11 +9,6 @@
这些类的使用方法如下:
"""
from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
__all__ = [
'EmbedLoader',
@@ -29,3 +24,8 @@ __all__ = [
'ModelLoader',
'ModelSaver',
]
from .embed_loader import EmbedLoader
from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
PeopleDailyCorpusLoader, Conll2003Loader
from .model_io import ModelLoader, ModelSaver
@@ -1,10 +1,10 @@
import _pickle as pickle
import os
__all__ = [
"BaseLoader"
]
import _pickle as pickle
import os
class BaseLoader(object):
"""
@@ -3,18 +3,18 @@
.. todo::
这个模块中的类可能被抛弃?
"""
import configparser
import json
import os
from .base_loader import BaseLoader
__all__ = [
"ConfigLoader",
"ConfigSection",
"ConfigSaver"
]
import configparser
import json
import os
from .base_loader import BaseLoader
class ConfigLoader(BaseLoader):
"""
@@ -10,12 +10,6 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的
# ... do stuff
"""
from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
__all__ = [
'DataSetLoader',
'CSVLoader',
@@ -27,6 +21,12 @@ __all__ = [
'Conll2003Loader',
]
from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
def _download_from_url(url, path):
try:
@@ -1,3 +1,7 @@
__all__ = [
"EmbedLoader"
]
import os
import warnings
@@ -6,10 +10,6 @@ import numpy as np
from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader
__all__ = [
"EmbedLoader"
]
class EmbedLoader(BaseLoader):
"""
@@ -1,15 +1,15 @@
"""
用于载入和保存模型
"""
import torch
from .base_loader import BaseLoader
__all__ = [
"ModelLoader",
"ModelSaver"
]
import torch
from .base_loader import BaseLoader
class ModelLoader(BaseLoader):
"""
@@ -7,15 +7,6 @@ fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models
"""
from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
__all__ = [
"CNNText",
@@ -32,3 +23,12 @@ __all__ = [
"BiaffineParser",
"GraphParser"
]
from .base_model import BaseModel
from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
BertForTokenClassification
from .biaffine_parser import BiaffineParser, GraphParser
from .cnn_text_classification import CNNText
from .sequence_labeling import SeqLabeling, AdvSeqLabel
from .snli import ESIM
from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
@@ -1,6 +1,11 @@
"""
Biaffine Dependency Parser 的 Pytorch 实现.
"""
__all__ = [
"BiaffineParser",
"GraphParser"
]
import numpy as np
import torch
import torch.nn as nn
@@ -19,11 +24,6 @@ from ..modules.utils import get_embeddings
from .base_model import BaseModel
from ..core.utils import seq_len_to_mask
__all__ = [
"BiaffineParser",
"GraphParser"
]
def _mst(scores):
"""
@@ -1,13 +1,13 @@
__all__ = [
"CNNText"
]
import torch
import torch.nn as nn
from ..core.const import Const as C
from ..modules import encoder
__all__ = [
"CNNText"
]
class CNNText(torch.nn.Module):
"""
@@ -1,6 +1,5 @@
# Code Modified from https://github.com/carpedm20/ENAS-pytorch
from __future__ import print_function
from collections import defaultdict
import collections
@@ -1,6 +1,11 @@
"""
本模块实现了两种序列标注模型
"""
__all__ = [
"SeqLabeling",
"AdvSeqLabel"
]
import torch
import torch.nn as nn
@@ -10,11 +15,6 @@ from ..modules.decoder.crf import allowed_transitions
from ..core.utils import seq_len_to_mask
from ..core.const import Const as C
__all__ = [
"SeqLabeling",
"AdvSeqLabel"
]
class SeqLabeling(BaseModel):
"""
@@ -1,3 +1,7 @@
__all__ = [
"ESIM"
]
import torch
import torch.nn as nn
@@ -8,10 +12,6 @@ from ..modules import encoder as Encoder
from ..modules import aggregator as Aggregator
from ..core.utils import seq_len_to_mask
__all__ = [
"ESIM"
]
my_inf = 10e12
@@ -1,6 +1,13 @@
"""
Star-Transformer 的 Pytorch 实现。
"""
__all__ = [
"StarTransEnc",
"STNLICls",
"STSeqCls",
"STSeqLabel",
]
import torch
from torch import nn
@@ -9,13 +16,6 @@ from ..core.utils import seq_len_to_mask
from ..modules.utils import get_embeddings
from ..core.const import Const
__all__ = [
"StarTransEnc",
"STNLICls",
"STSeqCls",
"STSeqLabel",
]
class StarTransEnc(nn.Module):
"""
@@ -22,15 +22,6 @@
+-----------------------+-----------------------+-----------------------+
"""
from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings
__all__ = [
# "BertModel",
"ConvolutionCharEncoder",
@@ -54,3 +45,12 @@ __all__ = [
"viterbi_decode",
"allowed_transitions",
]
from . import aggregator
from . import decoder
from . import encoder
from .aggregator import *
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
from .utils import get_embeddings
@@ -1,10 +1,3 @@
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask
from .attention import MultiHeadAttention
__all__ = [
"MaxPool",
"MaxPoolWithMask",
@@ -12,3 +5,10 @@ __all__ = [
"MultiHeadAttention",
]
from .pooling import MaxPool
from .pooling import MaxPoolWithMask
from .pooling import AvgPool
from .pooling import AvgPoolWithMask
from .attention import MultiHeadAttention
@@ -1,3 +1,7 @@
__all__ = [
"MultiHeadAttention"
]
import math
import torch
@@ -8,10 +12,6 @@ from ..dropout import TimestepDropout
from ..utils import initial_parameter
__all__ = [
"MultiHeadAttention"
]
class DotAttention(nn.Module):
"""
@@ -1,4 +1,8 @@
__all__ = ["MaxPool", "MaxPoolWithMask", "AvgPool"]
__all__ = [
"MaxPool",
"MaxPoolWithMask",
"AvgPool"
]
import torch
import torch.nn as nn
@@ -16,6 +20,7 @@ class MaxPool(nn.Module):
:param kernel_size: max pooling的窗口大小,默认为tensor最后k维,其中k为dimension
:param ceil_mode:
"""
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
super(MaxPool, self).__init__()
@@ -125,7 +130,7 @@ class AvgPoolWithMask(nn.Module):
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling
的时候只会考虑mask为1的位置
"""
def __init__(self):
super(AvgPoolWithMask, self).__init__()
self.inf = 10e12
@@ -1,11 +1,11 @@
from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions
__all__ = [
"MLP",
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions"
]
from .crf import ConditionalRandomField
from .mlp import MLP
from .utils import viterbi_decode
from .crf import allowed_transitions
@@ -1,13 +1,13 @@
import torch
from torch import nn
from ..utils import initial_parameter
__all__ = [
"ConditionalRandomField",
"allowed_transitions"
]
import torch
from torch import nn
from ..utils import initial_parameter
def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
"""
@@ -1,12 +1,12 @@
__all__ = [
"MLP"
]
import torch
import torch.nn as nn
from ..utils import initial_parameter
__all__ = [
"MLP"
]
class MLP(nn.Module):
"""
@@ -1,8 +1,7 @@
import torch
__all__ = [
"viterbi_decode"
]
import torch
def viterbi_decode(logits, transitions, mask=None, unpad=False):
@@ -1,6 +1,8 @@
import torch
__all__ = []
import torch
class TimestepDropout(torch.nn.Dropout):
"""
别名::class:`fastNLP.modules.TimestepDropout`
@@ -8,7 +10,7 @@ class TimestepDropout(torch.nn.Dropout):
接受的参数shape为``[batch_size, num_timesteps, embedding_dim)]`` 使用同一个mask(shape为``(batch_size, embedding_dim)``)
在每个timestamp上做dropout。
"""
def forward(self, x):
dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
@@ -1,12 +1,3 @@
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
__all__ = [
# "BertModel",
@@ -27,3 +18,11 @@ __all__ = [
"VarLSTM",
"VarGRU"
]
from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
from .embedding import Embedding
from .lstm import LSTM
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
@@ -1,12 +1,11 @@
import torch
import torch.nn as nn
from ..utils import initial_parameter
__all__ = [
"ConvolutionCharEncoder",
"LSTMCharEncoder"
]
import torch
import torch.nn as nn
from ..utils import initial_parameter
# from torch.nn.init import xavier_uniform
@@ -1,13 +1,12 @@
__all__ = [
"ConvMaxpool"
]
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..utils import initial_parameter
__all__ = [
"ConvMaxpool"
]
class ConvMaxpool(nn.Module):
"""
@@ -1,9 +1,8 @@
import torch.nn as nn
from ..utils import get_embeddings
__all__ = [
"Embedding"
]
import torch.nn as nn
from ..utils import get_embeddings
class Embedding(nn.Embedding):
@@ -2,16 +2,16 @@
轻量封装的 Pytorch LSTM 模块.
可在 forward 时传入序列的长度, 自动对padding做合适的处理.
"""
__all__ = [
"LSTM"
]
import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
from ..utils import initial_parameter
__all__ = [
"LSTM"
]
class LSTM(nn.Module):
"""
@@ -1,15 +1,15 @@
"""
Star-Transformer 的encoder部分的 Pytorch 实现
"""
__all__ = [
"StarTransformer"
]
import numpy as NP
import torch
from torch import nn
from torch.nn import functional as F
__all__ = [
"StarTransformer"
]
class StarTransformer(nn.Module):
"""
@@ -1,12 +1,11 @@
__all__ = [
"TransformerEncoder"
]
from torch import nn
from ..aggregator.attention import MultiHeadAttention
from ..dropout import TimestepDropout
__all__ = [
"TransformerEncoder"
]
class TransformerEncoder(nn.Module):
"""
@@ -1,6 +1,12 @@
"""
Variational RNN 的 Pytorch 实现
"""
__all__ = [
"VarRNN",
"VarLSTM",
"VarGRU"
]
import torch
import torch.nn as nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
@@ -17,25 +23,19 @@ except ImportError:
from ..utils import initial_parameter
__all__ = [
"VarRNN",
"VarLSTM",
"VarGRU"
]
class VarRnnCellWrapper(nn.Module):
"""
Wrapper for normal RNN Cells, make it support variational dropout
"""
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
"""
:param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -47,13 +47,13 @@ class VarRnnCellWrapper(nn.Module):
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
def get_hi(hi, h0, size):
h0_size = size - hi.size(0)
if h0_size > 0:
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
@@ -64,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
else:
batch_iter = batch_sizes
idx = 0
if is_lstm:
hn = (hidden[0].clone(), hidden[1].clone())
else:
@@ -91,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
hi = cell(input_i, hi)
hn[:size] = hi
output.append(hi)
if is_reversed:
output = list(reversed(output))
output = torch.cat(output, dim=0)
@@ -117,7 +117,7 @@ class VarRNNBase(nn.Module):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -141,7 +141,7 @@ class VarRNNBase(nn.Module):
cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)
self.is_lstm = (self.mode == "LSTM")
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
is_lstm = self.is_lstm
idx = self.num_directions * n_layer + n_direction
@@ -150,7 +150,7 @@ class VarRNNBase(nn.Module):
output_x, hidden_x = cell(
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
return output_x, hidden_x
def forward(self, x, hx=None):
"""
@@ -170,13 +170,13 @@ class VarRNNBase(nn.Module):
else:
max_batch_size = int(x.batch_sizes[0])
x, batch_sizes = x.data, x.batch_sizes
if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size, requires_grad=True)
if is_lstm:
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
mask_x = x.new_ones((max_batch_size, self.input_size))
mask_out = x.new_ones(
(max_batch_size, self.hidden_size * self.num_directions))
@@ -185,7 +185,7 @@ class VarRNNBase(nn.Module):
training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout,
training=self.training, inplace=True)
hidden = x.new_zeros(
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
if is_lstm:
@@ -207,16 +207,16 @@ class VarRNNBase(nn.Module):
else:
hidden[idx] = hidden_x
x = torch.cat(output_list, dim=-1)
if is_lstm:
hidden = (hidden, cellstate)
if is_packed:
output = PackedSequence(x, batch_sizes)
else:
x = PackedSequence(x, batch_sizes)
output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
return output, hidden
@@ -236,11 +236,11 @@ class VarLSTM(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarLSTM, self).forward(x, hx)
@@ -261,11 +261,11 @@ class VarRNN(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(
mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarRNN, self).forward(x, hx)
@@ -286,10 +286,10 @@ class VarGRU(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False``
"""
def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(
mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
def forward(self, x, hx=None):
return super(VarGRU, self).forward(x, hx)
@@ -1,5 +1,5 @@
from functools import reduce
from collections import OrderedDict
import numpy as np
import torch
import torch.nn as nn | |||