diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 24a1ab1d..c4c21832 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -47,7 +47,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.utils import load_url from fastNLP.api.processor import ModelProcessor -from fastNLP.io.dataset_loader import cut_long_sentence, ConllLoader +from fastNLP.io.dataset_loader import _cut_long_sentence, ConllLoader from fastNLP.core.instance import Instance from fastNLP.api.pipeline import Pipeline from fastNLP.core.metrics import SpanFPreRecMetric @@ -107,7 +107,7 @@ class ConllCWSReader(object): continue line = ' '.join(res) if cut_long_sent: - sents = cut_long_sentence(line) + sents = _cut_long_sentence(line) else: sents = [line] for raw_sentence in sents: diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index dbe86953..087882aa 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -5,7 +5,7 @@ from .instance import Instance from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward from .metrics import AccuracyMetric from .optimizer import Optimizer, SGD, Adam -from .sampler import SequentialSampler, BucketSampler, RandomSampler, BaseSampler +from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler from .tester import Tester from .trainer import Trainer from .vocabulary import Vocabulary diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py index 9d65ada8..3a62cefe 100644 --- a/fastNLP/core/batch.py +++ b/fastNLP/core/batch.py @@ -2,7 +2,7 @@ import numpy as np import torch import atexit -from fastNLP.core.sampler import RandomSampler +from fastNLP.core.sampler import RandomSampler, Sampler import torch.multiprocessing as mp _python_is_exit = False @@ -12,19 +12,25 @@ def _set_python_is_exit(): atexit.register(_set_python_is_exit) class Batch(object): - """Batch is an iterable object which iterates over mini-batches. - - Example:: - - for batch_x, batch_y in Batch(data_set, batch_size=16, sampler=SequentialSampler()): - # ... - - :param DataSet dataset: a DataSet object - :param int batch_size: the size of the batch - :param Sampler sampler: a Sampler object. If None, use fastNLP.sampler.RandomSampler - :param bool as_numpy: If True, return Numpy array. Otherwise, return torch tensors. - :param bool prefetch: If True, use multiprocessing to fetch next batch when training. - :param str or torch.device device: the batch's device, if as_numpy is True, device is ignored. + """ + Batch 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出. + 组成 `x` 和 `y` + + Example:: + + batch = Batch(data_set, batch_size=16, sampler=SequentialSampler()) + num_batch = len(batch) + for batch_x, batch_y in batch: + # do stuff ... + + :param DataSet dataset: `DataSet` 对象, 数据集 + :param int batch_size: 取出的batch大小 + :param Sampler sampler: 规定使用的 Sample 方式. 若为 ``None`` , 使用 RandomSampler. + Default: ``None`` + :param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 torch.Tensor. + Default: ``False`` + :param bool prefetch: 若为 ``True`` 使用多进程预先取出下一batch. 
+ Default: ``False`` """ def __init__(self, dataset, batch_size, sampler=None, as_numpy=False, prefetch=False): @@ -41,7 +47,7 @@ class Batch(object): self.prefetch = prefetch self.lengths = 0 - def fetch_one(self): + def _fetch_one(self): if self.curidx >= len(self.idx_list): return None else: @@ -55,7 +61,7 @@ class Batch(object): if field.is_target or field.is_input: batch = field.get(indices) if not self.as_numpy and field.padder is not None: - batch = to_tensor(batch, field.dtype) + batch = _to_tensor(batch, field.dtype) if field.is_target: batch_y[field_name] = batch if field.is_input: @@ -70,17 +76,17 @@ class Batch(object): :return: """ if self.prefetch: - return run_batch_iter(self) + return _run_batch_iter(self) def batch_iter(): - self.init_iter() + self._init_iter() while 1: - res = self.fetch_one() + res = self._fetch_one() if res is None: break yield res return batch_iter() - def init_iter(self): + def _init_iter(self): self.idx_list = self.sampler(self.dataset) self.curidx = 0 self.lengths = self.dataset.get_length() @@ -89,10 +95,14 @@ class Batch(object): return self.num_batches def get_batch_indices(self): + """取得当前batch在DataSet中所在的index下标序列 + + :return list(int) indexes: 下标序列 + """ return self.cur_batch_indices -def to_tensor(batch, dtype): +def _to_tensor(batch, dtype): try: if dtype in (int, np.int8, np.int16, np.int32, np.int64): batch = torch.LongTensor(batch) @@ -103,12 +113,12 @@ def to_tensor(batch, dtype): return batch -def run_fetch(batch, q): +def _run_fetch(batch, q): global _python_is_exit - batch.init_iter() + batch._init_iter() # print('start fetch') while 1: - res = batch.fetch_one() + res = batch._fetch_one() # print('fetch one') while 1: try: @@ -124,9 +134,9 @@ def run_fetch(batch, q): # print('fetch exit') -def run_batch_iter(batch): +def _run_batch_iter(batch): q = mp.JoinableQueue(maxsize=10) - fetch_p = mp.Process(target=run_fetch, args=(batch, q)) + fetch_p = mp.Process(target=_run_fetch, args=(batch, q)) fetch_p.daemon = True fetch_p.start() # print('fork fetch process') diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 3a4dfa55..68dfcc51 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -482,7 +482,7 @@ class DataSet(object): """ import warnings - warnings.warn('read_csv is deprecated, use CSVLoader instead', + warnings.warn('DataSet.read_csv is deprecated, use CSVLoader instead', category=DeprecationWarning) with open(csv_path, "r", encoding='utf-8') as f: start_idx = 0 diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py index 4a523f10..080825df 100644 --- a/fastNLP/core/sampler.py +++ b/fastNLP/core/sampler.py @@ -3,72 +3,49 @@ from itertools import chain import numpy as np import torch +class Sampler(object): + """ `Sampler` 类的基类. 规定以何种顺序取出data中的元素 -def convert_to_torch_tensor(data_list, use_cuda): - """Convert lists into (cuda) Tensors. - - :param data_list: 2-level lists - :param use_cuda: bool, whether to use GPU or not - :return data_list: PyTorch Tensor of shape [batch_size, max_seq_len] - """ - data_list = torch.Tensor(data_list).long() - if torch.cuda.is_available() and use_cuda: - data_list = data_list.cuda() - return data_list - - -class BaseSampler(object): - """The base class of all samplers. - - Sub-classes must implement the ``__call__`` method. - ``__call__`` takes a DataSet object and returns a list of int - the sampling indices. + 子类必须实现 ``__call__`` 方法. 
输入 `DataSet` 对象, 返回其中元素的下标序列 """ - def __call__(self, *args, **kwargs): + def __call__(self, data_set): + """ + :param DataSet data_set: `DataSet` 对象, 需要Sample的数据 + :return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 + """ raise NotImplementedError -class SequentialSampler(BaseSampler): - """Sample data in the original order. +class SequentialSampler(Sampler): + """顺序取出元素的 `Sampler` """ def __call__(self, data_set): - """ - - :param DataSet data_set: - :return result: a list of integers. - """ return list(range(len(data_set))) -class RandomSampler(BaseSampler): - """Sample data in random permutation order. +class RandomSampler(Sampler): + """随机化取元素的 `Sampler` """ def __call__(self, data_set): - """ - - :param DataSet data_set: - :return result: a list of integers. - """ return list(np.random.permutation(len(data_set))) -class BucketSampler(BaseSampler): - """ - - :param int num_buckets: the number of buckets to use. - :param int batch_size: batch size per epoch. - :param str seq_lens_field_name: the field name indicating the field about sequence length. +class BucketSampler(Sampler): + """带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 + :param int num_buckets: bucket的数量 + :param int batch_size: batch的大小 + :param str seq_lens_field_name: 对应序列长度的 `field` 的名字 """ - def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_lens'): + def __init__(self, num_buckets=10, batch_size=32, seq_lens_field_name='seq_len'): self.num_buckets = num_buckets self.batch_size = batch_size self.seq_lens_field_name = seq_lens_field_name def __call__(self, data_set): - seq_lens = data_set.get_all_fields()[self.seq_lens_field_name].content total_sample_num = len(seq_lens) diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index 67e7d2c0..867989bf 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -18,7 +18,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.core.losses import _prepare_losser from fastNLP.core.metrics import _prepare_metrics from fastNLP.core.optimizer import Adam -from fastNLP.core.sampler import BaseSampler +from fastNLP.core.sampler import Sampler from fastNLP.core.sampler import RandomSampler from fastNLP.core.sampler import SequentialSampler from fastNLP.core.tester import Tester @@ -57,7 +57,7 @@ class Trainer(object): smaller, add "-" in front of the string. For example:: metric_key="-PPL" # language model gets better as perplexity gets smaller - :param BaseSampler sampler: method used to generate batch data. + :param Sampler sampler: method used to generate batch data. :param prefetch: bool, 是否使用额外的进程对产生batch数据。 :param bool use_tqdm: whether to use tqdm to show train progress. :param callbacks: List[Callback]. 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 @@ -102,7 +102,7 @@ class Trainer(object): losser = _prepare_losser(loss) # sampler check - if sampler is not None and not isinstance(sampler, BaseSampler): + if sampler is not None and not isinstance(sampler, Sampler): raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) if check_code_level > -1: diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index c580dbec..6a1830ad 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -1,3 +1,4 @@ +from functools import wraps from collections import Counter from fastNLP.core.dataset import DataSet @@ -5,7 +6,7 @@ def check_build_vocab(func): """A decorator to make sure the indexing is built before used. 
""" - + @wraps(func) # to solve missing docstring def _wrapper(self, *args, **kwargs): if self.word2idx is None or self.rebuild is True: self.build_vocab() @@ -18,7 +19,7 @@ def check_build_status(func): """A decorator to check whether the vocabulary updates after the last build. """ - + @wraps(func) # to solve missing docstring def _wrapper(self, *args, **kwargs): if self.rebuild is False: self.rebuild = True @@ -32,23 +33,28 @@ def check_build_status(func): class Vocabulary(object): - """Use for word and index one to one mapping + """ + 用于构建, 存储和使用 `str` 到 `int` 的一一映射 Example:: vocab = Vocabulary() word_list = "this is a word list".split() vocab.update(word_list) - vocab["word"] - vocab.to_word(5) - - :param int max_size: set the max number of words in Vocabulary. Default: None - :param int min_freq: set the min occur frequency of words in Vocabulary. Default: None - :param padding: str, padding的字符,默认为。如果设置为None,则vocabulary中不考虑padding,为None的情况多在为label建立 - Vocabulary的情况。 - :param unknown: str, unknown的字符,默认为。如果设置为None,则vocabulary中不考虑unknown,为None的情况多在为label建立 - Vocabulary的情况。 - + vocab["word"] # str to int + vocab.to_word(5) # int to str + + :param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 + 若为 ``None`` , 则不限制大小. Default: ``None`` + :param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. + 若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` + :param str padding: padding的字符. 如果设置为 ``None`` , + 则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. + Default: '' + :param str unknow: unknow的字符,所有未被记录的词在转为 `int` 时将被视为unknown. + 如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. + 为 ``None`` 的情况多在为label建立Vocabulary的情况. + Default: '' """ def __init__(self, max_size=None, min_freq=None, padding='', unknown=''): @@ -63,7 +69,7 @@ class Vocabulary(object): @check_build_status def update(self, word_lst): - """Add a list of words into the vocabulary. + """依次增加序列中词在词典中的出现频率 :param list word_lst: a list of strings """ @@ -71,32 +77,35 @@ class Vocabulary(object): @check_build_status def add(self, word): - """Add a single word into the vocabulary. + """ + 增加一个新词在词典中的出现频率 - :param str word: a word or token. + :param str word: 新词 """ self.word_count[word] += 1 @check_build_status def add_word(self, word): - """Add a single word into the vocabulary. - - :param str word: a word or token. + """ + 增加一个新词在词典中的出现频率 + :param str word: 新词 """ self.add(word) @check_build_status def add_word_lst(self, word_lst): - """Add a list of words into the vocabulary. - - :param list word_lst: a list of strings + """ + 依次增加序列中词在词典中的出现频率 + :param list(str) word_lst: 词的序列 """ self.update(word_lst) def build_vocab(self): - """Build a mapping from word to index, and filter the word using ``max_size`` and ``min_freq``. + """ + 根据已经出现的词和出现频率构建词典. 注意: 重复构建可能会改变词典的大小, + 但已经记录在词典中的词, 不会改变对应的 `int` """ self.word2idx = {} @@ -117,7 +126,8 @@ class Vocabulary(object): self.rebuild = False def build_reverse_vocab(self): - """Build "index to word" dict based on "word to index" dict. + """ + 基于 "word to index" dict, 构建 "index to word" dict. """ self.idx2word = {i: w for w, i in self.word2idx.items()} @@ -128,7 +138,8 @@ class Vocabulary(object): @check_build_vocab def __contains__(self, item): - """Check if a word in vocabulary. 
+ """ + 检查词是否被记录 :param item: the word :return: True or False @@ -136,11 +147,24 @@ class Vocabulary(object): return item in self.word2idx def has_word(self, w): + """ + 检查词是否被记录 + + Example:: + + has_abc = vocab.has_word('abc') + # equals to + has_abc = 'abc' in vocab + + :param item: the word + :return: ``True`` or ``False`` + """ return self.__contains__(w) @check_build_vocab def __getitem__(self, w): - """To support usage like:: + """ + To support usage like:: vocab[w] """ @@ -154,14 +178,19 @@ class Vocabulary(object): @check_build_vocab def index_dataset(self, *datasets, field_name, new_field_name=None): """ - example: - # remember to use `field_name` - vocab.index_dataset(tr_data, dev_data, te_data, field_name='words') + 将DataSet中对应field的词转为数字. + + Example:: - :param datasets: fastNLP Dataset type. you can pass multiple datasets - :param field_name: str, what field to index. Only support 0,1,2 dimension. - :param new_field_name: str. What the indexed field should be named, default is to overwrite field_name - :return: + # remember to use `field_name` + vocab.index_dataset(train_data, dev_data, test_data, field_name='words') + + :param DataSet datasets: 需要转index的 DataSet, 支持一个或多个 + :param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. + 目前仅支持 ``str`` , ``list(str)`` , ``list(list(str))`` + :param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. + Default: ``None`` + :return self: """ def index_instance(ins): """ @@ -194,11 +223,18 @@ class Vocabulary(object): def from_dataset(self, *datasets, field_name): """ - Construct vocab from dataset. + 使用dataset的对应field中词构建词典 + + Example:: + + # remember to use `field_name` + vocab.from_dataset(train_data1, train_data2, field_name='words') - :param datasets: DataSet. - :param field_name: str, what field is used to construct dataset. - :return: + :param DataSet datasets: 需要转index的 DataSet, 支持一个或多个. + :param str field_name: 构建词典所使用的 field. + 若有多个 DataSet, 每个DataSet都必须有此 field. + 目前仅支持 ``str`` , ``list(str)`` , ``list(list(str))`` + :return self: """ def construct_vocab(ins): field = ins[field_name] @@ -223,15 +259,27 @@ class Vocabulary(object): return self def to_index(self, w): - """ Turn a word to an index. If w is not in Vocabulary, return the unknown label. + """ + 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 + ``ValueError`` + + Example:: + + index = vocab.to_index('abc') + # equals to + index = vocab['abc'] :param str w: a word + :return int index: the number """ return self.__getitem__(w) @property @check_build_vocab def unknown_idx(self): + """ + unknown 对应的数字. + """ if self.unknown is None: return None return self.word2idx[self.unknown] @@ -239,16 +287,20 @@ class Vocabulary(object): @property @check_build_vocab def padding_idx(self): + """ + padding 对应的数字 + """ if self.padding is None: return None return self.word2idx[self.padding] @check_build_vocab def to_word(self, idx): - """given a word's index, return the word itself + """ + 给定一个数字, 将其转为对应的词. 
:param int idx: the index - :return str word: the indexed word + :return str word: the word """ return self.idx2word[idx] diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index 5657e194..039c4242 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -4,7 +4,7 @@ from nltk.tree import Tree from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance -from fastNLP.io.file_reader import read_csv, read_json, read_conll +from fastNLP.io.file_reader import _read_csv, _read_json, _read_conll def _download_from_url(url, path): @@ -55,12 +55,12 @@ def _uncompress(src, dst): class DataSetLoader: - """Interface for all DataSetLoaders. + """所有`DataSetLoader`的接口 """ def load(self, path): - """Load data from a given file. + """从指定 ``path`` 的文件中读取数据,返回DataSet :param str path: file path :return: a DataSet object @@ -68,7 +68,7 @@ class DataSetLoader: raise NotImplementedError def convert(self, data): - """Optional operation to build a DataSet. + """用Python数据对象创建DataSet :param data: inner data structure (user-defined) to represent the data. :return: a DataSet object @@ -77,7 +77,7 @@ class DataSetLoader: class PeopleDailyCorpusLoader(DataSetLoader): - """人民日报数据集 + """读取人民日报数据集 """ def __init__(self): super(PeopleDailyCorpusLoader, self).__init__() @@ -154,8 +154,35 @@ class PeopleDailyCorpusLoader(DataSetLoader): return data_set -class ConllLoader: +class ConllLoader(DataSetLoader): + """ + 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html + + 列号从0开始, 每列对应内容为:: + + Column Type + 0 Document ID + 1 Part number + 2 Word number + 3 Word itself + 4 Part-of-Speech + 5 Parse bit + 6 Predicate lemma + 7 Predicate Frameset ID + 8 Word sense + 9 Speaker/Author + 10 Named Entities + 11:N Predicate Arguments + N Coreference + + :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexs`` 一一对应 + :param indexs: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` + :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` + """ def __init__(self, headers, indexs=None, dropna=True): + super(ConllLoader, self).__init__() + if not isinstance(headers, (list, tuple)): + raise TypeError('invalid headers: {}, should be list of strings'.format(headers)) self.headers = headers self.dropna = dropna if indexs is None: @@ -167,24 +194,17 @@ class ConllLoader: def load(self, path): ds = DataSet() - for idx, data in read_conll(path, indexes=self.indexs, dropna=self.dropna): - ins = {h:data[idx] for h, idx in zip(self.headers, self.indexs)} + for idx, data in _read_conll(path, indexes=self.indexs, dropna=self.dropna): + ins = {h:data[i] for i, h in enumerate(self.headers)} ds.append(Instance(**ins)) return ds - def get_one(self, sample): - sample = list(map(list, zip(*sample))) - for field in sample: - if len(field) <= 0: - return None - return sample - class Conll2003Loader(ConllLoader): - """Loader for conll2003 dataset + """读取Conll2003数据 - More information about the given dataset cound be found on - https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data + 关于数据集的更多信息,参考: + https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ def __init__(self): headers = [ @@ -193,9 +213,10 @@ class Conll2003Loader(ConllLoader): super(Conll2003Loader, self).__init__(headers=headers) -def cut_long_sentence(sent, max_sample_length=200): +def _cut_long_sentence(sent, max_sample_length=200): """ - 
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length + 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。 + 所以截取的句子可能长于或者短于max_sample_length :param sent: str. :param max_sample_length: int. @@ -223,8 +244,15 @@ def cut_long_sentence(sent, max_sample_length=200): class SSTLoader(DataSetLoader): - """load SST data in PTB tree format - data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + """读取SST数据集, DataSet包含fields:: + + words: list(str) 需要分类的文本 + target: str 文本的标签 + + 数据来源: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip + + :param subtree: 是否将数据展开为子树,扩充数据量. Default: ``False`` + :param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False`` """ def __init__(self, subtree=False, fine_grained=False): self.subtree = subtree @@ -247,14 +275,14 @@ class SSTLoader(DataSetLoader): datas = [] for l in f: datas.extend([(s, self.tag_v[t]) - for s, t in self.get_one(l, self.subtree)]) + for s, t in self._get_one(l, self.subtree)]) ds = DataSet() for words, tag in datas: - ds.append(Instance(words=words, raw_tag=tag)) + ds.append(Instance(words=words, target=tag)) return ds @staticmethod - def get_one(data, subtree): + def _get_one(data, subtree): tree = Tree.fromstring(data) if subtree: return [(t.leaves(), t.label()) for t in tree.subtrees()] @@ -262,11 +290,17 @@ class SSTLoader(DataSetLoader): class JsonLoader(DataSetLoader): - """Load json-format data, - every line contains a json obj, like a dict - fields is the dict key that need to be load """ - def __init__(self, dropna=False, fields=None): + 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 + + :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name + ``fields`` 的`key`必须是json对象的属性名. ``fields`` 的`value`为读入后在DataSet存储的`field_name`, + `value`也可为 ``None`` , 这时读入后的`field_name`与json对象对应属性同名 + ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``True`` + """ + def __init__(self, fields=None, dropna=False): super(JsonLoader, self).__init__() self.dropna = dropna self.fields = None @@ -279,7 +313,7 @@ class JsonLoader(DataSetLoader): def load(self, path): ds = DataSet() - for idx, d in read_json(path, fields=self.fields_list, dropna=self.dropna): + for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna): ins = {self.fields[k]:v for k,v in d.items()} ds.append(Instance(**ins)) return ds @@ -287,7 +321,13 @@ class JsonLoader(DataSetLoader): class SNLILoader(JsonLoader): """ - data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + 读取SNLI数据集,读取的DataSet包含fields:: + + words1: list(str),第一句文本, premise + words2: list(str), 第二句文本, hypothesis + target: str, 真实标签 + + 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip """ def __init__(self): fields = { @@ -309,14 +349,14 @@ class SNLILoader(JsonLoader): class CSVLoader(DataSetLoader): - """Load data from a CSV file and return a DataSet object. - - :param str csv_path: path to the CSV file - :param List[str] or Tuple[str] headers: headers of the CSV file - :param str sep: delimiter in CSV file. Default: "," - :param bool dropna: If True, drop rows that have less entries than headers. - :return dataset: the read data set + """ + 读取CSV格式的数据集。返回 ``DataSet`` + :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 + 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` + :param str sep: CSV文件中列与列之间的分隔符. 
Default: "," + :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . + Default: ``True`` """ def __init__(self, headers=None, sep=",", dropna=True): self.headers = headers @@ -325,8 +365,8 @@ class CSVLoader(DataSetLoader): def load(self, path): ds = DataSet() - for idx, data in read_csv(path, headers=self.headers, - sep=self.sep, dropna=self.dropna): + for idx, data in _read_csv(path, headers=self.headers, + sep=self.sep, dropna=self.dropna): ds.append(Instance(**data)) return ds diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 22766ebb..ffbab510 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -1,15 +1,16 @@ import json -def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): +def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): """ - Construct a generator to read csv items + Construct a generator to read csv items. + :param path: file path :param encoding: file's encoding, default: utf-8 :param headers: file's headers, if None, make file's first line as headers. default: None :param sep: separator for each column. default: ',' :param dropna: weather to ignore and drop invalid data, - if False, raise ValueError when reading invalid data. default: True + :if False, raise ValueError when reading invalid data. default: True :return: generator, every time yield (line number, csv item) """ with open(path, 'r', encoding=encoding) as f: @@ -35,14 +36,15 @@ def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): yield line_idx, _dict -def read_json(path, encoding='utf-8', fields=None, dropna=True): +def _read_json(path, encoding='utf-8', fields=None, dropna=True): """ - Construct a generator to read json items + Construct a generator to read json items. + :param path: file path :param encoding: file's encoding, default: utf-8 :param fields: json object's fields that needed, if None, all fields are needed. default: None :param dropna: weather to ignore and drop invalid data, - if False, raise ValueError when reading invalid data. default: True + :if False, raise ValueError when reading invalid data. default: True :return: generator, every time yield (line number, json item) """ if fields: @@ -65,14 +67,15 @@ def read_json(path, encoding='utf-8', fields=None, dropna=True): yield line_idx, _res -def read_conll(path, encoding='utf-8', indexes=None, dropna=True): +def _read_conll(path, encoding='utf-8', indexes=None, dropna=True): """ - Construct a generator to read conll items + Construct a generator to read conll items. + :param path: file path :param encoding: file's encoding, default: utf-8 :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None :param dropna: weather to ignore and drop invalid data, - if False, raise ValueError when reading invalid data. default: True + :if False, raise ValueError when reading invalid data. 
default: True :return: generator, every time yield (line number, conll item) """ def parse_conll(sample): diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py index dc294eb3..9a070c92 100644 --- a/fastNLP/models/biaffine_parser.py +++ b/fastNLP/models/biaffine_parser.py @@ -16,7 +16,7 @@ from fastNLP.modules.utils import initial_parameter from fastNLP.modules.utils import seq_mask -def mst(scores): +def _mst(scores): """ with some modification to support parser output for MST decoding https://github.com/tdozat/Parser/blob/0739216129cd39d69997d28cbc4133b360ea3934/lib/models/nn.py#L692 @@ -120,12 +120,22 @@ def _find_cycle(vertices, edges): class GraphParser(BaseModel): - """Graph based Parser helper class, support greedy decoding and MST(Maximum Spanning Tree) decoding + """ + 基于图的parser base class, 支持贪婪解码和最大生成树解码 """ def __init__(self): super(GraphParser, self).__init__() - def _greedy_decoder(self, arc_matrix, mask=None): + @staticmethod + def greedy_decoder(arc_matrix, mask=None): + """ + 贪心解码方式, 输入图, 输出贪心解码的parsing结果, 不保证合法的构成树 + + :param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 + :param mask: [batch, seq_len] 输入图的padding mask, 有内容的部分为 1, 否则为 0. + 若为 ``None`` 时, 默认为全1向量. Default: ``None`` + :return heads: [batch, seq_len] 每个元素在树中对应的head(parent)预测结果 + """ _, seq_len, _ = arc_matrix.shape matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf)) flip_mask = (mask == 0).byte() @@ -135,22 +145,34 @@ class GraphParser(BaseModel): heads *= mask.long() return heads - def _mst_decoder(self, arc_matrix, mask=None): + @staticmethod + def mst_decoder(arc_matrix, mask=None): + """ + 用最大生成树算法, 计算parsing结果, 保证输出合法的树结构 + + :param arc_matrix: [batch, seq_len, seq_len] 输入图矩阵 + :param mask: [batch, seq_len] 输入图的padding mask, 有内容的部分为 1, 否则为 0. + 若为 ``None`` 时, 默认为全1向量. Default: ``None`` + :return heads: [batch, seq_len] 每个元素在树中对应的head(parent)预测结果 + """ batch_size, seq_len, _ = arc_matrix.shape matrix = arc_matrix.clone() ans = matrix.new_zeros(batch_size, seq_len).long() lens = (mask.long()).sum(1) if mask is not None else torch.zeros(batch_size) + seq_len - batch_idx = torch.arange(batch_size, dtype=torch.long, device=lens.device) for i, graph in enumerate(matrix): len_i = lens[i] - ans[i, :len_i] = torch.as_tensor(mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) + ans[i, :len_i] = torch.as_tensor(_mst(graph.detach()[:len_i, :len_i].cpu().numpy()), device=ans.device) if mask is not None: ans *= mask.long() return ans class ArcBiaffine(nn.Module): - """helper module for Biaffine Dependency Parser predicting arc + """ + Biaffine Dependency Parser 的子模块, 用于构建预测边的图 + + :param hidden_size: 输入的特征维度 + :param bias: 是否使用bias. 
Default: ``True`` """ def __init__(self, hidden_size, bias=True): super(ArcBiaffine, self).__init__() @@ -164,10 +186,10 @@ class ArcBiaffine(nn.Module): def forward(self, head, dep): """ - :param head arc-head tensor = [batch, length, emb_dim] - :param dep arc-dependent tensor = [batch, length, emb_dim] - :return output tensor = [bacth, length, length] + :param head: arc-head tensor [batch, length, hidden] + :param dep: arc-dependent tensor [batch, length, hidden] + :return output: tensor [bacth, length, length] """ output = dep.matmul(self.U) output = output.bmm(head.transpose(-1, -2)) @@ -177,7 +199,13 @@ class ArcBiaffine(nn.Module): class LabelBilinear(nn.Module): - """helper module for Biaffine Dependency Parser predicting label + """ + Biaffine Dependency Parser 的子模块, 用于构建预测边类别的图 + + :param in1_features: 输入的特征1维度 + :param in2_features: 输入的特征2维度 + :param num_label: 边类别的个数 + :param bias: 是否使用bias. Default: ``True`` """ def __init__(self, in1_features, in2_features, num_label, bias=True): super(LabelBilinear, self).__init__() @@ -185,14 +213,34 @@ class LabelBilinear(nn.Module): self.lin = nn.Linear(in1_features + in2_features, num_label, bias=False) def forward(self, x1, x2): + """ + + :param x1: [batch, seq_len, hidden] 输入特征1, 即label-head + :param x2: [batch, seq_len, hidden] 输入特征2, 即label-dep + :return output: [batch, seq_len, num_cls] 每个元素对应类别的概率图 + """ output = self.bilinear(x1, x2) output += self.lin(torch.cat([x1, x2], dim=2)) return output class BiaffineParser(GraphParser): - """Biaffine Dependency Parser implemantation. - refer to ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) + """Biaffine Dependency Parser 实现. + 论文参考 ` Deep Biaffine Attention for Neural Dependency Parsing (Dozat and Manning, 2016) `_ . + + :param word_vocab_size: 单词词典大小 + :param word_emb_dim: 单词词嵌入向量的维度 + :param pos_vocab_size: part-of-speech 词典大小 + :param pos_emb_dim: part-of-speech 向量维度 + :param num_label: 边的类别个数 + :param rnn_layers: rnn encoder的层数 + :param rnn_hidden_size: rnn encoder 的隐状态维度 + :param arc_mlp_size: 边预测的MLP维度 + :param label_mlp_size: 类别预测的MLP维度 + :param dropout: dropout概率. + :param encoder: encoder类别, 可选 ('lstm', 'var-lstm', 'transformer'). Default: lstm + :param use_greedy_infer: 是否在inference时使用贪心算法. + 若 ``False`` , 使用更加精确但相对缓慢的MST算法. 
Default: ``False`` """ def __init__(self, word_vocab_size, @@ -207,7 +255,6 @@ class BiaffineParser(GraphParser): dropout=0.3, encoder='lstm', use_greedy_infer=False): - super(BiaffineParser, self).__init__() rnn_out_size = 2 * rnn_hidden_size word_hid_dim = pos_hid_dim = rnn_hidden_size @@ -275,27 +322,31 @@ class BiaffineParser(GraphParser): for p in m.parameters(): nn.init.normal_(p, 0, 0.1) - def forward(self, word_seq, pos_seq, seq_lens, gold_heads=None): - """ - :param word_seq: [batch_size, seq_len] sequence of word's indices - :param pos_seq: [batch_size, seq_len] sequence of word's indices - :param seq_lens: [batch_size, seq_len] sequence of length masks - :param gold_heads: [batch_size, seq_len] sequence of golden heads - :return dict: parsing results - arc_pred: [batch_size, seq_len, seq_len] - label_pred: [batch_size, seq_len, seq_len] - mask: [batch_size, seq_len] - head_pred: [batch_size, seq_len] if gold_heads is not provided, predicting the heads + def forward(self, words1, words2, seq_len, gold_heads=None): + """模型forward阶段 + + :param words1: [batch_size, seq_len] 输入word序列 + :param words2: [batch_size, seq_len] 输入pos序列 + :param seq_len: [batch_size, seq_len] 输入序列长度 + :param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, + 用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器 + Default: ``None`` + :return dict: parsing结果:: + + arc_pred: [batch_size, seq_len, seq_len] 边预测logits + label_pred: [batch_size, seq_len, num_label] label预测logits + mask: [batch_size, seq_len] 预测结果的mask + head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测 """ # prepare embeddings - batch_size, seq_len = word_seq.shape + batch_size, length = words1.shape # print('forward {} {}'.format(batch_size, seq_len)) # get sequence mask - mask = seq_mask(seq_lens, seq_len).long() + mask = seq_mask(seq_len, length).long() - word = self.word_embedding(word_seq) # [N,L] -> [N,L,C_0] - pos = self.pos_embedding(pos_seq) # [N,L] -> [N,L,C_1] + word = self.word_embedding(words1) # [N,L] -> [N,L,C_0] + pos = self.pos_embedding(words2) # [N,L] -> [N,L,C_1] word, pos = self.word_fc(word), self.pos_fc(pos) word, pos = self.word_norm(word), self.pos_norm(pos) @@ -303,7 +354,7 @@ class BiaffineParser(GraphParser): # encoder, extract features if self.encoder_name.endswith('lstm'): - sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) + sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) x = x[sort_idx] x = nn.utils.rnn.pack_padded_sequence(x, sort_lens, batch_first=True) feat, _ = self.encoder(x) # -> [N,L,C] @@ -329,20 +380,20 @@ class BiaffineParser(GraphParser): if gold_heads is None or not self.training: # use greedy decoding in training if self.training or self.use_greedy_infer: - heads = self._greedy_decoder(arc_pred, mask) + heads = self.greedy_decoder(arc_pred, mask) else: - heads = self._mst_decoder(arc_pred, mask) + heads = self.mst_decoder(arc_pred, mask) head_pred = heads else: assert self.training # must be training mode if gold_heads is None: - heads = self._greedy_decoder(arc_pred, mask) + heads = self.greedy_decoder(arc_pred, mask) head_pred = heads else: head_pred = None heads = gold_heads - batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=word_seq.device).unsqueeze(1) + batch_range = torch.arange(start=0, end=batch_size, dtype=torch.long, device=words1.device).unsqueeze(1) label_head = label_head[batch_range, heads].contiguous() label_pred = self.label_predictor(label_head, label_dep) # [N, L, num_label] res_dict = 
{'arc_pred': arc_pred, 'label_pred': label_pred, 'mask': mask} @@ -355,11 +406,11 @@ class BiaffineParser(GraphParser): """ Compute loss. - :param arc_pred: [batch_size, seq_len, seq_len] - :param label_pred: [batch_size, seq_len, n_tags] - :param arc_true: [batch_size, seq_len] - :param label_true: [batch_size, seq_len] - :param mask: [batch_size, seq_len] + :param arc_pred: [batch_size, seq_len, seq_len] 边预测logits + :param label_pred: [batch_size, seq_len, num_label] label预测logits + :param arc_true: [batch_size, seq_len] 真实边的标注 + :param label_true: [batch_size, seq_len] 真实类别的标注 + :param mask: [batch_size, seq_len] 预测结果的mask :return: loss value """ @@ -381,16 +432,23 @@ class BiaffineParser(GraphParser): label_nll = -label_loss.mean() return arc_nll + label_nll - def predict(self, word_seq, pos_seq, seq_lens): - """ - - :param word_seq: - :param pos_seq: - :param seq_lens: - :return: arc_pred: [B, L] - label_pred: [B, L] + def predict(self, words1, words2, seq_len): + """模型预测API + + :param words1: [batch_size, seq_len] 输入word序列 + :param words2: [batch_size, seq_len] 输入pos序列 + :param seq_len: [batch_size, seq_len] 输入序列长度 + :param gold_heads: [batch_size, seq_len] 输入真实标注的heads, 仅在训练阶段有效, + 用于训练label分类器. 若为 ``None`` , 使用预测的heads输入到label分类器 + Default: ``None`` + :return dict: parsing结果:: + + arc_pred: [batch_size, seq_len, seq_len] 边预测logits + label_pred: [batch_size, seq_len, num_label] label预测logits + mask: [batch_size, seq_len] 预测结果的mask + head_pred: [batch_size, seq_len] heads的预测结果, 在 ``gold_heads=None`` 时预测 """ - res = self(word_seq, pos_seq, seq_lens) + res = self(words1, words2, seq_len) output = {} output['arc_pred'] = res.pop('head_pred') _, label_pred = res.pop('label_pred').max(2) @@ -399,6 +457,16 @@ class BiaffineParser(GraphParser): class ParserLoss(LossFunc): + """ + 计算parser的loss + + :param arc_pred: [batch_size, seq_len, seq_len] 边预测logits + :param label_pred: [batch_size, seq_len, num_label] label预测logits + :param arc_true: [batch_size, seq_len] 真实边的标注 + :param label_true: [batch_size, seq_len] 真实类别的标注 + :param mask: [batch_size, seq_len] 预测结果的mask + :return loss: scalar + """ def __init__(self, arc_pred=None, label_pred=None, arc_true=None, label_true=None): super(ParserLoss, self).__init__(BiaffineParser.loss, arc_pred=arc_pred, @@ -408,12 +476,26 @@ class ParserLoss(LossFunc): class ParserMetric(MetricBase): + """ + 评估parser的性能 + + :param arc_pred: 边预测logits + :param label_pred: label预测logits + :param arc_true: 真实边的标注 + :param label_true: 真实类别的标注 + :param seq_len: 序列长度 + :return dict: 评估结果:: + + UAS: 不带label时, 边预测的准确率 + LAS: 同时预测边和label的准确率 + """ def __init__(self, arc_pred=None, label_pred=None, - arc_true=None, label_true=None, seq_lens=None): + arc_true=None, label_true=None, seq_len=None): + super().__init__() self._init_param_map(arc_pred=arc_pred, label_pred=label_pred, arc_true=arc_true, label_true=label_true, - seq_lens=seq_lens) + seq_len=seq_len) self.num_arc = 0 self.num_label = 0 self.num_sample = 0 @@ -424,13 +506,13 @@ class ParserMetric(MetricBase): self.num_sample = self.num_label = self.num_arc = 0 return res - def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_lens=None): + def evaluate(self, arc_pred, label_pred, arc_true, label_true, seq_len=None): """Evaluate the performance of prediction. 
""" - if seq_lens is None: + if seq_len is None: seq_mask = arc_pred.new_ones(arc_pred.size(), dtype=torch.long) else: - seq_mask = seq_lens_to_masks(seq_lens.long(), float=False).long() + seq_mask = seq_lens_to_masks(seq_len.long(), float=False).long() # mask out tag seq_mask[:,0] = 0 head_pred_correct = (arc_pred == arc_true).long() * seq_mask diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py index 3af3fe19..4f4ed551 100644 --- a/fastNLP/models/star_transformer.py +++ b/fastNLP/models/star_transformer.py @@ -7,6 +7,21 @@ import torch.nn.functional as F class StarTransEnc(nn.Module): + """ + 带word embedding的Star-Transformer Encoder + + :param vocab_size: 词嵌入的词典大小 + :param emb_dim: 每个词嵌入的特征维度 + :param num_cls: 输出类别个数 + :param hidden_size: 模型中特征维度. + :param num_layers: 模型层数. + :param num_head: 模型中multi-head的head个数. + :param head_dim: 模型中multi-head中每个head特征维度. + :param max_len: 模型能接受的最大输入长度. + :param cls_hidden_size: 分类器隐层维度. + :param emb_dropout: 词嵌入的dropout概率. + :param dropout: 模型除词嵌入外的dropout概率. + """ def __init__(self, vocab_size, emb_dim, hidden_size, num_layers, @@ -27,15 +42,23 @@ class StarTransEnc(nn.Module): max_len=max_len) def forward(self, x, mask): + """ + :param FloatTensor data: [batch, length, hidden] 输入的序列 + :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, + 否则为 1 + :return: [batch, length, hidden] 编码后的输出序列 + + [batch, hidden] 全局 relay 节点, 详见论文 + """ x = self.embedding(x) x = self.emb_fc(self.emb_drop(x)) nodes, relay = self.encoder(x, mask) return nodes, relay -class Cls(nn.Module): +class _Cls(nn.Module): def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): - super(Cls, self).__init__() + super(_Cls, self).__init__() self.fc = nn.Sequential( nn.Linear(in_dim, hid_dim), nn.LeakyReLU(), @@ -48,9 +71,9 @@ class Cls(nn.Module): return h -class NLICls(nn.Module): +class _NLICls(nn.Module): def __init__(self, in_dim, num_cls, hid_dim, dropout=0.1): - super(NLICls, self).__init__() + super(_NLICls, self).__init__() self.fc = nn.Sequential( nn.Dropout(dropout), nn.Linear(in_dim*4, hid_dim), #4 @@ -65,7 +88,19 @@ class NLICls(nn.Module): return h class STSeqLabel(nn.Module): - """star-transformer model for sequence labeling + """用于序列标注的Star-Transformer模型 + + :param vocab_size: 词嵌入的词典大小 + :param emb_dim: 每个词嵌入的特征维度 + :param num_cls: 输出类别个数 + :param hidden_size: 模型中特征维度. Default: 300 + :param num_layers: 模型层数. Default: 4 + :param num_head: 模型中multi-head的head个数. Default: 8 + :param head_dim: 模型中multi-head中每个head特征维度. Default: 32 + :param max_len: 模型能接受的最大输入长度. Default: 512 + :param cls_hidden_size: 分类器隐层维度. Default: 600 + :param emb_dropout: 词嵌入的dropout概率. Default: 0.1 + :param dropout: 模型除词嵌入外的dropout概率. 
Default: 0.1 """ def __init__(self, vocab_size, emb_dim, num_cls, hidden_size=300, @@ -86,23 +121,47 @@ class STSeqLabel(nn.Module): max_len=max_len, emb_dropout=emb_dropout, dropout=dropout) - self.cls = Cls(hidden_size, num_cls, cls_hidden_size) + self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) + + def forward(self, words, seq_len): + """ - def forward(self, word_seq, seq_lens): - mask = seq_lens_to_masks(seq_lens) - nodes, _ = self.enc(word_seq, mask) + :param words: [batch, seq_len] 输入序列 + :param seq_len: [batch,] 输入序列的长度 + :return output: [batch, num_cls, seq_len] 输出序列中每个元素的分类的概率 + """ + mask = seq_lens_to_masks(seq_len) + nodes, _ = self.enc(words, mask) output = self.cls(nodes) output = output.transpose(1,2) # make hidden to be dim 1 return {'output': output} # [bsz, n_cls, seq_len] - def predict(self, word_seq, seq_lens): - y = self.forward(word_seq, seq_lens) + def predict(self, words, seq_len): + """ + + :param words: [batch, seq_len] 输入序列 + :param seq_len: [batch,] 输入序列的长度 + :return output: [batch, seq_len] 输出序列中每个元素的分类 + """ + y = self.forward(words, seq_len) _, pred = y['output'].max(1) - return {'output': pred, 'seq_lens': seq_lens} + return {'output': pred} class STSeqCls(nn.Module): - """star-transformer model for sequence classification + """用于分类任务的Star-Transformer + + :param vocab_size: 词嵌入的词典大小 + :param emb_dim: 每个词嵌入的特征维度 + :param num_cls: 输出类别个数 + :param hidden_size: 模型中特征维度. Default: 300 + :param num_layers: 模型层数. Default: 4 + :param num_head: 模型中multi-head的head个数. Default: 8 + :param head_dim: 模型中multi-head中每个head特征维度. Default: 32 + :param max_len: 模型能接受的最大输入长度. Default: 512 + :param cls_hidden_size: 分类器隐层维度. Default: 600 + :param emb_dropout: 词嵌入的dropout概率. Default: 0.1 + :param dropout: 模型除词嵌入外的dropout概率. Default: 0.1 """ def __init__(self, vocab_size, emb_dim, num_cls, @@ -124,23 +183,47 @@ class STSeqCls(nn.Module): max_len=max_len, emb_dropout=emb_dropout, dropout=dropout) - self.cls = Cls(hidden_size, num_cls, cls_hidden_size) + self.cls = _Cls(hidden_size, num_cls, cls_hidden_size) - def forward(self, word_seq, seq_lens): - mask = seq_lens_to_masks(seq_lens) - nodes, relay = self.enc(word_seq, mask) + def forward(self, words, seq_len): + """ + + :param words: [batch, seq_len] 输入序列 + :param seq_len: [batch,] 输入序列的长度 + :return output: [batch, num_cls] 输出序列的分类的概率 + """ + mask = seq_lens_to_masks(seq_len) + nodes, relay = self.enc(words, mask) y = 0.5 * (relay + nodes.max(1)[0]) output = self.cls(y) # [bsz, n_cls] return {'output': output} - def predict(self, word_seq, seq_lens): - y = self.forward(word_seq, seq_lens) + def predict(self, words, seq_len): + """ + + :param words: [batch, seq_len] 输入序列 + :param seq_len: [batch,] 输入序列的长度 + :return output: [batch, num_cls] 输出序列的分类 + """ + y = self.forward(words, seq_len) _, pred = y['output'].max(1) return {'output': pred} class STNLICls(nn.Module): - """star-transformer model for NLI + """用于自然语言推断(NLI)的Star-Transformer + + :param vocab_size: 词嵌入的词典大小 + :param emb_dim: 每个词嵌入的特征维度 + :param num_cls: 输出类别个数 + :param hidden_size: 模型中特征维度. Default: 300 + :param num_layers: 模型层数. Default: 4 + :param num_head: 模型中multi-head的head个数. Default: 8 + :param head_dim: 模型中multi-head中每个head特征维度. Default: 32 + :param max_len: 模型能接受的最大输入长度. Default: 512 + :param cls_hidden_size: 分类器隐层维度. Default: 600 + :param emb_dropout: 词嵌入的dropout概率. Default: 0.1 + :param dropout: 模型除词嵌入外的dropout概率. 
Default: 0.1 """ def __init__(self, vocab_size, emb_dim, num_cls, @@ -162,20 +245,36 @@ class STNLICls(nn.Module): max_len=max_len, emb_dropout=emb_dropout, dropout=dropout) - self.cls = NLICls(hidden_size, num_cls, cls_hidden_size) + self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size) + + def forward(self, words1, words2, seq_len1, seq_len2): + """ - def forward(self, word_seq1, word_seq2, seq_lens1, seq_lens2): - mask1 = seq_lens_to_masks(seq_lens1) - mask2 = seq_lens_to_masks(seq_lens2) + :param words1: [batch, seq_len] 输入序列1 + :param words2: [batch, seq_len] 输入序列2 + :param seq_len1: [batch,] 输入序列1的长度 + :param seq_len2: [batch,] 输入序列2的长度 + :return output: [batch, num_cls] 输出分类的概率 + """ + mask1 = seq_lens_to_masks(seq_len1) + mask2 = seq_lens_to_masks(seq_len2) def enc(seq, mask): nodes, relay = self.enc(seq, mask) return 0.5 * (relay + nodes.max(1)[0]) - y1 = enc(word_seq1, mask1) - y2 = enc(word_seq2, mask2) + y1 = enc(words1, mask1) + y2 = enc(words2, mask2) output = self.cls(y1, y2) # [bsz, n_cls] return {'output': output} - def predict(self, word_seq1, word_seq2, seq_lens1, seq_lens2): - y = self.forward(word_seq1, word_seq2, seq_lens1, seq_lens2) + def predict(self, words1, words2, seq_len1, seq_len2): + """ + + :param words1: [batch, seq_len] 输入序列1 + :param words2: [batch, seq_len] 输入序列2 + :param seq_len1: [batch,] 输入序列1的长度 + :param seq_len2: [batch,] 输入序列2的长度 + :return output: [batch, num_cls] 输出分类的概率 + """ + y = self.forward(words1, words2, seq_len1, seq_len2) _, pred = y['output'].max(1) return {'output': pred} diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 04f331f7..fe740c70 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -6,17 +6,17 @@ from fastNLP.modules.utils import initial_parameter class LSTM(nn.Module): - """Long Short Term Memory + """LSTM 模块, 轻量封装的Pytorch LSTM - :param int input_size: - :param int hidden_size: - :param int num_layers: - :param float dropout: - :param bool batch_first: - :param bool bidirectional: - :param bool bias: - :param str initial_method: - :param bool get_hidden: + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param dropout: 层间dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param get_hidden: 是否返回隐状态 `h` . Default: ``False`` """ def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, bidirectional=False, bias=True, initial_method=None, get_hidden=False): @@ -27,14 +27,24 @@ class LSTM(nn.Module): self.get_hidden = get_hidden initial_parameter(self, initial_method) - def forward(self, x, seq_lens=None, h0=None, c0=None): + def forward(self, x, seq_len=None, h0=None, c0=None): + """ + + :param x: [batch, seq_len, input_size] 输入序列 + :param seq_len: [batch, ] 序列长度, 若为 ``None``, 所有输入看做一样长. Default: ``None`` + :param h0: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. Default: ``None`` + :param c0: [batch, hidden_size] 初始Cell状态, 若为 ``None`` , 设为全1向量. Default: ``None`` + :return (output, ht) 或 output: 若 ``get_hidden=True`` [batch, seq_len, hidden_size*num_direction] 输出序列 + :和 [batch, hidden_size*num_direction] 最后时刻隐状态. + :若 ``get_hidden=False`` 仅返回输出序列. 
+ """ if h0 is not None and c0 is not None: hx = (h0, c0) else: hx = None - if seq_lens is not None and not isinstance(x, rnn.PackedSequence): + if seq_len is not None and not isinstance(x, rnn.PackedSequence): print('padding') - sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) + sort_lens, sort_idx = torch.sort(seq_len, dim=0, descending=True) if self.batch_first: x = x[sort_idx] else: diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py index 1618c8ee..034cfa96 100644 --- a/fastNLP/modules/encoder/star_transformer.py +++ b/fastNLP/modules/encoder/star_transformer.py @@ -5,16 +5,19 @@ import numpy as NP class StarTransformer(nn.Module): - """Star-Transformer Encoder part。 + """ + Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码 + paper: https://arxiv.org/abs/1902.09113 - :param hidden_size: int, 输入维度的大小。同时也是输出维度的大小。 - :param num_layers: int, star-transformer的层数 - :param num_head: int,head的数量。 - :param head_dim: int, 每个head的维度大小。 - :param dropout: float dropout 概率 - :param max_len: int or None, 如果为int,输入序列的最大长度, - 模型会为属于序列加上position embedding。 - 若为None,忽略加上position embedding的步骤 + + :param int hidden_size: 输入维度的大小。同时也是输出维度的大小。 + :param int num_layers: star-transformer的层数 + :param int num_head: head的数量。 + :param int head_dim: 每个head的维度大小。 + :param float dropout: dropout 概率. Default: 0.1 + :param int max_len: int or None, 如果为int,输入序列的最大长度, + 模型会为输入序列加上position embedding。 + 若为`None`,忽略加上position embedding的步骤. Default: `None` """ def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None): super(StarTransformer, self).__init__() @@ -22,11 +25,11 @@ class StarTransformer(nn.Module): self.norm = nn.ModuleList([nn.LayerNorm(hidden_size) for _ in range(self.iters)]) self.ring_att = nn.ModuleList( - [MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) - for _ in range(self.iters)]) + [_MSA1(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) + for _ in range(self.iters)]) self.star_att = nn.ModuleList( - [MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) - for _ in range(self.iters)]) + [_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=dropout) + for _ in range(self.iters)]) if max_len is not None: self.pos_emb = self.pos_emb = nn.Embedding(max_len, hidden_size) @@ -35,10 +38,12 @@ class StarTransformer(nn.Module): def forward(self, data, mask): """ - :param FloatTensor data: [batch, length, hidden] the input sequence - :param ByteTensor mask: [batch, length] the padding mask for input, in which padding pos is 0 - :return: [batch, length, hidden] the output sequence - [batch, hidden] the global relay node + :param FloatTensor data: [batch, length, hidden] 输入的序列 + :param ByteTensor mask: [batch, length] 输入序列的padding mask, 在没有内容(padding 部分) 为 0, + 否则为 1 + :return: [batch, length, hidden] 编码后的输出序列 + + [batch, hidden] 全局 relay 节点, 详见论文 """ def norm_func(f, x): # B, H, L, 1 @@ -70,9 +75,9 @@ class StarTransformer(nn.Module): return nodes, relay.view(B, H) -class MSA1(nn.Module): +class _MSA1(nn.Module): def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): - super(MSA1, self).__init__() + super(_MSA1, self).__init__() # Multi-head Self Attention Case 1, doing self-attention for small regions # Due to the architecture of GPU, using hadamard production and summation are faster than dot production when unfold_size is very small self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) @@ -113,10 +118,10 @@ class MSA1(nn.Module): return ret -class 
MSA2(nn.Module): +class _MSA2(nn.Module): def __init__(self, nhid, nhead=10, head_dim=10, dropout=0.1): # Multi-head Self Attention Case 2, a broadcastable query for a sequence key and value - super(MSA2, self).__init__() + super(_MSA2, self).__init__() self.WQ = nn.Conv2d(nhid, nhead * head_dim, 1) self.WK = nn.Conv2d(nhid, nhead * head_dim, 1) self.WV = nn.Conv2d(nhid, nhead * head_dim, 1) diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index d1262141..60216c2b 100644 --- a/fastNLP/modules/encoder/transformer.py +++ b/fastNLP/modules/encoder/transformer.py @@ -7,13 +7,13 @@ from ..dropout import TimestepDropout class TransformerEncoder(nn.Module): """transformer的encoder模块,不包含embedding层 - :param num_layers: int, transformer的层数 - :param model_size: int, 输入维度的大小。同时也是输出维度的大小。 - :param inner_size: int, FFN层的hidden大小 - :param key_size: int, 每个head的维度大小。 - :param value_size: int,每个head中value的维度。 - :param num_head: int,head的数量。 - :param dropout: float。 + :param int num_layers: transformer的层数 + :param int model_size: 输入维度的大小。同时也是输出维度的大小。 + :param int inner_size: FFN层的hidden大小 + :param int key_size: 每个head的维度大小。 + :param int value_size: 每个head中value的维度。 + :param int num_head: head的数量。 + :param float dropout: dropout概率. Default: 0.1 """ class SubLayer(nn.Module): def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): @@ -48,7 +48,8 @@ class TransformerEncoder(nn.Module): def forward(self, x, seq_mask=None): """ :param x: [batch, seq_len, model_size] 输入序列 - :param seq_mask: [batch, seq_len] 输入序列的padding mask + :param seq_mask: [batch, seq_len] 输入序列的padding mask, 若为 ``None`` , 生成全1向量. + Default: ``None`` :return: [batch, seq_len, model_size] 输出序列 """ output = x diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py index a7902813..d63aa6e7 100644 --- a/fastNLP/modules/encoder/variational_rnn.py +++ b/fastNLP/modules/encoder/variational_rnn.py @@ -28,11 +28,11 @@ class VarRnnCellWrapper(nn.Module): """ :param PackedSequence input_x: [seq_len, batch_size, input_size] :param hidden: for LSTM, tuple of (h_0, c_0), [batch_size, hidden_size] - for other RNN, h_0, [batch_size, hidden_size] + :for other RNN, h_0, [batch_size, hidden_size] :param mask_x: [batch_size, input_size] dropout mask for input :param mask_h: [batch_size, hidden_size] dropout mask for hidden :return PackedSequence output: [seq_len, bacth_size, hidden_size] - hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] + :hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size] for other RNN, h_n, [batch_size, hidden_size] """ def get_hi(hi, h0, size): @@ -84,9 +84,21 @@ class VarRnnCellWrapper(nn.Module): class VarRNNBase(nn.Module): - """Implementation of Variational Dropout RNN network. - refer to `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) + """Variational Dropout RNN 实现. + 论文参考: `A Theoretically Grounded Application of Dropout in Recurrent Neural Networks (Yarin Gal and Zoubin Ghahramani, 2016) https://arxiv.org/abs/1512.05287`. + + :param mode: rnn 模式, (lstm or not) + :param Cell: rnn cell 类型, (lstm, gru, etc) + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. 
Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` """ def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1, @@ -120,36 +132,43 @@ class VarRNNBase(nn.Module): output_x, hidden_x = cell(input, hi, mask_x, mask_h, is_reversed=(n_direction == 1)) return output_x, hidden_x - def forward(self, input, hx=None): + def forward(self, x, hx=None): + """ + + :param x: [batch, seq_len, input_size] 输入序列 + :param hx: [batch, hidden_size] 初始隐状态, 若为 ``None`` , 设为全1向量. Default: ``None`` + :return (output, ht): [batch, seq_len, hidden_size*num_direction] 输出序列 + :和 [batch, hidden_size*num_direction] 最后时刻隐状态 + """ is_lstm = self.is_lstm - is_packed = isinstance(input, PackedSequence) + is_packed = isinstance(x, PackedSequence) if not is_packed: - seq_len = input.size(1) if self.batch_first else input.size(0) - max_batch_size = input.size(0) if self.batch_first else input.size(1) + seq_len = x.size(1) if self.batch_first else x.size(0) + max_batch_size = x.size(0) if self.batch_first else x.size(1) seq_lens = torch.LongTensor([seq_len for _ in range(max_batch_size)]) - input, batch_sizes = pack_padded_sequence(input, seq_lens, batch_first=self.batch_first) + x, batch_sizes = pack_padded_sequence(x, seq_lens, batch_first=self.batch_first) else: - max_batch_size = int(input.batch_sizes[0]) - input, batch_sizes = input + max_batch_size = int(x.batch_sizes[0]) + x, batch_sizes = x if hx is None: - hx = input.new_zeros(self.num_layers * self.num_directions, - max_batch_size, self.hidden_size, requires_grad=True) + hx = x.new_zeros(self.num_layers * self.num_directions, + max_batch_size, self.hidden_size, requires_grad=True) if is_lstm: hx = (hx, hx.new_zeros(hx.size(), requires_grad=True)) - mask_x = input.new_ones((max_batch_size, self.input_size)) - mask_out = input.new_ones((max_batch_size, self.hidden_size * self.num_directions)) - mask_h_ones = input.new_ones((max_batch_size, self.hidden_size)) + mask_x = x.new_ones((max_batch_size, self.input_size)) + mask_out = x.new_ones((max_batch_size, self.hidden_size * self.num_directions)) + mask_h_ones = x.new_ones((max_batch_size, self.hidden_size)) nn.functional.dropout(mask_x, p=self.input_dropout, training=self.training, inplace=True) nn.functional.dropout(mask_out, p=self.hidden_dropout, training=self.training, inplace=True) - hidden = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) + hidden = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) if is_lstm: - cellstate = input.new_zeros((self.num_layers*self.num_directions, max_batch_size, self.hidden_size)) + cellstate = x.new_zeros((self.num_layers * self.num_directions, max_batch_size, self.hidden_size)) for layer in range(self.num_layers): output_list = [] - input_seq = PackedSequence(input, batch_sizes) + input_seq = PackedSequence(x, batch_sizes) mask_h = nn.functional.dropout(mask_h_ones, p=self.hidden_dropout, training=self.training, inplace=False) for direction in range(self.num_directions): output_x, hidden_x = self._forward_one(layer, direction, input_seq, hx, @@ -161,22 +180,32 @@ class VarRNNBase(nn.Module): cellstate[idx] = hidden_x[1] else: hidden[idx] = hidden_x - input = torch.cat(output_list, dim=-1) + x = torch.cat(output_list, dim=-1) if is_lstm: hidden = (hidden, cellstate) if is_packed: - output = PackedSequence(input, batch_sizes) + output = PackedSequence(x, batch_sizes) else: - input = PackedSequence(input, batch_sizes) 
- output, _ = pad_packed_sequence(input, batch_first=self.batch_first) + x = PackedSequence(x, batch_sizes) + output, _ = pad_packed_sequence(x, batch_first=self.batch_first) return output, hidden class VarLSTM(VarRNNBase): """Variational Dropout LSTM. + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False`` """ def __init__(self, *args, **kwargs): @@ -185,6 +214,16 @@ class VarLSTM(VarRNNBase): class VarRNN(VarRNNBase): """Variational Dropout RNN. + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False`` """ def __init__(self, *args, **kwargs): @@ -193,6 +232,16 @@ class VarRNN(VarRNNBase): class VarGRU(VarRNNBase): """Variational Dropout GRU. + + :param input_size: 输入 `x` 的特征维度 + :param hidden_size: 隐状态 `h` 的特征维度 + :param num_layers: rnn的层数. Default: 1 + :param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True`` + :param batch_first: 若为 ``True``, 输入和输出 ``Tensor`` 形状为 + :(batch, seq, feature). Default: ``False`` + :param input_dropout: 对输入的dropout概率. Default: 0 + :param hidden_dropout: 对每个隐状态的dropout概率. Default: 0 + :param bidirectional: 若为 ``True``, 使用双向的GRU. 
Default: ``False`` """ def __init__(self, *args, **kwargs): diff --git a/test/core/test_sampler.py b/test/core/test_sampler.py index b23af470..f3cbb77f 100644 --- a/test/core/test_sampler.py +++ b/test/core/test_sampler.py @@ -4,17 +4,11 @@ import unittest import torch from fastNLP.core.dataset import DataSet -from fastNLP.core.sampler import convert_to_torch_tensor, SequentialSampler, RandomSampler, \ +from fastNLP.core.sampler import SequentialSampler, RandomSampler, \ k_means_1d, k_means_bucketing, simple_sort_bucketing, BucketSampler class TestSampler(unittest.TestCase): - def test_convert_to_torch_tensor(self): - data = [[1, 2, 3, 4, 5], [5, 4, 3, 2, 1], [1, 3, 4, 5, 2]] - ans = convert_to_torch_tensor(data, False) - assert isinstance(ans, torch.Tensor) - assert tuple(ans.shape) == (3, 5) - def test_sequential_sampler(self): sampler = SequentialSampler() data = [1, 3, 5, 7, 9, 2, 4, 6, 8, 10] diff --git a/test/models/test_biaffine_parser.py b/test/models/test_biaffine_parser.py index 88ba09b8..5d6c2102 100644 --- a/test/models/test_biaffine_parser.py +++ b/test/models/test_biaffine_parser.py @@ -44,34 +44,34 @@ data_file = """ def init_data(): ds = fastNLP.DataSet() - v = {'word_seq': fastNLP.Vocabulary(), - 'pos_seq': fastNLP.Vocabulary(), + v = {'words1': fastNLP.Vocabulary(), + 'words2': fastNLP.Vocabulary(), 'label_true': fastNLP.Vocabulary()} data = [] for line in data_file.split('\n'): line = line.split() if len(line) == 0 and len(data) > 0: data = list(zip(*data)) - ds.append(fastNLP.Instance(word_seq=data[1], - pos_seq=data[4], + ds.append(fastNLP.Instance(words1=data[1], + words2=data[4], arc_true=data[6], label_true=data[7])) data = [] elif len(line) > 0: data.append(line) - for name in ['word_seq', 'pos_seq', 'label_true']: + for name in ['words1', 'words2', 'label_true']: ds.apply(lambda x: [''] + list(x[name]), new_field_name=name) ds.apply(lambda x: v[name].add_word_lst(x[name])) - for name in ['word_seq', 'pos_seq', 'label_true']: + for name in ['words1', 'words2', 'label_true']: ds.apply(lambda x: [v[name].to_index(w) for w in x[name]], new_field_name=name) ds.apply(lambda x: [0] + list(map(int, x['arc_true'])), new_field_name='arc_true') - ds.apply(lambda x: len(x['word_seq']), new_field_name='seq_lens') - ds.set_input('word_seq', 'pos_seq', 'seq_lens', flag=True) - ds.set_target('arc_true', 'label_true', 'seq_lens', flag=True) - return ds, v['word_seq'], v['pos_seq'], v['label_true'] + ds.apply(lambda x: len(x['words1']), new_field_name='seq_len') + ds.set_input('words1', 'words2', 'seq_len', flag=True) + ds.set_target('arc_true', 'label_true', 'seq_len', flag=True) + return ds, v['words1'], v['words2'], v['label_true'] class TestBiaffineParser(unittest.TestCase): diff --git a/test/test_tutorials.py b/test/test_tutorials.py index 600699a3..eb77321c 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -437,4 +437,10 @@ class TestTutorial(unittest.TestCase): ) tester.test() - os.chdir("../..") + def setUp(self): + import os + self._init_wd = os.path.abspath(os.curdir) + + def tearDown(self): + import os + os.chdir(self._init_wd)
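
The patch above renames several public entry points (BaseSampler -> Sampler, seq_lens -> seq_len, word_seq/pos_seq -> words1/words2, private helpers gaining a leading underscore) and documents the Batch, Sampler and Vocabulary APIs. Below is a minimal usage sketch of the renamed API, assuming the post-patch package is importable; the toy sentences and the '<pad>'/'<unk>' tokens are illustrative placeholders, not repository data.

# Minimal sketch of the renamed fastNLP API shown in the patch above (assumptions noted in comments).
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.core.batch import Batch
from fastNLP.core.sampler import SequentialSampler, BucketSampler

# Build a small DataSet with a 'words' field and the new 'seq_len' field name (toy data).
ds = DataSet()
for sent in ["this is a word list", "another short sentence",
             "one more line", "and a fourth example",
             "a fifth toy sentence", "the last one"]:
    words = sent.split()
    ds.append(Instance(words=words, seq_len=len(words)))

# Vocabulary: from_dataset / index_dataset as documented in the patched docstrings.
vocab = Vocabulary(padding='<pad>', unknown='<unk>')
vocab.from_dataset(ds, field_name='words')
vocab.index_dataset(ds, field_name='words')
ds.set_input('words', 'seq_len')

# Batch iteration; Sampler replaces the old BaseSampler base class.
batch = Batch(ds, batch_size=2, sampler=SequentialSampler())
num_batch = len(batch)
for batch_x, batch_y in batch:
    pass  # do stuff ...

# BucketSampler's default length field is now 'seq_len' (previously 'seq_lens').
bucket_batch = Batch(ds, batch_size=2, sampler=BucketSampler(num_buckets=2, batch_size=2))
for batch_x, batch_y in bucket_batch:
    pass  # batches group instances of similar length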