@@ -1,8 +1,6 @@
"""
.. _dataset-loader:

The dataset_loader module implements many DataSetLoader classes, which read data
in different formats and return a :class:`~fastNLP.DataSet`. The resulting
:class:`~fastNLP.DataSet` object can be passed directly to a :class:`~fastNLP.Trainer`
or a :class:`~fastNLP.Tester` for model training and testing.

Example::

@@ -13,50 +11,50 @@ Example::
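
    # A minimal usage sketch (hedged; the loader choice and file path are placeholders):
    from fastNLP.io import CSVLoader
    loader = CSVLoader(headers=['raw_words', 'target'])
    train_ds = loader.load('path/to/train.csv')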
    # ... do stuff
""" |
|
|
""" |
|
|
import os |
|
|
|
|
|
import json |
|
|
|
|
|
|
|
|
|
|
|
from nltk.tree import Tree |
|
|
from nltk.tree import Tree |
|
|
|
|
|
|
|
|
from fastNLP.core.dataset import DataSet |
|
|
|
|
|
from fastNLP.core.instance import Instance |
|
|
|
|
|
from fastNLP.io.file_reader import _read_csv, _read_json, _read_conll |
|
|
|
|
|
|
|
|
from ..core.dataset import DataSet |
|
|
|
|
|
from ..core.instance import Instance |
|
|
|
|
|
from .file_reader import _read_csv, _read_json, _read_conll |
|
|
|
|
|
|
|
|
|
|
|
|
|
|

def _download_from_url(url, path):
    """Download the file at ``url`` and save it to ``path``."""
    from tqdm import tqdm
    import requests
    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
    chunk_size = 16 * 1024
    total_size = int(r.headers.get('Content-length', 0))
    with open(path, "wb") as file, \
            tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
        for chunk in r.iter_content(chunk_size):
            if chunk:
                file.write(chunk)
                t.update(len(chunk))
    return
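
# A hedged usage sketch for the helper above (URL and destination are placeholders):
#
#     _download_from_url('https://example.com/data.zip', 'data.zip')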


def _uncompress(src, dst):
    import zipfile, gzip, tarfile, os

    def unzip(src, dst):
        with zipfile.ZipFile(src, 'r') as f:
            f.extractall(dst)

    def ungz(src, dst):
        with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
            length = 16 * 1024  # 16KB
            buf = f.read(length)
            while buf:
                uf.write(buf)
                buf = f.read(length)

    def untar(src, dst):
        with tarfile.open(src, 'r:gz') as f:
            f.extractall(dst)

    fn, ext = os.path.splitext(src)
    _, ext_2 = os.path.splitext(fn)
    if ext == '.zip':

@@ -71,42 +69,48 @@ def _uncompress(src, dst):

class DataSetLoader:
    """
    Alias: :class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

    The interface shared by all DataSetLoader classes; you can subclass it to
    implement your own DataSetLoader.
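
    Example of a custom subclass (a hedged sketch; ``raw_words`` is a placeholder
    field name)::

        from fastNLP import DataSet, Instance

        class MyLoader(DataSetLoader):
            def load(self, path):
                ds = DataSet()
                with open(path, 'r', encoding='utf-8') as f:
                    for line in f:
                        # one whitespace-tokenized sentence per line
                        ds.append(Instance(raw_words=line.strip().split()))
                return ds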
    """

    def load(self, path):
        """Read data from the file at the given ``path`` and return a DataSet.

        :param str path: file path
        :return: a :class:`~fastNLP.DataSet` object
        """
        raise NotImplementedError

    def convert(self, data):
        """Create a DataSet from a Python data object; every subclass has to
        implement this method itself.

        :param data: a Python built-in data structure
        :return: a :class:`~fastNLP.DataSet` object
        """
        raise NotImplementedError


class PeopleDailyCorpusLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`

    Reads the People's Daily corpus.
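
    Example (a hedged sketch; the path is a placeholder)::

        loader = PeopleDailyCorpusLoader()
        data_set = loader.load('path/to/people_daily.txt', pos=True, ner=True)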
    """

    def __init__(self):
        super(PeopleDailyCorpusLoader, self).__init__()
        self.pos = True
        self.ner = True

    def load(self, data_path, pos=True, ner=True):
        """
        :param str data_path: path to the data
        :param bool pos: whether to use part-of-speech tags
        :param bool ner: whether to use named-entity tags
        :return: a :class:`~fastNLP.DataSet` object
        """
        self.pos, self.ner = pos, ner
        with open(data_path, "r", encoding="utf-8") as f:

@@ -152,8 +156,13 @@ class PeopleDailyCorpusLoader(DataSetLoader):

                example.append(sent_ner)
            examples.append(example)
        return self.convert(examples)

    def convert(self, data):
        """
        :param data: a Python built-in object
        :return: a :class:`~fastNLP.DataSet` object
        """
        data_set = DataSet()
        for item in data:
            sent_words = item[0]

@@ -172,6 +181,8 @@ class PeopleDailyCorpusLoader(DataSetLoader):

class ConllLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.dataset_loader.ConllLoader`

    Reads data in CoNLL format. See http://conll.cemantix.org/2012/data.html for
    details of the format.

    Column numbers start from 0, and the content of each column is::

@@ -195,6 +206,7 @@ class ConllLoader(DataSetLoader):

    :param indexs: the indices of the columns to keep, starting from 0. If
        ``None``, all columns are kept. Default: ``None``
    :param dropna: whether to ignore malformed data; if ``False``, a
        ``ValueError`` is raised when malformed data is encountered. Default: ``False``
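
    Example (a hedged sketch; the column layout and path are placeholders)::

        loader = ConllLoader(headers=['words', 'pos'], indexs=[0, 1])
        data_set = loader.load('path/to/sample.conll')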
    """

    def __init__(self, headers, indexs=None, dropna=False):
        super(ConllLoader, self).__init__()
        if not isinstance(headers, (list, tuple)):

@@ -207,21 +219,25 @@ class ConllLoader(DataSetLoader):

            if len(indexs) != len(headers):
                raise ValueError
            self.indexs = indexs

    def load(self, path):
        ds = DataSet()
        for idx, data in _read_conll(path, indexes=self.indexs, dropna=self.dropna):
            ins = {h: data[i] for i, h in enumerate(self.headers)}
            ds.append(Instance(**ins))
        return ds


class Conll2003Loader(ConllLoader):
    """
    Alias: :class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader`

    Reads Conll2003 data.

    For more information about the dataset, see
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
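
    Example (a hedged sketch; the path is a placeholder)::

        data_set = Conll2003Loader().load('path/to/eng.train')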
    """

    def __init__(self):
        headers = [
            'tokens', 'pos', 'chunks', 'ner',

@@ -260,7 +276,10 @@ def _cut_long_sentence(sent, max_sample_length=200):

class SSTLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.dataset_loader.SSTLoader`

    Reads the SST dataset. The DataSet contains the fields::

        words: list(str), the text to classify
        target: str, the label of the text

@@ -270,21 +289,22 @@ class SSTLoader(DataSetLoader):

    :param subtree: whether to expand the data into subtrees, enlarging the
        dataset. Default: ``False``
    :param fine_grained: whether to use the SST-5 standard; if ``False``, SST-2
        is used. Default: ``False``
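
    Example (a hedged sketch; the path is a placeholder)::

        loader = SSTLoader(subtree=False, fine_grained=False)
        data_set = loader.load('path/to/train.txt')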
    """

    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v

    def load(self, path):
        """
        :param str path: path to the stored data
        :return: a :class:`~fastNLP.DataSet` object
        """
        datalist = []
        with open(path, 'r', encoding='utf-8') as f:

@@ -296,7 +316,7 @@ class SSTLoader(DataSetLoader):

        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds

    @staticmethod
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)

@@ -307,15 +327,18 @@ class SSTLoader(DataSetLoader):

class JsonLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.dataset_loader.JsonLoader`

    Reads data in JSON format. The data must be stored one instance per line,
    each line being a JSON object with the relevant attributes.

    :param dict fields: the names of the JSON attributes to read, and the
        field_name each is stored under in the DataSet.
        Each `key` of ``fields`` must be an attribute name of the JSON objects.
        Each `value` of ``fields`` is the `field_name` the attribute is stored
        under in the DataSet; a `value` may also be ``None``, in which case the
        `field_name` is the same as the corresponding JSON attribute.
        ``fields`` itself may be ``None``, in which case all attributes of the
        JSON objects are kept in the DataSet. Default: ``None``
    :param bool dropna: whether to ignore malformed data; if ``True`` it is
        ignored, if ``False`` a ``ValueError`` is raised when malformed data is
        encountered. Default: ``False``
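
    Example (a hedged sketch; the attribute names and path are placeholders)::

        loader = JsonLoader(fields={'text': 'words', 'label': 'target'})
        data_set = loader.load('path/to/data.jsonl')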
    """

    def __init__(self, fields=None, dropna=False):
        super(JsonLoader, self).__init__()
        self.dropna = dropna

@@ -326,12 +349,12 @@ class JsonLoader(DataSetLoader):

            for k, v in fields.items():
                self.fields[k] = k if v is None else v
            self.fields_list = list(self.fields.keys())

    def load(self, path):
        ds = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
            if self.fields:
                ins = {self.fields[k]: v for k, v in d.items()}
            else:
                ins = d
            ds.append(Instance(**ins))

@@ -340,6 +363,8 @@ class JsonLoader(DataSetLoader):

class SNLILoader(JsonLoader):
    """
    Alias: :class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`

    Reads the SNLI dataset. The loaded DataSet contains the fields::

        words1: list(str), the first sentence (the premise)

@@ -348,6 +373,7 @@ class SNLILoader(JsonLoader):

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
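
    Example (a hedged sketch; the path is a placeholder)::

        data_set = SNLILoader().load('path/to/snli_1.0_train.jsonl')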
    """

    def __init__(self):
        fields = {
            'sentence1_parse': 'words1',

@@ -355,12 +381,14 @@ class SNLILoader(JsonLoader):

            'gold_label': 'target',
        }
        super(SNLILoader, self).__init__(fields=fields)

    def load(self, path):
        ds = super(SNLILoader, self).load(path)

        def parse_tree(x):
            t = Tree.fromstring(x)
            return t.leaves()

        ds.apply(lambda ins: parse_tree(ins['words1']), new_field_name='words1')
        ds.apply(lambda ins: parse_tree(ins['words2']), new_field_name='words2')
        ds.drop(lambda x: x['target'] == '-')

@@ -369,6 +397,8 @@ class SNLILoader(JsonLoader):

class CSVLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`

    Reads a dataset in CSV format and returns a ``DataSet``.

    :param List[str] headers: the header of the CSV file, defining the attribute
        name of each column, i.e. the name of each `field` in the returned DataSet

@@ -377,11 +407,12 @@ class CSVLoader(DataSetLoader):

    :param bool dropna: whether to ignore malformed data; if ``True`` it is
        ignored, if ``False`` a ``ValueError`` is raised when malformed data is
        encountered. Default: ``False``
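
    Example (a hedged sketch; the headers and path are placeholders)::

        loader = CSVLoader(headers=['raw_words', 'target'], sep='\t')
        data_set = loader.load('path/to/train.tsv')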
    """

    def __init__(self, headers=None, sep=",", dropna=False):
        self.headers = headers
        self.sep = sep
        self.dropna = dropna

    def load(self, path):
        ds = DataSet()
        for idx, data in _read_csv(path, headers=self.headers,

@@ -396,7 +427,7 @@ def _add_seg_tag(data):

    :param data: list of ([word], [pos], [heads], [head_tags])
    :return: list of ([word], [pos])
    """
    _processed = []
    for word_list, pos_list, _, _ in data:
        new_sample = []

@@ -410,4 +441,3 @@ def _add_seg_tag(data):

            new_sample.append((word[-1], 'E-' + pos))
        _processed.append(list(map(list, zip(*new_sample))))
    return _processed