Browse Source

[update] add Loader and Pipe for AG's News dataset

tags/v0.5.5
Yige Xu 5 years ago
parent
commit
46ea42498d
3 changed files with 61 additions and 0 deletions
  1. +15
    -0
      fastNLP/io/loader/classification.py
  2. +1
    -0
      fastNLP/io/loader/matching.py
  3. +45
    -0
      fastNLP/io/pipe/classification.py

+ 15
- 0
fastNLP/io/loader/classification.py View File

@@ -4,6 +4,7 @@ __all__ = [
     "YelpLoader",
     "YelpFullLoader",
     "YelpPolarityLoader",
+    "AGsNewsLoader",
     "IMDBLoader",
     "SSTLoader",
     "SST2Loader",
@@ -161,6 +162,20 @@ class YelpPolarityLoader(YelpLoader):
         return data_dir




class AGsNewsLoader(YelpLoader):
    """Loader for the AG's News topic-classification dataset.

    File parsing is inherited unchanged from :class:`YelpLoader`; this
    subclass only points ``download`` at the AG's News distribution.
    """

    def download(self):
        """Download the dataset automatically if it is not already cached.

        If you use this dataset, please cite:

        Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional
        Networks for Text Classification. Advances in Neural Information
        Processing Systems 28 (NIPS 2015)

        :return: str, path of the directory that holds the dataset
        """
        data_dir = self._get_dataset_path(dataset_name='ag-news')
        return data_dir


class IMDBLoader(Loader):
    """
    原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。


+ 1
- 0
fastNLP/io/loader/matching.py View File

@@ -57,6 +57,7 @@ class MNLILoader(Loader):
             f.readline()  # 跳过header
             if path.endswith("test_matched.tsv") or path.endswith('test_mismatched.tsv'):
                 warnings.warn("RTE's test file has no target.")
+                warnings.warn("MNLI's test file has no target.")
             for line in f:
                 line = line.strip()
                 if line:


+ 45
- 0
fastNLP/io/pipe/classification.py View File

@@ -21,6 +21,8 @@ from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
 from ..data_bundle import DataBundle
 from ..loader.classification import ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader
 from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
+from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader, \
+    AGsNewsLoader
 from ...core._logger import logger
 from ...core.const import Const
 from ...core.dataset import DataSet
@@ -272,6 +274,49 @@ class YelpPolarityPipe(_CLSPipe):
         return self.process(data_bundle=data_bundle)




class AGsNewsPipe(YelpFullPipe):
    """Pipe for the AG's News topic-classification dataset.

    After processing, every DataSet in the bundle carries the fields
    ``raw_words``, ``target``, ``words`` and ``seq_len``, e.g.:

    .. csv-table:: Fields produced by AGsNewsPipe
        :header: "raw_words", "target", "words", "seq_len"

        "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160
        " Don't waste your time.  We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40
        "...", ., "[...]", .

    The input/target status of each field, as reported by the dataset's
    ``print_field_meta()``, is::

        +-------------+-----------+--------+-------+---------+
        | field_names | raw_words | target | words | seq_len |
        +-------------+-----------+--------+-------+---------+
        |   is_input  |   False   | False  |  True |   True  |
        |  is_target  |   False   |  True  | False |  False  |
        | ignore_type |           | False  | False |  False  |
        |  pad_value  |           |   0    |   0   |    0    |
        +-------------+-----------+--------+-------+---------+

    """

    def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
        """
        :param bool lower: whether to lowercase the input text.
        :param str tokenizer: tokenization backend; ``'spacy'`` or ``'raw'``
            (``'raw'`` splits on whitespace).
        """
        super().__init__(lower=lower, tokenizer=tokenizer)
        # AG's News labels the four topics "1".."4"; map them to ids 0..3.
        self.tag_map = {str(label): label - 1 for label in range(1, 5)}

    def process_from_file(self, paths=None):
        """Load the raw files with :class:`AGsNewsLoader`, then run ``process``.

        :param str paths: path(s) to the raw data; ``None`` lets the loader
            fall back to its default/download location.
        :return: DataBundle
        """
        raw_bundle = AGsNewsLoader().load(paths)
        return self.process(data_bundle=raw_bundle)


class SSTPipe(_CLSPipe):
    """
    经过该Pipe之后,DataSet中具备的field如下所示


Loading…
Cancel
Save