@@ -47,7 +47,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.api.utils import load_url
from fastNLP.api.processor import ModelProcessor
from fastNLP.io.dataset_loader import cut_long_sentence, ConllLoader
from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader
from fastNLP.core.instance import Instance
from fastNLP.api.pipeline import Pipeline
from fastNLP.core.metrics import SpanFPreRecMetric
@@ -107,7 +107,7 @@ class ConllCWSReader(object):
continue
line = ' '.join(res)
if cut_long_sent:
sents = cut_long_sentence(line)
sents = _cut_long_sentence(line)
else:
sents = [line]
for raw_sentence in sents:
@@ -5,7 +5,7 @@ from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
from .metrics import AccuracyMetric
from .optimizer import Optimizer, SGD, Adam
from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester
from .trainer import Trainer
from .vocabulary import Vocabulary
@@ -2,7 +2,7 @@ import numpy as np
import torch
import atexit
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import RandomSampler, Sampler
import torch.multiprocessing as mp
_python_is_exit = False
@@ -12,19 +12,25 @@ def _set_python_is_exit():
atexit.register(_set_python_is_exit)
class Batch(object):
"""Batch is an iterable object which iterates over mini-batches.
Example::
for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()):
# ...
:param DataSet dataset: a DataSet object
:param int batch_size: the size of the batch
:param Sampler sampler: a Sampler object. If None, use fastNLP.sampler.RandomSampler
:param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors.
:param bool prefetch: If True, use multiprocessing to fetch next batch when training.
:param str or torch.device device: the batch's device, if as_numpy is True, device is ignored.
"""
Batch fetches data from a `DataSet` in a given order, ``batch_size`` samples at a time,
and assembles them into `x` and `y`.
Example::
batch = Batch(data_set, batch_size=16, sampler=SequentialSampler())
num_batch = len(batch)
for batch_x, batch_y in batch:
# do stuff ...
:param DataSet dataset: a `DataSet` object, the data set to iterate over
:param int batch_size: the size of each batch
:param Sampler sampler: the sampling strategy to use. If ``None``, RandomSampler is used.
Default: ``None``
:param bool as_numpy: if ``True``, batches are returned as numpy.array, otherwise as torch.Tensor.
Default: ``False``
:param bool prefetch: if ``True``, use an extra process to prefetch the next batch.
Default: ``False``
"""
def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False):
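A minimal usage sketch of the iteration pattern described in the docstring above (the field names and toy values are hypothetical; `set_input`/`set_target` decide which fields end up in batch_x and batch_y):

from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler

# a toy DataSet with one input field and one target field
ds = DataSet({"words": [[1, 2, 3], [4, 5, 6, 7]], "label": [0, 1]})
ds.set_input("words")
ds.set_target("label")

batch = Batch(ds, batch_size=2, sampler=SequentialSampler())
print(len(batch))                     # number of mini-batches
for batch_x, batch_y in batch:
    print(batch_x["words"].shape)     # padded torch.LongTensor
    print(batch_y["label"])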
@@ -41,7 +47,7 @@ class Batch(object):
self.prefetch = prefetch
self.lengths = 0
def fetch_one(self):
def _fetch_one(self):
if self.curidx >= len(self.idx_list):
return None
else:
@@ -55,7 +61,7 @@ class Batch(object):
if field.is_target or field.is_input:
batch = field.get(indices)
if not self.as_numpy and field.padder is not None:
batch = to_tensor(batch, field.dtype)
batch = _to_tensor(batch, field.dtype)
if field.is_target:
batch_y[field_name] = batch
if field.is_input:
@@ -70,17 +76,17 @@ class Batch(object):
:return:
"""
if self.prefetch:
return run_batch_iter(self)
return _run_batch_iter(self)
def batch_iter():
self.init_iter()
self._init_iter()
while 1:
res = self.fetch_one()
res = self._fetch_one()
if res is None:
break
yield res
return batch_iter()
def init_iter(self):
def _init_iter(self):
self.idx_list = self.sampler(self.dataset)
self.curidx = 0
self.lengths = self.dataset.get_length()
@@ -89,10 +95,14 @@ class Batch(object):
return self.num_batches
def get_batch_indices(self):
"""Get the indices, within the DataSet, of the samples in the current batch
:return list(int) indexes: the list of indices
"""
return self.cur_batch_indices
def to_tensor(batch, dtype):
def _to_tensor(batch, dtype):
try:
if dtype in (int, np.int8, np.int16, np.int32, np.int64):
batch = torch.LongTensor(batch)
@@ -103,12 +113,12 @@ def to_tensor(batch, dtype):
return batch
def run_fetch(batch, q):
def _run_fetch(batch, q):
global _python_is_exit
batch.init_iter()
batch._init_iter()
# print('start fetch')
while 1:
res = batch.fetch_one()
res = batch._fetch_one()
# print('fetch one')
while 1:
try:
@@ -124,9 +134,9 @@ def run_fetch(batch, q):
# print('fetch exit')
def run_batch_iter(batch):
def _run_batch_iter(batch):
q = mp.JoinableQueue(maxsize=10)
fetch_p = mp.Process(target=run_fetch, args=(batch, q))
fetch_p = mp.Process(target=_run_fetch, args=(batch, q))
fetch_p.daemon = True
fetch_p.start()
# print('fork fetch process')
@@ -482,7 +482,7 @@ class DataSet(object):
"""
import warnings
warnings.warn('read_csv is deprecated, use CSVLoader instead',
warnings.warn('DataSet.read_csv is deprecated, use CSVLoader instead',
category=DeprecationWarning)
with open(csv_path, "r", encoding='utf-8') as f:
start_idx = 0
@@ -3,72 +3,49 @@ from itertools import chain
import numpy as np
import torch
class Sampler(object):
""" Base class of all samplers. Specifies the order in which elements are drawn from the data
def convert_to_torch_tensor(data_list, use_cuda):
"""Convert lists into (cuda) Tensors.
:param data_list: 2-level lists
:param use_cuda: bool, whether to use GPU or not
:return data_list: PyTorch Tensor of shape [batch_size, max_seq_len]
"""
data_list = torch.Tensor(data_list).long()
if torch.cuda.is_available() and use_cuda:
data_list = data_list.cuda()
return data_list
class BaseSampler(object):
"""The base class of all samplers.
Sub-classes must implement the ``__call__`` method.
``__call__`` takes a DataSet object and returns a list of int - the sampling indices.
Sub-classes must implement the ``__call__`` method, which takes a `DataSet` object and returns a list of element indices
"""
def __call__(self, *args, **kwargs):
def __call__(self, data_set):
"""
:param DataSet data_set: the `DataSet` object to sample from
:return result: list(int), the indices; elements of ``data_set`` are fetched in the order given by ``result``
"""
raise NotImplementedError
class SequentialSampler(BaseSampler):
"""Sample data in the original order.
class SequentialSampler(Sampler):
"""A `Sampler` that returns elements in their original order
"""
def __call__(self, data_set):
"""
:param DataSet data_set:
:return result: a list of integers.
"""
return list(range(len(data_set)))
class RandomSampler(BaseSampler):
"""Sample data in random permutation order.
class RandomSampler(Sampler):
"""A `Sampler` that returns elements in a random order
"""
def __call__(self, data_set):
"""
:param DataSet data_set:
:return result: a list of integers.
"""
return list(np.random.permutation(len(data_set)))
class BucketSampler(BaseSampler):
"""
:param int num_buckets: the number of buckets to use.
:param int batch_size: batch size per epoch.
:param str seq_lens_field_name: the field name indicating the field about sequence length.
class BucketSampler(Sampler):
"""A bucketed `RandomSampler`: draws elements of similar length together at random
:param int num_buckets: the number of buckets
:param int batch_size: the batch size
:param str seq_lens_field_name: name of the `field` holding the sequence lengths
"""
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'):
def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_len'):
self.num_buckets = num_buckets
self.batch_size = batch_size
self.seq_lens_field_name = seq_lens_field_name
def __call__(self, data_set):
seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content
total_sample_num = len(seq_lens)
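A minimal sketch of how BucketSampler combines with Batch, as the docstring above suggests (the toy data and the 'seq_len' field name are hypothetical; the field name must match seq_lens_field_name):

from fastNLP.core.dataset import DataSet
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import BucketSampler

words = [[1] * n for n in (3, 9, 4, 10, 2, 8)]            # toy sequences of varying length
ds = DataSet({"words": words, "seq_len": [len(w) for w in words]})
ds.set_input("words")

# group sequences of similar length into buckets, then shuffle inside each bucket
sampler = BucketSampler(num_buckets=2, batch_size=2, seq_lens_field_name="seq_len")
for batch_x, batch_y in Batch(ds, batch_size=2, sampler=sampler):
    print(batch_x["words"].shape)   # noticeably less padding than a fully random order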
@@ -18,7 +18,7 @@ from fastNLP.core.dataset import DataSet
from fastNLP.core.losses import _prepare_losser
from fastNLP.core.metrics import _prepare_metrics
from fastNLP.core.optimizer import Adam
from fastNLP.core.sampler import BaseSampler
from fastNLP.core.sampler import Sampler
from fastNLP.core.sampler import RandomSampler
from fastNLP.core.sampler import SequentialSampler
from fastNLP.core.tester import Tester
@@ -57,7 +57,7 @@ class Trainer(object):
smaller, add "-" in front of the string. For example::
metric_key="-PPL" # language model gets better as perplexity gets smaller
:param BaseSampler sampler: method used to generate batch data.
:param Sampler sampler: method used to generate batch data.
:param prefetch: bool, whether to use an extra process to prefetch batch data.
:param bool use_tqdm: whether to use tqdm to show train progress.
:param callbacks: List[Callback]. Callbacks used to adjust the training process; for example, early stopping, negative sampling, etc. can
@@ -102,7 +102,7 @@ class Trainer(object):
losser = _prepare_losser(loss)
# sampler check
if sampler is not None and not isinstance(sampler, BaseSampler):
if sampler is not None and not isinstance(sampler, Sampler):
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
if check_code_level > -1:
@@ -1,3 +1,4 @@
from functools import wraps
from collections import Counter
from fastNLP.core.dataset import DataSet
@@ -5,7 +6,7 @@ def check_build_vocab(func):
"""A decorator to make sure the indexing is built before used.
"""
@wraps(func)  # to solve missing docstring
def _wrapper(self, *args, **kwargs):
if self.word2idx is None or self.rebuild is True:
self.build_vocab()
@@ -18,7 +19,7 @@ def check_build_status(func):
"""A decorator to check whether the vocabulary updates after the last build.
"""
@wraps(func)  # to solve missing docstring
def _wrapper(self, *args, **kwargs):
if self.rebuild is False:
self.rebuild = True
@@ -32,23 +33,28 @@ def check_build_status(func):
class Vocabulary(object):
"""Use for word and index one to one mapping
"""
Builds, stores and uses a one-to-one mapping from `str` to `int`
Example::
vocab = Vocabulary()
word_list = "this is a word list".split()
vocab.update(word_list)
vocab["word"]
vocab.to_word(5)
:param int max_size: set the max number of words in Vocabulary. Default: None
:param int min_freq: set the min occur frequency of words in Vocabulary. Default: None
:param padding: str, the padding token, <pad> by default. If set to None, the vocabulary does not include padding; this is usually the case when building
a Vocabulary for labels.
:param unknown: str, the unknown token, <unk> by default. If set to None, the vocabulary does not include unknown; this is usually the case when building
a Vocabulary for labels.
vocab["word"] # str to int
vocab.to_word(5) # int to str
:param int max_size: the maximum size of the `Vocabulary`, i.e. the maximum number of words it can store.
If ``None``, the size is unlimited. Default: ``None``
:param int min_freq: the minimum frequency a word must have in the text to be recorded; must be at least 1.
Words below this frequency are treated as `unknown`. If ``None``, every word in the text is recorded. Default: ``None``
:param str padding: the padding token. If ``None``,
the vocabulary does not include padding and it does not count towards the vocabulary size; this is usually the case when building a Vocabulary for labels.
Default: '<pad>'
:param str unknown: the unknown token; every unrecorded word is treated as unknown when converted to `int`.
If ``None``, the vocabulary does not include unknown and it does not count towards the vocabulary size.
This is usually the case when building a Vocabulary for labels.
Default: '<unk>'
"""
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'):
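A short sketch of the mapping workflow described in the docstring (the word list is hypothetical; the exact index values depend on insertion order and frequency):

from fastNLP.core.vocabulary import Vocabulary

vocab = Vocabulary(max_size=None, min_freq=1, padding='<pad>', unknown='<unk>')
vocab.update("this is a word list".split())   # record word frequencies
vocab.add_word("word")                        # same effect, one word at a time

idx = vocab["word"]          # str -> int (the vocab is built on first use)
word = vocab.to_word(idx)    # int -> str
print(idx, word, len(vocab), vocab.padding_idx, vocab.unknown_idx)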
@@ -63,7 +69,7 @@ class Vocabulary(object):
@check_build_status
def update(self, word_lst):
"""Add a list of words into the vocabulary.
"""Increase the recorded frequency of each word in the given sequence
:param list word_lst: a list of strings
"""
@@ -71,32 +77,35 @@ class Vocabulary(object):
@check_build_status
def add(self, word):
"""Add a single word into the vocabulary.
"""
Increase the recorded frequency of a single word
:param str word: a word or token.
:param str word: the word to add
"""
self.word_count[word] += 1
@check_build_status
def add_word(self, word):
"""Add a single word into the vocabulary.
:param str word: a word or token.
"""
Increase the recorded frequency of a single word
:param str word: the word to add
"""
self.add(word)
@check_build_status
def add_word_lst(self, word_lst):
"""Add a list of words into the vocabulary.
:param list word_lst: a list of strings
"""
Increase the recorded frequency of each word in the given sequence
:param list(str) word_lst: a sequence of words
"""
self.update(word_lst)
def build_vocab(self):
"""Build a mapping from word to index, and filter the word using ``max_size`` and ``min_freq``.
"""
Build the vocabulary from the words recorded so far and their frequencies. Note: rebuilding may change the size of the vocabulary,
but words already recorded in the vocabulary keep their `int` index
"""
self.word2idx = {}
@@ -117,7 +126,8 @@ class Vocabulary(object):
self.rebuild = False
def build_reverse_vocab(self):
"""Build "index to word" dict based on "word to index" dict.
"""
Build the "index to word" dict from the "word to index" dict.
"""
self.idx2word = {i: w for w, i in self.word2idx.items()}
@@ -128,7 +138,8 @@ class Vocabulary(object):
@check_build_vocab
def __contains__(self, item):
"""Check if a word in vocabulary.
"""
Check whether a word is recorded in the vocabulary
:param item: the word
:return: True or False
@@ -136,11 +147,24 @@ class Vocabulary(object):
return item in self.word2idx
def has_word(self, w):
"""
Check whether a word is recorded in the vocabulary
Example::
has_abc = vocab.has_word('abc')
# equals to
has_abc = 'abc' in vocab
:param item: the word
:return: ``True`` or ``False``
"""
return self.__contains__(w)
@check_build_vocab
def __getitem__(self, w):
"""To support usage like::
"""
To support usage like::
vocab[w]
"""
@@ -154,14 +178,19 @@ class Vocabulary(object):
@check_build_vocab
def index_dataset(self, *datasets, field_name, new_field_name=None):
"""
example:
# remember to use `field_name`
vocab.index_dataset(tr_data, dev_data, te_data, field_name='words')
Convert the words in the given field of the DataSet(s) into indices.
Example::
:param datasets: fastNLP Dataset type. you can pass multiple datasets
:param field_name: str, what field to index. Only support 0,1,2 dimension.
:param new_field_name: str. What the indexed field should be named, default is to overwrite field_name
:return:
# remember to use `field_name`
vocab.index_dataset(train_data, dev_data, test_data, field_name='words')
:param DataSet datasets: the DataSet(s) to index; one or more may be passed
:param str field_name: the field to index. When multiple DataSets are passed, every DataSet must have this field.
Currently only ``str``, ``list(str)`` and ``list(list(str))`` are supported
:param str new_field_name: the field_name under which to store the result. If ``None``, the original field is overwritten.
Default: ``None``
:return self:
"""
def index_instance(ins):
"""
@@ -194,11 +223,18 @@ class Vocabulary(object):
def from_dataset(self, *datasets, field_name):
"""
Construct vocab from dataset.
Build the vocabulary from the words in the given field of the DataSet(s)
Example::
# remember to use `field_name`
vocab.from_dataset(train_data1, train_data2, field_name='words')
:param datasets: DataSet.
:param field_name: str, what field is used to construct dataset.
:return:
:param DataSet datasets: the DataSet(s) to build the vocabulary from; one or more may be passed.
:param str field_name: the field used to build the vocabulary.
When multiple DataSets are passed, every DataSet must have this field.
Currently only ``str``, ``list(str)`` and ``list(list(str))`` are supported
:return self:
"""
def construct_vocab(ins):
field = ins[field_name]
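A minimal sketch that chains from_dataset and index_dataset as the two docstrings suggest (the toy sentences and field names are hypothetical):

from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary

train_data = DataSet({"words": [["I", "like", "NLP"], ["NLP", "is", "fun"]]})
dev_data = DataSet({"words": [["I", "like", "fun"]]})

vocab = Vocabulary()
vocab.from_dataset(train_data, field_name="words")              # collect words from train only
vocab.index_dataset(train_data, dev_data, field_name="words")   # words -> int, in place
print(train_data[0]["words"])                                   # now a list of indices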
@@ -223,15 +259,27 @@ class Vocabulary(object):
return self
def to_index(self, w):
""" Turn a word to an index. If w is not in Vocabulary, return the unknown label.
"""
Convert a word to an index. If the word is not recorded in the vocabulary it is treated as unknown; if ``unknown=None``, a
``ValueError`` is raised
Example::
index = vocab.to_index('abc')
# equals to
index = vocab['abc']
:param str w: a word
:return int index: the number
"""
return self.__getitem__(w)
@property
@check_build_vocab
def unknown_idx(self):
"""
The index of the unknown token.
"""
if self.unknown is None:
return None
return self.word2idx[self.unknown]
@@ -239,16 +287,20 @@ class Vocabulary(object):
@property
@check_build_vocab
def padding_idx(self):
"""
The index of the padding token
"""
if self.padding is None:
return None
return self.word2idx[self.padding]
@check_build_vocab
def to_word(self, idx):
"""given a word's index, return the word itself
"""
Given an index, return the corresponding word.
:param int idx: the index
:return str word: the indexed word
:return str word: the word
"""
return self.idx2word[idx]
@@ -4,7 +4,7 @@ from nltk.tree import Tree
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.file_reader import read_csv, read_json, read_conll
from fastNLP.io.file_reader import _read_csv, _read_json, _read_conll
def _download_from_url(url, path):
@@ -55,12 +55,12 @@ def _uncompress(src, dst):
class DataSetLoader:
"""Interface for all DataSetLoaders.
"""The interface implemented by all `DataSetLoader`s
"""
def load(self, path):
"""Load data from a given file.
"""Read data from the file at ``path`` and return a DataSet
:param str path: file path
:return: a DataSet object
@@ -68,7 +68,7 @@ class DataSetLoader:
raise NotImplementedError
def convert(self, data):
"""Optional operation to build a DataSet.
"""Build a DataSet from Python data objects
:param data: inner data structure (user-defined) to represent the data.
:return: a DataSet object
@@ -77,7 +77,7 @@ class DataSetLoader:
class PeopleDailyCorpusLoader(DataSetLoader):
"""The People's Daily (人民日报) dataset
"""Reads the People's Daily (人民日报) dataset
"""
def __init__(self):
super(PeopleDailyCorpusLoader, self).__init__()
@@ -154,8 +154,35 @@ class PeopleDailyCorpusLoader(DataSetLoader):
return data_set
class ConllLoader:
class ConllLoader(DataSetLoader):
"""
Reads data in CoNLL format. See http://conll.cemantix.org/2012/data.html for details of the format.
Columns are numbered from 0 and hold the following content::
Column Type
0 Document ID
1 Part number
2 Word number
3 Word itself
4 Part-of-Speech
5 Parse bit
6 Predicate lemma
7 Predicate Frameset ID
8 Word sense
9 Speaker/Author
10 Named Entities
11:N Predicate Arguments
N Coreference
:param headers: the name of each retained column; must be a List or Tuple of str. ``headers`` corresponds one-to-one with ``indexs``
:param indexs: the indices (starting from 0) of the columns to keep. If ``None``, all columns are kept. Default: ``None``
:param dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised when invalid data is encountered. Default: ``True``
"""
def __init__(self, headers, indexs=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
raise TypeError('invalid headers: {}, should be list of strings'.format(headers))
self.headers = headers
self.dropna = dropna
if indexs is None:
@@ -167,24 +194,17 @@ class ConllLoader:
def load(self, path):
ds = DataSet()
for idx, data in read_conll(path, indexes=self.indexs, dropna=self.dropna):
ins = {h:data[idx] for h, idx in zip(self.headers, self.indexs)}
for idx, data in _read_conll(path, indexes=self.indexs, dropna=self.dropna):
ins = {h:data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
def get_one(self, sample):
sample = list(map(list, zip(*sample)))
for field in sample:
if len(field) <= 0:
return None
return sample
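A brief usage sketch of the loader above (the column layout, header names, and the file path are hypothetical; headers must line up with the columns selected by indexs):

from fastNLP.io.dataset_loader import ConllLoader

# keep columns 0 and 3 of each CoNLL record and name them 'words' and 'pos'
loader = ConllLoader(headers=['words', 'pos'], indexs=[0, 3], dropna=True)
ds = loader.load('data/sample.conll')          # hypothetical path
print(len(ds), ds[0]['words'], ds[0]['pos'])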
class Conll2003Loader(ConllLoader):
"""Loader for conll2003 dataset
"""Reads the Conll2003 dataset
More information about the given dataset cound be found on
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
More information about the dataset can be found at:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""
def __init__(self):
headers = [
@@ -193,9 +213,10 @@ class Conll2003Loader(ConllLoader):
super(Conll2003Loader, self).__init__(headers=headers)
def cut_long_sentence(sent, max_sample_length=200):
def _cut_long_sentence(sent, max_sample_length=200):
"""
Cut a sentence longer than max_sample_length into several pieces; cutting only happens at spaces, so each piece may end up longer or shorter than max_sample_length
Cut a sentence longer than max_sample_length into several pieces; cutting only happens at spaces,
so each piece may end up longer or shorter than max_sample_length
:param sent: str.
:param max_sample_length: int.
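A small sketch of the behaviour described above (the sentence is made up; _cut_long_sentence is an internal helper of fastNLP.io.dataset_loader):

from fastNLP.io.dataset_loader import _cut_long_sentence

sent = ' '.join(['token'] * 100)                 # a space-separated sentence of ~600 characters
pieces = _cut_long_sentence(sent, max_sample_length=200)
print(len(pieces), [len(p) for p in pieces])     # several pieces, each cut at a space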
@@ -223,8 +244,15 @@ def cut_long_sentence(sent, max_sample_length=200):
class SSTLoader(DataSetLoader):
"""load SST data in PTB tree format
data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
"""Reads the SST dataset; the resulting DataSet contains the fields::
words: list(str), the text to classify
target: str, the label of the text
Data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip
:param subtree: whether to expand the data into all subtrees to enlarge the dataset. Default: ``False``
:param fine_grained: whether to use the SST-5 label scheme; if ``False``, SST-2 is used. Default: ``False``
"""
def __init__(self, subtree=False, fine_grained=False):
self.subtree = subtree
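A short usage sketch (the file path is hypothetical; the file is expected to hold one PTB-style tree per line, as in the zip linked above):

from fastNLP.io.dataset_loader import SSTLoader

loader = SSTLoader(subtree=False, fine_grained=False)   # SST-2 labels, root trees only
ds = loader.load('data/trees/train.txt')                # hypothetical path
print(len(ds), ds[0]['words'], ds[0]['target'])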
@@ -247,14 +275,14 @@ class SSTLoader(DataSetLoader):
datas = []
for l in f:
datas.extend([(s, self.tag_v[t])
for s, t in self.get_one(l, self.subtree)])
for s, t in self._get_one(l, self.subtree)])
ds = DataSet()
for words, tag in datas:
ds.append(Instance(words=words, raw_tag=tag))
ds.append(Instance(words=words, target=tag))
return ds
@staticmethod
def get_one(data, subtree):
def _get_one(data, subtree):
tree = Tree.fromstring(data)
if subtree:
return [(t.leaves(), t.label()) for t in tree.subtrees()]
@@ -262,11 +290,17 @@ class SSTLoader(DataSetLoader):
class JsonLoader(DataSetLoader):
"""Load json-format data,
every line contains a json obj, like a dict
fields is the dict key that need to be load
"""
def __init__(self, dropna=False, fields=None):
Reads data in json format. The data must be stored one object per line; each line is a json object with the attributes to read
:param dict fields: the json attribute names to read, and the field_name under which each is stored in the DataSet.
The `key`s of ``fields`` must be attribute names of the json objects; each `value` is the `field_name` used in the DataSet after loading,
a `value` may also be ``None``, in which case the resulting `field_name` is the same as the json attribute name
``fields`` may be ``None``, in which case all attributes of the json objects are stored in the DataSet. Default: ``None``
:param bool dropna: whether to ignore invalid data; if ``True`` it is ignored, if ``False`` a ``ValueError`` is raised when invalid data is encountered.
Default: ``False``
"""
def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__()
self.dropna = dropna
self.fields = None
@@ -279,7 +313,7 @@ class JsonLoader(DataSetLoader):
def load(self, path):
ds = DataSet()
for idx, d in read_json(path, fields=self.fields_list, dropna=self.dropna):
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
ins = {self.fields[k]:v for k,v in d.items()}
ds.append(Instance(**ins))
return ds
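A brief sketch of the fields mapping described above (the jsonl path and attribute names are hypothetical):

from fastNLP.io.dataset_loader import JsonLoader

# read 'text' and 'label' from each json line; rename 'text' to 'words', keep 'label' as-is
loader = JsonLoader(fields={'text': 'words', 'label': None}, dropna=False)
ds = loader.load('data/train.jsonl')     # hypothetical path, one json object per line
print(ds[0]['words'], ds[0]['label'])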
@@ -287,7 +321,13 @@ class JsonLoader(DataSetLoader):
class SNLILoader(JsonLoader):
"""
data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
Reads the SNLI dataset; the resulting DataSet contains the fields::
words1: list(str), the first sentence (premise)
words2: list(str), the second sentence (hypothesis)
target: str, the gold label
Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""
def __init__(self):
fields = {
@@ -309,14 +349,14 @@ class SNLILoader(JsonLoader):
class CSVLoader(DataSetLoader):
"""Load data from a CSV file and return a DataSet object.
:param str csv_path: path to the CSV file
:param List[str] or Tuple[str] headers: headers of the CSV file
:param str sep: delimiter in CSV file. Default: ","
:param bool dropna: If True, drop rows that have less entries than headers.
:return dataset: the read data set
"""
Reads a dataset in CSV format and returns a ``DataSet``
:param List[str] headers: the header of the CSV file, defining the name of each column, i.e. the `field` names of the returned DataSet.
If ``None``, the first line of the file is treated as the ``headers``. Default: ``None``
:param str sep: the delimiter between columns in the CSV file. Default: ","
:param bool dropna: whether to ignore invalid data; if ``True`` it is ignored, if ``False`` a ``ValueError`` is raised when invalid data is encountered.
Default: ``True``
"""
def __init__(self, headers=None, sep=",", dropna=True):
self.headers = headers
@@ -325,8 +365,8 @@ class CSVLoader(DataSetLoader):
def load(self, path):
ds = DataSet()
for idx, data in read_csv(path, headers=self.headers,
sep=self.sep, dropna=self.dropna):
for idx, data in _read_csv(path, headers=self.headers,
sep=self.sep, dropna=self.dropna):
ds.append(Instance(**data))
return ds
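A minimal usage sketch (the tsv path and header names are hypothetical):

from fastNLP.io.dataset_loader import CSVLoader

loader = CSVLoader(headers=['raw_words', 'label'], sep='\t', dropna=True)
ds = loader.load('data/train.tsv')     # hypothetical path: raw_words<TAB>label per line
print(len(ds), ds[0]['raw_words'], ds[0]['label'])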
@@ -1,15 +1,16 @@
import json
def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
"""
Construct a generator to read csv items
Construct a generator to read csv items.
:param path: file path
:param encoding: file's encoding, default: utf-8
:param headers: file's headers, if None, make file's first line as headers. default: None
:param sep: separator for each column. default: ','
:param dropna: whether to ignore and drop invalid data,
if False, raise ValueError when reading invalid data. default: True
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, csv item)
"""
with open(path, 'r', encoding=encoding) as f:
@@ -35,14 +36,15 @@ def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
yield line_idx, _dict
def read_json(path, encoding='utf-8', fields=None, dropna=True):
def _read_json(path, encoding='utf-8', fields=None, dropna=True):
"""
Construct a generator to read json items
Construct a generator to read json items.
:param path: file path
:param encoding: file's encoding, default: utf-8
:param fields: json object's fields that needed, if None, all fields are needed. default: None
:param dropna: whether to ignore and drop invalid data,
if False, raise ValueError when reading invalid data. default: True
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, json item)
"""
if fields:
@@ -65,14 +67,15 @@ def read_json(path, encoding='utf-8', fields=None, dropna=True):
yield line_idx, _res
def read_conll(path, encoding='utf-8', indexes=None, dropna=True):
def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
"""
Construct a generator to read conll items
Construct a generator to read conll items.
:param path: file path
:param encoding: file's encoding, default: utf-8
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None
:param dropna: whether to ignore and drop invalid data,
if False, raise ValueError when reading invalid data. default: True
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""
def parse_conll(sample):
@@ -16,7 +16,7 @@ from fastNLP.modules.utils import initial_parameter
from fastNLP.modules.utils import seq_mask
def mst(scores):
def _mst(scores):
"""
with some modification to support parser output for MST decoding
https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692
@@ -120,12 +120,22 @@ def _find_cycle(vertices, edges):
class GraphParser(BaseModel):
"""Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding
"""
Base class for graph-based parsers; supports greedy decoding and maximum spanning tree (MST) decoding
"""
def __init__(self):
super(GraphParser, self).__init__()
def _greedy_decoder(self, arc_matrix, mask=None):
@staticmethod
def greedy_decoder(arc_matrix, mask=None):
"""
Greedy decoding: takes the score graph as input and returns the greedily decoded parse; the result is not guaranteed to form a valid tree
:param arc_matrix: [batch, seq_len, seq_len] the input score graph
:param mask: [batch, seq_len] padding mask of the input graph; positions with content are 1, otherwise 0.
If ``None``, an all-ones vector is assumed. Default: ``None``
:return heads: [batch, seq_len] the predicted head (parent) of each element in the tree
"""
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
flip_mask = (mask == 0).byte()
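A small sketch of the two decoders on a random score graph (the module path fastNLP.models.biaffine_parser and the toy tensors are assumptions; index 0 conventionally plays the role of the root position):

import torch
from fastNLP.models.biaffine_parser import GraphParser   # assumed module path

arc_matrix = torch.randn(2, 5, 5)            # [batch, seq_len, seq_len] arc scores
mask = torch.ones(2, 5, dtype=torch.long)    # 1 = real token, 0 = padding
mask[1, 3:] = 0                              # the second sample has only 3 tokens

heads = GraphParser.greedy_decoder(arc_matrix, mask)   # fast, may not form a tree
tree = GraphParser.mst_decoder(arc_matrix, mask)       # guaranteed valid tree
print(heads.shape, tree.shape)                         # both [2, 5]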
@@ -135,22 +145,34 @@ class GraphParser(BaseModel):
heads *= mask.long()
return heads
def _mst_decoder(self, arc_matrix, mask=None):
@staticmethod
def mst_decoder(arc_matrix, mask=None):
"""
Decode the parse with the maximum spanning tree algorithm; the output is guaranteed to be a valid tree
:param arc_matrix: [batch, seq_len, seq_len] the input score graph
:param mask: [batch, seq_len] padding mask of the input graph; positions with content are 1, otherwise 0.
If ``None``, an all-ones vector is assumed. Default: ``None``
:return heads: [batch, seq_len] the predicted head (parent) of each element in the tree
"""
batch_size, seq_len, _ = arc_matrix.shape
matrix = arc_matrix.clone()
ans = matrix.new_zeros(batch_size, seq_len).long()
lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len
batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device)
for i, graph in enumerate(matrix):
len_i = lens[i]
ans[i, :len_i] = torch.as_tensor(mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device)
ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device)
if mask is not None:
ans *= mask.long()
return ans
class ArcBiaffine(nn.Module):
"""helper module for Biaffine Dependency Parser predicting arc
"""
Submodule of the Biaffine Dependency Parser; builds the score graph used for arc prediction
:param hidden_size: dimension of the input features
:param bias: whether to use a bias term. Default: ``True``
"""
def __init__(self, hidden_size, bias=True):
super(ArcBiaffine, self).__init__()
@@ -164,10 +186,10 @@ class ArcBiaffine(nn.Module):
def forward(self, head, dep):
"""
:param head arc-head tensor = [batch, length, emb_dim]
:param dep arc-dependent tensor = [batch, length, emb_dim]
:return output tensor = [bacth, length, length]
:param head: arc-head tensor [batch, length, hidden]
:param dep: arc-dependent tensor [batch, length, hidden]
:return output: tensor [batch, length, length]
"""
output = dep.matmul(self.U)
output = output.bmm(head.transpose(-1, -2))
@@ -177,7 +199,13 @@ class ArcBiaffine(nn.Module):
class LabelBilinear(nn.Module):
"""helper module for Biaffine Dependency Parser predicting label
"""
Submodule of the Biaffine Dependency Parser; builds the score graph used for arc label prediction
:param in1_features: dimension of input feature 1
:param in2_features: dimension of input feature 2
:param num_label: number of arc label classes
:param bias: whether to use a bias term. Default: ``True``
"""
def __init__(self, in1_features, in2_features, num_label, bias=True):
super(LabelBilinear, self).__init__()
@@ -185,14 +213,34 @@ class LabelBilinear(nn.Module):
self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False)
def forward(self, x1, x2):
"""
:param x1: [batch, seq_len, hidden] input feature 1, i.e. the label-head representation
:param x2: [batch, seq_len, hidden] input feature 2, i.e. the label-dep representation
:return output: [batch, seq_len, num_cls] the class score graph for each element
"""
output = self.bilinear(x1, x2)
output += self.lin(torch.cat([x1, x2], dim=2))
return output
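A shape-level sketch of the two scoring modules above (the module path and the toy dimensions are assumptions):

import torch
from fastNLP.models.biaffine_parser import ArcBiaffine, LabelBilinear  # assumed module path

head = torch.randn(2, 7, 128)   # [batch, length, hidden] head representations
dep = torch.randn(2, 7, 128)    # [batch, length, hidden] dependent representations

arc_scores = ArcBiaffine(hidden_size=128)(head, dep)                   # [2, 7, 7]
label_scores = LabelBilinear(in1_features=128, in2_features=128,
                             num_label=10)(head, dep)                  # [2, 7, 10]
print(arc_scores.shape, label_scores.shape)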
class BiaffineParser(GraphParser):
"""Biaffine Dependency Parser implemantation.
refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
"""Implementation of the Biaffine Dependency Parser.
See the paper ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016)
<https://arxiv.org/abs/1611.01734>`_ .
:param word_vocab_size: size of the word vocabulary
:param word_emb_dim: dimension of the word embeddings
:param pos_vocab_size: size of the part-of-speech vocabulary
:param pos_emb_dim: dimension of the part-of-speech embeddings
:param num_label: number of arc label classes
:param rnn_layers: number of layers in the rnn encoder
:param rnn_hidden_size: hidden state dimension of the rnn encoder
:param arc_mlp_size: MLP dimension for arc prediction
:param label_mlp_size: MLP dimension for label prediction
:param dropout: dropout probability.
:param encoder: type of encoder, one of ('lstm', 'var-lstm', 'transformer'). Default: lstm
:param use_greedy_infer: whether to use the greedy algorithm at inference time.
If ``False``, the more accurate but slower MST algorithm is used. Default: ``False``
"""
def __init__(self,
word_vocab_size,
@@ -207,7 +255,6 @@ class BiaffineParser(GraphParser):
dropout=0.3,
encoder='lstm',
use_greedy_infer=False):
super(BiaffineParser, self).__init__()
rnn_out_size = 2 * rnn_hidden_size
word_hid_dim = pos_hid_dim = rnn_hidden_size
@@ -275,27 +322,31 @@ class BiaffineParser(GraphParser):
for p in m.parameters():
nn.init.normal_(p, 0, 0.1)
def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None):
"""
:param word_seq: [batch_size, seq_len] sequence of word's indices
:param pos_seq: [batch_size, seq_len] sequence of word's indices
:param seq_lens: [batch_size, seq_len] sequence of length masks
:param gold_heads: [batch_size, seq_len] sequence of golden heads
:return dict: parsing results
arc_pred: [batch_size, seq_len, seq_len]
label_pred: [batch_size, seq_len, seq_len]
mask: [batch_size, seq_len]
head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads
def forward(self, words1, words2, seq_len, gold_heads=None):
"""The forward pass of the model
:param words1: [batch_size, seq_len] the input word index sequence
:param words2: [batch_size, seq_len] the input POS index sequence
:param seq_len: [batch_size, seq_len] the input sequence lengths
:param gold_heads: [batch_size, seq_len] the gold heads; only used during training,
to train the label classifier. If ``None``, the predicted heads are fed to the label classifier.
Default: ``None``
:return dict: parsing results::
arc_pred: [batch_size, seq_len, seq_len] arc prediction logits
label_pred: [batch_size, seq_len, num_label] label prediction logits
mask: [batch_size, seq_len] mask of the predictions
head_pred: [batch_size, seq_len] predicted heads, computed when ``gold_heads=None``
"""
# prepare embeddings
batch_size, seq_len = word_seq.shape
batch_size, length = words1.shape
# print('forward {} {}'.format(batch_size, seq_len))
# get sequence mask
mask = seq_mask(seq_lens, seq_len).long()
mask = seq_mask(seq_len, length).long()
word = self.word_embedding(word_seq) # [N,L] -> [N,L,C_0]
pos = self.pos_embedding(pos_seq) # [N,L] -> [N,L,C_1]
word = self.word_embedding(words1) # [N,L] -> [N,L,C_0]
pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1]
word, pos = self.word_fc(word), self.pos_fc(pos)
word, pos = self.word_norm(word), self.pos_norm(pos)
@@ -303,7 +354,7 @@ class BiaffineParser(GraphParser):
# encoder, extract features
if self.encoder_name.endswith('lstm'):
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True)
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True)
x = x[sort_idx]
x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True)
feat, _ = self.encoder(x) # -> [N,L,C]
@@ -329,20 +380,20 @@ class BiaffineParser(GraphParser):
if gold_heads is None or not self.training:
# use greedy decoding in training
if self.training or self.use_greedy_infer:
heads = self._greedy_decoder(arc_pred, mask)
heads = self.greedy_decoder(arc_pred, mask)
else:
heads = self._mst_decoder(arc_pred, mask)
heads = self.mst_decoder(arc_pred, mask)
head_pred = heads
else:
assert self.training # must be training mode
if gold_heads is None:
heads = self._greedy_decoder(arc_pred, mask)
heads = self.greedy_decoder(arc_pred, mask)
head_pred = heads
else:
head_pred = None
heads = gold_heads
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1)
batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1)
label_head = label_head[batch_range, heads].contiguous()
label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label]
res_dict = {'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask}
@@ -355,11 +406,11 @@ class BiaffineParser(GraphParser):
"""
Compute loss.
:param arc_pred: [batch_size, seq_len, seq_len]
:param label_pred: [batch_size, seq_len, n_tags]
:param arc_true: [batch_size, seq_len]
:param label_true: [batch_size, seq_len]
:param mask: [batch_size, seq_len]
:param arc_pred: [batch_size, seq_len, seq_len] arc prediction logits
:param label_pred: [batch_size, seq_len, num_label] label prediction logits
:param arc_true: [batch_size, seq_len] gold arc annotations
:param label_true: [batch_size, seq_len] gold label annotations
:param mask: [batch_size, seq_len] mask of the predictions
:return: loss value
"""
@@ -381,16 +432,23 @@ class BiaffineParser(GraphParser):
label_nll = -label_loss.mean()
return arc_nll + label_nll
def predict(self, word_seq, pos_seq, seq_lens):
"""
:param word_seq:
:param pos_seq:
:param seq_lens:
:return: arc_pred: [B, L]
label_pred: [B, L]
def predict(self, words1, words2, seq_len):
"""Prediction API of the model
:param words1: [batch_size, seq_len] the input word index sequence
:param words2: [batch_size, seq_len] the input POS index sequence
:param seq_len: [batch_size, seq_len] the input sequence lengths
:return dict: parsing results::
arc_pred: [batch_size, seq_len] the predicted head of each token
label_pred: [batch_size, seq_len] the predicted label of each arc
"""
res = self(word_seq, pos_seq, seq_lens)
res = self(words1, words2, seq_len)
output = {}
output['arc_pred'] = res.pop('head_pred')
_, label_pred = res.pop('label_pred').max(2)
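An untested end-to-end sketch of the prediction call; the constructor keyword names follow the parameter list in the class docstring above, while the module path and the toy sizes are assumptions:

import torch
from fastNLP.models.biaffine_parser import BiaffineParser   # assumed module path

parser = BiaffineParser(word_vocab_size=100, word_emb_dim=30,
                        pos_vocab_size=20, pos_emb_dim=30, num_label=10,
                        rnn_layers=1, rnn_hidden_size=100,
                        arc_mlp_size=100, label_mlp_size=100,
                        dropout=0.3, encoder='lstm', use_greedy_infer=True)
parser.eval()

words1 = torch.randint(1, 100, (2, 6))   # word indices, position 0 acts as <root>
words2 = torch.randint(1, 20, (2, 6))    # POS indices
seq_len = torch.tensor([6, 4])           # real lengths of the two samples
with torch.no_grad():
    out = parser.predict(words1, words2, seq_len)
print(out['arc_pred'].shape, out['label_pred'].shape)   # [2, 6] and [2, 6]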
@@ -399,6 +457,16 @@ class BiaffineParser(GraphParser):
class ParserLoss(LossFunc):
"""
Computes the loss of the parser
:param arc_pred: [batch_size, seq_len, seq_len] arc prediction logits
:param label_pred: [batch_size, seq_len, num_label] label prediction logits
:param arc_true: [batch_size, seq_len] gold arc annotations
:param label_true: [batch_size, seq_len] gold label annotations
:param mask: [batch_size, seq_len] mask of the predictions
:return loss: scalar
"""
def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None):
super(ParserLoss, self).__init__(BiaffineParser.loss,
arc_pred=arc_pred,
@@ -408,12 +476,26 @@ class ParserLoss(LossFunc):
class ParserMetric(MetricBase):
"""
Evaluates the performance of the parser
:param arc_pred: arc prediction logits
:param label_pred: label prediction logits
:param arc_true: gold arc annotations
:param label_true: gold label annotations
:param seq_len: sequence lengths
:return dict: evaluation results::
UAS: accuracy of arc prediction, ignoring labels
LAS: accuracy of predicting both the arc and its label
"""
def __init__(self, arc_pred=None, label_pred=None,
arc_true=None, label_true=None, seq_lens=None):
arc_true=None, label_true=None, seq_len=None):
super().__init__()
self._init_param_map(arc_pred=arc_pred, label_pred=label_pred,
arc_true=arc_true, label_true=label_true,
seq_lens=seq_lens)
seq_len=seq_len)
self.num_arc = 0
self.num_label = 0
self.num_sample = 0
@@ -424,13 +506,13 @@ class ParserMetric(MetricBase):
self.num_sample = self.num_label = self.num_arc = 0
return res
def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None):
def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_len=None):
"""Evaluate the performance of prediction.
"""
if seq_lens is None:
if seq_len is None:
seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long)
else:
seq_mask = seq_lens_to_masks(seq_lens.long(), float=False).long()
seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long()
# mask out <root> tag
seq_mask[:,0] = 0
head_pred_correct = (arc_pred == arc_true).long() * seq_mask
@@ -7,6 +7,21 @@ import torch.nn.functional as F | |||||
class StarTransEnc(nn.Module): | class StarTransEnc(nn.Module): | ||||
""" | |||||
带word embedding的Star-Transformer Encoder | |||||
:param vocab_size: 词嵌入的词典大小 | |||||
:param emb_dim: 每个词嵌入的特征维度 | |||||
:param num_cls: 输出类别个数 | |||||
:param hidden_size: 模型中特征维度. | |||||
:param num_layers: 模型层数. | |||||
:param num_head: 模型中multi-head的head个数. | |||||
:param head_dim: 模型中multi-head中每个head特征维度. | |||||
:param max_len: 模型能接受的最大输入长度. | |||||
:param cls_hidden_size: 分类器隐层维度. | |||||
:param emb_dropout: 词嵌入的dropout概率. | |||||
:param dropout: 模型除词嵌入外的dropout概率. | |||||
""" | |||||
def __init__(self, vocab_size, emb_dim, | def __init__(self, vocab_size, emb_dim, | ||||
hidden_size, | hidden_size, | ||||
num_layers, | num_layers, | ||||
@@ -27,15 +42,23 @@ class StarTransEnc(nn.Module): | |||||
max_len=max_len) | max_len=max_len) | ||||
def forward(self, x, mask): | def forward(self, x, mask): | ||||
""" | |||||
:param FloatTensor data: [batch, length, hidden] 输入的序列 | |||||
:param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, | |||||
否则为 1 | |||||
:return: [batch, length, hidden] 编码后的输出序列 | |||||
[batch, hidden] 全局 relay 节点, 详见论文 | |||||
""" | |||||
x = self.embedding(x) | x = self.embedding(x) | ||||
x = self.emb_fc(self.emb_drop(x)) | x = self.emb_fc(self.emb_drop(x)) | ||||
nodes, relay = self.encoder(x, mask) | nodes, relay = self.encoder(x, mask) | ||||
return nodes, relay | return nodes, relay | ||||
class Cls(nn.Module): | |||||
class _Cls(nn.Module): | |||||
def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): | def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): | ||||
super(Cls, self).__init__() | |||||
super(_Cls, self).__init__() | |||||
self.fc = nn.Sequential( | self.fc = nn.Sequential( | ||||
nn.Linear(in_dim, hid_dim), | nn.Linear(in_dim, hid_dim), | ||||
nn.LeakyReLU(), | nn.LeakyReLU(), | ||||
@@ -48,9 +71,9 @@ class Cls(nn.Module): | |||||
return h | return h | ||||
class NLICls(nn.Module): | |||||
class _NLICls(nn.Module): | |||||
def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): | def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): | ||||
super(NLICls, self).__init__() | |||||
super(_NLICls, self).__init__() | |||||
self.fc = nn.Sequential( | self.fc = nn.Sequential( | ||||
nn.Dropout(dropout), | nn.Dropout(dropout), | ||||
nn.Linear(in_dim*4, hid_dim), #4 | nn.Linear(in_dim*4, hid_dim), #4 | ||||
@@ -65,7 +88,19 @@ class NLICls(nn.Module): | |||||
return h | return h | ||||
class STSeqLabel(nn.Module): | class STSeqLabel(nn.Module): | ||||
"""star-transformer model for sequence labeling | |||||
"""用于序列标注的Star-Transformer模型 | |||||
:param vocab_size: 词嵌入的词典大小 | |||||
:param emb_dim: 每个词嵌入的特征维度 | |||||
:param num_cls: 输出类别个数 | |||||
:param hidden_size: 模型中特征维度. Default: 300 | |||||
:param num_layers: 模型层数. Default: 4 | |||||
:param num_head: 模型中multi-head的head个数. Default: 8 | |||||
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32 | |||||
:param max_len: 模型能接受的最大输入长度. Default: 512 | |||||
:param cls_hidden_size: 分类器隐层维度. Default: 600 | |||||
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1 | |||||
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 | |||||
""" | """ | ||||
def __init__(self, vocab_size, emb_dim, num_cls, | def __init__(self, vocab_size, emb_dim, num_cls, | ||||
hidden_size=300, | hidden_size=300, | ||||
@@ -86,23 +121,47 @@ class STSeqLabel(nn.Module): | |||||
max_len=max_len, | max_len=max_len, | ||||
emb_dropout=emb_dropout, | emb_dropout=emb_dropout, | ||||
dropout=dropout) | dropout=dropout) | ||||
self.cls = Cls(hidden_size, num_cls, cls_hidden_size) | |||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | |||||
def forward(self, words, seq_len): | |||||
""" | |||||
def forward(self, word_seq, seq_lens): | |||||
mask = seq_lens_to_masks(seq_lens) | |||||
nodes, _ = self.enc(word_seq, mask) | |||||
:param words: [batch, seq_len] input sequence | |||||
:param seq_len: [batch,] lengths of the input sequences | |||||
:return output: [batch, num_cls, seq_len] class scores for each element of the sequence | |||||
""" | |||||
mask = seq_lens_to_masks(seq_len) | |||||
nodes, _ = self.enc(words, mask) | |||||
output = self.cls(nodes) | output = self.cls(nodes) | ||||
output = output.transpose(1,2) # make hidden to be dim 1 | output = output.transpose(1,2) # make hidden to be dim 1 | ||||
return {'output': output} # [bsz, n_cls, seq_len] | return {'output': output} # [bsz, n_cls, seq_len] | ||||
def predict(self, word_seq, seq_lens): | |||||
y = self.forward(word_seq, seq_lens) | |||||
def predict(self, words, seq_len): | |||||
""" | |||||
:param words: [batch, seq_len] input sequence | |||||
:param seq_len: [batch,] lengths of the input sequences | |||||
:return output: [batch, seq_len] predicted class of each element in the sequence | |||||
""" | |||||
y = self.forward(words, seq_len) | |||||
_, pred = y['output'].max(1) | _, pred = y['output'].max(1) | ||||
return {'output': pred, 'seq_lens': seq_lens} | |||||
return {'output': pred} | |||||
class STSeqCls(nn.Module): | class STSeqCls(nn.Module): | ||||
"""star-transformer model for sequence classification | |||||
"""用于分类任务的Star-Transformer | |||||
:param vocab_size: 词嵌入的词典大小 | |||||
:param emb_dim: 每个词嵌入的特征维度 | |||||
:param num_cls: 输出类别个数 | |||||
:param hidden_size: 模型中特征维度. Default: 300 | |||||
:param num_layers: 模型层数. Default: 4 | |||||
:param num_head: 模型中multi-head的head个数. Default: 8 | |||||
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32 | |||||
:param max_len: 模型能接受的最大输入长度. Default: 512 | |||||
:param cls_hidden_size: 分类器隐层维度. Default: 600 | |||||
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1 | |||||
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 | |||||
""" | """ | ||||
def __init__(self, vocab_size, emb_dim, num_cls, | def __init__(self, vocab_size, emb_dim, num_cls, | ||||
@@ -124,23 +183,47 @@ class STSeqCls(nn.Module): | |||||
max_len=max_len, | max_len=max_len, | ||||
emb_dropout=emb_dropout, | emb_dropout=emb_dropout, | ||||
dropout=dropout) | dropout=dropout) | ||||
self.cls = Cls(hidden_size, num_cls, cls_hidden_size) | |||||
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) | |||||
def forward(self, word_seq, seq_lens): | |||||
mask = seq_lens_to_masks(seq_lens) | |||||
nodes, relay = self.enc(word_seq, mask) | |||||
def forward(self, words, seq_len): | |||||
""" | |||||
:param words: [batch, seq_len] input sequence | |||||
:param seq_len: [batch,] lengths of the input sequences | |||||
:return output: [batch, num_cls] classification scores for the sequence | |||||
""" | |||||
mask = seq_lens_to_masks(seq_len) | |||||
nodes, relay = self.enc(words, mask) | |||||
y = 0.5 * (relay + nodes.max(1)[0]) | y = 0.5 * (relay + nodes.max(1)[0]) | ||||
output = self.cls(y) # [bsz, n_cls] | output = self.cls(y) # [bsz, n_cls] | ||||
return {'output': output} | return {'output': output} | ||||
def predict(self, word_seq, seq_lens): | |||||
y = self.forward(word_seq, seq_lens) | |||||
def predict(self, words, seq_len): | |||||
""" | |||||
:param words: [batch, seq_len] input sequence | |||||
:param seq_len: [batch,] lengths of the input sequences | |||||
:return output: [batch,] predicted class of the sequence | |||||
""" | |||||
y = self.forward(words, seq_len) | |||||
_, pred = y['output'].max(1) | _, pred = y['output'].max(1) | ||||
return {'output': pred} | return {'output': pred} | ||||
class STNLICls(nn.Module): | class STNLICls(nn.Module): | ||||
"""star-transformer model for NLI | |||||
"""用于自然语言推断(NLI)的Star-Transformer | |||||
:param vocab_size: 词嵌入的词典大小 | |||||
:param emb_dim: 每个词嵌入的特征维度 | |||||
:param num_cls: 输出类别个数 | |||||
:param hidden_size: 模型中特征维度. Default: 300 | |||||
:param num_layers: 模型层数. Default: 4 | |||||
:param num_head: 模型中multi-head的head个数. Default: 8 | |||||
:param head_dim: 模型中multi-head中每个head特征维度. Default: 32 | |||||
:param max_len: 模型能接受的最大输入长度. Default: 512 | |||||
:param cls_hidden_size: 分类器隐层维度. Default: 600 | |||||
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1 | |||||
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 | |||||
""" | """ | ||||
def __init__(self, vocab_size, emb_dim, num_cls, | def __init__(self, vocab_size, emb_dim, num_cls, | ||||
@@ -162,20 +245,36 @@ class STNLICls(nn.Module): | |||||
max_len=max_len, | max_len=max_len, | ||||
emb_dropout=emb_dropout, | emb_dropout=emb_dropout, | ||||
dropout=dropout) | dropout=dropout) | ||||
self.cls = NLICls(hidden_size, num_cls, cls_hidden_size) | |||||
self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size) | |||||
def forward(self, words1, words2, seq_len1, seq_len2): | |||||
""" | |||||
def forward(self, word_seq1, word_seq2, seq_lens1, seq_lens2): | |||||
mask1 = seq_lens_to_masks(seq_lens1) | |||||
mask2 = seq_lens_to_masks(seq_lens2) | |||||
:param words1: [batch, seq_len] input sequence 1 | |||||
:param words2: [batch, seq_len] input sequence 2 | |||||
:param seq_len1: [batch,] lengths of input sequence 1 | |||||
:param seq_len2: [batch,] lengths of input sequence 2 | |||||
:return output: [batch, num_cls] classification scores | |||||
""" | |||||
mask1 = seq_lens_to_masks(seq_len1) | |||||
mask2 = seq_lens_to_masks(seq_len2) | |||||
def enc(seq, mask): | def enc(seq, mask): | ||||
nodes, relay = self.enc(seq, mask) | nodes, relay = self.enc(seq, mask) | ||||
return 0.5 * (relay + nodes.max(1)[0]) | return 0.5 * (relay + nodes.max(1)[0]) | ||||
y1 = enc(word_seq1, mask1) | |||||
y2 = enc(word_seq2, mask2) | |||||
y1 = enc(words1, mask1) | |||||
y2 = enc(words2, mask2) | |||||
output = self.cls(y1, y2) # [bsz, n_cls] | output = self.cls(y1, y2) # [bsz, n_cls] | ||||
return {'output': output} | return {'output': output} | ||||
def predict(self, word_seq1, word_seq2, seq_lens1, seq_lens2): | |||||
y = self.forward(word_seq1, word_seq2, seq_lens1, seq_lens2) | |||||
def predict(self, words1, words2, seq_len1, seq_len2): | |||||
""" | |||||
:param words1: [batch, seq_len] input sequence 1 | |||||
:param words2: [batch, seq_len] input sequence 2 | |||||
:param seq_len1: [batch,] lengths of input sequence 1 | |||||
:param seq_len2: [batch,] lengths of input sequence 2 | |||||
:return output: [batch,] predicted class | |||||
""" | |||||
y = self.forward(words1, words2, seq_len1, seq_len2) | |||||
_, pred = y['output'].max(1) | _, pred = y['output'].max(1) | ||||
return {'output': pred} | return {'output': pred} |
@@ -6,17 +6,17 @@ from fastNLP.modules.utils import initial_parameter | |||||
class LSTM(nn.Module): | class LSTM(nn.Module): | ||||
"""Long Short Term Memory | |||||
"""LSTM 模块, 轻量封装的Pytorch LSTM | |||||
:param int input_size: | |||||
:param int hidden_size: | |||||
:param int num_layers: | |||||
:param float dropout: | |||||
:param bool batch_first: | |||||
:param bool bidirectional: | |||||
:param bool bias: | |||||
:param str initial_method: | |||||
:param bool get_hidden: | |||||
:param input_size: feature dimension of the input `x` | |||||
:param hidden_size: feature dimension of the hidden state `h` | |||||
:param num_layers: number of RNN layers. Default: 1 | |||||
:param dropout: dropout probability between layers. Default: 0 | |||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | |||||
:param batch_first: if ``True``, input and output ``Tensor`` have shape | |||||
(batch, seq, feature). Default: ``True`` | |||||
:param bias: if ``False``, the layer does not use a bias term. Default: ``True`` | |||||
:param get_hidden: whether to also return the hidden state `h`. Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, | ||||
bidirectional=False, bias=True, initial_method=None, get_hidden=False): | bidirectional=False, bias=True, initial_method=None, get_hidden=False): | ||||
@@ -27,14 +27,24 @@ class LSTM(nn.Module): | |||||
self.get_hidden = get_hidden | self.get_hidden = get_hidden | ||||
initial_parameter(self, initial_method) | initial_parameter(self, initial_method) | ||||
def forward(self, x, seq_lens=None, h0=None, c0=None): | |||||
def forward(self, x, seq_len=None, h0=None, c0=None): | |||||
""" | |||||
:param x: [batch, seq_len, input_size] input sequence | |||||
:param seq_len: [batch, ] sequence lengths; if ``None``, all inputs are treated as equally long. Default: ``None`` | |||||
:param h0: [batch, hidden_size] initial hidden state; if ``None``, initialized to all zeros. Default: ``None`` | |||||
:param c0: [batch, hidden_size] initial cell state; if ``None``, initialized to all zeros. Default: ``None`` | |||||
:return (output, ht) or output: if ``get_hidden=True``, the output sequence [batch, seq_len, hidden_size*num_direction] | |||||
and the hidden state at the last time step [batch, hidden_size*num_direction]; | |||||
if ``get_hidden=False``, only the output sequence is returned. | |||||
""" | |||||
if h0 is not None and c0 is not None: | if h0 is not None and c0 is not None: | ||||
hx = (h0, c0) | hx = (h0, c0) | ||||
else: | else: | ||||
hx = None | hx = None | ||||
if seq_lens is not None and not isinstance(x, rnn.PackedSequence): | |||||
if seq_len is not None and not isinstance(x, rnn.PackedSequence): | |||||
print('padding') | print('padding') | ||||
sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) | |||||
sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) | |||||
if self.batch_first: | if self.batch_first: | ||||
x = x[sort_idx] | x = x[sort_idx] | ||||
else: | else: | ||||
@@ -5,16 +5,19 @@ import numpy as NP | |||||
class StarTransformer(nn.Module): | class StarTransformer(nn.Module): | ||||
"""Star-Transformer Encoder part。 | |||||
""" | |||||
The encoder part of the Star-Transformer. Takes a 3-d text representation as input and returns an encoding of the same length | |||||
paper: https://arxiv.org/abs/1902.09113 | paper: https://arxiv.org/abs/1902.09113 | ||||
:param hidden_size: int, size of the input dimension, which is also the output dimension. | |||||
:param num_layers: int, number of star-transformer layers | |||||
:param num_head: int, number of heads. | |||||
:param head_dim: int, dimension of each head. | |||||
:param dropout: float dropout probability | |||||
:param max_len: int or None; if int, the maximum length of the input sequence, | |||||
and the model adds a position embedding to the input sequence. | |||||
If None, the position embedding step is skipped | |||||
:param int hidden_size: size of the input dimension, which is also the output dimension. | |||||
:param int num_layers: number of star-transformer layers | |||||
:param int num_head: number of heads. | |||||
:param int head_dim: dimension of each head. | |||||
:param float dropout: dropout probability. Default: 0.1 | |||||
:param int max_len: int or None; if int, the maximum length of the input sequence, | |||||
and the model adds a position embedding to the input sequence. | |||||
If `None`, the position embedding step is skipped. Default: `None` | |||||
""" | """ | ||||
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): | ||||
super(StarTransformer, self).__init__() | super(StarTransformer, self).__init__() | ||||
@@ -22,11 +25,11 @@ class StarTransformer(nn.Module): | |||||
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) | ||||
self.ring_att = nn.ModuleList( | self.ring_att = nn.ModuleList( | ||||
[MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
for _ in range(self.iters)]) | |||||
[_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
for _ in range(self.iters)]) | |||||
self.star_att = nn.ModuleList( | self.star_att = nn.ModuleList( | ||||
[MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
for _ in range(self.iters)]) | |||||
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) | |||||
for _ in range(self.iters)]) | |||||
if max_len is not None: | if max_len is not None: | ||||
self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size) | self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size) | ||||
@@ -35,10 +38,12 @@ class StarTransformer(nn.Module): | |||||
def forward(self, data, mask): | def forward(self, data, mask): | ||||
""" | """ | ||||
:param FloatTensor data: [batch, length, hidden] the input sequence | |||||
:param ByteTensor mask: [batch, length] the padding mask for input, in which padding pos is 0 | |||||
:return: [batch, length, hidden] the output sequence | |||||
[batch, hidden] the global relay node | |||||
:param FloatTensor data: [batch, length, hidden] input sequence | |||||
:param ByteTensor mask: [batch, length] padding mask of the input sequence, 0 where there is | |||||
no content (padding) and 1 elsewhere | |||||
:return: [batch, length, hidden] the encoded output sequence | |||||
[batch, hidden] the global relay node; see the paper for details | |||||
""" | """ | ||||
def norm_func(f, x): | def norm_func(f, x): | ||||
# B, H, L, 1 | # B, H, L, 1 | ||||
@@ -70,9 +75,9 @@ class StarTransformer(nn.Module): | |||||
return nodes, relay.view(B, H) | return nodes, relay.view(B, H) | ||||
class MSA1(nn.Module): | |||||
class _MSA1(nn.Module): | |||||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | ||||
super(MSA1, self).__init__() | |||||
super(_MSA1, self).__init__() | |||||
# Multi-head Self Attention Case 1, doing self-attention for small regions | # Multi-head Self Attention Case 1, doing self-attention for small regions | ||||
# Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small | # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small | ||||
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
@@ -113,10 +118,10 @@ class MSA1(nn.Module): | |||||
return ret | return ret | ||||
class MSA2(nn.Module): | |||||
class _MSA2(nn.Module): | |||||
def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): | ||||
# Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value | # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value | ||||
super(MSA2, self).__init__() | |||||
super(_MSA2, self).__init__() | |||||
self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) | ||||
@@ -7,13 +7,13 @@ from ..dropout import TimestepDropout | |||||
class TransformerEncoder(nn.Module): | class TransformerEncoder(nn.Module): | ||||
"""transformer的encoder模块,不包含embedding层 | """transformer的encoder模块,不包含embedding层 | ||||
:param num_layers: int, number of transformer layers | |||||
:param model_size: int, size of the input dimension, which is also the output dimension. | |||||
:param inner_size: int, hidden size of the FFN layer | |||||
:param key_size: int, dimension of the keys in each head. | |||||
:param value_size: int, dimension of the values in each head. | |||||
:param num_head: int, number of heads. | |||||
:param dropout: float. | |||||
:param int num_layers: number of transformer layers | |||||
:param int model_size: size of the input dimension, which is also the output dimension. | |||||
:param int inner_size: hidden size of the FFN layer | |||||
:param int key_size: dimension of the keys in each head. | |||||
:param int value_size: dimension of the values in each head. | |||||
:param int num_head: number of heads. | |||||
:param float dropout: dropout probability. Default: 0.1 | |||||
""" | """ | ||||
class SubLayer(nn.Module): | class SubLayer(nn.Module): | ||||
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): | ||||
@@ -48,7 +48,8 @@ class TransformerEncoder(nn.Module): | |||||
def forward(self, x, seq_mask=None): | def forward(self, x, seq_mask=None): | ||||
""" | """ | ||||
:param x: [batch, seq_len, model_size] input sequence | :param x: [batch, seq_len, model_size] input sequence | ||||
:param seq_mask: [batch, seq_len] padding mask of the input sequence | |||||
:param seq_mask: [batch, seq_len] padding mask of the input sequence; if ``None``, an all-ones mask is generated. | |||||
Default: ``None`` | |||||
:return: [batch, seq_len, model_size] output sequence | :return: [batch, seq_len, model_size] output sequence | ||||
""" | """ | ||||
output = x | output = x | ||||
@@ -28,11 +28,11 @@ class VarRnnCellWrapper(nn.Module): | |||||
""" | """ | ||||
:param PackedSequence input_x: [seq_len, batch_size, input_size] | :param PackedSequence input_x: [seq_len, batch_size, input_size] | ||||
:param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] | ||||
for other RNN, h_0, [batch_size, hidden_size] | |||||
:for other RNN, h_0, [batch_size, hidden_size] | |||||
:param mask_x: [batch_size, input_size] dropout mask for input | :param mask_x: [batch_size, input_size] dropout mask for input | ||||
:param mask_h: [batch_size, hidden_size] dropout mask for hidden | :param mask_h: [batch_size, hidden_size] dropout mask for hidden | ||||
:return PackedSequence output: [seq_len, bacth_size, hidden_size] | :return PackedSequence output: [seq_len, bacth_size, hidden_size] | ||||
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||||
:hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] | |||||
for other RNN, h_n, [batch_size, hidden_size] | for other RNN, h_n, [batch_size, hidden_size] | ||||
""" | """ | ||||
def get_hi(hi, h0, size): | def get_hi(hi, h0, size): | ||||
@@ -84,9 +84,21 @@ class VarRnnCellWrapper(nn.Module): | |||||
class VarRNNBase(nn.Module): | class VarRNNBase(nn.Module): | ||||
"""Implementation of Variational Dropout RNN network. | |||||
refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | |||||
"""Variational Dropout RNN 实现. | |||||
论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) | |||||
https://arxiv.org/abs/1512.05287`. | https://arxiv.org/abs/1512.05287`. | ||||
:param mode: RNN mode (lstm or not) | |||||
:param Cell: RNN cell type (lstm, gru, etc.) | |||||
:param input_size: feature dimension of the input `x` | |||||
:param hidden_size: feature dimension of the hidden state `h` | |||||
:param num_layers: number of RNN layers. Default: 1 | |||||
:param bias: if ``False``, the layer does not use a bias term. Default: ``True`` | |||||
:param batch_first: if ``True``, input and output ``Tensor`` have shape | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: dropout probability applied to the input. Default: 0 | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, | ||||
@@ -120,36 +132,43 @@ class VarRNNBase(nn.Module): | |||||
output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) | ||||
return output_x, hidden_x | return output_x, hidden_x | ||||
def forward(self, input, hx=None): | |||||
def forward(self, x, hx=None): | |||||
""" | |||||
:param x: [batch, seq_len, input_size] input sequence | |||||
:param hx: [num_layers*num_direction, batch, hidden_size] initial hidden state; if ``None``, initialized to all zeros. Default: ``None`` | |||||
:return (output, ht): the output sequence [batch, seq_len, hidden_size*num_direction] | |||||
and the final hidden state [num_layers*num_direction, batch, hidden_size] | |||||
""" | |||||
is_lstm = self.is_lstm | is_lstm = self.is_lstm | ||||
is_packed = isinstance(input, PackedSequence) | |||||
is_packed = isinstance(x, PackedSequence) | |||||
if not is_packed: | if not is_packed: | ||||
seq_len = input.size(1) if self.batch_first else input.size(0) | |||||
max_batch_size = input.size(0) if self.batch_first else input.size(1) | |||||
seq_len = x.size(1) if self.batch_first else x.size(0) | |||||
max_batch_size = x.size(0) if self.batch_first else x.size(1) | |||||
seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) | ||||
input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) | |||||
x, batch_sizes = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) | |||||
else: | else: | ||||
max_batch_size = int(input.batch_sizes[0]) | |||||
input, batch_sizes = input | |||||
max_batch_size = int(x.batch_sizes[0]) | |||||
x, batch_sizes = x | |||||
if hx is None: | if hx is None: | ||||
hx = input.new_zeros(self.num_layers * self.num_directions, | |||||
max_batch_size, self.hidden_size, requires_grad=True) | |||||
hx = x.new_zeros(self.num_layers * self.num_directions, | |||||
max_batch_size, self.hidden_size, requires_grad=True) | |||||
if is_lstm: | if is_lstm: | ||||
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) | ||||
mask_x = input.new_ones((max_batch_size, self.input_size)) | |||||
mask_out = input.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = input.new_ones((max_batch_size, self.hidden_size)) | |||||
mask_x = x.new_ones((max_batch_size, self.input_size)) | |||||
mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions)) | |||||
mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) | |||||
nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) | ||||
nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) | ||||
hidden = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) | |||||
hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
if is_lstm: | if is_lstm: | ||||
cellstate = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) | |||||
cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) | |||||
for layer in range(self.num_layers): | for layer in range(self.num_layers): | ||||
output_list = [] | output_list = [] | ||||
input_seq = PackedSequence(input, batch_sizes) | |||||
input_seq = PackedSequence(x, batch_sizes) | |||||
mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) | ||||
for direction in range(self.num_directions): | for direction in range(self.num_directions): | ||||
output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, | ||||
@@ -161,22 +180,32 @@ class VarRNNBase(nn.Module): | |||||
cellstate[idx] = hidden_x[1] | cellstate[idx] = hidden_x[1] | ||||
else: | else: | ||||
hidden[idx] = hidden_x | hidden[idx] = hidden_x | ||||
input = torch.cat(output_list, dim=-1) | |||||
x = torch.cat(output_list, dim=-1) | |||||
if is_lstm: | if is_lstm: | ||||
hidden = (hidden, cellstate) | hidden = (hidden, cellstate) | ||||
if is_packed: | if is_packed: | ||||
output = PackedSequence(input, batch_sizes) | |||||
output = PackedSequence(x, batch_sizes) | |||||
else: | else: | ||||
input = PackedSequence(input, batch_sizes) | |||||
output, _ = pad_packed_sequence(input, batch_first=self.batch_first) | |||||
x = PackedSequence(x, batch_sizes) | |||||
output, _ = pad_packed_sequence(x, batch_first=self.batch_first) | |||||
return output, hidden | return output, hidden | ||||
class VarLSTM(VarRNNBase): | class VarLSTM(VarRNNBase): | ||||
"""Variational Dropout LSTM. | """Variational Dropout LSTM. | ||||
:param input_size: feature dimension of the input `x` | |||||
:param hidden_size: feature dimension of the hidden state `h` | |||||
:param num_layers: number of RNN layers. Default: 1 | |||||
:param bias: if ``False``, the layer does not use a bias term. Default: ``True`` | |||||
:param batch_first: if ``True``, input and output ``Tensor`` have shape | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: dropout probability applied to the input. Default: 0 | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||||
:param bidirectional: if ``True``, use a bidirectional LSTM. Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
@@ -185,6 +214,16 @@ class VarLSTM(VarRNNBase): | |||||
class VarRNN(VarRNNBase): | class VarRNN(VarRNNBase): | ||||
"""Variational Dropout RNN. | """Variational Dropout RNN. | ||||
:param input_size: feature dimension of the input `x` | |||||
:param hidden_size: feature dimension of the hidden state `h` | |||||
:param num_layers: number of RNN layers. Default: 1 | |||||
:param bias: if ``False``, the layer does not use a bias term. Default: ``True`` | |||||
:param batch_first: if ``True``, input and output ``Tensor`` have shape | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: dropout probability applied to the input. Default: 0 | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||||
:param bidirectional: if ``True``, use a bidirectional RNN. Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
@@ -193,6 +232,16 @@ class VarRNN(VarRNNBase): | |||||
class VarGRU(VarRNNBase): | class VarGRU(VarRNNBase): | ||||
"""Variational Dropout GRU. | """Variational Dropout GRU. | ||||
:param input_size: feature dimension of the input `x` | |||||
:param hidden_size: feature dimension of the hidden state `h` | |||||
:param num_layers: number of RNN layers. Default: 1 | |||||
:param bias: if ``False``, the layer does not use a bias term. Default: ``True`` | |||||
:param batch_first: if ``True``, input and output ``Tensor`` have shape | |||||
(batch, seq, feature). Default: ``False`` | |||||
:param input_dropout: dropout probability applied to the input. Default: 0 | |||||
:param hidden_dropout: dropout probability applied to each hidden state. Default: 0 | |||||
:param bidirectional: if ``True``, use a bidirectional GRU. Default: ``False`` | |||||
""" | """ | ||||
def __init__(self, *args, **kwargs): | def __init__(self, *args, **kwargs): | ||||
@@ -4,17 +4,11 @@ import unittest | |||||
import torch | import torch | ||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ | |||||
from fastNLP.core.sampler import SequentialSampler, RandomSampler, \ | |||||
k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler | ||||
class TestSampler(unittest.TestCase): | class TestSampler(unittest.TestCase): | ||||
def test_convert_to_torch_tensor(self): | |||||
data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] | |||||
ans = convert_to_torch_tensor(data, False) | |||||
assert isinstance(ans, torch.Tensor) | |||||
assert tuple(ans.shape) == (3, 5) | |||||
def test_sequential_sampler(self): | def test_sequential_sampler(self): | ||||
sampler = SequentialSampler() | sampler = SequentialSampler() | ||||
data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] | ||||
@@ -44,34 +44,34 @@ data_file = """ | |||||
def init_data(): | def init_data(): | ||||
ds = fastNLP.DataSet() | ds = fastNLP.DataSet() | ||||
v = {'word_seq': fastNLP.Vocabulary(), | |||||
'pos_seq': fastNLP.Vocabulary(), | |||||
v = {'words1': fastNLP.Vocabulary(), | |||||
'words2': fastNLP.Vocabulary(), | |||||
'label_true': fastNLP.Vocabulary()} | 'label_true': fastNLP.Vocabulary()} | ||||
data = [] | data = [] | ||||
for line in data_file.split('\n'): | for line in data_file.split('\n'): | ||||
line = line.split() | line = line.split() | ||||
if len(line) == 0 and len(data) > 0: | if len(line) == 0 and len(data) > 0: | ||||
data = list(zip(*data)) | data = list(zip(*data)) | ||||
ds.append(fastNLP.Instance(word_seq=data[1], | |||||
pos_seq=data[4], | |||||
ds.append(fastNLP.Instance(words1=data[1], | |||||
words2=data[4], | |||||
arc_true=data[6], | arc_true=data[6], | ||||
label_true=data[7])) | label_true=data[7])) | ||||
data = [] | data = [] | ||||
elif len(line) > 0: | elif len(line) > 0: | ||||
data.append(line) | data.append(line) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | |||||
for name in ['words1', 'words2', 'label_true']: | |||||
ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name) | ds.apply(lambda x: ['<st>'] + list(x[name]), new_field_name=name) | ||||
ds.apply(lambda x: v[name].add_word_lst(x[name])) | ds.apply(lambda x: v[name].add_word_lst(x[name])) | ||||
for name in ['word_seq', 'pos_seq', 'label_true']: | |||||
for name in ['words1', 'words2', 'label_true']: | |||||
ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) | ||||
ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') | ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') | ||||
ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') | |||||
ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) | |||||
ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) | |||||
return ds, v['word_seq'], v['pos_seq'], v['label_true'] | |||||
ds.apply(lambda x: len(x['words1']), new_field_name='seq_len') | |||||
ds.set_input('words1', 'words2', 'seq_len', flag=True) | |||||
ds.set_target('arc_true', 'label_true', 'seq_len', flag=True) | |||||
return ds, v['words1'], v['words2'], v['label_true'] | |||||
class TestBiaffineParser(unittest.TestCase): | class TestBiaffineParser(unittest.TestCase): | ||||
@@ -437,4 +437,10 @@ class TestTutorial(unittest.TestCase): | |||||
) | ) | ||||
tester.test() | tester.test() | ||||
os.chdir("../..") | |||||
def setUp(self): | |||||
import os | |||||
self._init_wd = os.path.abspath(os.curdir) | |||||
def tearDown(self): | |||||
import os | |||||
os.chdir(self._init_wd) |