- """
- dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` ,
- 得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。
- 以SNLI数据集为例::
-
- loader = SNLILoader()
- train_ds = loader.load('path/to/train')
- dev_ds = loader.load('path/to/dev')
- test_ds = loader.load('path/to/test')
-
- # ... do stuff
-
- 为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。
- """
__all__ = [
    'CSVLoader',
    'JsonLoader',
    'ConllLoader',
    'SNLILoader',
    'SSTLoader',
    'PeopleDailyCorpusLoader',
    'Conll2003Loader',
]

from nltk import Tree

from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from .base_loader import DataSetLoader
from .data_loader.sst import SSTLoader
from ..core.const import Const


class PeopleDailyCorpusLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`

    Loader for the People's Daily (人民日报) corpus.

    :param bool pos: whether to keep part-of-speech tags in the loaded DataSet. Default: ``True``
    :param bool ner: whether to keep named-entity tags in the loaded DataSet. Default: ``True``
    """

    def __init__(self, pos=True, ner=True):
        super(PeopleDailyCorpusLoader, self).__init__()
        self.pos = pos
        self.ner = ner

    def _load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            sents = f.readlines()
        examples = []
        for sent in sents:
            if len(sent) <= 2:
                continue
            inside_ne = False
            sent_pos_tag = []
            sent_words = []
            sent_ner = []
            # drop the leading article/sentence id token of each line
            words = sent.strip().split()[1:]
            for word in words:
                if "[" in word and "]" in word:
                    # a single-token named entity such as "[中国/ns]nt"
                    ner_tag = "U"
                    word = word[word.index("[") + 1:word.index("]")]
                elif "[" in word:
                    inside_ne = True
                    ner_tag = "B"
                    word = word[1:]
                elif "]" in word:
                    ner_tag = "L"
                    word = word[:word.index("]")]
                    if inside_ne is True:
                        inside_ne = False
                    else:
                        raise RuntimeError("only ] appears!")
                else:
                    if inside_ne is True:
                        ner_tag = "I"
                    else:
                        ner_tag = "O"
                tmp = word.split("/")
                token, pos = tmp[0], tmp[1]
                sent_ner.append(ner_tag)
                sent_pos_tag.append(pos)
                sent_words.append(token)
            example = [sent_words]
            if self.pos is True:
                example.append(sent_pos_tag)
            if self.ner is True:
                example.append(sent_ner)
            examples.append(example)
        return self.convert(examples)

    def convert(self, data):
        """
        Convert nested python lists into a :class:`~fastNLP.DataSet`.

        :param data: a python built-in object (a list of examples)
        :return: an object of type :class:`~fastNLP.DataSet`
        """
        data_set = DataSet()
        for item in data:
            sent_words = item[0]
            if self.pos is True and self.ner is True:
                instance = Instance(
                    words=sent_words, pos_tags=item[1], ner=item[2])
            elif self.pos is True:
                instance = Instance(words=sent_words, pos_tags=item[1])
            elif self.ner is True:
                instance = Instance(words=sent_words, ner=item[1])
            else:
                instance = Instance(words=sent_words)
            data_set.append(instance)
        data_set.apply(lambda ins: len(ins["words"]), new_field_name="seq_len")
        return data_set
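
# A minimal usage sketch (illustrative; the path below is a placeholder for a
# file in People's Daily format):
#
#     loader = PeopleDailyCorpusLoader(pos=True, ner=True)
#     ds = loader.load('path/to/people_daily.txt')
#     # each instance carries the fields: words, pos_tags, ner, seq_len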


class ConllLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`

    Loader for CoNLL-formatted data. See http://conll.cemantix.org/2012/data.html
    for details of the format. Lines starting with "-DOCSTART-" are ignored,
    because that token is used as a document separator in CoNLL 2003.

    Columns are numbered from 0, with the following contents::

        Column  Type
        0       Document ID
        1       Part number
        2       Word number
        3       Word itself
        4       Part-of-Speech
        5       Parse bit
        6       Predicate lemma
        7       Predicate Frameset ID
        8       Word sense
        9       Speaker/Author
        10      Named Entities
        11:N    Predicate Arguments
        N       Coreference

    :param headers: the name of each column; must be a list or tuple of str.
        ``headers`` corresponds one-to-one with ``indexes``
    :param indexes: zero-based indices of the columns to keep. If ``None``,
        all columns are kept. Default: ``None``
    :param dropna: whether to ignore malformed data; if ``False``, a
        ``ValueError`` is raised when malformed data is encountered.
        Default: ``False``
    """

    def __init__(self, headers, indexes=None, dropna=False):
        super(ConllLoader, self).__init__()
        if not isinstance(headers, (list, tuple)):
            raise TypeError(
                'invalid headers: {}, should be list of strings'.format(headers))
        self.headers = headers
        self.dropna = dropna
        if indexes is None:
            self.indexes = list(range(len(self.headers)))
        else:
            if len(indexes) != len(headers):
                raise ValueError('the number of indexes must equal the number of headers')
            self.indexes = indexes

    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
            ins = {h: data[i] for i, h in enumerate(self.headers)}
            ds.append(Instance(**ins))
        return ds
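
# A minimal usage sketch (illustrative; the path is a placeholder). Reads only
# columns 0 and 3 of each CoNLL row as the fields 'words' and 'pos':
#
#     loader = ConllLoader(headers=['words', 'pos'], indexes=[0, 3])
#     ds = loader.load('path/to/data.conll')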


class Conll2003Loader(ConllLoader):
    """
    Alias: :class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader`

    Loader for the CoNLL 2003 dataset.

    For more information about the dataset, see:
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
    """

    def __init__(self):
        headers = [
            'tokens', 'pos', 'chunks', 'ner',
        ]
        super(Conll2003Loader, self).__init__(headers=headers)
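
# A minimal usage sketch (illustrative; the path is a placeholder):
#
#     loader = Conll2003Loader()
#     ds = loader.load('path/to/conll2003/train.txt')
#     # each instance carries the fields: tokens, pos, chunks, ner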


def _cut_long_sentence(sent, max_sample_length=200):
    """
    Cut a sentence longer than max_sample_length into several pieces. Cuts only
    happen where there is a space, so the resulting pieces may be longer or
    shorter than max_sample_length.

    :param sent: str.
    :param max_sample_length: int.
    :return: list of str.
    """
    sent_no_space = sent.replace(' ', '')
    cut_sentences = []
    if len(sent_no_space) > max_sample_length:
        parts = sent.strip().split()
        new_line = ''
        length = 0
        for part in parts:
            length += len(part)
            new_line += part + ' '
            if length > max_sample_length:
                # strip the trailing space before saving the piece
                new_line = new_line[:-1]
                cut_sentences.append(new_line)
                length = 0
                new_line = ''
        if new_line != '':
            cut_sentences.append(new_line[:-1])
    else:
        cut_sentences.append(sent)
    return cut_sentences
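
# Worked example (traced by hand): with max_sample_length=4, 'abc def ghi' has
# 9 non-space characters, so it is cut at spaces into two pieces; note that a
# piece may exceed the limit because cuts only happen at spaces:
#
#     _cut_long_sentence('abc def ghi', max_sample_length=4)
#     # -> ['abc def', 'ghi']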


class JsonLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`

    Loader for json-formatted data. The data must be stored one object per line,
    each line being a json object with the desired attributes.

    :param dict fields: the json attribute names to read, mapped to the field_name
        under which each is stored in the DataSet. The `key` of ``fields`` must be
        a json attribute name; the `value` is the `field_name` used in the DataSet.
        A `value` of ``None`` means the field keeps the name of the corresponding
        json attribute. ``fields`` itself may also be ``None``, in which case all
        json attributes are stored in the DataSet. Default: ``None``
    :param bool dropna: whether to ignore malformed data; if ``True`` it is
        skipped, if ``False`` a ``ValueError`` is raised when malformed data is
        encountered. Default: ``False``
    """

    def __init__(self, fields=None, dropna=False):
        super(JsonLoader, self).__init__()
        self.dropna = dropna
        self.fields = None
        self.fields_list = None
        if fields:
            self.fields = {}
            for k, v in fields.items():
                self.fields[k] = k if v is None else v
            self.fields_list = list(self.fields.keys())

    def _load(self, path):
        ds = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
            if self.fields:
                ins = {self.fields[k]: v for k, v in d.items()}
            else:
                ins = d
            ds.append(Instance(**ins))
        return ds
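
# A minimal usage sketch (illustrative; path and attribute names are
# placeholders). Reads the 'text' and 'label' attributes of each json line,
# storing 'label' under the field name 'target':
#
#     loader = JsonLoader(fields={'text': None, 'label': 'target'})
#     ds = loader.load('path/to/data.jsonl')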


class SNLILoader(JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Loader for the SNLI dataset. The loaded DataSet contains the fields::

        words1: list(str), tokens of the first sentence (the premise)
        words2: list(str), tokens of the second sentence (the hypothesis)
        target: str, the gold label

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self):
        fields = {
            'sentence1_parse': Const.INPUTS(0),
            'sentence2_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }
        super(SNLILoader, self).__init__(fields=fields)

    def _load(self, path):
        ds = super(SNLILoader, self)._load(path)

        def parse_tree(x):
            # recover the tokens from the bracketed parse string
            t = Tree.fromstring(x)
            return t.leaves()

        ds.apply(lambda ins: parse_tree(
            ins[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: parse_tree(
            ins[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))
        # drop examples whose gold label is '-' (the annotators did not agree)
        ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds
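
# A minimal usage sketch (illustrative; the path is a placeholder for one of
# the unpacked snli_1.0 jsonl files):
#
#     loader = SNLILoader()
#     train_ds = loader.load('path/to/snli_1.0_train.jsonl')
#     # fields: words1 (premise tokens), words2 (hypothesis tokens), target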


class CSVLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`

    Loader for CSV-formatted datasets. Returns a ``DataSet``.

    :param List[str] headers: the header of the CSV file, defining the name of
        each column, i.e. the name of each `field` in the returned DataSet.
        If ``None``, the first row of the file is treated as the ``headers``.
        Default: ``None``
    :param str sep: the separator between columns in the CSV file. Default: ","
    :param bool dropna: whether to ignore malformed data; if ``True`` it is
        skipped, if ``False`` a ``ValueError`` is raised when malformed data is
        encountered. Default: ``False``
    """

    def __init__(self, headers=None, sep=",", dropna=False):
        super(CSVLoader, self).__init__()
        self.headers = headers
        self.sep = sep
        self.dropna = dropna

    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_csv(path, headers=self.headers,
                                   sep=self.sep, dropna=self.dropna):
            ds.append(Instance(**data))
        return ds
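
# A minimal usage sketch (illustrative; the path is a placeholder). Reads a
# tab-separated file whose first row is treated as the header:
#
#     loader = CSVLoader(headers=None, sep='\t')
#     ds = loader.load('path/to/data.tsv')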


def _add_seg_tag(data):
    """
    Convert word-level samples into character-level samples, attaching a
    BMES-style segmentation prefix (B/M/E for multi-character words, S for
    single-character words) to each POS tag.

    :param data: list of ([word], [pos], [heads], [head_tags])
    :return: list of ([char], [tag])
    """

    _processed = []
    for word_list, pos_list, _, _ in data:
        new_sample = []
        for word, pos in zip(word_list, pos_list):
            if len(word) == 1:
                new_sample.append((word, 'S-' + pos))
            else:
                new_sample.append((word[0], 'B-' + pos))
                for c in word[1:-1]:
                    new_sample.append((c, 'M-' + pos))
                new_sample.append((word[-1], 'E-' + pos))
        _processed.append(list(map(list, zip(*new_sample))))
    return _processed
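
# Worked example (traced by hand): the two-word sentence ['中国', '人'] with POS
# tags ['ns', 'n'] becomes character-level tokens with segmentation-prefixed
# tags (heads and head_tags are ignored, so empty lists suffice here):
#
#     _add_seg_tag([(['中国', '人'], ['ns', 'n'], [], [])])
#     # -> [[['中', '国', '人'], ['B-ns', 'E-ns', 'S-n']]]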