Browse Source

add original dataset loader

tags/v0.4.10
Danqing Wang 6 years ago
parent
commit
84e659b720
1 changed files with 333 additions and 0 deletions
  1. +333
    -0
      fastNLP/io/dataset_loader.py

+ 333
- 0
fastNLP/io/dataset_loader.py View File

@@ -0,0 +1,333 @@
"""
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` ,
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。
以SNLI数据集为例::

loader = SNLILoader()
train_ds = loader.load('path/to/train')
dev_ds = loader.load('path/to/dev')
test_ds = loader.load('path/to/test')

# ... do stuff
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。
"""
__all__ = [
'CSVLoader',
'JsonLoader',
'ConllLoader',
'SNLILoader',
'SSTLoader',
'PeopleDailyCorpusLoader',
'Conll2003Loader',
]

import os
from nltk import Tree
from typing import Union, Dict
from ..core.vocabulary import Vocabulary
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader, DataInfo
from .data_loader.sst import SSTLoader
from ..core.const import Const
from ..modules.encoder._bert import BertTokenizer


class PeopleDailyCorpusLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`

读取人民日报数据集
"""

def __init__(self, pos=True, ner=True):
super(PeopleDailyCorpusLoader, self).__init__()
self.pos = pos
self.ner = ner

def _load(self, data_path):
with open(data_path, "r", encoding="utf-8") as f:
sents = f.readlines()
examples = []
for sent in sents:
if len(sent) <= 2:
continue
inside_ne = False
sent_pos_tag = []
sent_words = []
sent_ner = []
words = sent.strip().split()[1:]
for word in words:
if "[" in word and "]" in word:
ner_tag = "U"
print(word)
elif "[" in word:
inside_ne = True
ner_tag = "B"
word = word[1:]
elif "]" in word:
ner_tag = "L"
word = word[:word.index("]")]
if inside_ne is True:
inside_ne = False
else:
raise RuntimeError("only ] appears!")
else:
if inside_ne is True:
ner_tag = "I"
else:
ner_tag = "O"
tmp = word.split("/")
token, pos = tmp[0], tmp[1]
sent_ner.append(ner_tag)
sent_pos_tag.append(pos)
sent_words.append(token)
example = [sent_words]
if self.pos is True:
example.append(sent_pos_tag)
if self.ner is True:
example.append(sent_ner)
examples.append(example)
return self.convert(examples)

def convert(self, data):
"""

:param data: python 内置对象
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象
"""
data_set = DataSet()
for item in data:
sent_words = item[0]
if self.pos is True and self.ner is True:
instance = Instance(
words=sent_words, pos_tags=item[1], ner=item[2])
elif self.pos is True:
instance = Instance(words=sent_words, pos_tags=item[1])
elif self.ner is True:
instance = Instance(words=sent_words, ner=item[1])
else:
instance = Instance(words=sent_words)
data_set.append(instance)
data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
return data_set


class ConllLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`

读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为
该符号在conll 2003中被用为文档分割符。

列号从0开始, 每列对应内容为::

Column Type
0 Document ID
1 Part number
2 Word number
3 Word itself
4 Part-of-Speech
5 Parse bit
6 Predicate lemma
7 Predicate Frameset ID
8 Word sense
9 Speaker/Author
10 Named Entities
11:N Predicate Arguments
N Coreference

:param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
:param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
:param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
"""

def __init__(self, headers, indexes=None, dropna=False):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
raise TypeError(
'invalid headers: {}, should be list of strings'.format(headers))
self.headers = headers
self.dropna = dropna
if indexes is None:
self.indexes = list(range(len(self.headers)))
else:
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes

def _load(self, path):
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds


class Conll2003Loader(ConllLoader):
"""
别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader`

读取Conll2003数据

关于数据集的更多信息,参考:
https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""

def __init__(self):
headers = [
'tokens', 'pos', 'chunks', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)


def _cut_long_sentence(sent, max_sample_length=200):
"""
将长于max_sample_length的sentence截成多段,只会在有空格的地方发生截断。
所以截取的句子可能长于或者短于max_sample_length

:param sent: str.
:param max_sample_length: int.
:return: list of str.
"""
sent_no_space = sent.replace(' ', '')
cutted_sentence = []
if len(sent_no_space) > max_sample_length:
parts = sent.strip().split()
new_line = ''
length = 0
for part in parts:
length += len(part)
new_line += part + ' '
if length > max_sample_length:
new_line = new_line[:-1]
cutted_sentence.append(new_line)
length = 0
new_line = ''
if new_line != '':
cutted_sentence.append(new_line[:-1])
else:
cutted_sentence.append(sent)
return cutted_sentence


class JsonLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`

读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象

:param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name
``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` ,
`value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名
``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None``
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False``
"""

def __init__(self, fields=None, dropna=False):
super(JsonLoader, self).__init__()
self.dropna = dropna
self.fields = None
self.fields_list = None
if fields:
self.fields = {}
for k, v in fields.items():
self.fields[k] = k if v is None else v
self.fields_list = list(self.fields.keys())

def _load(self, path):
ds = DataSet()
for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
if self.fields:
ins = {self.fields[k]: v for k, v in d.items()}
else:
ins = d
ds.append(Instance(**ins))
return ds


class SNLILoader(JsonLoader):
"""
别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

读取SNLI数据集,读取的DataSet包含fields::

words1: list(str),第一句文本, premise
words2: list(str), 第二句文本, hypothesis
target: str, 真实标签

数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
"""

def __init__(self):
fields = {
'sentence1_parse': Const.INPUTS(0),
'sentence2_parse': Const.INPUTS(1),
'gold_label': Const.TARGET,
}
super(SNLILoader, self).__init__(fields=fields)

def _load(self, path):
ds = super(SNLILoader, self)._load(path)

def parse_tree(x):
t = Tree.fromstring(x)
return t.leaves()

ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
ds.apply(lambda ins: parse_tree(
ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
ds.drop(lambda x: x[Const.TARGET] == '-')
return ds


class CSVLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`

读取CSV格式的数据集。返回 ``DataSet``

:param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称
若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None``
:param str sep: CSV文件中列与列之间的分隔符. Default: ","
:param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
Default: ``False``
"""

def __init__(self, headers=None, sep=",", dropna=False):
self.headers = headers
self.sep = sep
self.dropna = dropna

def _load(self, path):
ds = DataSet()
for idx, data in _read_csv(path, headers=self.headers,
sep=self.sep, dropna=self.dropna):
ds.append(Instance(**data))
return ds


def _add_seg_tag(data):
"""

:param data: list of ([word], [pos], [heads], [head_tags])
:return: list of ([word], [pos])
"""

_processed = []
for word_list, pos_list, _, _ in data:
new_sample = []
for word, pos in zip(word_list, pos_list):
if len(word) == 1:
new_sample.append((word, 'S-' + pos))
else:
new_sample.append((word[0], 'B-' + pos))
for c in word[1:-1]:
new_sample.append((c, 'M-' + pos))
new_sample.append((word[-1], 'E-' + pos))
_processed.append(list(map(list, zip(*new_sample))))
return _processed

Loading…
Cancel
Save