@@ -29,8 +29,12 @@ from nltk.tree import Tree
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json, _read_conll
from typing import Union, Dict, Iterable
import os
from ..core.utils import Example
from ..core import Vocabulary
from ..io import EmbedLoader
import numpy as np


def _download_from_url(url, path):

@@ -39,7 +43,7 @@ def _download_from_url(url, path):
    except:
        from ..core.utils import _pseudo_tqdm as tqdm
    import requests

    """Download file"""
    r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
    chunk_size = 16 * 1024
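
    # A sketch of how the streamed response is typically consumed (the rest of
    # the function is not shown in this hunk; iter_content is requests'
    # documented API for reading the body in chunk_size pieces):
    #
    #     with open(path, "wb") as f:
    #         for chunk in r.iter_content(chunk_size):
    #             if chunk:
    #                 f.write(chunk)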

@@ -58,11 +62,11 @@ def _uncompress(src, dst):
    import gzip
    import tarfile
    import os

    def unzip(src, dst):
        with zipfile.ZipFile(src, 'r') as f:
            f.extractall(dst)

    def ungz(src, dst):
        with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
            length = 16 * 1024  # 16KB

@@ -70,11 +74,11 @@ def _uncompress(src, dst):
            while buf:
                uf.write(buf)
                buf = f.read(length)

    def untar(src, dst):
        with tarfile.open(src, 'r:gz') as f:
            f.extractall(dst)

    fn, ext = os.path.splitext(src)
    _, ext_2 = os.path.splitext(fn)
    if ext == '.zip':

@@ -87,6 +91,34 @@ def _uncompress(src, dst):
        raise ValueError('unsupported file {}'.format(src))
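
# Worked examples of the extension routing above, assuming the elided branches
# dispatch '.gz' files on their inner extension to the ungz/untar helpers
# defined earlier:
#     'glove.zip'     -> ext '.zip'               -> unzip
#     'embed.txt.gz'  -> ext '.gz', ext_2 '.txt'  -> ungz
#     'corpus.tar.gz' -> ext '.gz', ext_2 '.tar'  -> untar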


class VocabularyOption(Example):
    def __init__(self,
                 max_size=None,
                 min_freq=None,
                 padding='<pad>',
                 unknown='<unk>'):
        super().__init__(
            max_size=max_size,
            min_freq=min_freq,
            padding=padding,
            unknown=unknown
        )


class EmbeddingOption(Example):
    def __init__(self,
                 embed_filepath=None,
                 dtype=np.float32,
                 normalize=True,
                 error='ignore'):
        super().__init__(
            embed_filepath=embed_filepath,
            dtype=dtype,
            normalize=normalize,
            error=error
        )
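
# A minimal usage sketch for the option containers above, assuming Example
# (from fastNLP's core.utils) behaves as a dict-like bundle of keyword
# arguments; SSTLoader.process below unpacks them with ** in the same way:
#
#     src_vocab_op = VocabularyOption(max_size=30000, min_freq=2)
#     vocab = Vocabulary(**src_vocab_op)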


class DataInfo:
    """
    Processed data information: a collection of datasets (e.g. separate train, dev and test sets) together with the vocabularies and embeddings they use.

@@ -95,7 +127,7 @@ class DataInfo:
    :param embeddings: dict mapping a name (str) to an embedding, see :class:`~fastNLP.io.EmbedLoader`
    :param datasets: dict mapping a name (str) to a :class:`~fastNLP.DataSet`
    """

    def __init__(self, vocabs: dict = None, embeddings: dict = None, datasets: dict = None):
        self.vocabs = vocabs or {}
        self.embeddings = embeddings or {}
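
# A sketch of how a loader might assemble a DataInfo (the names here are
# illustrative, not part of the API):
#
#     info = DataInfo(
#         datasets={'train': train_ds, 'dev': dev_ds},
#         vocabs={'words': src_vocab},
#     )
#     info.datasets['train']  # -> a fastNLP DataSet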

@@ -106,21 +138,21 @@ class DataSetLoader:
    """
    Alias: :class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`

    Defines the API required of every DataSetLoader; developers should subclass it to implement their own DataSetLoaders.

    A developer should implement at least the following:

    - the _load method: read data from one data file into a :class:`~fastNLP.DataSet`
    - the load method (the base-class implementation can be reused): read data from one or more data files into one or more :class:`~fastNLP.DataSet` objects
    - the process method: read data from one or more data files and process it into one or more :class:`~fastNLP.DataSet` objects ready for training

    **the process method may call the load or _load methods**
    """

    def _download(self, url: str, path: str, uncompress=True) -> str:
        """

        Download data from ``url`` to ``path``; if ``uncompress`` is ``True``, decompress it automatically.

        :param url: the URL to download from

@@ -136,7 +168,7 @@ class DataSetLoader:
            _uncompress(path, dst)
            return dst
        return path

    def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
        """
        Read data from the file(s) at one or more given paths, returning one or more :class:`~fastNLP.DataSet` objects.

@@ -148,7 +180,7 @@ class DataSetLoader:
        if isinstance(paths, str):
            return self._load(paths)
        return {name: self._load(path) for name, path in paths.items()}

    def _load(self, path: str) -> DataSet:
        """Read data from the file at the given path and return a :class:`~fastNLP.DataSet` object

@@ -156,16 +188,16 @@ class DataSetLoader:
        :return: a :class:`~fastNLP.DataSet` object
        """
        raise NotImplementedError

    def process(self, paths: Union[str, Dict[str, str]], **options) -> DataInfo:
        """
        For a specific task and dataset, read and process the data, returning a DataInfo object or a dict.

        Read data from the file(s) at one or more given paths; the DataInfo object may contain one or more datasets.
        When multiple paths are processed, the keys of the dict passed in match the keys of the dict in the returned DataInfo.

        The returned :class:`DataInfo` object has the following attributes:

        - vocabs: a dict of the vocabularies built from the datasets
        - embeddings: (optional) the embeddings matching the datasets
        - datasets: a dict holding a number of :class:`~fastNLP.DataSet` objects, whose field names follow :mod:`~fastNLP.core.const`
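
# A minimal subclass sketch of the contract described above (the loader name
# and file format are hypothetical):
#
#     class WhitespaceLoader(DataSetLoader):
#         def _load(self, path: str) -> DataSet:
#             ds = DataSet()
#             with open(path, encoding='utf-8') as f:
#                 for line in f:
#                     ds.append(Instance(words=line.split()))
#             return ds
#
#     WhitespaceLoader().load('train.txt')                         # -> DataSet
#     WhitespaceLoader().load({'train': 'a.txt', 'dev': 'b.txt'})  # -> dict of DataSet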

@@ -183,12 +215,12 @@ class PeopleDailyCorpusLoader(DataSetLoader):

    Reads the People's Daily corpus.
    """

    def __init__(self, pos=True, ner=True):
        super(PeopleDailyCorpusLoader, self).__init__()
        self.pos = pos
        self.ner = ner

    def _load(self, data_path):
        with open(data_path, "r", encoding="utf-8") as f:
            sents = f.readlines()

@@ -233,7 +265,7 @@ class PeopleDailyCorpusLoader(DataSetLoader):
                example.append(sent_ner)
            examples.append(example)
        return self.convert(examples)

    def convert(self, data):
        """

@@ -284,7 +316,7 @@ class ConllLoader(DataSetLoader):
    :param indexes: zero-based indexes of the data columns to keep. If ``None``, all columns are kept. Default: ``None``
    :param dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised when invalid data is encountered. Default: ``False``
    """

    def __init__(self, headers, indexes=None, dropna=False):
        super(ConllLoader, self).__init__()
        if not isinstance(headers, (list, tuple)):

@@ -298,7 +330,7 @@ class ConllLoader(DataSetLoader):
            if len(indexes) != len(headers):
                raise ValueError
            self.indexes = indexes
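
    # A usage sketch (the column layout is illustrative): keep only columns
    # 0 and 3 of a CoNLL-style file and name them 'tokens' and 'ner'.
    #
    #     loader = ConllLoader(headers=['tokens', 'ner'], indexes=[0, 3])
    #     ds = loader.load('train.conll')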

    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):

@@ -316,7 +348,7 @@ class Conll2003Loader(ConllLoader):
    For more information about the dataset, see:
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
    """

    def __init__(self):
        headers = [
            'tokens', 'pos', 'chunks', 'ner',

@@ -368,17 +400,17 @@ class SSTLoader(DataSetLoader):
    :param subtree: whether to expand each sample into all its subtrees to augment the data. Default: ``False``
    :param fine_grained: whether to use the SST-5 scheme; if ``False``, SST-2 is used. Default: ``False``
    """

    def __init__(self, subtree=False, fine_grained=False):
        self.subtree = subtree

        tag_v = {'0': 'very negative', '1': 'negative', '2': 'neutral',
                 '3': 'positive', '4': 'very positive'}
        if not fine_grained:
            tag_v['0'] = tag_v['1']
            tag_v['4'] = tag_v['3']
        self.tag_v = tag_v
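        # note: with fine_grained=False the five SST labels collapse to three
        # ('very negative' -> 'negative', 'very positive' -> 'positive')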

    def _load(self, path):
        """

@@ -395,7 +427,7 @@ class SSTLoader(DataSetLoader):
        for words, tag in datas:
            ds.append(Instance(words=words, target=tag))
        return ds

    @staticmethod
    def _get_one(data, subtree):
        tree = Tree.fromstring(data)

@@ -403,6 +435,40 @@ class SSTLoader(DataSetLoader):
            return [(t.leaves(), t.label()) for t in tree.subtrees()]
        return [(tree.leaves(), tree.label())]

    def process(self,
                paths,
                train_ds: Iterable[str] = None,
                src_vocab_op: VocabularyOption = None,
                tgt_vocab_op: VocabularyOption = None,
                embed_op: EmbeddingOption = None):
        input_name, target_name = 'words', 'target'
        src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
        tgt_vocab = Vocabulary() if tgt_vocab_op is None else Vocabulary(**tgt_vocab_op)

        info = DataInfo(datasets=self.load(paths))
        _train_ds = [info.datasets[name]
                     for name in train_ds] if train_ds else info.datasets.values()

        src_vocab.from_dataset(_train_ds, field_name=input_name)
        tgt_vocab.from_dataset(_train_ds, field_name=target_name)
        src_vocab.index_dataset(
            info.datasets.values(),
            field_name=input_name, new_field_name=input_name)
        tgt_vocab.index_dataset(
            info.datasets.values(),
            field_name=target_name, new_field_name=target_name)
        info.vocabs = {
            input_name: src_vocab,
            target_name: tgt_vocab
        }

        if embed_op is not None:
            embed_op.vocab = src_vocab
            init_emb = EmbedLoader.load_with_vocab(**embed_op)
            info.embeddings[input_name] = init_emb

        return info
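
# A usage sketch for the process pipeline above (the paths are illustrative):
# vocabularies are built on the named training split(s) only, then every
# split is indexed with them.
#
#     loader = SSTLoader(fine_grained=True)
#     info = loader.process(
#         {'train': 'train.txt', 'dev': 'dev.txt'},
#         train_ds=['train'],
#         src_vocab_op=VocabularyOption(max_size=30000),
#     )
#     info.vocabs['words'], info.datasets['train']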


class JsonLoader(DataSetLoader):
    """

@@ -417,7 +483,7 @@ class JsonLoader(DataSetLoader):
    :param bool dropna: whether to ignore invalid data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when invalid data is encountered.
        Default: ``False``
    """

    def __init__(self, fields=None, dropna=False):
        super(JsonLoader, self).__init__()
        self.dropna = dropna

@@ -428,7 +494,7 @@ class JsonLoader(DataSetLoader):
            for k, v in fields.items():
                self.fields[k] = k if v is None else v
            self.fields_list = list(self.fields.keys())
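
    # A fields-mapping sketch: keys are the JSON keys to read, values are the
    # DataSet field names they become (None keeps the original key), per the
    # assignment in __init__ above:
    #
    #     JsonLoader(fields={'text': 'words', 'label': None})
    #     # JSON key 'text' -> field 'words'; JSON key 'label' stays 'label'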

    def _load(self, path):
        ds = DataSet()
        for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):

@@ -452,7 +518,7 @@ class SNLILoader(JsonLoader):

    Data source: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
    """

    def __init__(self):
        fields = {
            'sentence1_parse': 'words1',

@@ -460,14 +526,14 @@ class SNLILoader(JsonLoader):
            'gold_label': 'target',
        }
        super(SNLILoader, self).__init__(fields=fields)

    def _load(self, path):
        ds = super(SNLILoader, self)._load(path)

        def parse_tree(x):
            t = Tree.fromstring(x)
            return t.leaves()
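
        # parse_tree flattens an SNLI binary-parse string to its tokens via
        # nltk.tree.Tree (imported at the top of this file), e.g.
        #     parse_tree('(ROOT (NP (DT A) (NN man)))')  # -> ['A', 'man']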

        ds.apply(lambda ins: parse_tree(
            ins['words1']), new_field_name='words1')
        ds.apply(lambda ins: parse_tree(

@@ -488,12 +554,12 @@ class CSVLoader(DataSetLoader):
    :param bool dropna: whether to ignore invalid data; if ``True`` it is skipped, if ``False`` a ``ValueError`` is raised when invalid data is encountered.
        Default: ``False``
    """

    def __init__(self, headers=None, sep=",", dropna=False):
        self.headers = headers
        self.sep = sep
        self.dropna = dropna
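
    # A usage sketch (the file name is illustrative): read a tab-separated
    # file with no header row, naming the two columns explicitly.
    #
    #     ds = CSVLoader(headers=['words', 'target'], sep='\t').load('data.tsv')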

    def _load(self, path):
        ds = DataSet()
        for idx, data in _read_csv(path, headers=self.headers,

@@ -508,7 +574,7 @@ def _add_seg_tag(data):
    :param data: list of ([word], [pos], [heads], [head_tags])
    :return: list of ([word], [pos])
    """

    _processed = []
    for word_list, pos_list, _, _ in data:
        new_sample = []
|