From 46ea42498d669957ec155477010771f4318ec3ac Mon Sep 17 00:00:00 2001
From: Yige Xu
Date: Sun, 12 Jan 2020 10:40:33 +0800
Subject: [PATCH] [update] add Loader and Pipe for AG's News dataset

---
 fastNLP/io/loader/classification.py | 15 ++++++++++
 fastNLP/io/loader/matching.py       |  2 +-
 fastNLP/io/pipe/classification.py   | 46 ++++++++++++++++++++++++++++-
 3 files changed, 61 insertions(+), 2 deletions(-)

diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py
index 12b10541..196673b2 100644
--- a/fastNLP/io/loader/classification.py
+++ b/fastNLP/io/loader/classification.py
@@ -4,6 +4,7 @@ __all__ = [
     "YelpLoader",
     "YelpFullLoader",
     "YelpPolarityLoader",
+    "AGsNewsLoader",
     "IMDBLoader",
     "SSTLoader",
     "SST2Loader",
@@ -161,6 +162,20 @@ class YelpPolarityLoader(YelpLoader):
         return data_dir
 
 
+class AGsNewsLoader(YelpLoader):
+    def download(self):
+        """
+        自动下载数据集,如果你使用了这个数据集,请引用以下的文章
+
+        Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
+        in Neural Information Processing Systems 28 (NIPS 2015)
+
+        :return: str, 数据集的目录地址
+        """
+
+        return self._get_dataset_path(dataset_name='ag-news')
+
+
 class IMDBLoader(Loader):
     """
     原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。
diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py
index 9c4c90d9..854ac7a8 100644
--- a/fastNLP/io/loader/matching.py
+++ b/fastNLP/io/loader/matching.py
@@ -57,6 +57,6 @@ class MNLILoader(Loader):
             f.readline()  # 跳过header
             if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
-                warnings.warn("RTE's test file has no target.")
+                warnings.warn("MNLI's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:
diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py
index ab31c9de..70fd8042 100644
--- a/fastNLP/io/pipe/classification.py
+++ b/fastNLP/io/pipe/classification.py
@@ -21,6 +21,7 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_insta
 from ..data_bundle import DataBundle
 from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
-from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
+from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \
+    AGsNewsLoader
 from ...core._logger import logger
 from ...core.const import Const
 from ...core.dataset import DataSet
@@ -272,6 +273,49 @@ class YelpPolarityPipe(_CLSPipe):
         return self.process(data_bundle=data_bundle)
 
 
+class AGsNewsPipe(YelpFullPipe):
+    """
+    处理AG's News的数据, 处理之后DataSet中的内容如下
+
+    .. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field
+       :header: "raw_words", "target", "words", "seq_len"
+
+       "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
+       " Don't waste your time.  We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
+       "...", ., "[...]", .
+
+    dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为::
+
+        +-------------+-----------+--------+-------+---------+
+        | field_names | raw_words | target | words | seq_len |
+        +-------------+-----------+--------+-------+---------+
+        |   is_input  |   False   | False  |  True |   True  |
+        |  is_target  |   False   |  True  | False |  False  |
+        | ignore_type |           | False  | False |  False  |
+        |  pad_value  |           |   0    |   0   |    0    |
+        +-------------+-----------+--------+-------+---------+
+
+    """
+
+    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+        """
+
+        :param bool lower: 是否对输入进行小写化。
+        :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
+        """
+        super().__init__(lower=lower, tokenizer=tokenizer)
+        self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3}
+
+    def process_from_file(self, paths=None):
+        """
+
+        :param str paths:
+        :return: DataBundle
+        """
+        data_bundle = AGsNewsLoader().load(paths)
+        return self.process(data_bundle=data_bundle)
+
+
 class SSTPipe(_CLSPipe):
     """
     经过该Pipe之后,DataSet中具备的field如下所示