@@ -52,8 +52,8 @@ __all__ = [
     "cache_results"
 ]
-__version__ = '0.4.0'
 
 from .core import *
 from . import models
 from . import modules
 
+__version__ = '0.4.0'
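Every hunk in this PR applies the same mechanical change: each module's `__all__` moves above its imports, directly after the module docstring. That matches PEP 8, which places module-level dunders after the docstring but before all imports (only `from __future__` imports may come earlier), and it keeps the public API visible at the top of the file. A minimal sketch of the convention, with hypothetical names (`public_fn` and `_helper` are not fastNLP symbols):

```python
"""Module docstring comes first."""
# Sketch of the ordering this PR enforces repo-wide; names are hypothetical.
__all__ = [
    "public_fn"  # only names listed here are exported by `from mod import *`
]

import os  # imports come after the module-level dunders


def public_fn():
    return os.getpid()


def _helper():  # absent from __all__, so star-imports do not pick it up
    return None
```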
@@ -2,6 +2,10 @@
 The batch module implements the Batch class required by fastNLP.
 """
+__all__ = [
+    "Batch"
+]
 
 import atexit
 from queue import Empty, Full
@@ -11,10 +15,6 @@ import torch.multiprocessing as mp
 from .sampler import RandomSampler
 
-__all__ = [
-    "Batch"
-]
 
 _python_is_exit = False
@@ -49,6 +49,18 @@ The callback module implements many of fastNLP's callback classes, used to enhance :class:
     trainer.train()
 """
+__all__ = [
+    "Callback",
+    "GradientClipCallback",
+    "EarlyStopCallback",
+    "TensorboardCallback",
+    "LRScheduler",
+    "ControlC",
+    "CallbackException",
+    "EarlyStopError"
+]
 
 import os
 import torch
@@ -62,18 +74,6 @@ except:
 from ..io.model_io import ModelSaver, ModelLoader
 
-__all__ = [
-    "Callback",
-    "GradientClipCallback",
-    "EarlyStopCallback",
-    "TensorboardCallback",
-    "LRScheduler",
-    "ControlC",
-    "CallbackException",
-    "EarlyStopError"
-]
 
 class Callback(object):
     """
@@ -272,6 +272,10 @@
 """
+__all__ = [
+    "DataSet"
+]
 
 import _pickle as pickle
 import warnings
@@ -282,10 +286,6 @@ from .field import FieldArray
 from .instance import Instance
 from .utils import _get_func_signature
 
-__all__ = [
-    "DataSet"
-]
 
 class DataSet(object):
     """
@@ -3,10 +3,6 @@ The field module implements FieldArray and several Padders. FieldArray is :class:`~fas
 For the underlying design, see :doc:`fastNLP.core.dataset`
 """
-from copy import deepcopy
-import numpy as np
 
 __all__ = [
     "FieldArray",
     "Padder",
@@ -14,6 +10,10 @@ __all__ = [
     "EngChar2DPadder"
 ]
 
+from copy import deepcopy
+import numpy as np
 
 class FieldArray(object):
     """
@@ -2,6 +2,18 @@
 The losses module defines the various loss functions fastNLP needs; they are generally used as arguments to :class:`~fastNLP.Trainer`.
 """
+__all__ = [
+    "LossBase",
+    "LossFunc",
+    "LossInForward",
+    "CrossEntropyLoss",
+    "BCELoss",
+    "L1Loss",
+    "NLLLoss"
+]
 
 import inspect
 from collections import defaultdict
@@ -15,18 +27,6 @@ from .utils import _check_arg_dict_list
 from .utils import _check_function_or_method
 from .utils import _get_func_signature
 
-__all__ = [
-    "LossBase",
-    "LossFunc",
-    "LossInForward",
-    "CrossEntropyLoss",
-    "BCELoss",
-    "L1Loss",
-    "NLLLoss"
-]
 
 class LossBase(object):
     """
@@ -2,6 +2,13 @@
 The metrics module implements the common evaluation metrics fastNLP needs; they are generally used as arguments to :class:`~fastNLP.Trainer`.
 """
+__all__ = [
+    "MetricBase",
+    "AccuracyMetric",
+    "SpanFPreRecMetric",
+    "SQuADMetric"
+]
 
 import inspect
 from collections import defaultdict
@@ -16,13 +23,6 @@ from .utils import _get_func_signature
 from .utils import seq_len_to_mask
 from .vocabulary import Vocabulary
 
-__all__ = [
-    "MetricBase",
-    "AccuracyMetric",
-    "SpanFPreRecMetric",
-    "SQuADMetric"
-]
 
 class MetricBase(object):
     """
@@ -2,14 +2,14 @@
 The optimizer module defines the various optimizers fastNLP needs; they are generally used as arguments to :class:`~fastNLP.Trainer`.
 """
-import torch
 
 __all__ = [
     "Optimizer",
     "SGD",
    "Adam"
 ]
 
+import torch
 
 class Optimizer(object):
     """
@@ -1,10 +1,6 @@
 """
 The sampler subclasses implement the various samplers fastNLP needs.
 """
-from itertools import chain
-import numpy as np
 
 __all__ = [
     "Sampler",
     "BucketSampler",
@@ -12,6 +8,10 @@ __all__ = [
     "RandomSampler"
 ]
 
+from itertools import chain
+import numpy as np
 
 class Sampler(object):
     """
@@ -295,6 +295,9 @@ Example2.3
 fastNLP already ships with many callback functions; see :doc:`fastNLP.core.callback`.
 """
+__all__ = [
+    "Trainer"
+]
 
 import os
 import time
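The losses, metrics, and callback docstrings above all point at the same wiring: those objects are handed to :class:`~fastNLP.Trainer` as constructor arguments. A hedged sketch of that pattern (the `train_data`, `dev_data`, and `model` variables are placeholders; the exact Trainer signature should be checked against this fastNLP version):

```python
# Hedged usage sketch, not taken from this diff: how the classes exported by
# the losses/metrics/callback modules are typically passed to Trainer.
# `train_data`, `dev_data`, and `model` are assumed to exist already.
from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.core.callback import EarlyStopCallback, GradientClipCallback

trainer = Trainer(
    train_data=train_data,               # a fastNLP DataSet
    model=model,                         # an nn.Module whose forward fastNLP can call
    loss=CrossEntropyLoss(),             # from the losses module above
    metrics=AccuracyMetric(),            # from the metrics module above
    dev_data=dev_data,
    callbacks=[EarlyStopCallback(10),    # stop after 10 epochs without improvement
               GradientClipCallback()],  # clip gradients each step
)
trainer.train()
```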
@@ -1,6 +1,11 @@
 """
 The utils module implements many utilities needed both inside and outside fastNLP. The one intended for users is the :func:`cache_results` decorator.
 """
+__all__ = [
+    "cache_results",
+    "seq_len_to_mask"
+]
 
 import _pickle
 import inspect
 import os
@@ -11,10 +16,6 @@ import numpy as np
 import torch
 import torch.nn as nn
 
-__all__ = [
-    "cache_results",
-    "seq_len_to_mask"
-]
 
 _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed',
                                      'varargs'])
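Since the docstring singles out :func:`cache_results` as the user-facing utility here, a hedged sketch of its intended use (the cache path and function body are placeholders, and the decorator's keyword options are not shown):

```python
# Hedged sketch of the cache_results decorator exported above: the first call
# runs the function and pickles its return value to the given path; later
# calls load the pickle instead of recomputing. The path is a placeholder and
# its parent directory is assumed to exist.
from fastNLP import cache_results

@cache_results('caches/processed_data.pkl')
def prepare_data():
    # ... expensive preprocessing would go here ...
    return {"train": [1, 2, 3]}

data = prepare_data()  # computed once, then served from caches/processed_data.pkl
```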
@@ -1,12 +1,12 @@
+__all__ = [
+    "Vocabulary"
+]
 
 from functools import wraps
 from collections import Counter
 from .dataset import DataSet
 
-__all__ = [
-    "Vocabulary"
-]
 
 def _check_build_vocab(func):
     """A decorator to make sure the indexing is built before used.
@@ -322,7 +322,7 @@ class Vocabulary(object):
         :return str word: the word
         """
         return self.idx2word[idx]
 
     def clear(self):
         """
         Delete the Vocabulary's word-table data; equivalent to re-initializing the object.
@@ -333,7 +333,7 @@ class Vocabulary(object):
         self.word2idx = None
         self.idx2word = None
         self.rebuild = True
 
     def __getstate__(self):
         """Use to prepare data for pickle.
@@ -9,11 +9,6 @@
 These classes are used as follows:
 """
-from .embed_loader import EmbedLoader
-from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
-    PeopleDailyCorpusLoader, Conll2003Loader
-from .model_io import ModelLoader, ModelSaver
 
 __all__ = [
     'EmbedLoader',
@@ -29,3 +24,8 @@ __all__ = [
     'ModelLoader',
     'ModelSaver',
 ]
+
+from .embed_loader import EmbedLoader
+from .dataset_loader import DataSetLoader, CSVLoader, JsonLoader, ConllLoader, SNLILoader, SSTLoader, \
+    PeopleDailyCorpusLoader, Conll2003Loader
+from .model_io import ModelLoader, ModelSaver
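The loaders re-exported here share one pattern: construct, then `load(path)` to get a fastNLP DataSet. A hedged sketch with CSVLoader (the file path, headers, and exact constructor defaults are assumptions, not taken from this diff):

```python
# Hedged sketch: DataSetLoader subclasses such as CSVLoader read a file into
# a fastNLP DataSet. 'data/train.csv' and the header names are placeholders.
from fastNLP.io import CSVLoader

loader = CSVLoader(headers=["raw_words", "label"], sep=",")
dataset = loader.load("data/train.csv")   # returns a fastNLP DataSet
print(len(dataset))
```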
@@ -1,10 +1,10 @@
-import _pickle as pickle
-import os
 __all__ = [
     "BaseLoader"
 ]
 
+import _pickle as pickle
+import os
 
 class BaseLoader(object):
     """
@@ -3,18 +3,18 @@
 .. todo::
     The classes in this module may be dropped?
 """
-import configparser
-import json
-import os
-from .base_loader import BaseLoader
 
 __all__ = [
     "ConfigLoader",
     "ConfigSection",
     "ConfigSaver"
 ]
 
+import configparser
+import json
+import os
+from .base_loader import BaseLoader
 
 class ConfigLoader(BaseLoader):
     """
@@ -10,12 +10,6 @@ The dataset_loader module implements many DataSetLoaders for reading data of different formats
     # ... do stuff
 """
-from nltk.tree import Tree
-from ..core.dataset import DataSet
-from ..core.instance import Instance
-from .file_reader import _read_csv, _read_json, _read_conll
 
 __all__ = [
     'DataSetLoader',
     'CSVLoader',
@@ -27,6 +21,12 @@ __all__ = [
     'Conll2003Loader',
 ]
 
+from nltk.tree import Tree
+from ..core.dataset import DataSet
+from ..core.instance import Instance
+from .file_reader import _read_csv, _read_json, _read_conll
 
 def _download_from_url(url, path):
     try:
@@ -1,3 +1,7 @@
+__all__ = [
+    "EmbedLoader"
+]
 
 import os
 import warnings
@@ -6,10 +10,6 @@ import numpy as np
 from ..core.vocabulary import Vocabulary
 from .base_loader import BaseLoader
 
-__all__ = [
-    "EmbedLoader"
-]
 
 class EmbedLoader(BaseLoader):
     """
@@ -1,15 +1,15 @@
 """
 Utilities for loading and saving models
 """
-import torch
-from .base_loader import BaseLoader
 
 __all__ = [
     "ModelLoader",
     "ModelSaver"
 ]
 
+import torch
+from .base_loader import BaseLoader
 
 class ModelLoader(BaseLoader):
     """
@@ -7,15 +7,6 @@ fastNLP's :mod:`~fastNLP.models` module ships built-in models such as :class:`~fastNLP.models
 """
-from .base_model import BaseModel
-from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
-    BertForTokenClassification
-from .biaffine_parser import BiaffineParser, GraphParser
-from .cnn_text_classification import CNNText
-from .sequence_labeling import SeqLabeling, AdvSeqLabel
-from .snli import ESIM
-from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
 
 __all__ = [
     "CNNText",
@@ -32,3 +23,12 @@ __all__ = [
     "BiaffineParser",
     "GraphParser"
 ]
+
+from .base_model import BaseModel
+from .bert import BertForMultipleChoice, BertForQuestionAnswering, BertForSequenceClassification, \
+    BertForTokenClassification
+from .biaffine_parser import BiaffineParser, GraphParser
+from .cnn_text_classification import CNNText
+from .sequence_labeling import SeqLabeling, AdvSeqLabel
+from .snli import ESIM
+from .star_transformer import StarTransEnc, STSeqCls, STNLICls, STSeqLabel
@@ -1,6 +1,11 @@
 """
 A PyTorch implementation of the Biaffine Dependency Parser.
 """
+__all__ = [
+    "BiaffineParser",
+    "GraphParser"
+]
 
 import numpy as np
 import torch
 import torch.nn as nn
@@ -19,11 +24,6 @@ from ..modules.utils import get_embeddings
 from .base_model import BaseModel
 from ..core.utils import seq_len_to_mask
 
-__all__ = [
-    "BiaffineParser",
-    "GraphParser"
-]
 
 def _mst(scores):
     """
@@ -1,13 +1,13 @@
+__all__ = [
+    "CNNText"
+]
 
 import torch
 import torch.nn as nn
 
 from ..core.const import Const as C
 from ..modules import encoder
 
-__all__ = [
-    "CNNText"
-]
 
 class CNNText(torch.nn.Module):
     """
@@ -1,6 +1,5 @@
 # Code Modified from https://github.com/carpedm20/ENAS-pytorch
-from __future__ import print_function
 
 from collections import defaultdict
 import collections
@@ -1,6 +1,11 @@
 """
 This module implements two sequence labeling models
 """
+__all__ = [
+    "SeqLabeling",
+    "AdvSeqLabel"
+]
 
 import torch
 import torch.nn as nn
@@ -10,11 +15,6 @@ from ..modules.decoder.crf import allowed_transitions
 from ..core.utils import seq_len_to_mask
 from ..core.const import Const as C
 
-__all__ = [
-    "SeqLabeling",
-    "AdvSeqLabel"
-]
 
 class SeqLabeling(BaseModel):
     """
@@ -1,3 +1,7 @@
+__all__ = [
+    "ESIM"
+]
 
 import torch
 import torch.nn as nn
@@ -8,10 +12,6 @@ from ..modules import encoder as Encoder
 from ..modules import aggregator as Aggregator
 from ..core.utils import seq_len_to_mask
 
-__all__ = [
-    "ESIM"
-]
 
 my_inf = 10e12
@@ -1,6 +1,13 @@
 """
 A PyTorch implementation of the Star-Transformer.
 """
+__all__ = [
+    "StarTransEnc",
+    "STNLICls",
+    "STSeqCls",
+    "STSeqLabel",
+]
 
 import torch
 from torch import nn
@@ -9,13 +16,6 @@ from ..core.utils import seq_len_to_mask
 from ..modules.utils import get_embeddings
 from ..core.const import Const
 
-__all__ = [
-    "StarTransEnc",
-    "STNLICls",
-    "STSeqCls",
-    "STSeqLabel",
-]
 
 class StarTransEnc(nn.Module):
     """
@@ -22,15 +22,6 @@
 +-----------------------+-----------------------+-----------------------+
 """
-from . import aggregator
-from . import decoder
-from . import encoder
-from .aggregator import *
-from .decoder import *
-from .dropout import TimestepDropout
-from .encoder import *
-from .utils import get_embeddings
 
 __all__ = [
     # "BertModel",
     "ConvolutionCharEncoder",
@@ -54,3 +45,12 @@ __all__ = [
     "viterbi_decode",
     "allowed_transitions",
 ]
+
+from . import aggregator
+from . import decoder
+from . import encoder
+from .aggregator import *
+from .decoder import *
+from .dropout import TimestepDropout
+from .encoder import *
+from .utils import get_embeddings
@@ -1,10 +1,3 @@
-from .pooling import MaxPool
-from .pooling import MaxPoolWithMask
-from .pooling import AvgPool
-from .pooling import AvgPoolWithMask
-from .attention import MultiHeadAttention
 
 __all__ = [
     "MaxPool",
     "MaxPoolWithMask",
@@ -12,3 +5,10 @@ __all__ = [
     "MultiHeadAttention",
 ]
+
+from .pooling import MaxPool
+from .pooling import MaxPoolWithMask
+from .pooling import AvgPool
+from .pooling import AvgPoolWithMask
+from .attention import MultiHeadAttention
@@ -1,3 +1,7 @@
+__all__ = [
+    "MultiHeadAttention"
+]
 
 import math
 import torch
@@ -8,10 +12,6 @@ from ..dropout import TimestepDropout
 from ..utils import initial_parameter
 
-__all__ = [
-    "MultiHeadAttention"
-]
 
 class DotAttention(nn.Module):
     """
@@ -1,4 +1,8 @@
-__all__ = ["MaxPool", "MaxPoolWithMask", "AvgPool"]
+__all__ = [
+    "MaxPool",
+    "MaxPoolWithMask",
+    "AvgPool"
+]
 
 import torch
 import torch.nn as nn
@@ -16,6 +20,7 @@ class MaxPool(nn.Module):
     :param kernel_size: window size of the max pooling; defaults to the last k dimensions of the tensor, where k is ``dimension``
     :param ceil_mode:
     """
+
     def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
         super(MaxPool, self).__init__()
@@ -125,7 +130,7 @@ class AvgPoolWithMask(nn.Module):
     Given input of shape [batch_size, max_len, hidden_size], apply avg pooling over the max_len dimension; the output has
     shape [batch_size, hidden_size], and only positions where the mask is 1 are taken into account.
     """
 
     def __init__(self):
         super(AvgPoolWithMask, self).__init__()
         self.inf = 10e12
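To make the AvgPoolWithMask docstring concrete, a hedged numeric sketch of masked averaging (tensor values are invented; this restates the idea rather than reproducing fastNLP's exact implementation):

```python
# Hedged sketch of masked average pooling as the docstring describes it:
# padded positions (mask == 0) contribute nothing to the average.
import torch

x = torch.tensor([[[1.0], [3.0], [100.0]]])   # [batch=1, max_len=3, hidden=1]
mask = torch.tensor([[1, 1, 0]])              # last position is padding

masked_sum = (x * mask.unsqueeze(-1).float()).sum(dim=1)  # padding ignored
lengths = mask.sum(dim=1, keepdim=True).float()
avg = masked_sum / lengths                                # -> [[2.0]], not 34.67
print(avg)
```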
@@ -1,11 +1,11 @@
-from .crf import ConditionalRandomField
-from .mlp import MLP
-from .utils import viterbi_decode
-from .crf import allowed_transitions
 
 __all__ = [
     "MLP",
     "ConditionalRandomField",
     "viterbi_decode",
     "allowed_transitions"
 ]
 
+from .crf import ConditionalRandomField
+from .mlp import MLP
+from .utils import viterbi_decode
+from .crf import allowed_transitions
@@ -1,13 +1,13 @@
-import torch
-from torch import nn
-from ..utils import initial_parameter
 
 __all__ = [
     "ConditionalRandomField",
     "allowed_transitions"
 ]
 
+import torch
+from torch import nn
+from ..utils import initial_parameter
 
 def allowed_transitions(id2target, encoding_type='bio', include_start_end=True):
     """
@@ -1,12 +1,12 @@
+__all__ = [
+    "MLP"
+]
 
 import torch
 import torch.nn as nn
 
 from ..utils import initial_parameter
 
-__all__ = [
-    "MLP"
-]
 
 class MLP(nn.Module):
     """
@@ -1,8 +1,7 @@
-import torch
 
 __all__ = [
     "viterbi_decode"
 ]
 
+import torch
 
 def viterbi_decode(logits, transitions, mask=None, unpad=False):
@@ -1,6 +1,8 @@
-import torch
 __all__ = []
+
+import torch
+
 
 class TimestepDropout(torch.nn.Dropout):
     """
     Alias: :class:`fastNLP.modules.TimestepDropout`
@@ -8,7 +10,7 @@ class TimestepDropout(torch.nn.Dropout):
    Accepts input of shape ``[batch_size, num_timesteps, embedding_dim]`` and applies the same dropout mask
    (of shape ``(batch_size, embedding_dim)``) at every timestep.
     """
 
     def forward(self, x):
         dropout_mask = x.new_ones(x.shape[0], x.shape[-1])
         torch.nn.functional.dropout(dropout_mask, self.p, self.training, inplace=True)
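The forward above samples one mask of shape (batch_size, embedding_dim) and broadcasts it over time; a hedged sketch of the observable effect (sizes are arbitrary):

```python
# Hedged sketch: with TimestepDropout, an embedding channel that gets dropped
# is dropped at *every* timestep of a sequence, unlike ordinary Dropout,
# which samples an independent mask per (timestep, channel) element.
import torch
from fastNLP.modules import TimestepDropout

x = torch.ones(2, 5, 4)          # [batch, num_timesteps, embedding_dim]
drop = TimestepDropout(p=0.5)
drop.train()

y = drop(x)
# Each [batch, channel] column is either all zeros or all scaled values
# across the 5 timesteps, because the same mask is reused over time.
print(y[0])
```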
@@ -1,12 +1,3 @@
-from .bert import BertModel
-from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
-from .conv_maxpool import ConvMaxpool
-from .embedding import Embedding
-from .lstm import LSTM
-from .star_transformer import StarTransformer
-from .transformer import TransformerEncoder
-from .variational_rnn import VarRNN, VarLSTM, VarGRU
 
 __all__ = [
     # "BertModel",
@@ -27,3 +18,11 @@ __all__ = [
     "VarLSTM",
     "VarGRU"
 ]
+
+from .bert import BertModel
+from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
+from .conv_maxpool import ConvMaxpool
+from .embedding import Embedding
+from .lstm import LSTM
+from .star_transformer import StarTransformer
+from .transformer import TransformerEncoder
+from .variational_rnn import VarRNN, VarLSTM, VarGRU
@@ -1,12 +1,11 @@
-import torch
-import torch.nn as nn
-from ..utils import initial_parameter
 
 __all__ = [
     "ConvolutionCharEncoder",
     "LSTMCharEncoder"
 ]
 
+import torch
+import torch.nn as nn
+from ..utils import initial_parameter
 
 # from torch.nn.init import xavier_uniform
@@ -1,13 +1,12 @@
+__all__ = [
+    "ConvMaxpool"
+]
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ..utils import initial_parameter
 
-__all__ = [
-    "ConvMaxpool"
-]
 
 class ConvMaxpool(nn.Module):
     """
@@ -1,9 +1,8 @@
-import torch.nn as nn
-from ..utils import get_embeddings
 
 __all__ = [
     "Embedding"
 ]
 
+import torch.nn as nn
+from ..utils import get_embeddings
 
 class Embedding(nn.Embedding):
@@ -2,16 +2,16 @@
 A lightly wrapped PyTorch LSTM module.
 Sequence lengths can be passed in at forward time, and padding is then handled appropriately and automatically.
 """
+__all__ = [
+    "LSTM"
+]
 
 import torch
 import torch.nn as nn
 import torch.nn.utils.rnn as rnn
 
 from ..utils import initial_parameter
 
-__all__ = [
-    "LSTM"
-]
 
 class LSTM(nn.Module):
     """
@@ -1,15 +1,15 @@
 """
 A PyTorch implementation of the encoder part of the Star-Transformer
 """
+__all__ = [
+    "StarTransformer"
+]
 
 import numpy as NP
 import torch
 from torch import nn
 from torch.nn import functional as F
 
-__all__ = [
-    "StarTransformer"
-]
 
 class StarTransformer(nn.Module):
     """
@@ -1,12 +1,11 @@
+__all__ = [
+    "TransformerEncoder"
+]
 
 from torch import nn
 
 from ..aggregator.attention import MultiHeadAttention
 from ..dropout import TimestepDropout
 
-__all__ = [
-    "TransformerEncoder"
-]
 
 class TransformerEncoder(nn.Module):
     """
@@ -1,6 +1,12 @@
 """
 A PyTorch implementation of Variational RNN
 """
+__all__ = [
+    "VarRNN",
+    "VarLSTM",
+    "VarGRU"
+]
 
 import torch
 import torch.nn as nn
 from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
@@ -17,25 +23,19 @@ except ImportError:
 
 from ..utils import initial_parameter
 
-__all__ = [
-    "VarRNN",
-    "VarLSTM",
-    "VarGRU"
-]
 
 class VarRnnCellWrapper(nn.Module):
     """
     Wrapper for normal RNN Cells, make it support variational dropout
     """
 
     def __init__(self, cell, hidden_size, input_p, hidden_p):
         super(VarRnnCellWrapper, self).__init__()
         self.cell = cell
         self.hidden_size = hidden_size
         self.input_p = input_p
         self.hidden_p = hidden_p
 
     def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
         """
         :param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -47,13 +47,13 @@ class VarRnnCellWrapper(nn.Module):
                 hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
                         for other RNN, h_n, [batch_size, hidden_size]
         """
 
         def get_hi(hi, h0, size):
             h0_size = size - hi.size(0)
             if h0_size > 0:
                 return torch.cat([hi, h0[:h0_size]], dim=0)
             return hi[:size]
 
         is_lstm = isinstance(hidden, tuple)
         input, batch_sizes = input_x.data, input_x.batch_sizes
         output = []
@@ -64,7 +64,7 @@ class VarRnnCellWrapper(nn.Module):
         else:
             batch_iter = batch_sizes
         idx = 0
 
         if is_lstm:
             hn = (hidden[0].clone(), hidden[1].clone())
         else:
@@ -91,7 +91,7 @@ class VarRnnCellWrapper(nn.Module):
                 hi = cell(input_i, hi)
                 hn[:size] = hi
                 output.append(hi)
 
         if is_reversed:
             output = list(reversed(output))
         output = torch.cat(output, dim=0)
@@ -117,7 +117,7 @@ class VarRNNBase(nn.Module):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
     """
 
     def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
                  bias=True, batch_first=False,
                  input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -141,7 +141,7 @@ class VarRNNBase(nn.Module):
                 cell, self.hidden_size, input_dropout, hidden_dropout))
         initial_parameter(self)
         self.is_lstm = (self.mode == "LSTM")
 
     def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
         is_lstm = self.is_lstm
         idx = self.num_directions * n_layer + n_direction
@@ -150,7 +150,7 @@ class VarRNNBase(nn.Module):
         output_x, hidden_x = cell(
             input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
         return output_x, hidden_x
 
     def forward(self, x, hx=None):
         """
@@ -170,13 +170,13 @@ class VarRNNBase(nn.Module):
         else:
             max_batch_size = int(x.batch_sizes[0])
         x, batch_sizes = x.data, x.batch_sizes
 
         if hx is None:
             hx = x.new_zeros(self.num_layers * self.num_directions,
                              max_batch_size, self.hidden_size, requires_grad=True)
             if is_lstm:
                 hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
 
         mask_x = x.new_ones((max_batch_size, self.input_size))
         mask_out = x.new_ones(
             (max_batch_size, self.hidden_size * self.num_directions))
@@ -185,7 +185,7 @@ class VarRNNBase(nn.Module):
                               training=self.training, inplace=True)
         nn.functional.dropout(mask_out, p=self.hidden_dropout,
                               training=self.training, inplace=True)
 
         hidden = x.new_zeros(
             (self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
         if is_lstm:
@@ -207,16 +207,16 @@ class VarRNNBase(nn.Module):
                 else:
                     hidden[idx] = hidden_x
             x = torch.cat(output_list, dim=-1)
 
         if is_lstm:
             hidden = (hidden, cellstate)
 
         if is_packed:
             output = PackedSequence(x, batch_sizes)
         else:
             x = PackedSequence(x, batch_sizes)
             output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
 
         return output, hidden
@@ -236,11 +236,11 @@ class VarLSTM(VarRNNBase):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False``
     """
 
     def __init__(self, *args, **kwargs):
         super(VarLSTM, self).__init__(
             mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
 
     def forward(self, x, hx=None):
         return super(VarLSTM, self).forward(x, hx)
@@ -261,11 +261,11 @@ class VarRNN(VarRNNBase):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False``
     """
 
     def __init__(self, *args, **kwargs):
         super(VarRNN, self).__init__(
             mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
 
     def forward(self, x, hx=None):
         return super(VarRNN, self).forward(x, hx)
@@ -286,10 +286,10 @@ class VarGRU(VarRNNBase):
     :param hidden_dropout: dropout probability for each hidden state. Default: 0
     :param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False``
     """
 
     def __init__(self, *args, **kwargs):
         super(VarGRU, self).__init__(
             mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
 
     def forward(self, x, hx=None):
         return super(VarGRU, self).forward(x, hx)
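Closing out the variational RNN file, a hedged usage sketch of VarLSTM (sizes are arbitrary; the constructor mirrors nn.LSTM plus the input_dropout/hidden_dropout arguments documented above):

```python
# Hedged sketch: VarLSTM applies the same dropout mask at every timestep
# (variational dropout), configured via input_dropout / hidden_dropout.
import torch
from fastNLP.modules import VarLSTM

lstm = VarLSTM(input_size=16, hidden_size=32, num_layers=2,
               batch_first=True, input_dropout=0.3, hidden_dropout=0.3,
               bidirectional=True)

x = torch.randn(4, 10, 16)     # [batch, seq_len, input_size]
output, (h_n, c_n) = lstm(x)
print(output.shape)            # expected torch.Size([4, 10, 64]) with bidirectional=True
```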
@@ -1,5 +1,5 @@
 from functools import reduce
-from collections import OrderedDict
+
 import numpy as np
 import torch
 import torch.nn as nn