
Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

Tag: v0.4.10
Author: 陈怡然, 5 years ago
Commit: a38a32755c
36 changed files with 280 additions and 593 deletions
  1. docs/Makefile (+1, -1)
  2. docs/format.py (+4, -1)
  3. docs/source/fastNLP.io.base_loader.rst (+0, -7)
  4. docs/source/fastNLP.io.data_bundle.rst (+7, -0)
  5. docs/source/fastNLP.io.rst (+1, -1)
  6. fastNLP/io/__init__.py (+3, -4)
  7. fastNLP/io/config_io.py (+0, -313)
  8. fastNLP/io/data_bundle.py (+0, -2)
  9. fastNLP/io/data_loader/conll.py (+2, -2)
  10. fastNLP/io/data_loader/imdb.py (+1, -1)
  11. fastNLP/io/data_loader/matching.py (+1, -1)
  12. fastNLP/io/data_loader/mtl.py (+1, -1)
  13. fastNLP/io/data_loader/people_daily.py (+1, -1)
  14. fastNLP/io/data_loader/sst.py (+1, -1)
  15. fastNLP/io/data_loader/yelp.py (+1, -1)
  16. fastNLP/io/dataset_loader.py (+1, -1)
  17. fastNLP/io/embed_loader.py (+1, -1)
  18. fastNLP/io/loader/__init__.py (+2, -1)
  19. fastNLP/io/loader/classification.py (+31, -29)
  20. fastNLP/io/loader/loader.py (+17, -21)
  21. fastNLP/io/loader/matching.py (+2, -1)
  22. fastNLP/io/model_io.py (+1, -1)
  23. fastNLP/io/pipe/classification.py (+1, -1)
  24. reproduction/Summarization/Baseline/data/dataloader.py (+188, -188)
  25. reproduction/Summarization/BertSum/dataloader.py (+1, -1)
  26. reproduction/coreference_resolution/data_load/cr_loader.py (+1, -1)
  27. reproduction/joint_cws_parse/data/data_loader.py (+1, -1)
  28. reproduction/matching/data/MatchingDataLoader.py (+1, -1)
  29. reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py (+1, -1)
  30. reproduction/seqence_labelling/cws/data/CWSDataLoader.py (+1, -1)
  31. reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+1, -1)
  32. reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+1, -1)
  33. reproduction/text_classification/data/IMDBLoader.py (+1, -1)
  34. reproduction/text_classification/data/MTL16Loader.py (+1, -1)
  35. reproduction/text_classification/data/sstloader.py (+1, -1)
  36. reproduction/text_classification/data/yelpLoader.py (+1, -1)
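
Taken together, the changes below amount to one rename: fastNLP/io/base_loader.py becomes fastNLP/io/data_bundle.py, config_io.py (flagged in its own docstring as a candidate for removal) is deleted, and every import site in fastNLP/ and reproduction/ is updated. A minimal before/after sketch of the import change, copied from the diffs below:

    # before this commit
    from fastNLP.io.base_loader import DataBundle, DataSetLoader

    # after this commit, the same classes live in the renamed module
    from fastNLP.io.data_bundle import DataBundle, DataSetLoader

DataBundle also remains re-exported from the fastNLP.io package itself, so "from fastNLP.io import DataBundle" continues to work.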

docs/Makefile (+1, -1)

@@ -20,7 +20,7 @@ server:
cd build/html && python -m http.server

dev:
rm -rf build/html && make html && make server
rm -rf build && make html && make server

.PHONY: help Makefile



docs/format.py (+4, -1)

@@ -59,7 +59,10 @@ def clear(path='./source/'):
else:
shorten(path + file, to_delete)
for file in to_delete:
os.remove(path + file + ".rst")
try:
os.remove(path + file + ".rst")
except:
pass


clear()

docs/source/fastNLP.io.base_loader.rst (+0, -7)

@@ -1,7 +0,0 @@
fastNLP.io.base\_loader
=======================

.. automodule:: fastNLP.io.base_loader
:members:
:undoc-members:
:show-inheritance:

docs/source/fastNLP.io.data_bundle.rst (+7, -0)

@@ -0,0 +1,7 @@
fastNLP.io.data\_bundle
=======================

.. automodule:: fastNLP.io.data_bundle
:members:
:undoc-members:
:show-inheritance:

docs/source/fastNLP.io.rst (+1, -1)

@@ -20,7 +20,7 @@ Submodules

.. toctree::

fastNLP.io.base_loader
fastNLP.io.data_bundle
fastNLP.io.dataset_loader
fastNLP.io.embed_loader
fastNLP.io.file_utils


fastNLP/io/__init__.py (+3, -4)

@@ -12,10 +12,9 @@
这些类的使用方法如下:
"""
__all__ = [
'EmbedLoader',

'DataBundle',
'DataSetLoader',
'EmbedLoader',

'YelpLoader',
'YelpFullLoader',
@@ -69,7 +68,7 @@ __all__ = [
]

from .embed_loader import EmbedLoader
from .base_loader import DataBundle, DataSetLoader
from .data_bundle import DataBundle
from .dataset_loader import CSVLoader, JsonLoader
from .model_io import ModelLoader, ModelSaver



fastNLP/io/config_io.py (+0, -313)

@@ -1,313 +0,0 @@
"""
用于读入和处理和保存 config 文件

.. todo::
这个模块中的类可能被抛弃?

"""
__all__ = [
"ConfigLoader",
"ConfigSection",
"ConfigSaver"
]

import configparser
import json
import os

from .base_loader import BaseLoader


class ConfigLoader(BaseLoader):
"""
别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader`

读取配置文件的Loader

:param str data_path: 配置文件的路径

"""
def __init__(self, data_path=None):
super(ConfigLoader, self).__init__()
if data_path is not None:
self.config = self.parse(super(ConfigLoader, self).load(data_path))
@staticmethod
def parse(string):
raise NotImplementedError
@staticmethod
def load_config(file_path, sections):
"""
把配置文件的section 存入提供的 ``sections`` 中

:param str file_path: 配置文件的路径
:param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection`
Example::

test_args = ConfigSection()
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})

"""
assert isinstance(sections, dict)
cfg = configparser.ConfigParser()
if not os.path.exists(file_path):
raise FileNotFoundError("config file {} not found. ".format(file_path))
cfg.read(file_path)
for s in sections:
attr_list = [i for i in sections[s].__dict__.keys() if
not callable(getattr(sections[s], i)) and not i.startswith("__")]
if s not in cfg:
print('section %s not found in config file' % (s))
continue
gen_sec = cfg[s]
for attr in gen_sec.keys():
try:
val = json.loads(gen_sec[attr])
# print(s, attr, val, type(val))
if attr in attr_list:
assert type(val) == type(getattr(sections[s], attr)), \
'type not match, except %s but got %s' % \
(type(getattr(sections[s], attr)), type(val))
"""
if attr in attr_list then check its type and
update its value.
else add a new attr in sections[s]
"""
setattr(sections[s], attr, val)
except Exception as e:
print("cannot load attribute %s in section %s"
% (attr, s))
pass


class ConfigSection(object):
"""
别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection`

ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用

"""
def __init__(self):
super(ConfigSection, self).__init__()
def __getitem__(self, key):
"""
:param key: str, the name of the attribute
:return attr: the value of this attribute
if key not in self.__dict__.keys():
return self[key]
else:
raise AttributeError
"""
if key in self.__dict__.keys():
return getattr(self, key)
raise AttributeError("do NOT have attribute %s" % key)
def __setitem__(self, key, value):
"""
:param key: str, the name of the attribute
:param value: the value of this attribute
if key not in self.__dict__.keys():
self[key] will be added
else:
self[key] will be updated
"""
if key in self.__dict__.keys():
if not isinstance(value, type(getattr(self, key))):
raise AttributeError("attr %s except %s but got %s" %
(key, str(type(getattr(self, key))), str(type(value))))
setattr(self, key, value)
def __contains__(self, item):
"""
:param item: The key of item.
:return: True if the key in self.__dict__.keys() else False.
"""
return item in self.__dict__.keys()
def __eq__(self, other):
"""Overwrite the == operator

:param other: Another ConfigSection() object which to be compared.
:return: True if value of each key in each ConfigSection() object are equal to the other, else False.
"""
for k in self.__dict__.keys():
if k not in other.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False
for k in other.__dict__.keys():
if k not in self.__dict__.keys():
return False
if getattr(self, k) != getattr(self, k):
return False
return True
def __ne__(self, other):
"""Overwrite the != operator

:param other:
:return:
"""
return not self.__eq__(other)
@property
def data(self):
return self.__dict__


class ConfigSaver(object):
"""
别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver`

ConfigSaver 是用来存储配置文件并解决相关冲突的类

:param str file_path: 配置文件的路径

"""
def __init__(self, file_path):
self.file_path = file_path
if not os.path.exists(self.file_path):
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path))
def _get_section(self, sect_name):
"""
This is the function to get the section with the section name.

:param sect_name: The name of section what wants to load.
:return: The section.
"""
sect = ConfigSection()
ConfigLoader().load_config(self.file_path, {sect_name: sect})
return sect
def _read_section(self):
"""
This is the function to read sections from the config file.

:return: sect_list, sect_key_list
sect_list: A list of ConfigSection().
sect_key_list: A list of names in sect_list.
"""
sect_name = None
sect_list = {}
sect_key_list = []
single_section = {}
single_section_key = []
with open(self.file_path, 'r') as f:
lines = f.readlines()
for line in lines:
if line.startswith('[') and line.endswith(']\n'):
if sect_name is None:
pass
else:
sect_list[sect_name] = single_section, single_section_key
single_section = {}
single_section_key = []
sect_key_list.append(sect_name)
sect_name = line[1: -2]
continue
if line.startswith('#'):
single_section[line] = '#'
single_section_key.append(line)
continue
if line.startswith('\n'):
single_section_key.append('\n')
continue
if '=' not in line:
raise RuntimeError("can NOT load config file {}".__format__(self.file_path))
key = line.split('=', maxsplit=1)[0].strip()
value = line.split('=', maxsplit=1)[1].strip() + '\n'
single_section[key] = value
single_section_key.append(key)
if sect_name is not None:
sect_list[sect_name] = single_section, single_section_key
sect_key_list.append(sect_name)
return sect_list, sect_key_list
def _write_section(self, sect_list, sect_key_list):
"""
This is the function to write config file with section list and name list.

:param sect_list: A list of ConfigSection() need to be writen into file.
:param sect_key_list: A list of name of sect_list.
:return:
"""
with open(self.file_path, 'w') as f:
for sect_key in sect_key_list:
single_section, single_section_key = sect_list[sect_key]
f.write('[' + sect_key + ']\n')
for key in single_section_key:
if key == '\n':
f.write('\n')
continue
if single_section[key] == '#':
f.write(key)
continue
f.write(key + ' = ' + single_section[key])
f.write('\n')
def save_config_file(self, section_name, section):
"""
这个方法可以用来修改并保存配置文件中单独的一个 section

:param str section_name: 需要保存的 section 的名字.
:param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型
"""
section_file = self._get_section(section_name)
if len(section_file.__dict__.keys()) == 0: # the section not in the file before
# append this section to config file
with open(self.file_path, 'a') as f:
f.write('[' + section_name + ']\n')
for k in section.__dict__.keys():
f.write(k + ' = ')
if isinstance(section[k], str):
f.write('\"' + str(section[k]) + '\"\n\n')
else:
f.write(str(section[k]) + '\n\n')
else:
# the section exists
change_file = False
for k in section.__dict__.keys():
if k not in section_file:
# find a new key in this section
change_file = True
break
if section_file[k] != section[k]:
change_file = True
break
if not change_file:
return
sect_list, sect_key_list = self._read_section()
if section_name not in sect_key_list:
raise AttributeError()
sect, sect_key = sect_list[section_name]
for k in section.__dict__.keys():
if k not in sect_key:
if sect_key[-1] != '\n':
sect_key.append('\n')
sect_key.append(k)
sect[k] = str(section[k])
if isinstance(section[k], str):
sect[k] = "\"" + sect[k] + "\""
sect[k] = sect[k] + "\n"
sect_list[section_name] = sect, sect_key
self._write_section(sect_list, sect_key_list)

fastNLP/io/base_loader.py → fastNLP/io/data_bundle.py

@@ -1,7 +1,5 @@
__all__ = [
"BaseLoader",
'DataBundle',
'DataSetLoader',
]

import _pickle as pickle

fastNLP/io/data_loader/conll.py (+2, -2)

@@ -1,11 +1,11 @@

from ...core.dataset import DataSet
from ...core.instance import Instance
from ..base_loader import DataSetLoader
from ..data_bundle import DataSetLoader
from ..file_reader import _read_conll
from typing import Union, Dict
from ..utils import check_loader_paths
from ..base_loader import DataBundle
from ..data_bundle import DataBundle

class ConllLoader(DataSetLoader):
"""


fastNLP/io/data_loader/imdb.py (+1, -1)

@@ -2,7 +2,7 @@
from typing import Union, Dict

from ..embed_loader import EmbeddingOption, EmbedLoader
from ..base_loader import DataSetLoader, DataBundle
from ..data_bundle import DataSetLoader, DataBundle
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance


fastNLP/io/data_loader/matching.py (+1, -1)

@@ -4,7 +4,7 @@ from typing import Union, Dict, List

from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..base_loader import DataBundle, DataSetLoader
from ..data_bundle import DataBundle, DataSetLoader
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from ...modules.encoder.bert import BertTokenizer



fastNLP/io/data_loader/mtl.py (+1, -1)

@@ -1,7 +1,7 @@

from typing import Union, Dict

from ..base_loader import DataBundle
from ..data_bundle import DataBundle
from ..dataset_loader import CSVLoader
from ...core.vocabulary import Vocabulary, VocabularyOption
from ...core.const import Const


fastNLP/io/data_loader/people_daily.py (+1, -1)

@@ -1,5 +1,5 @@

from ..base_loader import DataSetLoader
from ..data_bundle import DataSetLoader
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.const import Const


fastNLP/io/data_loader/sst.py (+1, -1)

@@ -2,7 +2,7 @@
from typing import Union, Dict
from nltk import Tree

from ..base_loader import DataBundle, DataSetLoader
from ..data_bundle import DataBundle, DataSetLoader
from ..dataset_loader import CSVLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet


fastNLP/io/data_loader/yelp.py (+1, -1)

@@ -6,7 +6,7 @@ from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import VocabularyOption, Vocabulary
from ..base_loader import DataBundle, DataSetLoader
from ..data_bundle import DataBundle, DataSetLoader
from typing import Union, Dict
from ..utils import check_loader_paths, get_tokenizer



fastNLP/io/dataset_loader.py (+1, -1)

@@ -26,7 +26,7 @@ __all__ = [
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json
from .base_loader import DataSetLoader
from .data_bundle import DataSetLoader


class JsonLoader(DataSetLoader):


fastNLP/io/embed_loader.py (+1, -1)

@@ -9,7 +9,7 @@ import warnings
import numpy as np

from ..core.vocabulary import Vocabulary
from .base_loader import BaseLoader
from .data_bundle import BaseLoader
from ..core.utils import Option




fastNLP/io/loader/__init__.py (+2, -1)

@@ -44,6 +44,8 @@ fastNLP 目前提供了如下的 Loader
"""

__all__ = [
'Loader',
'YelpLoader',
'YelpFullLoader',
'YelpPolarityLoader',
@@ -57,7 +59,6 @@ __all__ = [
'OntoNotesNERLoader',
'CTBLoader',

'Loader',
'CSVLoader',
'JsonLoader',



fastNLP/io/loader/classification.py (+31, -29)

@@ -7,6 +7,7 @@ import random
import shutil
import numpy as np


class YelpLoader(Loader):
"""
别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader`
@@ -14,6 +15,7 @@ class YelpLoader(Loader):
原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。

Example::
"1","I got 'new' tires from the..."
"1","Don't waste your time..."

@@ -28,11 +30,11 @@ class YelpLoader(Loader):
"...", "..."

"""
def __init__(self):
super(YelpLoader, self).__init__()
def _load(self, path: str=None):
def _load(self, path: str = None):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
@@ -69,12 +71,12 @@ class YelpFullLoader(YelpLoader):
:param int seed: 划分dev时的随机数种子
:return: str, 数据集的目录地址
"""
dataset_name = 'yelp-review-full'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否需要重新下载
re_download = True
if dev_ratio>0:
if dev_ratio > 0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
@@ -83,14 +85,14 @@ class YelpFullLoader(YelpLoader):
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -109,7 +111,7 @@ class YelpFullLoader(YelpLoader):
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
os.remove(os.path.join(data_dir, 'middle_file.csv'))
return data_dir


@@ -131,7 +133,7 @@ class YelpPolarityLoader(YelpLoader):
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.csv')): # 存在dev的话,check是否符合比例要求
re_download = True
if dev_ratio>0:
if dev_ratio > 0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
@@ -140,14 +142,14 @@ class YelpPolarityLoader(YelpLoader):
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -166,7 +168,7 @@ class YelpPolarityLoader(YelpLoader):
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
os.remove(os.path.join(data_dir, 'middle_file.csv'))
return data_dir


@@ -185,10 +187,10 @@ class IMDBLoader(Loader):
"...", "..."

"""
def __init__(self):
super(IMDBLoader, self).__init__()
def _load(self, path: str):
dataset = DataSet()
with open(path, 'r', encoding="utf-8") as f:
@@ -201,12 +203,12 @@ class IMDBLoader(Loader):
words = parts[1]
if words:
dataset.append(Instance(raw_words=words, target=target))
if len(dataset) == 0:
raise RuntimeError(f"{path} has no valid data.")
return dataset
def download(self, dev_ratio: float = 0.1, seed: int = 0):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -221,9 +223,9 @@ class IMDBLoader(Loader):
"""
dataset_name = 'aclImdb'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求
if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求
re_download = True
if dev_ratio>0:
if dev_ratio > 0:
dev_line_count = 0
tr_line_count = 0
with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \
@@ -232,14 +234,14 @@ class IMDBLoader(Loader):
tr_line_count += 1
for line in f2:
dev_line_count += 1
if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
re_download = True
else:
re_download = False
if re_download:
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -258,7 +260,7 @@ class IMDBLoader(Loader):
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
os.remove(os.path.join(data_dir, 'middle_file.txt'))
return data_dir


@@ -278,10 +280,10 @@ class SSTLoader(Loader):
raw_words列是str。

"""
def __init__(self):
super().__init__()
def _load(self, path: str):
"""
从path读取SST文件
@@ -296,7 +298,7 @@ class SSTLoader(Loader):
if line:
ds.append(Instance(raw_words=line))
return ds
def download(self):
"""
自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -323,10 +325,10 @@ class SST2Loader(Loader):

test的DataSet没有target列。
"""
def __init__(self):
super().__init__()
def _load(self, path: str):
"""
从path读取SST2文件
@@ -335,7 +337,7 @@ class SST2Loader(Loader):
:return: DataSet
"""
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline() # 跳过header
if 'test' in os.path.split(path)[1]:
@@ -356,7 +358,7 @@ class SST2Loader(Loader):
if raw_words:
ds.append(Instance(raw_words=raw_words, target=target))
return ds
def download(self):
"""
自动下载数据集,如果你使用了该数据集,请引用以下的文章
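
The download() methods in this file share one pattern: fetch the dataset into the cache directory, and when dev_ratio > 0, split a dev set off the training file, re-downloading first if an existing dev split no longer matches the requested ratio within 0.5% (rtol=0.005). A hedged sketch of a call, following the signature download(dev_ratio=0.1, seed=0) shown above for IMDBLoader:

    from fastNLP.io.loader.classification import IMDBLoader

    # dev_ratio controls the size of the dev split; seed is the random seed used when splitting
    data_dir = IMDBLoader().download(dev_ratio=0.1, seed=0)
    # data_dir is the cached dataset directory; it gains a dev file when dev_ratio > 0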


fastNLP/io/loader/loader.py (+17, -21)

@@ -2,17 +2,21 @@ from ...core.dataset import DataSet
from .. import DataBundle
from ..utils import check_loader_paths
from typing import Union, Dict
import os
from ..file_utils import _get_dataset_url, get_cache_path, cached_path


class Loader:
"""
各种数据 Loader 的基类,提供了 API 的参考.
"""
def __init__(self):
pass

def _load(self, path:str) -> DataSet:
def _load(self, path: str) -> DataSet:
raise NotImplementedError
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
"""
从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。

@@ -22,31 +26,25 @@ class Loader:
(0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。

(1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件
名包含'train'、 'dev'、 'test'则会报错

Example::
名包含'train'、 'dev'、 'test'则会报错::

data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、
# dev、 test等有所变化,可以通过以下的方式取出DataSet
tr_data = data_bundle.datasets['train']
te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段

(2) 传入文件路径

Example::
(2) 传入文件路径::

data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train'
tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet

(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test

Example::
(3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test::

paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}
data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test"
dev_data = data_bundle.datasets['dev']

:return: 返回的:class:`~fastNLP.io.DataBundle`
:return: 返回的 :class:`~fastNLP.io.DataBundle`
"""
if paths is None:
paths = self.download()
@@ -54,10 +52,10 @@ class Loader:
datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle
def download(self):
raise NotImplementedError(f"{self.__class__} cannot download data automatically.")
def _get_dataset_path(self, dataset_name):
"""
传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存
@@ -65,11 +63,9 @@ class Loader:
:param str dataset_name: 数据集的名称
:return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。
"""
default_cache_path = get_cache_path()
url = _get_dataset_url(dataset_name)
output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
return output_dir
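
The reflowed load() docstring above describes three ways to point a Loader at data. A hedged usage sketch assembled from those docstring examples (ConllLoader and the placeholder paths are taken from the docstring itself; the exact import path is an assumption):

    from fastNLP.io import ConllLoader  # import path assumed; the docstring only shows ConllLoader()

    # (1) a directory: files whose names contain 'train' / 'dev' / 'test' are picked up
    data_bundle = ConllLoader().load('/path/to/dir')
    tr_data = data_bundle.datasets['train']

    # (2) a single file: the returned DataBundle then only contains a 'train' dataset
    data_bundle = ConllLoader().load('/path/to/a/train.conll')

    # (3) an explicit dict when the splits live in different places or use other names
    paths = {'train': '/path/to/tr.conll', 'dev': '/to/validate.conll', 'test': '/to/te.conll'}
    data_bundle = ConllLoader().load(paths)
    dev_data = data_bundle.datasets['dev']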



fastNLP/io/loader/matching.py (+2, -1)

@@ -203,7 +203,8 @@ class QNLILoader(JsonLoader):
"""
如果您的实验使用到了该数据,请引用

TODO 补充
.. todo::
补充

:return:
"""


fastNLP/io/model_io.py (+1, -1)

@@ -8,7 +8,7 @@ __all__ = [

import torch

from .base_loader import BaseLoader
from .data_bundle import BaseLoader


class ModelLoader(BaseLoader):


fastNLP/io/pipe/classification.py (+1, -1)

@@ -1,6 +1,6 @@
from nltk import Tree

from ..base_loader import DataBundle
from ..data_bundle import DataBundle
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader


reproduction/Summarization/Baseline/data/dataloader.py (+188, -188)

@@ -1,188 +1,188 @@
import pickle
import numpy as np
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import DataBundle
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.core.const import Const
from tools.logger import *
WORD_PAD = "[PAD]"
WORD_UNK = "[UNK]"
DOMAIN_UNK = "X"
TAG_UNK = "X"
class SummarizationLoader(JsonLoader):
"""
读取summarization数据集,读取的DataSet包含fields::
text: list(str),document
summary: list(str), summary
text_wd: list(list(str)),tokenized document
summary_wd: list(list(str)), tokenized summary
labels: list(int),
flatten_label: list(int), 0 or 1, flatten labels
domain: str, optional
tag: list(str), optional
数据来源: CNN_DailyMail Newsroom DUC
"""
def __init__(self):
super(SummarizationLoader, self).__init__()
def _load(self, path):
ds = super(SummarizationLoader, self)._load(path)
def _lower_text(text_list):
return [text.lower() for text in text_list]
def _split_list(text_list):
return [text.split() for text in text_list]
def _convert_label(label, sent_len):
np_label = np.zeros(sent_len, dtype=int)
if label != []:
np_label[np.array(label)] = 1
return np_label.tolist()
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd')
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd')
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
return ds
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
"""
:param paths: dict path for each dataset
:param vocab_size: int max_size for vocab
:param vocab_path: str vocab path
:param sent_max_len: int max token number of the sentence
:param doc_max_timesteps: int max sentence number of the document
:param domain: bool build vocab for publication, use 'X' for unknown
:param tag: bool build vocab for tag, use 'X' for unknown
:param load_vocab_file: bool build vocab (False) or load vocab (True)
:return: DataBundle
datasets: dict keys correspond to the paths dict
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
embeddings: optional
"""
def _pad_sent(text_wd):
pad_text_wd = []
for sent_wd in text_wd:
if len(sent_wd) < sent_max_len:
pad_num = sent_max_len - len(sent_wd)
sent_wd.extend([WORD_PAD] * pad_num)
else:
sent_wd = sent_wd[:sent_max_len]
pad_text_wd.append(sent_wd)
return pad_text_wd
def _token_mask(text_wd):
token_mask_list = []
for sent_wd in text_wd:
token_num = len(sent_wd)
if token_num < sent_max_len:
mask = [1] * token_num + [0] * (sent_max_len - token_num)
else:
mask = [1] * sent_max_len
token_mask_list.append(mask)
return token_mask_list
def _pad_label(label):
text_len = len(label)
if text_len < doc_max_timesteps:
pad_label = label + [0] * (doc_max_timesteps - text_len)
else:
pad_label = label[:doc_max_timesteps]
return pad_label
def _pad_doc(text_wd):
text_len = len(text_wd)
if text_len < doc_max_timesteps:
padding = [WORD_PAD] * sent_max_len
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
else:
pad_text = text_wd[:doc_max_timesteps]
return pad_text
def _sent_mask(text_wd):
text_len = len(text_wd)
if text_len < doc_max_timesteps:
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
else:
sent_mask = [1] * doc_max_timesteps
return sent_mask
datasets = {}
train_ds = None
for key, value in paths.items():
ds = self.load(value)
# pad sent
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask")
# pad document
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label")
# rename field
ds.rename_field("pad_text", Const.INPUT)
ds.rename_field("seq_len", Const.INPUT_LEN)
ds.rename_field("pad_label", Const.TARGET)
# set input and target
ds.set_input(Const.INPUT, Const.INPUT_LEN)
ds.set_target(Const.TARGET, Const.INPUT_LEN)
datasets[key] = ds
if "train" in key:
train_ds = datasets[key]
vocab_dict = {}
if load_vocab_file == False:
logger.info("[INFO] Build new vocab from training dataset!")
if train_ds == None:
raise ValueError("Lack train file to build vocabulary!")
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"])
vocab_dict["vocab"] = vocabs
else:
logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
word_list = []
with open(vocab_path, 'r', encoding='utf8') as vocab_f:
cnt = 2 # pad and unk
for line in vocab_f:
pieces = line.split("\t")
word_list.append(pieces[0])
cnt += 1
if cnt > vocab_size:
break
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
vocabs.add_word_lst(word_list)
vocabs.build_vocab()
vocab_dict["vocab"] = vocabs
if domain == True:
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
domaindict.from_dataset(train_ds, field_name="publication")
vocab_dict["domain"] = domaindict
if tag == True:
tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
tagdict.from_dataset(train_ds, field_name="tag")
vocab_dict["tag"] = tagdict
for ds in datasets.values():
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
return DataBundle(vocabs=vocab_dict, datasets=datasets)
import pickle
import numpy as np
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.data_bundle import DataBundle
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.core.const import Const
from tools.logger import *
WORD_PAD = "[PAD]"
WORD_UNK = "[UNK]"
DOMAIN_UNK = "X"
TAG_UNK = "X"
class SummarizationLoader(JsonLoader):
"""
读取summarization数据集,读取的DataSet包含fields::
text: list(str),document
summary: list(str), summary
text_wd: list(list(str)),tokenized document
summary_wd: list(list(str)), tokenized summary
labels: list(int),
flatten_label: list(int), 0 or 1, flatten labels
domain: str, optional
tag: list(str), optional
数据来源: CNN_DailyMail Newsroom DUC
"""
def __init__(self):
super(SummarizationLoader, self).__init__()
def _load(self, path):
ds = super(SummarizationLoader, self)._load(path)
def _lower_text(text_list):
return [text.lower() for text in text_list]
def _split_list(text_list):
return [text.split() for text in text_list]
def _convert_label(label, sent_len):
np_label = np.zeros(sent_len, dtype=int)
if label != []:
np_label[np.array(label)] = 1
return np_label.tolist()
ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd')
ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd')
ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
return ds
def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
"""
:param paths: dict path for each dataset
:param vocab_size: int max_size for vocab
:param vocab_path: str vocab path
:param sent_max_len: int max token number of the sentence
:param doc_max_timesteps: int max sentence number of the document
:param domain: bool build vocab for publication, use 'X' for unknown
:param tag: bool build vocab for tag, use 'X' for unknown
:param load_vocab_file: bool build vocab (False) or load vocab (True)
:return: DataBundle
datasets: dict keys correspond to the paths dict
vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
embeddings: optional
"""
def _pad_sent(text_wd):
pad_text_wd = []
for sent_wd in text_wd:
if len(sent_wd) < sent_max_len:
pad_num = sent_max_len - len(sent_wd)
sent_wd.extend([WORD_PAD] * pad_num)
else:
sent_wd = sent_wd[:sent_max_len]
pad_text_wd.append(sent_wd)
return pad_text_wd
def _token_mask(text_wd):
token_mask_list = []
for sent_wd in text_wd:
token_num = len(sent_wd)
if token_num < sent_max_len:
mask = [1] * token_num + [0] * (sent_max_len - token_num)
else:
mask = [1] * sent_max_len
token_mask_list.append(mask)
return token_mask_list
def _pad_label(label):
text_len = len(label)
if text_len < doc_max_timesteps:
pad_label = label + [0] * (doc_max_timesteps - text_len)
else:
pad_label = label[:doc_max_timesteps]
return pad_label
def _pad_doc(text_wd):
text_len = len(text_wd)
if text_len < doc_max_timesteps:
padding = [WORD_PAD] * sent_max_len
pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
else:
pad_text = text_wd[:doc_max_timesteps]
return pad_text
def _sent_mask(text_wd):
text_len = len(text_wd)
if text_len < doc_max_timesteps:
sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
else:
sent_mask = [1] * doc_max_timesteps
return sent_mask
datasets = {}
train_ds = None
for key, value in paths.items():
ds = self.load(value)
# pad sent
ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask")
# pad document
ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label")
# rename field
ds.rename_field("pad_text", Const.INPUT)
ds.rename_field("seq_len", Const.INPUT_LEN)
ds.rename_field("pad_label", Const.TARGET)
# set input and target
ds.set_input(Const.INPUT, Const.INPUT_LEN)
ds.set_target(Const.TARGET, Const.INPUT_LEN)
datasets[key] = ds
if "train" in key:
train_ds = datasets[key]
vocab_dict = {}
if load_vocab_file == False:
logger.info("[INFO] Build new vocab from training dataset!")
if train_ds == None:
raise ValueError("Lack train file to build vocabulary!")
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"])
vocab_dict["vocab"] = vocabs
else:
logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
word_list = []
with open(vocab_path, 'r', encoding='utf8') as vocab_f:
cnt = 2 # pad and unk
for line in vocab_f:
pieces = line.split("\t")
word_list.append(pieces[0])
cnt += 1
if cnt > vocab_size:
break
vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
vocabs.add_word_lst(word_list)
vocabs.build_vocab()
vocab_dict["vocab"] = vocabs
if domain == True:
domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
domaindict.from_dataset(train_ds, field_name="publication")
vocab_dict["domain"] = domaindict
if tag == True:
tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
tagdict.from_dataset(train_ds, field_name="tag")
vocab_dict["tag"] = tagdict
for ds in datasets.values():
vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
return DataBundle(vocabs=vocab_dict, datasets=datasets)
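
SummarizationLoader.process() above chains sentence padding, token and sentence masking, label padding, and vocabulary construction into one DataBundle. A hedged sketch of a call; every path and size below is an illustrative placeholder, not a value from the repository:

    loader = SummarizationLoader()
    data_bundle = loader.process(
        paths={'train': 'train.jsonl', 'test': 'test.jsonl'},  # placeholder paths, one per split
        vocab_size=100000,         # max vocabulary size
        vocab_path='vocab_file',   # only read when load_vocab_file=True
        sent_max_len=100,          # max tokens kept per sentence
        doc_max_timesteps=50,      # max sentences kept per document
        load_vocab_file=False,     # build the vocab from the train split instead of loading it
    )
    train_ds = data_bundle.datasets['train']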

reproduction/Summarization/BertSum/dataloader.py (+1, -1)

@@ -3,7 +3,7 @@ from datetime import timedelta

from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.modules.encoder._bert import BertTokenizer
from fastNLP.io.base_loader import DataBundle
from fastNLP.io.data_bundle import DataBundle
from fastNLP.core.const import Const

class BertData(JsonLoader):


reproduction/coreference_resolution/data_load/cr_loader.py (+1, -1)

@@ -1,7 +1,7 @@
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
from fastNLP.io.file_reader import _read_json
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import DataBundle
from fastNLP.io.data_bundle import DataBundle
from reproduction.coreference_resolution.model.config import Config
import reproduction.coreference_resolution.model.preprocess as preprocess



reproduction/joint_cws_parse/data/data_loader.py (+1, -1)

@@ -1,6 +1,6 @@


from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from fastNLP.io.data_loader import ConllLoader
import numpy as np



reproduction/matching/data/MatchingDataLoader.py (+1, -1)

@@ -9,7 +9,7 @@ from typing import Union, Dict

from fastNLP.core.const import Const
from fastNLP.core.vocabulary import Vocabulary
from fastNLP.io.base_loader import DataBundle, DataSetLoader
from fastNLP.io.data_bundle import DataBundle, DataSetLoader
from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
from fastNLP.modules.encoder._bert import BertTokenizer


reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py (+1, -1)

@@ -1,6 +1,6 @@


from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from fastNLP.io import ConllLoader
from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
from fastNLP import Const


reproduction/seqence_labelling/cws/data/CWSDataLoader.py (+1, -1)

@@ -1,7 +1,7 @@

from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance


reproduction/seqence_labelling/ner/data/Conll2003Loader.py (+1, -1)

@@ -1,6 +1,6 @@

from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict
from fastNLP import Vocabulary
from fastNLP import Const


reproduction/seqence_labelling/ner/data/OntoNoteLoader.py (+1, -1)

@@ -1,5 +1,5 @@
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict
from fastNLP import DataSet
from fastNLP import Vocabulary


reproduction/text_classification/data/IMDBLoader.py (+1, -1)

@@ -1,6 +1,6 @@
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance


reproduction/text_classification/data/MTL16Loader.py (+1, -1)

@@ -1,6 +1,6 @@
from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io.base_loader import DataSetLoader, DataBundle
from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from typing import Union, Dict, List, Iterator
from fastNLP import DataSet
from fastNLP import Instance


reproduction/text_classification/data/sstloader.py (+1, -1)

@@ -1,6 +1,6 @@
from typing import Iterable
from nltk import Tree
from fastNLP.io.base_loader import DataBundle, DataSetLoader
from fastNLP.io.data_bundle import DataBundle, DataSetLoader
from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
from fastNLP import DataSet
from fastNLP import Instance


reproduction/text_classification/data/yelpLoader.py (+1, -1)

@@ -4,7 +4,7 @@ from typing import Iterable
from fastNLP import DataSet, Instance, Vocabulary
from fastNLP.core.vocabulary import VocabularyOption
from fastNLP.io import JsonLoader
from fastNLP.io.base_loader import DataBundle,DataSetLoader
from fastNLP.io.data_bundle import DataBundle,DataSetLoader
from fastNLP.io.embed_loader import EmbeddingOption
from fastNLP.io.file_reader import _read_json
from typing import Union, Dict

