diff --git a/docs/Makefile b/docs/Makefile index e978dfe6..6f2f2821 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,6 +3,7 @@ # You can set these variables from the command line. SPHINXOPTS = +SPHINXAPIDOC = sphinx-apidoc SPHINXBUILD = sphinx-build SPHINXPROJ = fastNLP SOURCEDIR = source @@ -12,6 +13,12 @@ BUILDDIR = build help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +apidoc: + @$(SPHINXAPIDOC) -f -o source ../fastNLP + +server: + cd build/html && python -m http.server + .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new diff --git a/docs/source/conf.py b/docs/source/conf.py index e449a9f8..96f7f437 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -23,9 +23,9 @@ copyright = '2018, xpqiu' author = 'xpqiu' # The short X.Y version -version = '0.2' +version = '0.4' # The full version, including alpha/beta/rc tags -release = '0.2' +release = '0.4' # -- General configuration --------------------------------------------------- @@ -67,7 +67,7 @@ language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . -exclude_patterns = [] +exclude_patterns = ['modules.rst'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' diff --git a/docs/source/fastNLP.api.rst b/docs/source/fastNLP.api.rst index eb9192da..ee2413fb 100644 --- a/docs/source/fastNLP.api.rst +++ b/docs/source/fastNLP.api.rst @@ -1,36 +1,62 @@ -fastNLP.api -============ +fastNLP.api package +=================== -fastNLP.api.api ----------------- +Submodules +---------- + +fastNLP.api.api module +---------------------- .. automodule:: fastNLP.api.api :members: + :undoc-members: + :show-inheritance: -fastNLP.api.converter ----------------------- +fastNLP.api.converter module +---------------------------- .. automodule:: fastNLP.api.converter :members: + :undoc-members: + :show-inheritance: -fastNLP.api.model\_zoo ------------------------ +fastNLP.api.examples module +--------------------------- -.. automodule:: fastNLP.api.model_zoo +.. automodule:: fastNLP.api.examples :members: + :undoc-members: + :show-inheritance: -fastNLP.api.pipeline ---------------------- +fastNLP.api.pipeline module +--------------------------- .. automodule:: fastNLP.api.pipeline :members: + :undoc-members: + :show-inheritance: -fastNLP.api.processor ----------------------- +fastNLP.api.processor module +---------------------------- .. automodule:: fastNLP.api.processor :members: + :undoc-members: + :show-inheritance: + +fastNLP.api.utils module +------------------------ + +.. automodule:: fastNLP.api.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.api :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst index b9f6c89f..79d26c76 100644 --- a/docs/source/fastNLP.core.rst +++ b/docs/source/fastNLP.core.rst @@ -1,84 +1,126 @@ -fastNLP.core -============= +fastNLP.core package +==================== -fastNLP.core.batch -------------------- +Submodules +---------- + +fastNLP.core.batch module +------------------------- .. automodule:: fastNLP.core.batch :members: + :undoc-members: + :show-inheritance: + +fastNLP.core.callback module +---------------------------- -fastNLP.core.dataset ---------------------- +.. 
automodule:: fastNLP.core.callback + :members: + :undoc-members: + :show-inheritance: + +fastNLP.core.dataset module +--------------------------- .. automodule:: fastNLP.core.dataset :members: + :undoc-members: + :show-inheritance: -fastNLP.core.fieldarray ------------------------- +fastNLP.core.fieldarray module +------------------------------ .. automodule:: fastNLP.core.fieldarray :members: + :undoc-members: + :show-inheritance: -fastNLP.core.instance ----------------------- +fastNLP.core.instance module +---------------------------- .. automodule:: fastNLP.core.instance :members: + :undoc-members: + :show-inheritance: -fastNLP.core.losses --------------------- +fastNLP.core.losses module +-------------------------- .. automodule:: fastNLP.core.losses :members: + :undoc-members: + :show-inheritance: -fastNLP.core.metrics ---------------------- +fastNLP.core.metrics module +--------------------------- .. automodule:: fastNLP.core.metrics :members: + :undoc-members: + :show-inheritance: -fastNLP.core.optimizer ------------------------ +fastNLP.core.optimizer module +----------------------------- .. automodule:: fastNLP.core.optimizer :members: + :undoc-members: + :show-inheritance: -fastNLP.core.predictor ------------------------ +fastNLP.core.predictor module +----------------------------- .. automodule:: fastNLP.core.predictor :members: + :undoc-members: + :show-inheritance: -fastNLP.core.sampler ---------------------- +fastNLP.core.sampler module +--------------------------- .. automodule:: fastNLP.core.sampler :members: + :undoc-members: + :show-inheritance: -fastNLP.core.tester --------------------- +fastNLP.core.tester module +-------------------------- .. automodule:: fastNLP.core.tester :members: + :undoc-members: + :show-inheritance: -fastNLP.core.trainer ---------------------- +fastNLP.core.trainer module +--------------------------- .. automodule:: fastNLP.core.trainer :members: + :undoc-members: + :show-inheritance: -fastNLP.core.utils -------------------- +fastNLP.core.utils module +------------------------- .. automodule:: fastNLP.core.utils :members: + :undoc-members: + :show-inheritance: -fastNLP.core.vocabulary ------------------------- +fastNLP.core.vocabulary module +------------------------------ .. automodule:: fastNLP.core.vocabulary :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.core :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index d91e0d1c..e73f27d3 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -1,42 +1,62 @@ -fastNLP.io -=========== +fastNLP.io package +================== -fastNLP.io.base\_loader ------------------------- +Submodules +---------- + +fastNLP.io.base\_loader module +------------------------------ .. automodule:: fastNLP.io.base_loader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.config\_io ----------------------- +fastNLP.io.config\_io module +---------------------------- .. automodule:: fastNLP.io.config_io :members: + :undoc-members: + :show-inheritance: -fastNLP.io.dataset\_loader ---------------------------- +fastNLP.io.dataset\_loader module +--------------------------------- .. automodule:: fastNLP.io.dataset_loader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.embed\_loader -------------------------- +fastNLP.io.embed\_loader module +------------------------------- .. 
automodule:: fastNLP.io.embed_loader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.logger ------------------- +fastNLP.io.file\_reader module +------------------------------ -.. automodule:: fastNLP.io.logger +.. automodule:: fastNLP.io.file_reader :members: + :undoc-members: + :show-inheritance: -fastNLP.io.model\_io ---------------------- +fastNLP.io.model\_io module +--------------------------- .. automodule:: fastNLP.io.model_io :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.io :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst index 7452fdf6..3ebf9608 100644 --- a/docs/source/fastNLP.models.rst +++ b/docs/source/fastNLP.models.rst @@ -1,42 +1,110 @@ -fastNLP.models -=============== +fastNLP.models package +====================== -fastNLP.models.base\_model ---------------------------- +Submodules +---------- + +fastNLP.models.base\_model module +--------------------------------- .. automodule:: fastNLP.models.base_model :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.bert module +-------------------------- -fastNLP.models.biaffine\_parser --------------------------------- +.. automodule:: fastNLP.models.bert + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.biaffine\_parser module +-------------------------------------- .. automodule:: fastNLP.models.biaffine_parser :members: + :undoc-members: + :show-inheritance: -fastNLP.models.char\_language\_model -------------------------------------- +fastNLP.models.char\_language\_model module +------------------------------------------- .. automodule:: fastNLP.models.char_language_model :members: + :undoc-members: + :show-inheritance: -fastNLP.models.cnn\_text\_classification ------------------------------------------ +fastNLP.models.cnn\_text\_classification module +----------------------------------------------- .. automodule:: fastNLP.models.cnn_text_classification :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.enas\_controller module +-------------------------------------- + +.. automodule:: fastNLP.models.enas_controller + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.enas\_model module +--------------------------------- + +.. automodule:: fastNLP.models.enas_model + :members: + :undoc-members: + :show-inheritance: -fastNLP.models.sequence\_modeling ----------------------------------- +fastNLP.models.enas\_trainer module +----------------------------------- + +.. automodule:: fastNLP.models.enas_trainer + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.enas\_utils module +--------------------------------- + +.. automodule:: fastNLP.models.enas_utils + :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.sequence\_modeling module +---------------------------------------- .. automodule:: fastNLP.models.sequence_modeling :members: + :undoc-members: + :show-inheritance: -fastNLP.models.snli --------------------- +fastNLP.models.snli module +-------------------------- .. automodule:: fastNLP.models.snli :members: + :undoc-members: + :show-inheritance: + +fastNLP.models.star\_transformer module +--------------------------------------- + +.. automodule:: fastNLP.models.star_transformer + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. 
automodule:: fastNLP.models :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.aggregator.rst b/docs/source/fastNLP.modules.aggregator.rst index 073da4a5..63d351e4 100644 --- a/docs/source/fastNLP.modules.aggregator.rst +++ b/docs/source/fastNLP.modules.aggregator.rst @@ -1,36 +1,54 @@ -fastNLP.modules.aggregator -=========================== +fastNLP.modules.aggregator package +================================== -fastNLP.modules.aggregator.attention -------------------------------------- +Submodules +---------- + +fastNLP.modules.aggregator.attention module +------------------------------------------- .. automodule:: fastNLP.modules.aggregator.attention :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.avg\_pool -------------------------------------- +fastNLP.modules.aggregator.avg\_pool module +------------------------------------------- .. automodule:: fastNLP.modules.aggregator.avg_pool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.kmax\_pool --------------------------------------- +fastNLP.modules.aggregator.kmax\_pool module +-------------------------------------------- .. automodule:: fastNLP.modules.aggregator.kmax_pool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.max\_pool -------------------------------------- +fastNLP.modules.aggregator.max\_pool module +------------------------------------------- .. automodule:: fastNLP.modules.aggregator.max_pool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.aggregator.self\_attention -------------------------------------------- +fastNLP.modules.aggregator.self\_attention module +------------------------------------------------- .. automodule:: fastNLP.modules.aggregator.self_attention :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules.aggregator :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.decoder.rst b/docs/source/fastNLP.modules.decoder.rst index 6844543a..60706b06 100644 --- a/docs/source/fastNLP.modules.decoder.rst +++ b/docs/source/fastNLP.modules.decoder.rst @@ -1,18 +1,38 @@ -fastNLP.modules.decoder -======================== +fastNLP.modules.decoder package +=============================== -fastNLP.modules.decoder.CRF ----------------------------- +Submodules +---------- + +fastNLP.modules.decoder.CRF module +---------------------------------- .. automodule:: fastNLP.modules.decoder.CRF :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.decoder.MLP ----------------------------- +fastNLP.modules.decoder.MLP module +---------------------------------- .. automodule:: fastNLP.modules.decoder.MLP :members: + :undoc-members: + :show-inheritance: + +fastNLP.modules.decoder.utils module +------------------------------------ + +.. automodule:: fastNLP.modules.decoder.utils + :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. 
automodule:: fastNLP.modules.decoder :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst index ea8fc699..ab93a169 100644 --- a/docs/source/fastNLP.modules.encoder.rst +++ b/docs/source/fastNLP.modules.encoder.rst @@ -1,60 +1,94 @@ -fastNLP.modules.encoder -======================== +fastNLP.modules.encoder package +=============================== -fastNLP.modules.encoder.char\_embedding ----------------------------------------- +Submodules +---------- + +fastNLP.modules.encoder.char\_embedding module +---------------------------------------------- .. automodule:: fastNLP.modules.encoder.char_embedding :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.conv ------------------------------ +fastNLP.modules.encoder.conv module +----------------------------------- .. automodule:: fastNLP.modules.encoder.conv :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.conv\_maxpool --------------------------------------- +fastNLP.modules.encoder.conv\_maxpool module +-------------------------------------------- .. automodule:: fastNLP.modules.encoder.conv_maxpool :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.embedding ----------------------------------- +fastNLP.modules.encoder.embedding module +---------------------------------------- .. automodule:: fastNLP.modules.encoder.embedding :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.linear -------------------------------- +fastNLP.modules.encoder.linear module +------------------------------------- .. automodule:: fastNLP.modules.encoder.linear :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.lstm ------------------------------ +fastNLP.modules.encoder.lstm module +----------------------------------- .. automodule:: fastNLP.modules.encoder.lstm :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.masked\_rnn ------------------------------------- +fastNLP.modules.encoder.masked\_rnn module +------------------------------------------ .. automodule:: fastNLP.modules.encoder.masked_rnn :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.transformer ------------------------------------- +fastNLP.modules.encoder.star\_transformer module +------------------------------------------------ + +.. automodule:: fastNLP.modules.encoder.star_transformer + :members: + :undoc-members: + :show-inheritance: + +fastNLP.modules.encoder.transformer module +------------------------------------------ .. automodule:: fastNLP.modules.encoder.transformer :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.encoder.variational\_rnn ------------------------------------------ +fastNLP.modules.encoder.variational\_rnn module +----------------------------------------------- .. automodule:: fastNLP.modules.encoder.variational_rnn :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules.encoder :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst index 965fb27d..57858176 100644 --- a/docs/source/fastNLP.modules.rst +++ b/docs/source/fastNLP.modules.rst @@ -1,5 +1,8 @@ -fastNLP.modules -================ +fastNLP.modules package +======================= + +Subpackages +----------- .. 
toctree:: @@ -7,24 +10,38 @@ fastNLP.modules fastNLP.modules.decoder fastNLP.modules.encoder -fastNLP.modules.dropout ------------------------- +Submodules +---------- + +fastNLP.modules.dropout module +------------------------------ .. automodule:: fastNLP.modules.dropout :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.other\_modules -------------------------------- +fastNLP.modules.other\_modules module +------------------------------------- .. automodule:: fastNLP.modules.other_modules :members: + :undoc-members: + :show-inheritance: -fastNLP.modules.utils ----------------------- +fastNLP.modules.utils module +---------------------------- .. automodule:: fastNLP.modules.utils :members: + :undoc-members: + :show-inheritance: + +Module contents +--------------- .. automodule:: fastNLP.modules :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 61882359..6348c9a6 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -1,13 +1,22 @@ -fastNLP -======== +fastNLP package +=============== + +Subpackages +----------- .. toctree:: fastNLP.api + fastNLP.automl fastNLP.core fastNLP.io fastNLP.models fastNLP.modules +Module contents +--------------- + .. automodule:: fastNLP :members: + :undoc-members: + :show-inheritance: diff --git a/fastNLP/api/api.py b/fastNLP/api/api.py index 53a80131..24a1ab1d 100644 --- a/fastNLP/api/api.py +++ b/fastNLP/api/api.py @@ -1,3 +1,41 @@ +""" +api.api的介绍文档 + 直接缩进会把上面的文字变成标题 + +空行缩进的写法比较合理 + + 比较合理 + +*这里是斜体内容* + +**这里是粗体内容** + +数学公式块 + +.. math:: + E = mc^2 + +.. note:: + 注解型提示。 + +.. warning:: + 警告型提示。 + +.. seealso:: + `参考与超链接 `_ + +普通代码块需要空一行, Example:: + + from fitlog import fitlog + fitlog.commit() + +普通下标和上标: + +H\ :sub:`2`\ O + +E = mc\ :sup:`2` + +""" import warnings import torch @@ -9,7 +47,7 @@ from fastNLP.core.dataset import DataSet from fastNLP.api.utils import load_url from fastNLP.api.processor import ModelProcessor -from fastNLP.io.dataset_loader import ConllCWSReader, ConllxDataLoader +from fastNLP.io.dataset_loader import cut_long_sentence, ConllLoader from fastNLP.core.instance import Instance from fastNLP.api.pipeline import Pipeline from fastNLP.core.metrics import SpanFPreRecMetric @@ -23,7 +61,89 @@ model_urls = { } +class ConllCWSReader(object): + """Deprecated. 
Use ConllLoader for all types of conll-format files.""" + def __init__(self): + pass + + def load(self, path, cut_long_sent=False): + """ + 返回的DataSet只包含raw_sentence这个field,内容为str。 + 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 + :: + + 1 编者按 编者按 NN O 11 nmod:topic + 2 : : PU O 11 punct + 3 7月 7月 NT DATE 4 compound:nn + 4 12日 12日 NT DATE 11 nmod:tmod + 5 , , PU O 11 punct + + 1 这 这 DT O 3 det + 2 款 款 M O 1 mark:clf + 3 飞行 飞行 NN O 8 nsubj + 4 从 从 P O 5 case + 5 外型 外型 NN O 8 nmod:prep + + """ + datalist = [] + with open(path, 'r', encoding='utf-8') as f: + sample = [] + for line in f: + if line.startswith('\n'): + datalist.append(sample) + sample = [] + elif line.startswith('#'): + continue + else: + sample.append(line.strip().split()) + if len(sample) > 0: + datalist.append(sample) + + ds = DataSet() + for sample in datalist: + # print(sample) + res = self.get_char_lst(sample) + if res is None: + continue + line = ' '.join(res) + if cut_long_sent: + sents = cut_long_sentence(line) + else: + sents = [line] + for raw_sentence in sents: + ds.append(Instance(raw_sentence=raw_sentence)) + return ds + + def get_char_lst(self, sample): + if len(sample) == 0: + return None + text = [] + for w in sample: + t1, t2, t3, t4 = w[1], w[3], w[6], w[7] + if t3 == '_': + return None + text.append(t1) + return text + +class ConllxDataLoader(ConllLoader): + """返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 + + Deprecated. Use ConllLoader for all types of conll-format files. + """ + def __init__(self): + headers = [ + 'words', 'pos_tags', 'heads', 'labels', + ] + indexs = [ + 1, 3, 6, 7, + ] + super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) + + class API: + """ + 这是 API 类的文档 + """ def __init__(self): self.pipeline = None self._dict = None @@ -69,8 +189,9 @@ class POS(API): self.load(model_path, device) def predict(self, content): - """ - + """predict函数的介绍, + 函数介绍的第二句,这句话不会换行 + :param content: list of list of str. Each string is a token(word). :return answer: list of list of str. Each string is a tag. 
""" @@ -136,13 +257,14 @@ class POS(API): class CWS(API): - def __init__(self, model_path=None, device='cpu'): - """ - 中文分词高级接口。 + """ + 中文分词高级接口。 - :param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 - :param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 - """ + :param model_path: 当model_path为None,使用默认位置的model。如果默认位置不存在,则自动下载模型 + :param device: str,可以为'cpu', 'cuda'或'cuda:0'等。会将模型load到相应device进行推断。 + """ + def __init__(self, model_path=None, device='cpu'): + super(CWS, self).__init__() if model_path is None: model_path = model_urls['cws'] @@ -183,18 +305,20 @@ class CWS(API): def test(self, filepath): """ 传入一个分词文件路径,返回该数据集上分词f1, precision, recall。 - 分词文件应该为: + 分词文件应该为:: + 1 编者按 编者按 NN O 11 nmod:topic 2 : : PU O 11 punct 3 7月 7月 NT DATE 4 compound:nn 4 12日 12日 NT DATE 11 nmod:tmod 5 , , PU O 11 punct - + 1 这 这 DT O 3 det 2 款 款 M O 1 mark:clf 3 飞行 飞行 NN O 8 nsubj 4 从 从 P O 5 case 5 外型 外型 NN O 8 nmod:prep + 以空行分割两个句子,有内容的每行有7列。 :param filepath: str, 文件路径路径。 diff --git a/fastNLP/automl/enas_trainer.py b/fastNLP/automl/enas_trainer.py index 7c0da752..061d604c 100644 --- a/fastNLP/automl/enas_trainer.py +++ b/fastNLP/automl/enas_trainer.py @@ -62,13 +62,14 @@ class ENASTrainer(fastNLP.Trainer): """ :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 最好的模型参数。 - :return results: 返回一个字典类型的数据, 内含以下内容:: - - seconds: float, 表示训练时长 - 以下三个内容只有在提供了dev_data的情况下会有。 - best_eval: Dict of Dict, 表示evaluation的结果 - best_epoch: int,在第几个epoch取得的最佳值 - best_step: int, 在第几个step(batch)更新取得的最佳值 + :return results: 返回一个字典类型的数据, + 内含以下内容:: + + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {} diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py index 76a34655..3a4dfa55 100644 --- a/fastNLP/core/dataset.py +++ b/fastNLP/core/dataset.py @@ -1,3 +1,18 @@ +""" +fastNLP.core.DataSet的介绍文档 + +DataSet是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格,每一行是一个instance(或sample),每一列是一个feature。 + +csv-table:: +:header: "Field1", "Field2", "Field3" +:widths:20, 10, 10 + +"This is the first instance", ['This', 'is', 'the', 'first', 'instance'], 5 +"Second instance", ['Second', 'instance'], 2 + +""" + + import _pickle as pickle import numpy as np @@ -31,7 +46,7 @@ class DataSet(object): length_set.add(len(value)) assert len(length_set) == 1, "Arrays must all be same length." for key, value in data.items(): - self.add_field(name=key, fields=value) + self.add_field(field_name=key, fields=value) elif isinstance(data, list): for ins in data: assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) @@ -88,7 +103,7 @@ class DataSet(object): raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self)-1}") data_set = DataSet() for field in self.field_arrays.values(): - data_set.add_field(name=field.name, fields=field.content[idx], padder=field.padder, + data_set.add_field(field_name=field.name, fields=field.content[idx], padder=field.padder, is_input=field.is_input, is_target=field.is_target, ignore_type=field.ignore_type) return data_set elif isinstance(idx, str): @@ -131,7 +146,7 @@ class DataSet(object): return "DataSet(" + self.__inner_repr__() + ")" def append(self, ins): - """Add an instance to the DataSet. + """将一个instance对象append到DataSet后面。 If the DataSet is not empty, the instance must have the same field names as the rest instances in the DataSet. 
:param ins: an Instance object @@ -151,54 +166,60 @@ class DataSet(object): assert name in self.field_arrays self.field_arrays[name].append(field) - def add_field(self, name, fields, padder=AutoPadder(pad_val=0), is_input=False, is_target=False, ignore_type=False): - """Add a new field to the DataSet. + def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): + """新增一个field - :param str name: the name of the field. - :param fields: a list of int, float, or other objects. - :param int padder: PadBase对象,如何对该Field进行padding。大部分情况使用默认值即可 - :param bool is_input: whether this field is model input. - :param bool is_target: whether this field is label or target. - :param bool ignore_type: If True, do not perform type check. (Default: False) + :param str field_name: 新增的field的名称 + :param list fields: 需要新增的field的内容 + :param None, Padder padder: 如果为None,则不进行pad。 + :param bool is_input: 新加入的field是否是input + :param bool is_target: 新加入的field是否是target + :param bool ignore_type: 是否忽略对新加入的field的类型检查 """ + if len(self.field_arrays) != 0: if len(self) != len(fields): raise RuntimeError(f"The field to append must have the same size as dataset. " f"Dataset size {len(self)} != field size {len(fields)}") - self.field_arrays[name] = FieldArray(name, fields, is_target=is_target, is_input=is_input, - padder=padder, ignore_type=ignore_type) + self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, + padder=padder, ignore_type=ignore_type) - def delete_field(self, name): - """Delete a field based on the field name. + def delete_field(self, field_name): + """删除field - :param name: the name of the field to be deleted. + :param str field_name: 需要删除的field的名称. """ - self.field_arrays.pop(name) + self.field_arrays.pop(field_name) def get_field(self, field_name): + """获取field_name这个field + + :param str field_name: field的名称 + :return: FieldArray + """ if field_name not in self.field_arrays: raise KeyError("Field name {} not found in DataSet".format(field_name)) return self.field_arrays[field_name] def get_all_fields(self): - """Return all the fields with their names. + """返回一个dict,key为field_name, value为对应的FieldArray - :return field_arrays: the internal data structure of DataSet. + :return: dict: """ return self.field_arrays def get_length(self): - """Fetch the length of the dataset. + """获取DataSet的元素数量 - :return length: + :return: int length: """ return len(self) def rename_field(self, old_name, new_name): - """Rename a field. + """将某个field重新命名. - :param str old_name: - :param str new_name: + :param str old_name: 原来的field名称 + :param str new_name: 修改为new_name """ if old_name in self.field_arrays: self.field_arrays[new_name] = self.field_arrays.pop(old_name) @@ -207,34 +228,62 @@ class DataSet(object): raise KeyError("DataSet has no field named {}.".format(old_name)) def set_target(self, *field_names, flag=True): - """Change the target flag of these fields. + """将field_names的target设置为flag状态 + Example:: - :param field_names: a sequence of str, indicating field names - :param bool flag: Set these fields as target if True. Unset them if False. + dataset.set_target('labels', 'seq_len') # 将labels和seq_len这两个field的target属性设置为True + dataset.set_target('labels', 'seq_lens', flag=False) # 将labels和seq_len的target属性设置为False + + :param str field_names: field的名称 + :param bool flag: 将field_name的target状态设置为flag """ + assert isinstance(flag, bool), "Only bool type supported." 
for name in field_names: if name in self.field_arrays: self.field_arrays[name].is_target = flag else: raise KeyError("{} is not a valid field name.".format(name)) - def set_input(self, *field_name, flag=True): - """Set the input flag of these fields. + def set_input(self, *field_names, flag=True): + """将field_name的input设置为flag状态 + Example:: + + dataset.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True + dataset.set_input('words', flag=False) # 将words这个field的input属性设置为False - :param field_name: a sequence of str, indicating field names. - :param bool flag: Set these fields as input if True. Unset them if False. + :param str field_names: field的名称 + :param bool flag: 将field_name的input状态设置为flag """ - for name in field_name: + for name in field_names: if name in self.field_arrays: self.field_arrays[name].is_input = flag else: raise KeyError("{} is not a valid field name.".format(name)) - def set_padder(self, field_name, padder): + def set_ignore_type(self, *field_names, flag=True): + """将field_names的ignore_type设置为flag状态 + + :param str field_names: field的名称 + :param bool flag: 将field_name的ignore_type状态设置为flag + :return: """ - 为field_name设置padder - :param field_name: str, 设置field的padding方式为padder - :param padder: PadderBase类型或None. 设置为None即删除padder。即对该field不进行padding操作. + assert isinstance(flag, bool), "Only bool type supported." + for name in field_names: + if name in self.field_arrays: + self.field_arrays[name].ignore_type = flag + else: + raise KeyError("{} is not a valid field name.".format(name)) + + def set_padder(self, field_name, padder): + """为field_name设置padder + Example:: + + from fastNLP import EngChar2DPadder + padder = EngChar2DPadder() + dataset.set_padder('chars', padder) # 则chars这个field会使用EngChar2DPadder进行pad操作 + + :param str field_name: 设置field的padding方式为padder + :param None, Padder padder: 设置为None即删除padder, 即对该field不进行pad操作. :return: """ if field_name not in self.field_arrays: @@ -242,11 +291,10 @@ class DataSet(object): self.field_arrays[field_name].set_padder(padder) def set_pad_val(self, field_name, pad_val): - """ - 为某个 + """为某个field设置对应的pad_val. - :param field_name: str,修改该field的pad_val - :param pad_val: int,该field的padder会以pad_val作为padding index + :param str field_name: 修改该field的pad_val + :param int pad_val: 该field的padder会以pad_val作为padding index :return: """ if field_name not in self.field_arrays: @@ -254,43 +302,68 @@ class DataSet(object): self.field_arrays[field_name].set_pad_val(pad_val) def get_input_name(self): - """Get all field names with `is_input` as True. + """返回所有is_input被设置为True的field名称 - :return field_names: a list of str + :return: list, 里面的元素为被设置为input的field名称 """ return [name for name, field in self.field_arrays.items() if field.is_input] def get_target_name(self): - """Get all field names with `is_target` as True. + """返回所有is_target被设置为True的field名称 - :return field_names: a list of str + :return list, 里面的元素为被设置为target的field名称 """ return [name for name, field in self.field_arrays.items() if field.is_target] - def apply(self, func, new_field_name=None, **kwargs): - """Apply a function to every instance of the DataSet. - - :param func: a function that takes an instance as input. - :param str new_field_name: If not None, results of the function will be stored as a new field. - :param **kwargs: Accept parameters will be - (1) is_input: boolean, will be ignored if new_field is None. If True, the new field will be as input. - (2) is_target: boolean, will be ignored if new_field is None. If True, the new field will be as target. 
- :return results: if new_field_name is not passed, returned values of the function over all instances. + def apply_field(self, func, field_name, new_field_name=None, **kwargs): + """将DataSet中的每个instance中的`field_name`这个field传给func,并获取它的返回值. + + :param callable func: input是instance的`field_name`这个field. + :param str field_name: 传入func的是哪个field. + :param str, None new_field_name: 将func返回的内容放入到什么field中 + + 1. str, 将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有的field相 + 同,则覆盖之前的field + + 2. None, 不创建新的field + :param kwargs: 合法的参数有以下三个 + + 1. is_input: bool, 如果为True则将`new_field_name`的field设置为input + + 2. is_target: bool, 如果为True则将`new_field_name`的field设置为target + + 3. ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 + :return: list(Any), 里面的元素为func的返回值,所以list长度为DataSet的长度 + """ - assert len(self)!=0, "Null dataset cannot use .apply()." + assert len(self)!=0, "Null DataSet cannot use apply()." + if field_name not in self: + raise KeyError("DataSet has no field named `{}`.".format(field_name)) results = [] idx = -1 try: for idx, ins in enumerate(self._inner_iter()): - results.append(func(ins)) + results.append(func(ins[field_name])) except Exception as e: if idx!=-1: print("Exception happens at the `{}`th instance.".format(idx)) raise e - # results = [func(ins) for ins in self._inner_iter()] if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None raise ValueError("{} always return None.".format(get_func_signature(func=func))) + if new_field_name is not None: + self._add_apply_field(results, new_field_name, kwargs) + + return results + + def _add_apply_field(self, results, new_field_name, kwargs): + """将results作为加入到新的field中,field名称为new_field_name + + :param list(str) results: 一般是apply*()之后的结果 + :param str new_field_name: 新加入的field的名称 + :param dict kwargs: 用户apply*()时传入的自定义参数 + :return: + """ extra_param = {} if 'is_input' in kwargs: extra_param['is_input'] = kwargs['is_input'] @@ -298,56 +371,91 @@ class DataSet(object): extra_param['is_target'] = kwargs['is_target'] if 'ignore_type' in kwargs: extra_param['ignore_type'] = kwargs['ignore_type'] - if new_field_name is not None: - if new_field_name in self.field_arrays: - # overwrite the field, keep same attributes - old_field = self.field_arrays[new_field_name] - if 'is_input' not in extra_param: - extra_param['is_input'] = old_field.is_input - if 'is_target' not in extra_param: - extra_param['is_target'] = old_field.is_target - if 'ignore_type' not in extra_param: - extra_param['ignore_type'] = old_field.ignore_type - self.add_field(name=new_field_name, fields=results, is_input=extra_param["is_input"], - is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) - else: - self.add_field(name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), - is_target=extra_param.get("is_target", None), - ignore_type=extra_param.get("ignore_type", False)) + if new_field_name in self.field_arrays: + # overwrite the field, keep same attributes + old_field = self.field_arrays[new_field_name] + if 'is_input' not in extra_param: + extra_param['is_input'] = old_field.is_input + if 'is_target' not in extra_param: + extra_param['is_target'] = old_field.is_target + if 'ignore_type' not in extra_param: + extra_param['ignore_type'] = old_field.ignore_type + self.add_field(field_name=new_field_name, fields=results, is_input=extra_param["is_input"], + is_target=extra_param["is_target"], ignore_type=extra_param['ignore_type']) else: - return results + 
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), + is_target=extra_param.get("is_target", None), + ignore_type=extra_param.get("ignore_type", False)) + + def apply(self, func, new_field_name=None, **kwargs): + """将DataSet中每个instance传入到func中,并获取它的返回值. + + :param callable func: 参数是DataSet中的instance + :param str, None new_field_name: 将func返回的内容放入到什么field中 + + 1. str, 将func的返回值放入这个名为`new_field_name`的新field中,如果名称与已有的field相 + 同,则覆盖之前的field + + 2. None, 不创建新的field + :param kwargs: 合法的参数有以下三个 + + 1. is_input: bool, 如果为True则将`new_field_name`的field设置为input + + 2. is_target: bool, 如果为True则将`new_field_name`的field设置为target + + 3. ignore_type: bool, 如果为True则将`new_field_name`的field的ignore_type设置为true, 忽略其类型 + :return: List[], 里面的元素为func的返回值,所以list长度为DataSet的长度 + """ + assert len(self)!=0, "Null DataSet cannot use apply()." + idx = -1 + try: + results = [] + for idx, ins in enumerate(self._inner_iter()): + results.append(func(ins)) + except Exception as e: + if idx!=-1: + print("Exception happens at the `{}`th instance.".format(idx)) + raise e + # results = [func(ins) for ins in self._inner_iter()] + if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None + raise ValueError("{} always return None.".format(get_func_signature(func=func))) + + if new_field_name is not None: + self._add_apply_field(results, new_field_name, kwargs) + + return results def drop(self, func, inplace=True): - """Drop instances if a condition holds. + """func接受一个instance,返回bool值,返回值为True时,该instance会被删除。 - :param func: a function that takes an Instance object as input, and returns bool. - The instance will be dropped if the function returns True. - :param inplace: bool, whether to drop inpalce. Otherwise a new dataset will be returned. + :param callable func: 接受一个instance作为参数,返回bool值。为True时删除该instance + :param bool inplace: 是否在当前DataSet中直接删除instance。如果为False,返回值为一个删除了相应instance的新的DataSet + :return: DataSet """ if inplace: results = [ins for ins in self._inner_iter() if not func(ins)] for name, old_field in self.field_arrays.items(): self.field_arrays[name].content = [ins[name] for ins in results] + return self else: results = [ins for ins in self if not func(ins)] - data = DataSet(results) + dataset = DataSet(results) for field_name, field in self.field_arrays.items(): - data.field_arrays[field_name].to(field) + dataset.field_arrays[field_name].to(field) + return dataset - def split(self, dev_ratio): - """Split the dataset into training and development(validation) set. + def split(self, ratio): + """将DataSet按照ratio的比例拆分,返回两个DataSet - :param float dev_ratio: the ratio of test set in all data. - :return (train_set, dev_set): - train_set: the training set - dev_set: the development set + :param float ratio: 0>[1, 1, 1] - ins.add_field("field_3", [3, 3, 3]) - - :param fields: a dict of (str: list). 
- + Example:: + + ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2]) + ins["field_1"] + >>[1, 1, 1] + ins.add_field("field_3", [3, 3, 3]) + """ def __init__(self, **fields): diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py index b52244e5..6b0b4460 100644 --- a/fastNLP/core/losses.py +++ b/fastNLP/core/losses.py @@ -272,7 +272,7 @@ def squash(predict, truth, **kwargs): :param predict: Tensor, model output :param truth: Tensor, truth from dataset - :param **kwargs: extra arguments + :param kwargs: extra arguments :return predict , truth: predict & truth after processing """ return predict.view(-1, predict.size()[-1]), truth.view(-1, ) @@ -316,7 +316,7 @@ def mask(predict, truth, **kwargs): :param predict: Tensor, [batch_size , max_len , tag_size] :param truth: Tensor, [batch_size , max_len] - :param **kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. + :param kwargs: extra arguments, kwargs["mask"]: ByteTensor, [batch_size , max_len], the mask Tensor. The position that is 1 will be selected. :return predict , truth: predict & truth after processing """ diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py index 5687cc85..314be0d9 100644 --- a/fastNLP/core/metrics.py +++ b/fastNLP/core/metrics.py @@ -17,66 +17,72 @@ class MetricBase(object): """Base class for all metrics. 所有的传入到Trainer, Tester的Metric需要继承自该对象。需要覆盖写入evaluate(), get_metric()方法。 + evaluate(xxx)中传入的是一个batch的数据。 + get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值 + 以分类问题中,Accuracy计算为例 - 假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy - class Model(nn.Module): - def __init__(xxx): - # do something - def forward(self, xxx): - # do something - return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes + 假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy:: + + class Model(nn.Module): + def __init__(xxx): + # do something + def forward(self, xxx): + # do something + return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes + 假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target - 对应的AccMetric可以按如下的定义 - # version1, 只使用这一次 - class AccMetric(MetricBase): - def __init__(self): - super().__init__() - - # 根据你的情况自定义指标 - self.corr_num = 0 - self.total = 0 - - def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value - # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric - self.total += label.size(0) - self.corr_num += label.eq(pred).sum().item() - - def get_metric(self, reset=True): # 在这里定义如何计算metric - acc = self.corr_num/self.total - if reset: # 是否清零以便重新计算 + 对应的AccMetric可以按如下的定义, version1, 只使用这一次:: + + class AccMetric(MetricBase): + def __init__(self): + super().__init__() + + # 根据你的情况自定义指标 self.corr_num = 0 self.total = 0 - return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 - - - # version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred - class AccMetric(MetricBase): - def __init__(self, label=None, pred=None): - # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, - # acc_metric = AccMetric(label='y', pred='pred_y')即可。 - # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 - # 应的的值 - super().__init__() - self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 
仅需要注册evaluate()方法会用到的参数名即可 - # 如果没有注册该则效果与version1就是一样的 - - # 根据你的情况自定义指标 - self.corr_num = 0 - self.total = 0 - - def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 - # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric - self.total += label.size(0) - self.corr_num += label.eq(pred).sum().item() - - def get_metric(self, reset=True): # 在这里定义如何计算metric - acc = self.corr_num/self.total - if reset: # 是否清零以便重新计算 + + def evaluate(self, label, pred): # 这里的名称需要和dataset中target field与model返回的key是一样的,不然找不到对应的value + # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric + self.total += label.size(0) + self.corr_num += label.eq(pred).sum().item() + + def get_metric(self, reset=True): # 在这里定义如何计算metric + acc = self.corr_num/self.total + if reset: # 是否清零以便重新计算 + self.corr_num = 0 + self.total = 0 + return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 + + + version2,如果需要复用Metric,比如下一次使用AccMetric时,dataset中目标field不叫label而叫y,或者model的输出不是pred:: + + class AccMetric(MetricBase): + def __init__(self, label=None, pred=None): + # 假设在另一场景使用时,目标field叫y,model给出的key为pred_y。则只需要在初始化AccMetric时, + # acc_metric = AccMetric(label='y', pred='pred_y')即可。 + # 当初始化为acc_metric = AccMetric(),即label=None, pred=None, fastNLP会直接使用'label', 'pred'作为key去索取对 + # 应的的值 + super().__init__() + self._init_param_map(label=label, pred=pred) # 该方法会注册label和pred. 仅需要注册evaluate()方法会用到的参数名即可 + # 如果没有注册该则效果与version1就是一样的 + + # 根据你的情况自定义指标 self.corr_num = 0 self.total = 0 - return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 + + def evaluate(self, label, pred): # 这里的参数名称需要和self._init_param_map()注册时一致。 + # dev或test时,每个batch结束会调用一次该方法,需要实现如何根据每个batch累加metric + self.total += label.size(0) + self.corr_num += label.eq(pred).sum().item() + + def get_metric(self, reset=True): # 在这里定义如何计算metric + acc = self.corr_num/self.total + if reset: # 是否清零以便重新计算 + self.corr_num = 0 + self.total = 0 + return {'acc': acc} # 需要返回一个dict,key为该metric的名称,该名称会显示到Trainer的progress bar中 ``MetricBase`` handles validity check of its input dictionaries - ``pred_dict`` and ``target_dict``. @@ -84,12 +90,12 @@ class MetricBase(object): ``target_dict`` is the ground truth from DataSet where ``is_target`` is set ``True``. ``MetricBase`` will do the following type checks: - 1. whether self.evaluate has varargs, which is not supported. - 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. - 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. + 1. whether self.evaluate has varargs, which is not supported. + 2. whether params needed by self.evaluate is not included in ``pred_dict``, ``target_dict``. + 3. whether params needed by self.evaluate duplicate in ``pred_dict``, ``target_dict``. Besides, before passing params into self.evaluate, this function will filter out params from output_dict and - target_dict which are not used in self.evaluate. (but if **kwargs presented in self.evaluate, no filtering + target_dict which are not used in self.evaluate. (but if kwargs presented in self.evaluate, no filtering will be conducted.) """ @@ -388,23 +394,26 @@ class SpanFPreRecMetric(MetricBase): """ 在序列标注问题中,以span的方式计算F, pre, rec. 
比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例) - ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 - 最后得到的metric结果为 - { - 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 - 'pre': xxx, - 'rec':xxx - } - 若only_gross=False, 即还会返回各个label的metric统计值 + ['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 + 最后得到的metric结果为:: + { - 'f': xxx, - 'pre': xxx, - 'rec':xxx, - 'f-label': xxx, - 'pre-label': xxx, - 'rec-label':xxx, - ... - } + 'f': xxx, # 这里使用f考虑以后可以计算f_beta值 + 'pre': xxx, + 'rec':xxx + } + + 若only_gross=False, 即还会返回各个label的metric统计值:: + + { + 'f': xxx, + 'pre': xxx, + 'rec':xxx, + 'f-label': xxx, + 'pre-label': xxx, + 'rec-label':xxx, + ... + } """ def __init__(self, tag_vocab, pred=None, target=None, seq_lens=None, encoding_type='bio', ignore_labels=None, @@ -573,13 +582,21 @@ class BMESF1PreRecMetric(MetricBase): """ 按照BMES标注方式计算f1, precision, recall。由于可能存在非法tag,比如"BS",所以需要用以下的表格做转换,cur_B意思是当前tag是B, next_B意思是后一个tag是B。则cur_B=S,即将当前被predict是B的tag标为S;next_M=B, 即将后一个被predict是M的tag标为B + + +-------+---------+----------+----------+---------+---------+ | | next_B | next_M | next_E | next_S | end | - |:-----:|:-------:|:--------:|:--------:|:-------:|:-------:| - | start | 合法 | next_M=B | next_E=S | 合法 | - | + +=======+=========+==========+==========+=========+=========+ + | start | 合法 | next_M=B | next_E=S | 合法 | -- | + +-------+---------+----------+----------+---------+---------+ | cur_B | cur_B=S | 合法 | 合法 | cur_B=S | cur_B=S | + +-------+---------+----------+----------+---------+---------+ | cur_M | cur_M=E | 合法 | 合法 | cur_M=E | cur_M=E | + +-------+---------+----------+----------+---------+---------+ | cur_E | 合法 | next_M=B | next_E=S | 合法 | 合法 | + +-------+---------+----------+----------+---------+---------+ | cur_S | 合法 | next_M=B | next_E=S | 合法 | 合法 | + +-------+---------+----------+----------+---------+---------+ + 举例: prediction为BSEMS,会被认为是SSSSS. diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py index b45dd148..67e7d2c0 100644 --- a/fastNLP/core/trainer.py +++ b/fastNLP/core/trainer.py @@ -79,7 +79,7 @@ class Trainer(object): raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") # check update every - assert update_every>=1, "update_every must be no less than 1." + assert update_every >= 1, "update_every must be no less than 1." 
self.update_every = int(update_every) # check save_path @@ -120,7 +120,7 @@ class Trainer(object): self.use_cuda = bool(use_cuda) self.save_path = save_path self.print_every = int(print_every) - self.validate_every = int(validate_every) if validate_every!=0 else -1 + self.validate_every = int(validate_every) if validate_every != 0 else -1 self.best_metric_indicator = None self.best_dev_epoch = None self.best_dev_step = None @@ -129,7 +129,7 @@ class Trainer(object): self.prefetch = prefetch self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) self.n_steps = (len(self.train_data) // self.batch_size + int( - len(self.train_data) % self.batch_size != 0)) * self.n_epochs + len(self.train_data) % self.batch_size != 0)) * self.n_epochs if isinstance(optimizer, torch.optim.Optimizer): self.optimizer = optimizer @@ -156,7 +156,6 @@ class Trainer(object): self.callback_manager = CallbackManager(env={"trainer": self}, callbacks=callbacks) - def train(self, load_best_model=True): """ @@ -185,14 +184,15 @@ class Trainer(object): 根据metrics进行evaluation,并根据是否提供了save_path判断是否存储模型 :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 - 最好的模型参数。 - :return results: 返回一个字典类型的数据, 内含以下内容:: + 最好的模型参数。 + :return results: 返回一个字典类型的数据, + 内含以下内容:: - seconds: float, 表示训练时长 - 以下三个内容只有在提供了dev_data的情况下会有。 - best_eval: Dict of Dict, 表示evaluation的结果 - best_epoch: int,在第几个epoch取得的最佳值 - best_step: int, 在第几个step(batch)更新取得的最佳值 + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {} @@ -218,8 +218,9 @@ class Trainer(object): self.callback_manager.on_exception(e) if self.dev_data is not None and hasattr(self, 'best_dev_perf'): - print("\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + - self.tester._format_eval_results(self.best_dev_perf),) + print( + "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + + self.tester._format_eval_results(self.best_dev_perf), ) results['best_eval'] = self.best_dev_perf results['best_epoch'] = self.best_dev_epoch results['best_step'] = self.best_dev_step @@ -250,7 +251,7 @@ class Trainer(object): avg_loss = 0 data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False, prefetch=self.prefetch) - for epoch in range(1, self.n_epochs+1): + for epoch in range(1, self.n_epochs + 1): self.epoch = epoch pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) # early stopping @@ -267,7 +268,7 @@ class Trainer(object): self.callback_manager.on_loss_begin(batch_y, prediction) loss = self._compute_loss(prediction, batch_y).mean() avg_loss += loss.item() - loss = loss/self.update_every + loss = loss / self.update_every # Is loss NaN or inf? requires_grad = False self.callback_manager.on_backward_begin(loss) @@ -277,8 +278,8 @@ class Trainer(object): self._update() self.callback_manager.on_step_end() - if (self.step+1) % self.print_every == 0: - avg_loss = avg_loss / self.print_every + if self.step % self.print_every == 0: + avg_loss = float(avg_loss) / self.print_every if self.use_tqdm: print_output = "loss:{0:<6.5f}".format(avg_loss) pbar.update(self.print_every) @@ -297,7 +298,7 @@ class Trainer(object): eval_res = self._do_validation(epoch=epoch, step=self.step) eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. 
".format(epoch, self.n_epochs, self.step, self.n_steps) + \ - self.tester._format_eval_results(eval_res) + self.tester._format_eval_results(eval_res) pbar.write(eval_str + '\n') # ================= mini-batch end ==================== # @@ -317,7 +318,7 @@ class Trainer(object): if self._better_eval_result(res): if self.save_path is not None: self._save_model(self.model, - "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) + "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) else: self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} self.best_dev_perf = res @@ -344,7 +345,7 @@ class Trainer(object): """Perform weight update on a model. """ - if (self.step+1)%self.update_every==0: + if (self.step + 1) % self.update_every == 0: self.optimizer.step() def _data_forward(self, network, x): @@ -361,7 +362,7 @@ class Trainer(object): For PyTorch, just do "loss.backward()" """ - if self.step%self.update_every==0: + if self.step % self.update_every == 0: self.model.zero_grad() loss.backward() @@ -437,6 +438,7 @@ class Trainer(object): DEFAULT_CHECK_BATCH_SIZE = 2 DEFAULT_CHECK_NUM_BATCH = 2 + def _get_value_info(_dict): # given a dict value, return information about this dict's value. Return list of str strs = [] @@ -453,6 +455,7 @@ def _get_value_info(_dict): strs.append(_str) return strs + def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, dev_data=None, metric_key=None, check_level=0): @@ -463,17 +466,17 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ for batch_count, (batch_x, batch_y) in enumerate(batch): _move_dict_value_to_device(batch_x, batch_y, device=model_devcie) # forward check - if batch_count==0: + if batch_count == 0: info_str = "" input_fields = _get_value_info(batch_x) target_fields = _get_value_info(batch_y) - if len(input_fields)>0: + if len(input_fields) > 0: info_str += "input fields after batch(if batch size is {}):\n".format(batch_size) info_str += "\n".join(input_fields) info_str += '\n' else: raise RuntimeError("There is no input field.") - if len(target_fields)>0: + if len(target_fields) > 0: info_str += "target fields after batch(if batch size is {}):\n".format(batch_size) info_str += "\n".join(target_fields) info_str += '\n' @@ -481,7 +484,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ info_str += 'There is no target field.' print(info_str) _check_forward_error(forward_func=model.forward, dataset=dataset, - batch_x=batch_x, check_level=check_level) + batch_x=batch_x, check_level=check_level) refined_batch_x = _build_args(model.forward, **batch_x) pred_dict = model(**refined_batch_x) diff --git a/fastNLP/core/utils.py b/fastNLP/core/utils.py index d9141412..b2c10fb4 100644 --- a/fastNLP/core/utils.py +++ b/fastNLP/core/utils.py @@ -24,7 +24,7 @@ def _prepare_cache_filepath(filepath): if not os.path.exists(cache_dir): os.makedirs(cache_dir) - +# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 def cache_results(cache_filepath, refresh=False, verbose=1): def wrapper_(func): signature = inspect.signature(func) @@ -197,17 +197,22 @@ def get_func_signature(func): Given a function or method, return its signature. 
For example: - (1) function + + 1 function:: + def func(a, b='a', *args): xxxx get_func_signature(func) # 'func(a, b='a', *args)' - (2) method + + 2 method:: + class Demo: def __init__(self): xxx def forward(self, a, b='a', **args) demo = Demo() get_func_signature(demo.forward) # 'Demo.forward(self, a, b='a', **args)' + :param func: a function or a method :return: str or None """ diff --git a/fastNLP/io/config_io.py b/fastNLP/io/config_io.py index 5a64b96c..c0ffe53e 100644 --- a/fastNLP/io/config_io.py +++ b/fastNLP/io/config_io.py @@ -26,10 +26,10 @@ class ConfigLoader(BaseLoader): :param str file_path: the path of config file :param dict sections: the dict of ``{section_name(string): ConfigSection object}`` - Example:: - - test_args = ConfigSection() - ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) + Example:: + + test_args = ConfigSection() + ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) """ assert isinstance(sections, dict) diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py index e33384a8..5657e194 100644 --- a/fastNLP/io/dataset_loader.py +++ b/fastNLP/io/dataset_loader.py @@ -1,71 +1,13 @@ import os import json +from nltk.tree import Tree from fastNLP.core.dataset import DataSet from fastNLP.core.instance import Instance -from fastNLP.io.base_loader import DataLoaderRegister +from fastNLP.io.file_reader import read_csv, read_json, read_conll -def convert_seq_dataset(data): - """Create an DataSet instance that contains no labels. - - :param data: list of list of strings, [num_examples, *]. - Example:: - - [ - [word_11, word_12, ...], - ... - ] - - :return: a DataSet. - """ - dataset = DataSet() - for word_seq in data: - dataset.append(Instance(word_seq=word_seq)) - return dataset - - -def convert_seq2tag_dataset(data): - """Convert list of data into DataSet. - - :param data: list of list of strings, [num_examples, *]. - Example:: - - [ - [ [word_11, word_12, ...], label_1 ], - [ [word_21, word_22, ...], label_2 ], - ... - ] - - :return: a DataSet. - """ - dataset = DataSet() - for sample in data: - dataset.append(Instance(word_seq=sample[0], label=sample[1])) - return dataset - - -def convert_seq2seq_dataset(data): - """Convert list of data into DataSet. - - :param data: list of list of strings, [num_examples, *]. - Example:: - - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - - :return: a DataSet. 
- """ - dataset = DataSet() - for sample in data: - dataset.append(Instance(word_seq=sample[0], label_seq=sample[1])) - return dataset - - -def download_from_url(url, path): +def _download_from_url(url, path): from tqdm import tqdm import requests @@ -81,7 +23,7 @@ def download_from_url(url, path): t.update(len(chunk)) return -def uncompress(src, dst): +def _uncompress(src, dst): import zipfile, gzip, tarfile, os def unzip(src, dst): @@ -134,241 +76,6 @@ class DataSetLoader: raise NotImplementedError -class NativeDataSetLoader(DataSetLoader): - """A simple example of DataSetLoader - - """ - - def __init__(self): - super(NativeDataSetLoader, self).__init__() - - def load(self, path): - ds = DataSet.read_csv(path, headers=("raw_sentence", "label"), sep="\t") - ds.set_input("raw_sentence") - ds.set_target("label") - return ds - - -DataLoaderRegister.set_reader(NativeDataSetLoader, 'read_naive') - - -class RawDataSetLoader(DataSetLoader): - """A simple example of raw data reader - - """ - - def __init__(self): - super(RawDataSetLoader, self).__init__() - - def load(self, data_path, split=None): - with open(data_path, "r", encoding="utf-8") as f: - lines = f.readlines() - lines = lines if split is None else [l.split(split) for l in lines] - lines = list(filter(lambda x: len(x) > 0, lines)) - return self.convert(lines) - - def convert(self, data): - return convert_seq_dataset(data) - - -DataLoaderRegister.set_reader(RawDataSetLoader, 'read_rawdata') - - -class DummyPOSReader(DataSetLoader): - """A simple reader for a dummy POS tagging dataset. - - In these datasets, each line are divided by "\t". The first Col is the vocabulary and the second - Col is the label. Different sentence are divided by an empty line. - E.g:: - - Tom label1 - and label2 - Jerry label1 - . label3 - (separated by an empty line) - Hello label4 - world label5 - ! label3 - - In this example, there are two sentences "Tom and Jerry ." and "Hello world !". Each word has its own label. - """ - - def __init__(self): - super(DummyPOSReader, self).__init__() - - def load(self, data_path): - """ - :return data: three-level list - Example:: - [ - [ [word_11, word_12, ...], [label_1, label_1, ...] ], - [ [word_21, word_22, ...], [label_2, label_1, ...] ], - ... - ] - """ - with open(data_path, "r", encoding="utf-8") as f: - lines = f.readlines() - data = self.parse(lines) - return self.convert(data) - - @staticmethod - def parse(lines): - data = [] - sentence = [] - for line in lines: - line = line.strip() - if len(line) > 1: - sentence.append(line.split('\t')) - else: - words = [] - labels = [] - for tokens in sentence: - words.append(tokens[0]) - labels.append(tokens[1]) - data.append([words, labels]) - sentence = [] - if len(sentence) != 0: - words = [] - labels = [] - for tokens in sentence: - words.append(tokens[0]) - labels.append(tokens[1]) - data.append([words, labels]) - return data - - def convert(self, data): - """Convert lists of strings into Instances with Fields. - """ - return convert_seq2seq_dataset(data) - - -DataLoaderRegister.set_reader(DummyPOSReader, 'read_pos') - - -class DummyCWSReader(DataSetLoader): - """Load pku dataset for Chinese word segmentation. - """ - def __init__(self): - super(DummyCWSReader, self).__init__() - - def load(self, data_path, max_seq_len=32): - """Load pku dataset for Chinese word segmentation. - CWS (Chinese Word Segmentation) pku training dataset format: - 1. Each line is a sentence. - 2. Each word in a sentence is separated by space. 
- This function convert the pku dataset into three-level lists with labels . - B: beginning of a word - M: middle of a word - E: ending of a word - S: single character - - :param str data_path: path to the data set. - :param max_seq_len: int, the maximum length of a sequence. If a sequence is longer than it, split it into - several sequences. - :return: three-level lists - """ - assert isinstance(max_seq_len, int) and max_seq_len > 0 - with open(data_path, "r", encoding="utf-8") as f: - sentences = f.readlines() - data = [] - for sent in sentences: - tokens = sent.strip().split() - words = [] - labels = [] - for token in tokens: - if len(token) == 1: - words.append(token) - labels.append("S") - else: - words.append(token[0]) - labels.append("B") - for idx in range(1, len(token) - 1): - words.append(token[idx]) - labels.append("M") - words.append(token[-1]) - labels.append("E") - num_samples = len(words) // max_seq_len - if len(words) % max_seq_len != 0: - num_samples += 1 - for sample_idx in range(num_samples): - start = sample_idx * max_seq_len - end = (sample_idx + 1) * max_seq_len - seq_words = words[start:end] - seq_labels = labels[start:end] - data.append([seq_words, seq_labels]) - return self.convert(data) - - def convert(self, data): - return convert_seq2seq_dataset(data) - - -class DummyClassificationReader(DataSetLoader): - """Loader for a dummy classification data set""" - - def __init__(self): - super(DummyClassificationReader, self).__init__() - - def load(self, data_path): - assert os.path.exists(data_path) - with open(data_path, "r", encoding="utf-8") as f: - lines = f.readlines() - data = self.parse(lines) - return self.convert(data) - - @staticmethod - def parse(lines): - """每行第一个token是标签,其余是字/词;由空格分隔。 - - :param lines: lines from dataset - :return: list(list(list())): the three level of lists are words, sentence, and dataset - """ - dataset = list() - for line in lines: - line = line.strip().split() - label = line[0] - words = line[1:] - if len(words) <= 1: - continue - - sentence = [words, label] - dataset.append(sentence) - return dataset - - def convert(self, data): - return convert_seq2tag_dataset(data) - - -class DummyLMReader(DataSetLoader): - """A Dummy Language Model Dataset Reader - """ - def __init__(self): - super(DummyLMReader, self).__init__() - - def load(self, data_path): - if not os.path.exists(data_path): - raise FileNotFoundError("file {} not found.".format(data_path)) - with open(data_path, "r", encoding="utf=8") as f: - text = " ".join(f.readlines()) - tokens = text.strip().split() - data = self.sentence_cut(tokens) - return self.convert(data) - - def sentence_cut(self, tokens, sentence_length=15): - start_idx = 0 - data_set = [] - for idx in range(len(tokens) // sentence_length): - x = tokens[start_idx * idx: start_idx * idx + sentence_length] - y = tokens[start_idx * idx + 1: start_idx * idx + sentence_length + 1] - if start_idx * idx + sentence_length + 1 >= len(tokens): - # ad hoc - y.extend([""]) - data_set.append([x, y]) - return data_set - - def convert(self, data): - pass - - class PeopleDailyCorpusLoader(DataSetLoader): """人民日报数据集 """ @@ -448,8 +155,9 @@ class PeopleDailyCorpusLoader(DataSetLoader): class ConllLoader: - def __init__(self, headers, indexs=None): + def __init__(self, headers, indexs=None, dropna=True): self.headers = headers + self.dropna = dropna if indexs is None: self.indexs = list(range(len(self.headers))) else: @@ -458,33 +166,10 @@ class ConllLoader: self.indexs = indexs def load(self, path): - datalist = [] - with 
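The B/M/E/S labelling used by the removed CWS readers is easy to keep around as a small standalone helper; a sketch reproducing the per-token logic above:

def bmes_labels(token):
    """Return the B/M/E/S character labels for one whitespace-separated token."""
    if len(token) == 1:
        return ["S"]
    return ["B"] + ["M"] * (len(token) - 2) + ["E"]

assert bmes_labels("f") == ["S"]
assert bmes_labels("分词") == ["B", "E"]
assert bmes_labels("编者按") == ["B", "M", "E"]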
open(path, 'r', encoding='utf-8') as f: - sample = [] - start = next(f) - if '-DOCSTART-' not in start: - sample.append(start.split()) - for line in f: - if line.startswith('\n'): - if len(sample): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split()) - if len(sample) > 0: - datalist.append(sample) - - data = [self.get_one(sample) for sample in datalist] - data = filter(lambda x: x is not None, data) - ds = DataSet() - for sample in data: - ins = Instance() - for name, idx in zip(self.headers, self.indexs): - ins.add_field(field_name=name, field=sample[idx]) - ds.append(ins) + for idx, data in read_conll(path, indexes=self.indexs, dropna=self.dropna): + ins = {h:data[idx] for h, idx in zip(self.headers, self.indexs)} + ds.append(Instance(**ins)) return ds def get_one(self, sample): @@ -499,9 +184,7 @@ class Conll2003Loader(ConllLoader): """Loader for conll2003 dataset More information about the given dataset cound be found on - https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data - - Deprecated. Use ConllLoader for all types of conll-format files. + https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data """ def __init__(self): headers = [ @@ -510,194 +193,6 @@ class Conll2003Loader(ConllLoader): super(Conll2003Loader, self).__init__(headers=headers) -class SNLIDataSetReader(DataSetLoader): - """A data set loader for SNLI data set. - - """ - def __init__(self): - super(SNLIDataSetReader, self).__init__() - - def load(self, path_list): - """ - - :param list path_list: A list of file name, in the order of premise file, hypothesis file, and label file. - :return: A DataSet object. - """ - assert len(path_list) == 3 - line_set = [] - for file in path_list: - if not os.path.exists(file): - raise FileNotFoundError("file {} NOT found".format(file)) - - with open(file, 'r', encoding='utf-8') as f: - lines = f.readlines() - line_set.append(lines) - - premise_lines, hypothesis_lines, label_lines = line_set - assert len(premise_lines) == len(hypothesis_lines) and len(premise_lines) == len(label_lines) - - data_set = [] - for premise, hypothesis, label in zip(premise_lines, hypothesis_lines, label_lines): - p = premise.strip().split() - h = hypothesis.strip().split() - l = label.strip() - data_set.append([p, h, l]) - - return self.convert(data_set) - - def convert(self, data): - """Convert a 3D list to a DataSet object. - - :param data: A 3D tensor. - Example:: - [ - [ [premise_word_11, premise_word_12, ...], [hypothesis_word_11, hypothesis_word_12, ...], [label_1] ], - [ [premise_word_21, premise_word_22, ...], [hypothesis_word_21, hypothesis_word_22, ...], [label_2] ], - ... - ] - - :return: A DataSet object. - """ - - data_set = DataSet() - - for example in data: - p, h, l = example - # list, list, str - instance = Instance() - instance.add_field("premise", p) - instance.add_field("hypothesis", h) - instance.add_field("truth", l) - data_set.append(instance) - data_set.apply(lambda ins: len(ins["premise"]), new_field_name="premise_len") - data_set.apply(lambda ins: len(ins["hypothesis"]), new_field_name="hypothesis_len") - data_set.set_input("premise", "hypothesis", "premise_len", "hypothesis_len") - data_set.set_target("truth") - return data_set - - -class ConllCWSReader(object): - """Deprecated. 
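Note that the rewritten ConllLoader.load above appends to ds without a DataSet having been created first. A hedged sketch of the apparently intended logic, written as a standalone helper and assuming read_conll yields the selected columns in the order given by indexes:

from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance
from fastNLP.io.file_reader import read_conll

def load_conll(path, headers, indexes, dropna=True):
    """Hypothetical standalone version of ConllLoader.load, with the DataSet created up front."""
    ds = DataSet()
    for line_idx, fields in read_conll(path, indexes=indexes, dropna=dropna):
        # read_conll returns only the selected columns, in the order of `indexes`
        ins = {header: field for header, field in zip(headers, fields)}
        ds.append(Instance(**ins))
    return ds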
Use ConllLoader for all types of conll-format files.""" - def __init__(self): - pass - - def load(self, path, cut_long_sent=False): - """ - 返回的DataSet只包含raw_sentence这个field,内容为str。 - 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 - :: - - 1 编者按 编者按 NN O 11 nmod:topic - 2 : : PU O 11 punct - 3 7月 7月 NT DATE 4 compound:nn - 4 12日 12日 NT DATE 11 nmod:tmod - 5 , , PU O 11 punct - - 1 这 这 DT O 3 det - 2 款 款 M O 1 mark:clf - 3 飞行 飞行 NN O 8 nsubj - 4 从 从 P O 5 case - 5 外型 外型 NN O 8 nmod:prep - - """ - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.strip().split()) - if len(sample) > 0: - datalist.append(sample) - - ds = DataSet() - for sample in datalist: - # print(sample) - res = self.get_char_lst(sample) - if res is None: - continue - line = ' '.join(res) - if cut_long_sent: - sents = cut_long_sentence(line) - else: - sents = [line] - for raw_sentence in sents: - ds.append(Instance(raw_sentence=raw_sentence)) - return ds - - def get_char_lst(self, sample): - if len(sample) == 0: - return None - text = [] - for w in sample: - t1, t2, t3, t4 = w[1], w[3], w[6], w[7] - if t3 == '_': - return None - text.append(t1) - return text - - -class NaiveCWSReader(DataSetLoader): - """ - 这个reader假设了分词数据集为以下形式, 即已经用空格分割好内容了 - 例如:: - - 这是 fastNLP , 一个 非常 good 的 包 . - - 或者,即每个part后面还有一个pos tag - 例如:: - - 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY - - """ - - def __init__(self, in_word_splitter=None): - super(NaiveCWSReader, self).__init__() - self.in_word_splitter = in_word_splitter - - def load(self, filepath, in_word_splitter=None, cut_long_sent=False): - """ - 允许使用的情况有(默认以\t或空格作为seg) - 这是 fastNLP , 一个 非常 good 的 包 . - 和 - 也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY - 如果splitter不为None则认为是第二种情况, 且我们会按splitter分割"也/D", 然后取第一部分. 例如"也/D".split('/')[0] - - :param filepath: - :param in_word_splitter: - :param cut_long_sent: - :return: - """ - if in_word_splitter == None: - in_word_splitter = self.in_word_splitter - dataset = DataSet() - with open(filepath, 'r') as f: - for line in f: - line = line.strip() - if len(line.replace(' ', '')) == 0: # 不能接受空行 - continue - - if not in_word_splitter is None: - words = [] - for part in line.split(): - word = part.split(in_word_splitter)[0] - words.append(word) - line = ' '.join(words) - if cut_long_sent: - sents = cut_long_sentence(line) - else: - sents = [line] - for sent in sents: - instance = Instance(raw_sentence=sent) - dataset.append(instance) - - return dataset - - def cut_long_sentence(sent, max_sample_length=200): """ 将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。所以截取的句子可能长于或者短于max_sample_length @@ -727,103 +222,6 @@ def cut_long_sentence(sent, max_sample_length=200): return cutted_sentence -class ZhConllPOSReader(object): - """读取中文Conll格式。返回“字级别”的标签,使用BMES记号扩展原来的词级别标签。 - - Deprecated. Use ConllLoader for all types of conll-format files. - """ - def __init__(self): - pass - - def load(self, path): - """ - 返回的DataSet, 包含以下的field - words:list of str, - tag: list of str, 被加入了BMES tag, 比如原来的序列为['VP', 'NN', 'NN', ..],会被认为是["S-VP", "B-NN", "M-NN",..] 
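For reference, the tag-stripping step the removed NaiveCWSReader performed on "词/POS"-style input is a one-liner; a tiny sketch:

line = "也/D 在/P 團員/Na 之中/Ng ,/COMMACATEGORY"
words = [part.split("/")[0] for part in line.split()]
assert words == ["也", "在", "團員", "之中", ","]
# the kept cut_long_sentence(" ".join(words)) would then split overly long sentences at spaces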
- 假定了输入为conll的格式,以空行隔开两个句子,每行共7列,即 - :: - - 1 编者按 编者按 NN O 11 nmod:topic - 2 : : PU O 11 punct - 3 7月 7月 NT DATE 4 compound:nn - 4 12日 12日 NT DATE 11 nmod:tmod - 5 , , PU O 11 punct - - 1 这 这 DT O 3 det - 2 款 款 M O 1 mark:clf - 3 飞行 飞行 NN O 8 nsubj - 4 从 从 P O 5 case - 5 外型 外型 NN O 8 nmod:prep - - """ - datalist = [] - with open(path, 'r', encoding='utf-8') as f: - sample = [] - for line in f: - if line.startswith('\n'): - datalist.append(sample) - sample = [] - elif line.startswith('#'): - continue - else: - sample.append(line.split('\t')) - if len(sample) > 0: - datalist.append(sample) - - ds = DataSet() - for sample in datalist: - # print(sample) - res = self.get_one(sample) - if res is None: - continue - char_seq = [] - pos_seq = [] - for word, tag in zip(res[0], res[1]): - char_seq.extend(list(word)) - if len(word) == 1: - pos_seq.append('S-{}'.format(tag)) - elif len(word) > 1: - pos_seq.append('B-{}'.format(tag)) - for _ in range(len(word) - 2): - pos_seq.append('M-{}'.format(tag)) - pos_seq.append('E-{}'.format(tag)) - else: - raise ValueError("Zero length of word detected.") - - ds.append(Instance(words=char_seq, - tag=pos_seq)) - - return ds - - def get_one(self, sample): - if len(sample) == 0: - return None - text = [] - pos_tags = [] - for w in sample: - t1, t2, t3, t4 = w[1], w[3], w[6], w[7] - if t3 == '_': - return None - text.append(t1) - pos_tags.append(t2) - return text, pos_tags - - -class ConllxDataLoader(ConllLoader): - """返回“词级别”的标签信息,包括词、词性、(句法)头依赖、(句法)边标签。跟``ZhConllPOSReader``完全不同。 - - Deprecated. Use ConllLoader for all types of conll-format files. - """ - def __init__(self): - headers = [ - 'words', 'pos_tags', 'heads', 'labels', - ] - indexs = [ - 1, 3, 6, 7, - ] - super(ConllxDataLoader, self).__init__(headers=headers, indexs=indexs) - - class SSTLoader(DataSetLoader): """load SST data in PTB tree format data source: https://nlp.stanford.edu/sentiment/trainDevTestTrees_PTB.zip @@ -842,10 +240,7 @@ class SSTLoader(DataSetLoader): """ :param path: str,存储数据的路径 - :return: DataSet。内含field有'words', 'pos_tags', 'heads', 'labels'(parser的label) - 类似于拥有以下结构, 一行为一个instance(sample) - words pos_tags heads labels - ['some', ..] ['NN', ...] [2, 3...] ['nn', 'nn'...] 
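Both removed readers above point to ConllLoader as their replacement; a sketch of configuring it the way the removed ConllxDataLoader did, with the same header names and column indices (the file path is hypothetical):

from fastNLP.io.dataset_loader import ConllLoader

# the columns the removed ConllxDataLoader selected from a CoNLL-X file
loader = ConllLoader(headers=["words", "pos_tags", "heads", "labels"], indexs=[1, 3, 6, 7])
ds = loader.load("path/to/train.conllx")  # hypothetical path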
+ :return: DataSet。 """ datalist = [] with open(path, 'r', encoding='utf-8') as f: @@ -860,7 +255,6 @@ class SSTLoader(DataSetLoader): @staticmethod def get_one(data, subtree): - from nltk.tree import Tree tree = Tree.fromstring(data) if subtree: return [(t.leaves(), t.label()) for t in tree.subtrees()] @@ -872,26 +266,72 @@ class JsonLoader(DataSetLoader): every line contains a json obj, like a dict fields is the dict key that need to be load """ - def __init__(self, **fields): + def __init__(self, dropna=False, fields=None): super(JsonLoader, self).__init__() - self.fields = {} - for k, v in fields.items(): - self.fields[k] = k if v is None else v + self.dropna = dropna + self.fields = None + self.fields_list = None + if fields: + self.fields = {} + for k, v in fields.items(): + self.fields[k] = k if v is None else v + self.fields_list = list(self.fields.keys()) + + def load(self, path): + ds = DataSet() + for idx, d in read_json(path, fields=self.fields_list, dropna=self.dropna): + ins = {self.fields[k]:v for k,v in d.items()} + ds.append(Instance(**ins)) + return ds + + +class SNLILoader(JsonLoader): + """ + data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip + """ + def __init__(self): + fields = { + 'sentence1_parse': 'words1', + 'sentence2_parse': 'words2', + 'gold_label': 'target', + } + super(SNLILoader, self).__init__(fields=fields) + + def load(self, path): + ds = super(SNLILoader, self).load(path) + def parse_tree(x): + t = Tree.fromstring(x) + return t.leaves() + ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1') + ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2') + ds.drop(lambda x: x['target'] == '-') + return ds + + +class CSVLoader(DataSetLoader): + """Load data from a CSV file and return a DataSet object. + + :param str csv_path: path to the CSV file + :param List[str] or Tuple[str] headers: headers of the CSV file + :param str sep: delimiter in CSV file. Default: "," + :param bool dropna: If True, drop rows that have less entries than headers. + :return dataset: the read data set + + """ + def __init__(self, headers=None, sep=",", dropna=True): + self.headers = headers + self.sep = sep + self.dropna = dropna def load(self, path): - with open(path, 'r', encoding='utf-8') as f: - datas = [json.loads(l) for l in f] ds = DataSet() - for d in datas: - ins = Instance() - for k, v in d.items(): - if k in self.fields: - ins.add_field(self.fields[k], v) - ds.append(ins) + for idx, data in read_csv(path, headers=self.headers, + sep=self.sep, dropna=self.dropna): + ds.append(Instance(**data)) return ds -def add_seg_tag(data): +def _add_seg_tag(data): """ :param data: list of ([word], [pos], [heads], [head_tags]) diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index c80e8b5f..258e8595 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -132,7 +132,7 @@ class EmbedLoader(BaseLoader): def load_with_vocab(embed_filepath, vocab, dtype=np.float32, normalize=True, error='ignore'): """ load pretraining embedding in {embed_file} based on words in vocab. Words in vocab but not in the pretraining - embedding are initialized from a normal distribution which has the mean and std of the found words vectors. + embedding are initialized from a normal distribution which has the mean and std of the found words vectors. The embedding type is determined automatically, support glove and word2vec(the first line only has two elements). 
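A sketch of the new loader API added above, mirroring the SNLI test case later in this patch; the field renaming follows the fields mapping defined in SNLILoader:

from fastNLP.io.dataset_loader import SNLILoader, JsonLoader

# SNLILoader maps sentence*_parse / gold_label to words1 / words2 / target and drops '-' labels
ds = SNLILoader().load("test/data_for_tests/sample_snli.jsonl")

# the generic JsonLoader can do the same kind of renaming for other json-lines files
loader = JsonLoader(fields={"sentence1": "premise", "gold_label": "target"})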
:param embed_filepath: str, where to read pretrain embedding diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py new file mode 100644 index 00000000..22766ebb --- /dev/null +++ b/fastNLP/io/file_reader.py @@ -0,0 +1,112 @@ +import json + + +def read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): + """ + Construct a generator to read csv items + :param path: file path + :param encoding: file's encoding, default: utf-8 + :param headers: file's headers, if None, make file's first line as headers. default: None + :param sep: separator for each column. default: ',' + :param dropna: weather to ignore and drop invalid data, + if False, raise ValueError when reading invalid data. default: True + :return: generator, every time yield (line number, csv item) + """ + with open(path, 'r', encoding=encoding) as f: + start_idx = 0 + if headers is None: + headers = f.readline().rstrip('\r\n') + headers = headers.split(sep) + start_idx += 1 + elif not isinstance(headers, (list, tuple)): + raise TypeError("headers should be list or tuple, not {}." \ + .format(type(headers))) + for line_idx, line in enumerate(f, start_idx): + contents = line.rstrip('\r\n').split(sep) + if len(contents) != len(headers): + if dropna: + continue + else: + raise ValueError("Line {} has {} parts, while header has {} parts." \ + .format(line_idx, len(contents), len(headers))) + _dict = {} + for header, content in zip(headers, contents): + _dict[header] = content + yield line_idx, _dict + + +def read_json(path, encoding='utf-8', fields=None, dropna=True): + """ + Construct a generator to read json items + :param path: file path + :param encoding: file's encoding, default: utf-8 + :param fields: json object's fields that needed, if None, all fields are needed. default: None + :param dropna: weather to ignore and drop invalid data, + if False, raise ValueError when reading invalid data. default: True + :return: generator, every time yield (line number, json item) + """ + if fields: + fields = set(fields) + with open(path, 'r', encoding=encoding) as f: + for line_idx, line in enumerate(f): + data = json.loads(line) + if fields is None: + yield line_idx, data + continue + _res = {} + for k, v in data.items(): + if k in fields: + _res[k] = v + if len(_res) < len(fields): + if dropna: + continue + else: + raise ValueError('invalid instance at line: {}'.format(line_idx)) + yield line_idx, _res + + +def read_conll(path, encoding='utf-8', indexes=None, dropna=True): + """ + Construct a generator to read conll items + :param path: file path + :param encoding: file's encoding, default: utf-8 + :param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None + :param dropna: weather to ignore and drop invalid data, + if False, raise ValueError when reading invalid data. 
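The readers in the new file_reader.py are generators that yield (line number, item); a sketch of consuming them directly, with hypothetical file paths:

from fastNLP.io.file_reader import read_csv, read_json

# csv/tsv: yields (line_idx, {header: value}) per row
for line_idx, row in read_csv("data.tsv", headers=["words", "label"], sep="\t", dropna=True):
    print(line_idx, row["label"])

# json-lines: keep only the listed fields, raise on incomplete rows instead of dropping them
for line_idx, obj in read_json("data.jsonl", fields=["gold_label"], dropna=False):
    print(line_idx, obj["gold_label"])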
default: True + :return: generator, every time yield (line number, conll item) + """ + def parse_conll(sample): + sample = list(map(list, zip(*sample))) + sample = [sample[i] for i in indexes] + for f in sample: + if len(f) <= 0: + raise ValueError('empty field') + return sample + with open(path, 'r', encoding=encoding) as f: + sample = [] + start = next(f) + if '-DOCSTART-' not in start: + sample.append(start.split()) + for line_idx, line in enumerate(f, 1): + if line.startswith('\n'): + if len(sample): + try: + res = parse_conll(sample) + sample = [] + yield line_idx, res + except Exception as e: + if dropna: + continue + raise ValueError('invalid instance at line: {}'.format(line_idx)) + elif line.startswith('#'): + continue + else: + sample.append(line.split()) + if len(sample) > 0: + try: + res = parse_conll(sample) + yield line_idx, res + except Exception as e: + if dropna: + return + raise ValueError('invalid instance at line: {}'.format(line_idx)) diff --git a/fastNLP/io/model_io.py b/fastNLP/io/model_io.py index 422eb919..53bdc7ce 100644 --- a/fastNLP/io/model_io.py +++ b/fastNLP/io/model_io.py @@ -31,16 +31,18 @@ class ModelLoader(BaseLoader): class ModelSaver(object): """Save a model + Example:: - :param str save_path: the path to the saving directory. - Example:: - - saver = ModelSaver("./save/model_ckpt_100.pkl") - saver.save_pytorch(model) + saver = ModelSaver("./save/model_ckpt_100.pkl") + saver.save_pytorch(model) """ def __init__(self, save_path): + """ + + :param save_path: the path to the saving directory. + """ self.save_path = save_path def save_pytorch(self, model, param_only=True): diff --git a/fastNLP/models/char_language_model.py b/fastNLP/models/char_language_model.py index 5fbde3cc..d5e3359d 100644 --- a/fastNLP/models/char_language_model.py +++ b/fastNLP/models/char_language_model.py @@ -20,16 +20,23 @@ class Highway(nn.Module): class CharLM(nn.Module): """CNN + highway network + LSTM - # Input: + + # Input:: + 4D tensor with shape [batch_size, in_channel, height, width] - # Output: + + # Output:: + 2D Tensor with shape [batch_size, vocab_size] - # Arguments: + + # Arguments:: + char_emb_dim: the size of each character's attention word_emb_dim: the size of each word's attention vocab_size: num of unique words num_char: num of characters use_gpu: True or False + """ def __init__(self, char_emb_dim, word_emb_dim, diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/models/enas_trainer.py index 6b51c897..26b7cd49 100644 --- a/fastNLP/models/enas_trainer.py +++ b/fastNLP/models/enas_trainer.py @@ -65,13 +65,14 @@ class ENASTrainer(fastNLP.Trainer): """ :param bool load_best_model: 该参数只有在初始化提供了dev_data的情况下有效,如果True, trainer将在返回之前重新加载dev表现 最好的模型参数。 - :return results: 返回一个字典类型的数据, 内含以下内容:: - - seconds: float, 表示训练时长 - 以下三个内容只有在提供了dev_data的情况下会有。 - best_eval: Dict of Dict, 表示evaluation的结果 - best_epoch: int,在第几个epoch取得的最佳值 - best_step: int, 在第几个step(batch)更新取得的最佳值 + :return results: 返回一个字典类型的数据, + 内含以下内容:: + + seconds: float, 表示训练时长 + 以下三个内容只有在提供了dev_data的情况下会有。 + best_eval: Dict of Dict, 表示evaluation的结果 + best_epoch: int,在第几个epoch取得的最佳值 + best_step: int, 在第几个step(batch)更新取得的最佳值 """ results = {} diff --git a/fastNLP/models/sequence_modeling.py b/fastNLP/models/sequence_modeling.py index cb9e9478..cb615daf 100644 --- a/fastNLP/models/sequence_modeling.py +++ b/fastNLP/models/sequence_modeling.py @@ -79,7 +79,7 @@ class SeqLabeling(BaseModel): :return prediction: list of [decode path(list)] """ max_len = x.shape[1] - tag_seq = self.Crf.viterbi_decode(x, self.mask) + 
tag_seq, _ = self.Crf.viterbi_decode(x, self.mask) # pad prediction to equal length if pad is True: for pred in tag_seq: diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py index 6a7d8d84..901f2dd4 100644 --- a/fastNLP/models/snli.py +++ b/fastNLP/models/snli.py @@ -1,6 +1,5 @@ import torch import torch.nn as nn -import torch.nn.functional as F from fastNLP.models.base_model import BaseModel from fastNLP.modules import decoder as Decoder @@ -40,7 +39,7 @@ class ESIM(BaseModel): batch_first=self.batch_first, bidirectional=True ) - self.bi_attention = Aggregator.Bi_Attention() + self.bi_attention = Aggregator.BiAttention() self.mean_pooling = Aggregator.MeanPoolWithMask() self.max_pooling = Aggregator.MaxPoolWithMask() @@ -53,23 +52,23 @@ class ESIM(BaseModel): self.output = Decoder.MLP([4 * self.hidden_size, self.hidden_size, self.n_labels], 'tanh', dropout=self.dropout) - def forward(self, premise, hypothesis, premise_len, hypothesis_len): + def forward(self, words1, words2, seq_len1, seq_len2): """ Forward function - :param premise: A Tensor represents premise: [batch size(B), premise seq len(PL)]. - :param hypothesis: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. - :param premise_len: A Tensor record which is a real word and which is a padding word in premise: [B, PL]. - :param hypothesis_len: A Tensor record which is a real word and which is a padding word in hypothesis: [B, HL]. + :param words1: A Tensor represents premise: [batch size(B), premise seq len(PL)]. + :param words2: A Tensor represents hypothesis: [B, hypothesis seq len(HL)]. + :param seq_len1: A Tensor record which is a real word and which is a padding word in premise: [B]. + :param seq_len2: A Tensor record which is a real word and which is a padding word in hypothesis: [B]. :return: prediction: A Dict with Tensor of classification result: [B, n_labels(N)]. 
""" - premise0 = self.embedding_layer(self.embedding(premise)) - hypothesis0 = self.embedding_layer(self.embedding(hypothesis)) + premise0 = self.embedding_layer(self.embedding(words1)) + hypothesis0 = self.embedding_layer(self.embedding(words2)) _BP, _PSL, _HP = premise0.size() _BH, _HSL, _HH = hypothesis0.size() - _BPL, _PLL = premise_len.size() - _HPL, _HLL = hypothesis_len.size() + _BPL, _PLL = seq_len1.size() + _HPL, _HLL = seq_len2.size() assert _BP == _BH and _BPL == _HPL and _BP == _BPL assert _HP == _HH @@ -84,7 +83,7 @@ class ESIM(BaseModel): a = torch.mean(a0.view(B, PL, -1, H), dim=2) # a: [B, PL, H] b = torch.mean(b0.view(B, HL, -1, H), dim=2) # b: [B, HL, H] - ai, bi = self.bi_attention(a, b, premise_len, hypothesis_len) + ai, bi = self.bi_attention(a, b, seq_len1, seq_len2) ma = torch.cat((a, ai, a - ai, a * ai), dim=2) # ma: [B, PL, 4 * H] mb = torch.cat((b, bi, b - bi, b * bi), dim=2) # mb: [B, HL, 4 * H] @@ -98,17 +97,18 @@ class ESIM(BaseModel): va = torch.mean(vat.view(B, PL, -1, H), dim=2) # va: [B, PL, H] vb = torch.mean(vbt.view(B, HL, -1, H), dim=2) # vb: [B, HL, H] - va_ave = self.mean_pooling(va, premise_len, dim=1) # va_ave: [B, H] - va_max, va_arg_max = self.max_pooling(va, premise_len, dim=1) # va_max: [B, H] - vb_ave = self.mean_pooling(vb, hypothesis_len, dim=1) # vb_ave: [B, H] - vb_max, vb_arg_max = self.max_pooling(vb, hypothesis_len, dim=1) # vb_max: [B, H] + va_ave = self.mean_pooling(va, seq_len1, dim=1) # va_ave: [B, H] + va_max, va_arg_max = self.max_pooling(va, seq_len1, dim=1) # va_max: [B, H] + vb_ave = self.mean_pooling(vb, seq_len2, dim=1) # vb_ave: [B, H] + vb_max, vb_arg_max = self.max_pooling(vb, seq_len2, dim=1) # vb_max: [B, H] v = torch.cat((va_ave, va_max, vb_ave, vb_max), dim=1) # v: [B, 4 * H] - prediction = F.tanh(self.output(v)) # prediction: [B, N] + prediction = torch.tanh(self.output(v)) # prediction: [B, N] return {'pred': prediction} - def predict(self, premise, hypothesis, premise_len, hypothesis_len): - return self.forward(premise, hypothesis, premise_len, hypothesis_len) + def predict(self, words1, words2, seq_len1, seq_len2): + prediction = self.forward(words1, words2, seq_len1, seq_len2)['pred'] + return {'pred': torch.argmax(prediction, dim=-1)} diff --git a/fastNLP/modules/aggregator/__init__.py b/fastNLP/modules/aggregator/__init__.py index 2fabb89e..43d60cac 100644 --- a/fastNLP/modules/aggregator/__init__.py +++ b/fastNLP/modules/aggregator/__init__.py @@ -5,6 +5,6 @@ from .avg_pool import MeanPoolWithMask from .kmax_pool import KMaxPool from .attention import Attention -from .attention import Bi_Attention +from .attention import BiAttention from .self_attention import SelfAttention diff --git a/fastNLP/modules/aggregator/attention.py b/fastNLP/modules/aggregator/attention.py index ef9d159d..4155fdd6 100644 --- a/fastNLP/modules/aggregator/attention.py +++ b/fastNLP/modules/aggregator/attention.py @@ -23,9 +23,9 @@ class Attention(torch.nn.Module): raise NotImplementedError -class DotAtte(nn.Module): +class DotAttention(nn.Module): def __init__(self, key_size, value_size, dropout=0.1): - super(DotAtte, self).__init__() + super(DotAttention, self).__init__() self.key_size = key_size self.value_size = value_size self.scale = math.sqrt(key_size) @@ -48,7 +48,7 @@ class DotAtte(nn.Module): return torch.matmul(output, V) -class MultiHeadAtte(nn.Module): +class MultiHeadAttention(nn.Module): def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1): """ @@ -58,7 +58,7 @@ class MultiHeadAtte(nn.Module): :param 
num_head: int,head的数量。 :param dropout: float。 """ - super(MultiHeadAtte, self).__init__() + super(MultiHeadAttention, self).__init__() self.input_size = input_size self.key_size = key_size self.value_size = value_size @@ -68,7 +68,7 @@ class MultiHeadAtte(nn.Module): self.q_in = nn.Linear(input_size, in_size) self.k_in = nn.Linear(input_size, in_size) self.v_in = nn.Linear(input_size, in_size) - self.attention = DotAtte(key_size=key_size, value_size=value_size) + self.attention = DotAttention(key_size=key_size, value_size=value_size) self.out = nn.Linear(value_size * num_head, input_size) self.drop = TimestepDropout(dropout) self.reset_parameters() @@ -109,16 +109,34 @@ class MultiHeadAtte(nn.Module): return output -class Bi_Attention(nn.Module): +class BiAttention(nn.Module): + """Bi Attention module + Calculate Bi Attention matrix `e` + + .. math:: + + \begin{array}{ll} \\ + e_ij = {a}^{\mathbf{T}}_{i}{b}_{j} \\ + a_i = + b_j = + \end{array} + + """ + def __init__(self): - super(Bi_Attention, self).__init__() + super(BiAttention, self).__init__() self.inf = 10e12 def forward(self, in_x1, in_x2, x1_len, x2_len): - # in_x1: [batch_size, x1_seq_len, hidden_size] - # in_x2: [batch_size, x2_seq_len, hidden_size] - # x1_len: [batch_size, x1_seq_len] - # x2_len: [batch_size, x2_seq_len] + """ + :param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] 第一句的特征表示 + :param torch.Tensor in_x2: [batch_size, x2_seq_len, hidden_size] 第二句的特征表示 + :param torch.Tensor x1_len: [batch_size, x1_seq_len] 第一句的0/1mask矩阵 + :param torch.Tensor x2_len: [batch_size, x2_seq_len] 第二句的0/1mask矩阵 + :return: torch.Tensor out_x1: [batch_size, x1_seq_len, hidden_size] 第一句attend到的特征表示 + torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] 第一句attend到的特征表示 + + """ assert in_x1.size()[0] == in_x2.size()[0] assert in_x1.size()[2] == in_x2.size()[2] diff --git a/fastNLP/modules/decoder/CRF.py b/fastNLP/modules/decoder/CRF.py index df004224..99e7a9c2 100644 --- a/fastNLP/modules/decoder/CRF.py +++ b/fastNLP/modules/decoder/CRF.py @@ -2,12 +2,7 @@ import torch from torch import nn from fastNLP.modules.utils import initial_parameter - - -def log_sum_exp(x, dim=-1): - max_value, _ = x.max(dim=dim, keepdim=True) - res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value - return res.squeeze(dim) +from fastNLP.modules.decoder.utils import log_sum_exp def seq_len_to_byte_mask(seq_lens): @@ -20,22 +15,27 @@ def seq_len_to_byte_mask(seq_lens): return mask -def allowed_transitions(id2label, encoding_type='bio'): +def allowed_transitions(id2label, encoding_type='bio', include_start_end=True): """ + 给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。 - :param dict id2label: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 - "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()id2label。 + :param id2label: Dict, key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是 + "B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.get_id2word()得到id2label。 :param encoding_type: str, 支持"bio", "bmes", "bmeso"。 - :return: List[Tuple(int, int)]], 内部的Tuple是(from_tag_id, to_tag_id)。 返回的结果考虑了start和end,比如"BIO"中,B、O可以 - 位于序列的开端,而I不行。所以返回的结果中会包含(start_idx, B_idx), (start_idx, O_idx), 但是不包含(start_idx, I_idx). 
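A sketch of calling the renamed BiAttention with the tensor shapes its new forward docstring describes; all inputs are random toy tensors and the all-ones masks simply mark every position as a real token:

import torch
from fastNLP.modules.aggregator import BiAttention

batch, len1, len2, hidden = 2, 5, 7, 16
x1 = torch.randn(batch, len1, hidden)
x2 = torch.randn(batch, len2, hidden)
mask1 = torch.ones(batch, len1)   # 0/1 mask over positions of the first sentence
mask2 = torch.ones(batch, len2)   # 0/1 mask over positions of the second sentence

biatt = BiAttention()
out1, out2 = biatt(x1, x2, mask1, mask2)   # [B, len1, H], [B, len2, H]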
- start_idx=len(id2label), end_idx=len(id2label)+1。 + :param include_start_end: bool, 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头; + 为True,返回的结果中会包含(start_idx, b_idx), (start_idx, o_idx), 但是不包含(start_idx, i_idx); + start_idx=len(id2label), end_idx=len(id2label)+1。 + 为False, 返回的结果中不含与开始结尾相关的内容 + :return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。 """ num_tags = len(id2label) start_idx = num_tags end_idx = num_tags + 1 encoding_type = encoding_type.lower() allowed_trans = [] - id_label_lst = list(id2label.items()) + [(start_idx, 'start'), (end_idx, 'end')] + id_label_lst = list(id2label.items()) + if include_start_end: + id_label_lst += [(start_idx, 'start'), (end_idx, 'end')] def split_tag_label(from_label): from_label = from_label.lower() if from_label in ['start', 'end']: @@ -54,12 +54,12 @@ def allowed_transitions(id2label, encoding_type='bio'): if to_label in ['', '']: continue to_tag, to_label = split_tag_label(to_label) - if is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): + if _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): allowed_trans.append((from_id, to_id)) return allowed_trans -def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): +def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label): """ :param encoding_type: str, 支持"BIO", "BMES", "BEMSO"。 @@ -140,20 +140,22 @@ def is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label) raise ValueError("Unexpect tag type {}. Expect only 'B', 'M', 'E', 'S', 'O'.".format(from_tag)) else: - raise ValueError("Only support BIO, BMES encoding type, got {}.".format(encoding_type)) + raise ValueError("Only support BIO, BMES, BMESO encoding type, got {}.".format(encoding_type)) class ConditionalRandomField(nn.Module): - """ - - :param int num_tags: 标签的数量。 - :param bool include_start_end_trans: 是否包含起始tag - :param list allowed_transitions: ``List[Tuple[from_tag_id(int), to_tag_id(int)]]``. 
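A sketch of calling the updated allowed_transitions with the new include_start_end flag, using a tiny hand-written id2label map (purely illustrative):

from fastNLP.modules.decoder.CRF import allowed_transitions

id2label = {0: "B-PER", 1: "I-PER", 2: "O"}
# with include_start_end=True the start/end pseudo-tags get ids 3 and 4 (len(id2label), len(id2label)+1)
trans = allowed_transitions(id2label, encoding_type="bio", include_start_end=True)
assert (3, 0) in trans      # a sequence may start with B-PER
assert (3, 1) not in trans  # but not with I-PER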
允许的跃迁,可以通过allowed_transitions()得到。 - 如果为None,则所有跃迁均为合法 - :param str initial_method: - """ - - def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, initial_method=None): + def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None, + initial_method=None): + """条件随机场。 + 提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。 + + :param num_tags: int, 标签的数量 + :param include_start_end_trans: bool, 是否考虑各个tag作为开始以及结尾的分数。 + :param allowed_transitions: List[Tuple[from_tag_id(int), to_tag_id(int)]], 内部的Tuple[from_tag_id(int), + to_tag_id(int)]视为允许发生的跃迁,其他没有包含的跃迁认为是禁止跃迁,可以通过 + allowed_transitions()函数得到;如果为None,则所有跃迁均为合法 + :param initial_method: str, 初始化方法。见initial_parameter + """ super(ConditionalRandomField, self).__init__() self.include_start_end_trans = include_start_end_trans @@ -168,18 +170,12 @@ class ConditionalRandomField(nn.Module): if allowed_transitions is None: constrain = torch.zeros(num_tags + 2, num_tags + 2) else: - constrain = torch.ones(num_tags + 2, num_tags + 2) * -1000 + constrain = torch.new_full((num_tags+2, num_tags+2), fill_value=-10000.0, dtype=torch.float) for from_tag_id, to_tag_id in allowed_transitions: constrain[from_tag_id, to_tag_id] = 0 self._constrain = nn.Parameter(constrain, requires_grad=False) - # self.reset_parameter() initial_parameter(self, initial_method) - def reset_parameter(self): - nn.init.xavier_normal_(self.trans_m) - if self.include_start_end_trans: - nn.init.normal_(self.start_scores) - nn.init.normal_(self.end_scores) def _normalizer_likelihood(self, logits, mask): """Computes the (batch_size,) denominator term for the log-likelihood, which is the @@ -239,10 +235,11 @@ class ConditionalRandomField(nn.Module): def forward(self, feats, tags, mask): """ - Calculate the neg log likelihood - :param feats:FloatTensor, batch_size x max_len x num_tags - :param tags:LongTensor, batch_size x max_len - :param mask:ByteTensor batch_size x max_len + 用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。 + + :param feats:FloatTensor, batch_size x max_len x num_tags,特征矩阵。 + :param tags:LongTensor, batch_size x max_len,标签矩阵。 + :param mask:ByteTensor batch_size x max_len,为0的位置认为是padding。 :return:FloatTensor, batch_size """ feats = feats.transpose(0, 1) @@ -253,28 +250,27 @@ class ConditionalRandomField(nn.Module): return all_path_score - gold_path_score - def viterbi_decode(self, data, mask, get_score=False, unpad=False): - """Given a feats matrix, return best decode path and best score. + def viterbi_decode(self, feats, mask, unpad=False): + """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 - :param data:FloatTensor, batch_size x max_len x num_tags - :param mask:ByteTensor batch_size x max_len - :param get_score: bool, whether to output the decode score. - :param unpad: bool, 是否将结果unpad, - 如果False, 返回的是batch_size x max_len的tensor, - 如果True,返回的是List[List[int]], List[int]为每个sequence的label,已经unpadding了,即每个 - List[int]的长度是这个sample的有效长度 - :return: 如果get_score为False,返回结果根据unpadding变动 - 如果get_score为True, 返回 (paths, List[float], )。第一个仍然是解码后的路径(根据unpad变化),第二个List[Float] - 为每个seqence的解码分数。 + :param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 + :param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 + :param unpad: bool, 是否将结果删去padding, + False, 返回的是batch_size x max_len的tensor, + True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int] + 的长度是这个sample的有效长度。 + :return: 返回 (paths, scores)。 + paths: 是解码后的路径, 其值参照unpad参数. 
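Putting the two documented entry points together, a minimal training/decoding sketch on random toy data; shapes follow the docstrings above. (Constrained decoding would additionally pass allowed_transitions=allowed_transitions(id2label, "bio"); note the constrained branch above calls torch.new_full, which exists only as a Tensor method, so torch.full or tensor.new_full is presumably what is meant, and the unconstrained form is shown here.)

import torch
from fastNLP.modules.decoder.CRF import ConditionalRandomField

num_tags, batch, max_len = 3, 4, 6
crf = ConditionalRandomField(num_tags, include_start_end_trans=True)

feats = torch.randn(batch, max_len, num_tags)
tags = torch.randint(0, num_tags, (batch, max_len))
mask = torch.ones(batch, max_len, dtype=torch.uint8)

loss = crf(feats, tags, mask).mean()               # forward() returns a (batch_size,) loss vector
paths, scores = crf.viterbi_decode(feats, mask, unpad=True)
assert len(paths) == batch and scores.size(0) == batch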
+ scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 """ - batch_size, seq_len, n_tags = data.size() - data = data.transpose(0, 1).data # L, B, H + batch_size, seq_len, n_tags = feats.size() + feats = feats.transpose(0, 1).data # L, B, H mask = mask.transpose(0, 1).data.byte() # L, B # dp - vpath = data.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) - vscore = data[0] + vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) + vscore = feats[0] transitions = self._constrain.data.clone() transitions[:n_tags, :n_tags] += self.trans_m.data if self.include_start_end_trans: @@ -285,23 +281,24 @@ class ConditionalRandomField(nn.Module): trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data for i in range(1, seq_len): prev_score = vscore.view(batch_size, n_tags, 1) - cur_score = data[i].view(batch_size, 1, n_tags) + cur_score = feats[i].view(batch_size, 1, n_tags) score = prev_score + trans_score + cur_score best_score, best_dst = score.max(1) vpath[i] = best_dst vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ vscore.masked_fill(mask[i].view(batch_size, 1), 0) - vscore += transitions[:n_tags, n_tags+1].view(1, -1) + if self.include_start_end_trans: + vscore += transitions[:n_tags, n_tags+1].view(1, -1) # backtrace - batch_idx = torch.arange(batch_size, dtype=torch.long, device=data.device) - seq_idx = torch.arange(seq_len, dtype=torch.long, device=data.device) + batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) + seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) lens = (mask.long().sum(0) - 1) # idxes [L, B], batched idx from seq_len-1 to 0 idxes = (lens.view(1,-1) - seq_idx.view(-1,1)) % seq_len - ans = data.new_empty((seq_len, batch_size), dtype=torch.long) + ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) ans_score, last_tags = vscore.max(1) ans[idxes[0], batch_idx] = last_tags for i in range(seq_len - 1): diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py new file mode 100644 index 00000000..6e35af9a --- /dev/null +++ b/fastNLP/modules/decoder/utils.py @@ -0,0 +1,70 @@ + +import torch + + +def log_sum_exp(x, dim=-1): + max_value, _ = x.max(dim=dim, keepdim=True) + res = torch.log(torch.sum(torch.exp(x - max_value), dim=dim, keepdim=True)) + max_value + return res.squeeze(dim) + + +def viterbi_decode(feats, transitions, mask=None, unpad=False): + """给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 + + :param feats: FloatTensor, batch_size x max_len x num_tags,特征矩阵。 + :param transitions: FloatTensor, n_tags x n_tags。[i, j]位置的值认为是从tag i到tag j的转换。 + :param mask: ByteTensor, batch_size x max_len, 为0的位置认为是pad;如果为None,则认为没有padding。 + :param unpad: bool, 是否将结果删去padding, + False, 返回的是batch_size x max_len的tensor, + True,返回的是List[List[int]], 内部的List[int]为每个sequence的label,已经除去pad部分,即每个List[int]的长度是 + 这个sample的有效长度。 + :return: 返回 (paths, scores)。 + paths: 是解码后的路径, 其值参照unpad参数. + scores: torch.FloatTensor, size为(batch_size,), 对应每个最优路径的分数。 + + """ + batch_size, seq_len, n_tags = feats.size() + assert n_tags==transitions.size(0) and n_tags==transitions.size(1), "The shapes of transitions and feats are not " \ + "compatible." 
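The log_sum_exp moved into the new decoder/utils.py is the usual max-shifted formulation; a quick numeric check of what it computes:

import math
import torch
from fastNLP.modules.decoder.utils import log_sum_exp

x = torch.tensor([[1.0, 2.0, 3.0]])
expected = math.log(math.exp(1.0) + math.exp(2.0) + math.exp(3.0))   # about 3.4076
assert torch.allclose(log_sum_exp(x, dim=-1), torch.tensor([expected]))
# the standalone viterbi_decode(feats, transitions, mask, unpad=True) defined below follows
# the same (paths, scores) return convention as ConditionalRandomField.viterbi_decode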
+ feats = feats.transpose(0, 1).data # L, B, H + if mask is not None: + mask = mask.transpose(0, 1).data.byte() # L, B + else: + mask = feats.new_ones((seq_len, batch_size), dtype=torch.uint8) + + # dp + vpath = feats.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long) + vscore = feats[0] + + vscore += transitions[n_tags, :n_tags] + trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data + for i in range(1, seq_len): + prev_score = vscore.view(batch_size, n_tags, 1) + cur_score = feats[i].view(batch_size, 1, n_tags) + score = prev_score + trans_score + cur_score + best_score, best_dst = score.max(1) + vpath[i] = best_dst + vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \ + vscore.masked_fill(mask[i].view(batch_size, 1), 0) + + # backtrace + batch_idx = torch.arange(batch_size, dtype=torch.long, device=feats.device) + seq_idx = torch.arange(seq_len, dtype=torch.long, device=feats.device) + lens = (mask.long().sum(0) - 1) + # idxes [L, B], batched idx from seq_len-1 to 0 + idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len + + ans = feats.new_empty((seq_len, batch_size), dtype=torch.long) + ans_score, last_tags = vscore.max(1) + ans[idxes[0], batch_idx] = last_tags + for i in range(seq_len - 1): + last_tags = vpath[idxes[i], batch_idx, last_tags] + ans[idxes[i + 1], batch_idx] = last_tags + ans = ans.transpose(0, 1) + if unpad: + paths = [] + for idx, seq_len in enumerate(lens): + paths.append(ans[idx, :seq_len + 1].tolist()) + else: + paths = ans + return paths, ans_score \ No newline at end of file diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py index 48c67a64..04f331f7 100644 --- a/fastNLP/modules/encoder/lstm.py +++ b/fastNLP/modules/encoder/lstm.py @@ -1,4 +1,6 @@ +import torch import torch.nn as nn +import torch.nn.utils.rnn as rnn from fastNLP.modules.utils import initial_parameter @@ -19,21 +21,44 @@ class LSTM(nn.Module): def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True, bidirectional=False, bias=True, initial_method=None, get_hidden=False): super(LSTM, self).__init__() + self.batch_first = batch_first self.lstm = nn.LSTM(input_size, hidden_size, num_layers, bias=bias, batch_first=batch_first, dropout=dropout, bidirectional=bidirectional) self.get_hidden = get_hidden initial_parameter(self, initial_method) - def forward(self, x, h0=None, c0=None): + def forward(self, x, seq_lens=None, h0=None, c0=None): if h0 is not None and c0 is not None: - x, (ht, ct) = self.lstm(x, (h0, c0)) + hx = (h0, c0) else: - x, (ht, ct) = self.lstm(x) - if self.get_hidden: - return x, (ht, ct) + hx = None + if seq_lens is not None and not isinstance(x, rnn.PackedSequence): + print('padding') + sort_lens, sort_idx = torch.sort(seq_lens, dim=0, descending=True) + if self.batch_first: + x = x[sort_idx] + else: + x = x[:, sort_idx] + x = rnn.pack_padded_sequence(x, sort_lens, batch_first=self.batch_first) + output, hx = self.lstm(x, hx) # -> [N,L,C] + output, _ = rnn.pad_packed_sequence(output, batch_first=self.batch_first) + _, unsort_idx = torch.sort(sort_idx, dim=0, descending=False) + if self.batch_first: + output = output[unsort_idx] + else: + output = output[:, unsort_idx] else: - return x + output, hx = self.lstm(x, hx) + if self.get_hidden: + return output, hx + return output if __name__ == "__main__": - lstm = LSTM(10) + lstm = LSTM(input_size=2, hidden_size=2, get_hidden=False) + x = torch.randn((3, 5, 2)) + seq_lens = torch.tensor([5,1,2]) + y = lstm(x, seq_lens) + 
print(x) + print(y) + print(x.size(), y.size(), ) diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py index d7b8c544..d1262141 100644 --- a/fastNLP/modules/encoder/transformer.py +++ b/fastNLP/modules/encoder/transformer.py @@ -1,6 +1,6 @@ from torch import nn -from ..aggregator.attention import MultiHeadAtte +from ..aggregator.attention import MultiHeadAttention from ..dropout import TimestepDropout @@ -18,7 +18,7 @@ class TransformerEncoder(nn.Module): class SubLayer(nn.Module): def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1): super(TransformerEncoder.SubLayer, self).__init__() - self.atte = MultiHeadAtte(model_size, key_size, value_size, num_head, dropout) + self.atte = MultiHeadAttention(model_size, key_size, value_size, num_head, dropout) self.norm1 = nn.LayerNorm(model_size) self.ffn = nn.Sequential(nn.Linear(model_size, inner_size), nn.ReLU(), diff --git a/reproduction/Chinese_word_segmentation/models/cws_model.py b/reproduction/Chinese_word_segmentation/models/cws_model.py index daefc380..13632207 100644 --- a/reproduction/Chinese_word_segmentation/models/cws_model.py +++ b/reproduction/Chinese_word_segmentation/models/cws_model.py @@ -183,7 +183,7 @@ class CWSBiLSTMCRF(BaseModel): masks = seq_lens_to_mask(seq_lens) feats = self.encoder_model(chars, bigrams, seq_lens) feats = self.decoder_model(feats) - probs = self.crf.viterbi_decode(feats, masks, get_score=False) + paths, _ = self.crf.viterbi_decode(feats, masks) - return {'pred': probs, 'seq_lens':seq_lens} + return {'pred': paths, 'seq_lens':seq_lens} diff --git a/reproduction/Chinese_word_segmentation/models/cws_transformer.py b/reproduction/Chinese_word_segmentation/models/cws_transformer.py index 375eaa14..f6c2dab6 100644 --- a/reproduction/Chinese_word_segmentation/models/cws_transformer.py +++ b/reproduction/Chinese_word_segmentation/models/cws_transformer.py @@ -145,9 +145,9 @@ class TransformerDilatedCWS(nn.Module): feats = self.transformer(x, masks) feats = self.fc2(feats) - probs = self.crf.viterbi_decode(feats, masks, get_score=False) + paths, _ = self.crf.viterbi_decode(feats, masks) - return {'pred': probs, 'seq_lens':seq_lens} + return {'pred': paths, 'seq_lens':seq_lens} diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py index 5ed1a711..833ee9ce 100644 --- a/test/core/test_dataset.py +++ b/test/core/test_dataset.py @@ -163,6 +163,11 @@ class TestDataSetMethods(unittest.TestCase): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target]) + def test_split(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) + d1, d2 = ds.split(0.1) + + def test_apply2(self): def split_sent(ins): return ins['raw_sentence'].split() @@ -202,20 +207,6 @@ class TestDataSetMethods(unittest.TestCase): self.assertTrue(isinstance(ans, FieldArray)) self.assertEqual(ans.content, [[5, 6]] * 10) - def test_reader(self): - # 跑通即可 - ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv") - self.assertTrue(isinstance(ds, DataSet)) - self.assertTrue(len(ds) > 0) - - ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt") - self.assertTrue(isinstance(ds, DataSet)) - self.assertTrue(len(ds) > 0) - - ds = DataSet().read_pos("test/data_for_tests/people.txt") - self.assertTrue(isinstance(ds, DataSet)) - self.assertTrue(len(ds) > 0) - def test_add_null(self): # TODO test failed because 
'fastNLP\core\fieldarray.py:143: RuntimeError' ds = DataSet() diff --git a/test/data_for_tests/sample_snli.jsonl b/test/data_for_tests/sample_snli.jsonl new file mode 100644 index 00000000..e62856ac --- /dev/null +++ b/test/data_for_tests/sample_snli.jsonl @@ -0,0 +1,3 @@ +{"annotator_labels": ["neutral"], "captionID": "3416050480.jpg#4", "gold_label": "neutral", "pairID": "3416050480.jpg#4r1n", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is training his horse for a competition.", "sentence2_binary_parse": "( ( A person ) ( ( is ( ( training ( his horse ) ) ( for ( a competition ) ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (VP (VBG training) (NP (PRP$ his) (NN horse)) (PP (IN for) (NP (DT a) (NN competition))))) (. .)))"} +{"annotator_labels": ["contradiction"], "captionID": "3416050480.jpg#4", "gold_label": "contradiction", "pairID": "3416050480.jpg#4r1c", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is at a diner, ordering an omelette.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is ( at ( a diner ) ) ) , ) ( ordering ( an omelette ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (PP (IN at) (NP (DT a) (NN diner))) (, ,) (S (VP (VBG ordering) (NP (DT an) (NN omelette))))) (. .)))"} +{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. 
.)))"} \ No newline at end of file diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py index 4dddc5d0..97379a7d 100644 --- a/test/io/test_dataset_loader.py +++ b/test/io/test_dataset_loader.py @@ -1,8 +1,7 @@ import unittest -from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, ConllCWSReader, \ - ZhConllPOSReader, ConllxDataLoader - +from fastNLP.io.dataset_loader import Conll2003Loader, PeopleDailyCorpusLoader, \ + CSVLoader, SNLILoader class TestDatasetLoader(unittest.TestCase): @@ -17,3 +16,11 @@ class TestDatasetLoader(unittest.TestCase): def test_PeopleDailyCorpusLoader(self): data_set = PeopleDailyCorpusLoader().load("test/data_for_tests/people_daily_raw.txt") + def test_CSVLoader(self): + ds = CSVLoader(sep='\t', headers=['words', 'label'])\ + .load('test/data_for_tests/tutorial_sample_dataset.csv') + assert len(ds) > 0 + + def test_SNLILoader(self): + ds = SNLILoader().load('test/data_for_tests/sample_snli.jsonl') + assert len(ds) == 3 diff --git a/test/modules/test_utils.py b/test/modules/test_utils.py deleted file mode 100644 index 1d3cfcac..00000000 --- a/test/modules/test_utils.py +++ /dev/null @@ -1,9 +0,0 @@ -import unittest - - -class TestUtils(unittest.TestCase): - def test_case_1(self): - pass - - def test_case_2(self): - pass diff --git a/test/test_tutorials.py b/test/test_tutorials.py index bc0b5d2b..600699a3 100644 --- a/test/test_tutorials.py +++ b/test/test_tutorials.py @@ -379,6 +379,14 @@ class TestTutorial(unittest.TestCase): dev_data_2.apply(lambda x: [vocab_bert.to_index(word) for word in x['hypothesis']], new_field_name='hypothesis') train_data_2[-1], dev_data_2[-1] + for data in [train_data, dev_data, test_data]: + data.rename_field('premise', 'words1') + data.rename_field('hypothesis', 'words2') + data.rename_field('premise_len', 'seq_len1') + data.rename_field('hypothesis_len', 'seq_len2') + data.set_input('words1', 'words2', 'seq_len1', 'seq_len2') + + # step 1:加载模型参数(非必选) from fastNLP.io.config_io import ConfigSection, ConfigLoader args = ConfigSection()