Merge branch 'dev0.5.0' of https://github.com/fastnlp/fastNLP into dev0.5.0

tags/v0.4.10
陈怡然 5 years ago
parent commit a38a32755c
36 changed files with 280 additions and 593 deletions
  1. docs/Makefile  (+1 / -1)
  2. docs/format.py  (+4 / -1)
  3. docs/source/fastNLP.io.base_loader.rst  (+0 / -7)
  4. docs/source/fastNLP.io.data_bundle.rst  (+7 / -0)
  5. docs/source/fastNLP.io.rst  (+1 / -1)
  6. fastNLP/io/__init__.py  (+3 / -4)
  7. fastNLP/io/config_io.py  (+0 / -313)
  8. fastNLP/io/data_bundle.py  (+0 / -2)
  9. fastNLP/io/data_loader/conll.py  (+2 / -2)
 10. fastNLP/io/data_loader/imdb.py  (+1 / -1)
 11. fastNLP/io/data_loader/matching.py  (+1 / -1)
 12. fastNLP/io/data_loader/mtl.py  (+1 / -1)
 13. fastNLP/io/data_loader/people_daily.py  (+1 / -1)
 14. fastNLP/io/data_loader/sst.py  (+1 / -1)
 15. fastNLP/io/data_loader/yelp.py  (+1 / -1)
 16. fastNLP/io/dataset_loader.py  (+1 / -1)
 17. fastNLP/io/embed_loader.py  (+1 / -1)
 18. fastNLP/io/loader/__init__.py  (+2 / -1)
 19. fastNLP/io/loader/classification.py  (+31 / -29)
 20. fastNLP/io/loader/loader.py  (+17 / -21)
 21. fastNLP/io/loader/matching.py  (+2 / -1)
 22. fastNLP/io/model_io.py  (+1 / -1)
 23. fastNLP/io/pipe/classification.py  (+1 / -1)
 24. reproduction/Summarization/Baseline/data/dataloader.py  (+188 / -188)
 25. reproduction/Summarization/BertSum/dataloader.py  (+1 / -1)
 26. reproduction/coreference_resolution/data_load/cr_loader.py  (+1 / -1)
 27. reproduction/joint_cws_parse/data/data_loader.py  (+1 / -1)
 28. reproduction/matching/data/MatchingDataLoader.py  (+1 / -1)
 29. reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py  (+1 / -1)
 30. reproduction/seqence_labelling/cws/data/CWSDataLoader.py  (+1 / -1)
 31. reproduction/seqence_labelling/ner/data/Conll2003Loader.py  (+1 / -1)
 32. reproduction/seqence_labelling/ner/data/OntoNoteLoader.py  (+1 / -1)
 33. reproduction/text_classification/data/IMDBLoader.py  (+1 / -1)
 34. reproduction/text_classification/data/MTL16Loader.py  (+1 / -1)
 35. reproduction/text_classification/data/sstloader.py  (+1 / -1)
 36. reproduction/text_classification/data/yelpLoader.py  (+1 / -1)

docs/Makefile  (+1 / -1)

@@ -20,7 +20,7 @@ server:
 	cd build/html && python -m http.server
 
 
 dev:
-	rm -rf build/html && make html && make server
+	rm -rf build && make html && make server
 
 .PHONY: help Makefile

docs/format.py  (+4 / -1)

@@ -59,7 +59,10 @@ def clear(path='./source/'):
         else:
             shorten(path + file, to_delete)
     for file in to_delete:
-        os.remove(path + file + ".rst")
+        try:
+            os.remove(path + file + ".rst")
+        except:
+            pass
 
 
 clear()

docs/source/fastNLP.io.base_loader.rst  (+0 / -7, file deleted)

@@ -1,7 +0,0 @@
-fastNLP.io.base\_loader
-=======================
-
-.. automodule:: fastNLP.io.base_loader
-    :members:
-    :undoc-members:
-    :show-inheritance:

docs/source/fastNLP.io.data_bundle.rst  (+7 / -0, new file)

@@ -0,0 +1,7 @@
+fastNLP.io.data\_bundle
+=======================
+
+.. automodule:: fastNLP.io.data_bundle
+    :members:
+    :undoc-members:
+    :show-inheritance:

docs/source/fastNLP.io.rst  (+1 / -1)

@@ -20,7 +20,7 @@ Submodules
 
 .. toctree::
 
-    fastNLP.io.base_loader
+    fastNLP.io.data_bundle
     fastNLP.io.dataset_loader
     fastNLP.io.embed_loader
    fastNLP.io.file_utils


fastNLP/io/__init__.py  (+3 / -4)

@@ -12,10 +12,9 @@
 这些类的使用方法如下:
 """
 __all__ = [
-    'EmbedLoader',
-
     'DataBundle',
-    'DataSetLoader',
+    'EmbedLoader',
 
     'YelpLoader',
     'YelpFullLoader',
@@ -69,7 +68,7 @@ __all__ = [
 ]
 
 from .embed_loader import EmbedLoader
-from .base_loader import DataBundle, DataSetLoader
+from .data_bundle import DataBundle
 from .dataset_loader import CSVLoader, JsonLoader
 from .model_io import ModelLoader, ModelSaver
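
For code that consumed the old module, the visible effect of this change is only an import-path rename. A minimal before/after sketch (the surrounding user code is hypothetical; the import paths are taken from the hunks in this commit):

    # before this commit
    # from fastNLP.io.base_loader import DataBundle

    # after this commit the module is called data_bundle
    from fastNLP.io.data_bundle import DataBundle

    # DataBundle is also still re-exported at package level (see __all__ above):
    # from fastNLP.io import DataBundle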




fastNLP/io/config_io.py  (+0 / -313, file deleted)

@@ -1,313 +0,0 @@
"""
用于读入和处理和保存 config 文件

.. todo::
    这个模块中的类可能被抛弃?

"""
__all__ = [
    "ConfigLoader",
    "ConfigSection",
    "ConfigSaver"
]

import configparser
import json
import os

from .base_loader import BaseLoader


class ConfigLoader(BaseLoader):
    """
    别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader`

    读取配置文件的Loader

    :param str data_path: 配置文件的路径

    """

    def __init__(self, data_path=None):
        super(ConfigLoader, self).__init__()
        if data_path is not None:
            self.config = self.parse(super(ConfigLoader, self).load(data_path))

    @staticmethod
    def parse(string):
        raise NotImplementedError

    @staticmethod
    def load_config(file_path, sections):
        """
        把配置文件的section 存入提供的 ``sections`` 中

        :param str file_path: 配置文件的路径
        :param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection`
            Example::

                test_args = ConfigSection()
                ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})

        """
        assert isinstance(sections, dict)
        cfg = configparser.ConfigParser()
        if not os.path.exists(file_path):
            raise FileNotFoundError("config file {} not found. ".format(file_path))
        cfg.read(file_path)
        for s in sections:
            attr_list = [i for i in sections[s].__dict__.keys() if
                         not callable(getattr(sections[s], i)) and not i.startswith("__")]
            if s not in cfg:
                print('section %s not found in config file' % (s))
                continue
            gen_sec = cfg[s]
            for attr in gen_sec.keys():
                try:
                    val = json.loads(gen_sec[attr])
                    # print(s, attr, val, type(val))
                    if attr in attr_list:
                        assert type(val) == type(getattr(sections[s], attr)), \
                            'type not match, except %s but got %s' % \
                            (type(getattr(sections[s], attr)), type(val))
                    """
                    if attr in attr_list then check its type and
                    update its value.
                    else add a new attr in sections[s]
                    """
                    setattr(sections[s], attr, val)
                except Exception as e:
                    print("cannot load attribute %s in section %s"
                          % (attr, s))
                    pass


class ConfigSection(object):
    """
    别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection`

    ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用

    """

    def __init__(self):
        super(ConfigSection, self).__init__()

    def __getitem__(self, key):
        """
        :param key: str, the name of the attribute
        :return attr: the value of this attribute
            if key not in self.__dict__.keys():
                return self[key]
            else:
                raise AttributeError
        """
        if key in self.__dict__.keys():
            return getattr(self, key)
        raise AttributeError("do NOT have attribute %s" % key)

    def __setitem__(self, key, value):
        """
        :param key: str, the name of the attribute
        :param value: the value of this attribute
            if key not in self.__dict__.keys():
                self[key] will be added
            else:
                self[key] will be updated
        """
        if key in self.__dict__.keys():
            if not isinstance(value, type(getattr(self, key))):
                raise AttributeError("attr %s except %s but got %s" %
                                     (key, str(type(getattr(self, key))), str(type(value))))
        setattr(self, key, value)

    def __contains__(self, item):
        """
        :param item: The key of item.
        :return: True if the key in self.__dict__.keys() else False.
        """
        return item in self.__dict__.keys()

    def __eq__(self, other):
        """Overwrite the == operator

        :param other: Another ConfigSection() object which to be compared.
        :return: True if value of each key in each ConfigSection() object are equal to the other, else False.
        """
        for k in self.__dict__.keys():
            if k not in other.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False
        for k in other.__dict__.keys():
            if k not in self.__dict__.keys():
                return False
            if getattr(self, k) != getattr(self, k):
                return False
        return True

    def __ne__(self, other):
        """Overwrite the != operator

        :param other:
        :return:
        """
        return not self.__eq__(other)

    @property
    def data(self):
        return self.__dict__


class ConfigSaver(object):
    """
    别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver`

    ConfigSaver 是用来存储配置文件并解决相关冲突的类

    :param str file_path: 配置文件的路径

    """

    def __init__(self, file_path):
        self.file_path = file_path
        if not os.path.exists(self.file_path):
            raise FileNotFoundError("file {} NOT found!".__format__(self.file_path))

    def _get_section(self, sect_name):
        """
        This is the function to get the section with the section name.

        :param sect_name: The name of section what wants to load.
        :return: The section.
        """
        sect = ConfigSection()
        ConfigLoader().load_config(self.file_path, {sect_name: sect})
        return sect

    def _read_section(self):
        """
        This is the function to read sections from the config file.

        :return: sect_list, sect_key_list
            sect_list: A list of ConfigSection().
            sect_key_list: A list of names in sect_list.
        """
        sect_name = None

        sect_list = {}
        sect_key_list = []

        single_section = {}
        single_section_key = []

        with open(self.file_path, 'r') as f:
            lines = f.readlines()

        for line in lines:
            if line.startswith('[') and line.endswith(']\n'):
                if sect_name is None:
                    pass
                else:
                    sect_list[sect_name] = single_section, single_section_key
                    single_section = {}
                    single_section_key = []
                    sect_key_list.append(sect_name)
                sect_name = line[1: -2]
                continue

            if line.startswith('#'):
                single_section[line] = '#'
                single_section_key.append(line)
                continue

            if line.startswith('\n'):
                single_section_key.append('\n')
                continue

            if '=' not in line:
                raise RuntimeError("can NOT load config file {}".__format__(self.file_path))

            key = line.split('=', maxsplit=1)[0].strip()
            value = line.split('=', maxsplit=1)[1].strip() + '\n'
            single_section[key] = value
            single_section_key.append(key)

        if sect_name is not None:
            sect_list[sect_name] = single_section, single_section_key
            sect_key_list.append(sect_name)
        return sect_list, sect_key_list

    def _write_section(self, sect_list, sect_key_list):
        """
        This is the function to write config file with section list and name list.

        :param sect_list: A list of ConfigSection() need to be writen into file.
        :param sect_key_list: A list of name of sect_list.
        :return:
        """
        with open(self.file_path, 'w') as f:
            for sect_key in sect_key_list:
                single_section, single_section_key = sect_list[sect_key]
                f.write('[' + sect_key + ']\n')
                for key in single_section_key:
                    if key == '\n':
                        f.write('\n')
                        continue
                    if single_section[key] == '#':
                        f.write(key)
                        continue
                    f.write(key + ' = ' + single_section[key])
                f.write('\n')

    def save_config_file(self, section_name, section):
        """
        这个方法可以用来修改并保存配置文件中单独的一个 section

        :param str section_name: 需要保存的 section 的名字.
        :param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型
        """
        section_file = self._get_section(section_name)
        if len(section_file.__dict__.keys()) == 0:  # the section not in the file before
            # append this section to config file
            with open(self.file_path, 'a') as f:
                f.write('[' + section_name + ']\n')
                for k in section.__dict__.keys():
                    f.write(k + ' = ')
                    if isinstance(section[k], str):
                        f.write('\"' + str(section[k]) + '\"\n\n')
                    else:
                        f.write(str(section[k]) + '\n\n')
        else:
            # the section exists
            change_file = False
            for k in section.__dict__.keys():
                if k not in section_file:
                    # find a new key in this section
                    change_file = True
                    break
                if section_file[k] != section[k]:
                    change_file = True
                    break
            if not change_file:
                return

            sect_list, sect_key_list = self._read_section()
            if section_name not in sect_key_list:
                raise AttributeError()

            sect, sect_key = sect_list[section_name]
            for k in section.__dict__.keys():
                if k not in sect_key:
                    if sect_key[-1] != '\n':
                        sect_key.append('\n')
                    sect_key.append(k)
                sect[k] = str(section[k])
                if isinstance(section[k], str):
                    sect[k] = "\"" + sect[k] + "\""
                sect[k] = sect[k] + "\n"
            sect_list[section_name] = sect, sect_key
            self._write_section(sect_list, sect_key_list)

fastNLP/io/base_loader.py → fastNLP/io/data_bundle.py  (+0 / -2, renamed)

@@ -1,7 +1,5 @@
 __all__ = [
-    "BaseLoader",
     'DataBundle',
-    'DataSetLoader',
 ]
 
 import _pickle as pickle

fastNLP/io/data_loader/conll.py  (+2 / -2)

@@ -1,11 +1,11 @@
 from ...core.dataset import DataSet
 from ...core.instance import Instance
-from ..base_loader import DataSetLoader
+from ..data_bundle import DataSetLoader
 from ..file_reader import _read_conll
 from typing import Union, Dict
 from ..utils import check_loader_paths
-from ..base_loader import DataBundle
+from ..data_bundle import DataBundle
 
 
 class ConllLoader(DataSetLoader):
     """


fastNLP/io/data_loader/imdb.py  (+1 / -1)

@@ -2,7 +2,7 @@
 from typing import Union, Dict
 
 from ..embed_loader import EmbeddingOption, EmbedLoader
-from ..base_loader import DataSetLoader, DataBundle
+from ..data_bundle import DataSetLoader, DataBundle
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet
 from ...core.instance import Instance


fastNLP/io/data_loader/matching.py  (+1 / -1)

@@ -4,7 +4,7 @@ from typing import Union, Dict, List
 
 from ...core.const import Const
 from ...core.vocabulary import Vocabulary
-from ..base_loader import DataBundle, DataSetLoader
+from ..data_bundle import DataBundle, DataSetLoader
 from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from ...modules.encoder.bert import BertTokenizer




fastNLP/io/data_loader/mtl.py  (+1 / -1)

@@ -1,7 +1,7 @@
 
 from typing import Union, Dict
 
-from ..base_loader import DataBundle
+from ..data_bundle import DataBundle
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import Vocabulary, VocabularyOption
 from ...core.const import Const


fastNLP/io/data_loader/people_daily.py  (+1 / -1)

@@ -1,5 +1,5 @@
 
-from ..base_loader import DataSetLoader
+from ..data_bundle import DataSetLoader
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.const import Const


fastNLP/io/data_loader/sst.py  (+1 / -1)

@@ -2,7 +2,7 @@
 from typing import Union, Dict
 from nltk import Tree
 
-from ..base_loader import DataBundle, DataSetLoader
+from ..data_bundle import DataBundle, DataSetLoader
 from ..dataset_loader import CSVLoader
 from ...core.vocabulary import VocabularyOption, Vocabulary
 from ...core.dataset import DataSet


fastNLP/io/data_loader/yelp.py  (+1 / -1)

@@ -6,7 +6,7 @@ from ...core.const import Const
 from ...core.dataset import DataSet
 from ...core.instance import Instance
 from ...core.vocabulary import VocabularyOption, Vocabulary
-from ..base_loader import DataBundle, DataSetLoader
+from ..data_bundle import DataBundle, DataSetLoader
 from typing import Union, Dict
 from ..utils import check_loader_paths, get_tokenizer




fastNLP/io/dataset_loader.py  (+1 / -1)

@@ -26,7 +26,7 @@ __all__ = [
 from ..core.dataset import DataSet
 from ..core.instance import Instance
 from .file_reader import _read_csv, _read_json
-from .base_loader import DataSetLoader
+from .data_bundle import DataSetLoader
 
 
 class JsonLoader(DataSetLoader):


fastNLP/io/embed_loader.py  (+1 / -1)

@@ -9,7 +9,7 @@ import warnings
 import numpy as np
 
 from ..core.vocabulary import Vocabulary
-from .base_loader import BaseLoader
+from .data_bundle import BaseLoader
 from ..core.utils import Option
 






fastNLP/io/loader/__init__.py  (+2 / -1)

@@ -44,6 +44,8 @@ fastNLP 目前提供了如下的 Loader
 """
 
 __all__ = [
+    'Loader',
+
     'YelpLoader',
     'YelpFullLoader',
     'YelpPolarityLoader',
@@ -57,7 +59,6 @@ __all__ = [
     'OntoNotesNERLoader',
     'CTBLoader',
 
-    'Loader',
     'CSVLoader',
     'JsonLoader',




fastNLP/io/loader/classification.py  (+31 / -29)
(mostly PEP 8 cleanup: added blank lines between methods and spaces around operators and keyword defaults; lines whose only change is whitespace are shown as context below)

@@ -7,6 +7,7 @@ import random
 import shutil
 import numpy as np
 
 
 class YelpLoader(Loader):
     """
     别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader`
@@ -14,6 +15,7 @@ class YelpLoader(Loader):
     原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。
 
     Example::
+
         "1","I got 'new' tires from the..."
         "1","Don't waste your time..."
 
@@ -28,11 +30,11 @@ class YelpLoader(Loader):
         "...", "..."
 
     """
 
     def __init__(self):
         super(YelpLoader, self).__init__()
 
-    def _load(self, path: str=None):
+    def _load(self, path: str = None):
         ds = DataSet()
         with open(path, 'r', encoding='utf-8') as f:
             for line in f:
@@ -69,12 +71,12 @@ class YelpFullLoader(YelpLoader):
         :param int seed: 划分dev时的随机数种子
         :return: str, 数据集的目录地址
         """
         dataset_name = 'yelp-review-full'
         data_dir = self._get_dataset_path(dataset_name=dataset_name)
         if os.path.exists(os.path.join(data_dir, 'dev.csv')):  # 存在dev的话,check是否需要重新下载
             re_download = True
-            if dev_ratio>0:
+            if dev_ratio > 0:
                 dev_line_count = 0
                 tr_line_count = 0
                 with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
@@ -83,14 +85,14 @@ class YelpFullLoader(YelpLoader):
                         tr_line_count += 1
                     for line in f2:
                         dev_line_count += 1
-                    if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
+                    if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
                         re_download = True
                     else:
                         re_download = False
             if re_download:
                 shutil.rmtree(data_dir)
                 data_dir = self._get_dataset_path(dataset_name=dataset_name)
 
         if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
             if dev_ratio > 0:
                 assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -109,7 +111,7 @@ class YelpFullLoader(YelpLoader):
                 finally:
                     if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
                         os.remove(os.path.join(data_dir, 'middle_file.csv'))
 
         return data_dir
@@ -131,7 +133,7 @@ class YelpPolarityLoader(YelpLoader):
         data_dir = self._get_dataset_path(dataset_name=dataset_name)
         if os.path.exists(os.path.join(data_dir, 'dev.csv')):  # 存在dev的话,check是否符合比例要求
             re_download = True
-            if dev_ratio>0:
+            if dev_ratio > 0:
                 dev_line_count = 0
                 tr_line_count = 0
                 with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f1, \
@@ -140,14 +142,14 @@ class YelpPolarityLoader(YelpLoader):
                         tr_line_count += 1
                     for line in f2:
                         dev_line_count += 1
-                    if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
+                    if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
                         re_download = True
                     else:
                         re_download = False
             if re_download:
                 shutil.rmtree(data_dir)
                 data_dir = self._get_dataset_path(dataset_name=dataset_name)
 
         if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
             if dev_ratio > 0:
                 assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -166,7 +168,7 @@ class YelpPolarityLoader(YelpLoader):
                 finally:
                     if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
                         os.remove(os.path.join(data_dir, 'middle_file.csv'))
 
         return data_dir
@@ -185,10 +187,10 @@ class IMDBLoader(Loader):
         "...", "..."
 
     """
 
     def __init__(self):
         super(IMDBLoader, self).__init__()
 
     def _load(self, path: str):
         dataset = DataSet()
         with open(path, 'r', encoding="utf-8") as f:
@@ -201,12 +203,12 @@ class IMDBLoader(Loader):
                     words = parts[1]
                     if words:
                         dataset.append(Instance(raw_words=words, target=target))
 
         if len(dataset) == 0:
             raise RuntimeError(f"{path} has no valid data.")
 
         return dataset
 
     def download(self, dev_ratio: float = 0.1, seed: int = 0):
         """
         自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -221,9 +223,9 @@ class IMDBLoader(Loader):
         """
         dataset_name = 'aclImdb'
         data_dir = self._get_dataset_path(dataset_name=dataset_name)
-        if os.path.exists(os.path.join(data_dir, 'dev.txt')): # 存在dev的话,check是否符合比例要求
+        if os.path.exists(os.path.join(data_dir, 'dev.txt')):  # 存在dev的话,check是否符合比例要求
             re_download = True
-            if dev_ratio>0:
+            if dev_ratio > 0:
                 dev_line_count = 0
                 tr_line_count = 0
                 with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f1, \
@@ -232,14 +234,14 @@ class IMDBLoader(Loader):
                         tr_line_count += 1
                     for line in f2:
                         dev_line_count += 1
-                    if not np.isclose(dev_line_count, dev_ratio*(tr_line_count + dev_line_count), rtol=0.005):
+                    if not np.isclose(dev_line_count, dev_ratio * (tr_line_count + dev_line_count), rtol=0.005):
                         re_download = True
                     else:
                         re_download = False
             if re_download:
                 shutil.rmtree(data_dir)
                 data_dir = self._get_dataset_path(dataset_name=dataset_name)
 
         if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
             if dev_ratio > 0:
                 assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -258,7 +260,7 @@ class IMDBLoader(Loader):
                 finally:
                     if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
                         os.remove(os.path.join(data_dir, 'middle_file.txt'))
 
         return data_dir
@@ -278,10 +280,10 @@ class SSTLoader(Loader):
     raw_words列是str。
 
     """
 
     def __init__(self):
         super().__init__()
 
     def _load(self, path: str):
         """
         从path读取SST文件
@@ -296,7 +298,7 @@ class SSTLoader(Loader):
             if line:
                 ds.append(Instance(raw_words=line))
         return ds
 
     def download(self):
         """
         自动下载数据集,如果你使用了这个数据集,请引用以下的文章
@@ -323,10 +325,10 @@ class SST2Loader(Loader):
 
     test的DataSet没有target列。
     """
 
     def __init__(self):
         super().__init__()
 
     def _load(self, path: str):
         """
         从path读取SST2文件
@@ -335,7 +337,7 @@ class SST2Loader(Loader):
         :return: DataSet
         """
         ds = DataSet()
 
         with open(path, 'r', encoding='utf-8') as f:
             f.readline()  # 跳过header
             if 'test' in os.path.split(path)[1]:
@@ -356,7 +358,7 @@ class SST2Loader(Loader):
                     if raw_words:
                         ds.append(Instance(raw_words=raw_words, target=target))
         return ds
 
     def download(self):
         """
         自动下载数据集,如果你使用了该数据集,请引用以下的文章
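
The download() methods above fetch a dataset into fastNLP's cache directory and, when dev_ratio > 0, split a dev file off the training data, re-downloading if an existing split no longer matches the requested ratio. A hypothetical end-to-end use of the IMDB loader shown above (the 0.1 ratio is simply its documented default):

    from fastNLP.io.loader.classification import IMDBLoader

    loader = IMDBLoader()
    # caches aclImdb and carves a dev set out of train at roughly dev_ratio
    data_dir = loader.download(dev_ratio=0.1, seed=0)
    # with paths=None, load() falls back to download() and reads the cached splits
    data_bundle = loader.load()
    train_data = data_bundle.datasets['train']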


fastNLP/io/loader/loader.py  (+17 / -21)

@@ -2,17 +2,21 @@ from ...core.dataset import DataSet
 from .. import DataBundle
 from ..utils import check_loader_paths
 from typing import Union, Dict
+import os
 from ..file_utils import _get_dataset_url, get_cache_path, cached_path
 
 
 class Loader:
+    """
+    各种数据 Loader 的基类,提供了 API 的参考.
+    """
 
     def __init__(self):
         pass
 
-    def _load(self, path:str) -> DataSet:
+    def _load(self, path: str) -> DataSet:
         raise NotImplementedError
 
-    def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
+    def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
         """
         从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。
@@ -22,31 +26,25 @@ class Loader:
         (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。
 
         (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件
-        名包含'train'、 'dev'、 'test'则会报错
-
-        Example::
+        名包含'train'、 'dev'、 'test'则会报错::
 
             data_bundle = ConllLoader().load('/path/to/dir')  # 返回的DataBundle中datasets根据目录下是否检测到train、
             #  dev、 test等有所变化,可以通过以下的方式取出DataSet
             tr_data = data_bundle.datasets['train']
             te_data = data_bundle.datasets['test']  # 如果目录下有文件包含test这个字段
 
-        (2) 传入文件路径
-
-        Example::
+        (2) 传入文件路径::
 
             data_bundle = ConllLoader().load("/path/to/a/train.conll")  # 返回DataBundle对象, datasets中仅包含'train'
             tr_data = data_bundle.datasets['train']  # 可以通过以下的方式取出DataSet
 
-        (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test
-
-        Example::
+        (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test::
 
             paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}
             data_bundle = ConllLoader().load(paths)  # 返回的DataBundle中的dataset中包含"train", "dev", "test"
             dev_data = data_bundle.datasets['dev']
 
-        :return: 返回的:class:`~fastNLP.io.DataBundle`
+        :return: 返回的 :class:`~fastNLP.io.DataBundle`
         """
         if paths is None:
             paths = self.download()
@@ -54,10 +52,10 @@ class Loader:
         datasets = {name: self._load(path) for name, path in paths.items()}
         data_bundle = DataBundle(datasets=datasets)
         return data_bundle
 
     def download(self):
         raise NotImplementedError(f"{self.__class__} cannot download data automatically.")
 
     def _get_dataset_path(self, dataset_name):
         """
         传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存
@@ -65,11 +63,9 @@ class Loader:
         :param str dataset_name: 数据集的名称
         :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。
         """
 
         default_cache_path = get_cache_path()
         url = _get_dataset_url(dataset_name)
         output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
 
         return output_dir
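
With Loader now documented as the base class for all loaders, a new loader only has to implement _load() (and optionally download()). A minimal hypothetical subclass, assuming a two-column label<TAB>text file format (the class name and file format are illustrative, not part of this commit):

    from fastNLP import DataSet, Instance
    from fastNLP.io.loader import Loader


    class TabSeparatedLoader(Loader):
        """Hypothetical loader for lines of the form: target<TAB>raw_words."""

        def _load(self, path: str) -> DataSet:
            ds = DataSet()
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    target, raw_words = line.split('\t', maxsplit=1)
                    ds.append(Instance(raw_words=raw_words, target=target))
            return ds


    # load() accepts a single file, a directory, or a dict of named paths:
    # data_bundle = TabSeparatedLoader().load({'train': 'tr.tsv', 'dev': 'dev.tsv'})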



fastNLP/io/loader/matching.py  (+2 / -1)

@@ -203,7 +203,8 @@ class QNLILoader(JsonLoader):
         """
         如果您的实验使用到了该数据,请引用
 
-        TODO 补充
+        .. todo::
+            补充
 
         :return:
         """


fastNLP/io/model_io.py  (+1 / -1)

@@ -8,7 +8,7 @@ __all__ = [
 
 import torch
 
-from .base_loader import BaseLoader
+from .data_bundle import BaseLoader
 
 
 class ModelLoader(BaseLoader):


fastNLP/io/pipe/classification.py  (+1 / -1)

@@ -1,6 +1,6 @@
 from nltk import Tree
 
-from ..base_loader import DataBundle
+from ..data_bundle import DataBundle
 from ...core.vocabulary import Vocabulary
 from ...core.const import Const
 from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader


reproduction/Summarization/Baseline/data/dataloader.py  (+188 / -188)

@@ -1,188 +1,188 @@
(whole-file rewrite; apart from the DataBundle import marked below, the old and new contents are identical, so the file is shown once)

import pickle

import numpy as np

from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataBundle
+from fastNLP.io.data_bundle import DataBundle
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.core.const import Const

from tools.logger import *

WORD_PAD = "[PAD]"
WORD_UNK = "[UNK]"
DOMAIN_UNK = "X"
TAG_UNK = "X"


class SummarizationLoader(JsonLoader):
    """
    读取summarization数据集,读取的DataSet包含fields::

        text: list(str),document
        summary: list(str), summary
        text_wd: list(list(str)),tokenized document
        summary_wd: list(list(str)), tokenized summary
        labels: list(int),
        flatten_label: list(int), 0 or 1, flatten labels
        domain: str, optional
        tag: list(str), optional

    数据来源: CNN_DailyMail Newsroom DUC
    """

    def __init__(self):
        super(SummarizationLoader, self).__init__()

    def _load(self, path):
        ds = super(SummarizationLoader, self)._load(path)

        def _lower_text(text_list):
            return [text.lower() for text in text_list]

        def _split_list(text_list):
            return [text.split() for text in text_list]

        def _convert_label(label, sent_len):
            np_label = np.zeros(sent_len, dtype=int)
            if label != []:
                np_label[np.array(label)] = 1
            return np_label.tolist()

        ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
        ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
        ds.apply(lambda x: _split_list(x['text']), new_field_name='text_wd')
        ds.apply(lambda x: _split_list(x['summary']), new_field_name='summary_wd')
        ds.apply(lambda x: _convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")

        return ds

    def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
        """
        :param paths: dict  path for each dataset
        :param vocab_size: int  max_size for vocab
        :param vocab_path: str  vocab path
        :param sent_max_len: int  max token number of the sentence
        :param doc_max_timesteps: int  max sentence number of the document
        :param domain: bool  build vocab for publication, use 'X' for unknown
        :param tag: bool  build vocab for tag, use 'X' for unknown
        :param load_vocab_file: bool  build vocab (False) or load vocab (True)
        :return: DataBundle
            datasets: dict  keys correspond to the paths dict
            vocabs: dict  key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
            embeddings: optional
        """

        def _pad_sent(text_wd):
            pad_text_wd = []
            for sent_wd in text_wd:
                if len(sent_wd) < sent_max_len:
                    pad_num = sent_max_len - len(sent_wd)
                    sent_wd.extend([WORD_PAD] * pad_num)
                else:
                    sent_wd = sent_wd[:sent_max_len]
                pad_text_wd.append(sent_wd)
            return pad_text_wd

        def _token_mask(text_wd):
            token_mask_list = []
            for sent_wd in text_wd:
                token_num = len(sent_wd)
                if token_num < sent_max_len:
                    mask = [1] * token_num + [0] * (sent_max_len - token_num)
                else:
                    mask = [1] * sent_max_len
                token_mask_list.append(mask)
            return token_mask_list

        def _pad_label(label):
            text_len = len(label)
            if text_len < doc_max_timesteps:
                pad_label = label + [0] * (doc_max_timesteps - text_len)
            else:
                pad_label = label[:doc_max_timesteps]
            return pad_label

        def _pad_doc(text_wd):
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                padding = [WORD_PAD] * sent_max_len
                pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
            else:
                pad_text = text_wd[:doc_max_timesteps]
            return pad_text

        def _sent_mask(text_wd):
            text_len = len(text_wd)
            if text_len < doc_max_timesteps:
                sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
            else:
                sent_mask = [1] * doc_max_timesteps
            return sent_mask

        datasets = {}
        train_ds = None
        for key, value in paths.items():
            ds = self.load(value)
            # pad sent
            ds.apply(lambda x: _pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
            ds.apply(lambda x: _token_mask(x["text_wd"]), new_field_name="pad_token_mask")
            # pad document
            ds.apply(lambda x: _pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
            ds.apply(lambda x: _sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
            ds.apply(lambda x: _pad_label(x["flatten_label"]), new_field_name="pad_label")

            # rename field
            ds.rename_field("pad_text", Const.INPUT)
            ds.rename_field("seq_len", Const.INPUT_LEN)
            ds.rename_field("pad_label", Const.TARGET)

            # set input and target
            ds.set_input(Const.INPUT, Const.INPUT_LEN)
            ds.set_target(Const.TARGET, Const.INPUT_LEN)

            datasets[key] = ds
            if "train" in key:
                train_ds = datasets[key]

        vocab_dict = {}
        if load_vocab_file == False:
            logger.info("[INFO] Build new vocab from training dataset!")
            if train_ds == None:
                raise ValueError("Lack train file to build vocabulary!")

            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.from_dataset(train_ds, field_name=["text_wd", "summary_wd"])
            vocab_dict["vocab"] = vocabs
        else:
            logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
            word_list = []
            with open(vocab_path, 'r', encoding='utf8') as vocab_f:
                cnt = 2  # pad and unk
                for line in vocab_f:
                    pieces = line.split("\t")
                    word_list.append(pieces[0])
                    cnt += 1
                    if cnt > vocab_size:
                        break
            vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
            vocabs.add_word_lst(word_list)
            vocabs.build_vocab()
            vocab_dict["vocab"] = vocabs

        if domain == True:
            domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
            domaindict.from_dataset(train_ds, field_name="publication")
            vocab_dict["domain"] = domaindict
        if tag == True:
            tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
            tagdict.from_dataset(train_ds, field_name="tag")
            vocab_dict["tag"] = tagdict

        for ds in datasets.values():
            vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)

        return DataBundle(vocabs=vocab_dict, datasets=datasets)
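
Given the process() signature documented above, a hypothetical call looks like the following (the file names and size limits are illustrative only):

    # hypothetical usage of SummarizationLoader.process(); paths and limits are made up
    loader = SummarizationLoader()
    paths = {"train": "train.jsonl", "valid": "valid.jsonl", "test": "test.jsonl"}
    data_bundle = loader.process(
        paths=paths,
        vocab_size=50000,
        vocab_path="vocab.txt",      # only read when load_vocab_file=True
        sent_max_len=100,            # max tokens kept per sentence
        doc_max_timesteps=50,        # max sentences kept per document
        domain=False,
        tag=False,
        load_vocab_file=True,
    )
    train_ds = data_bundle.datasets["train"]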

reproduction/Summarization/BertSum/dataloader.py  (+1 / -1)

@@ -3,7 +3,7 @@ from datetime import timedelta
 
 from fastNLP.io.dataset_loader import JsonLoader
 from fastNLP.modules.encoder._bert import BertTokenizer
-from fastNLP.io.base_loader import DataBundle
+from fastNLP.io.data_bundle import DataBundle
 from fastNLP.core.const import Const
 
 
 class BertData(JsonLoader):


reproduction/coreference_resolution/data_load/cr_loader.py  (+1 / -1)

@@ -1,7 +1,7 @@
 from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
 from fastNLP.io.file_reader import _read_json
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataBundle
+from fastNLP.io.data_bundle import DataBundle
 from reproduction.coreference_resolution.model.config import Config
 import reproduction.coreference_resolution.model.preprocess as preprocess




reproduction/joint_cws_parse/data/data_loader.py  (+1 / -1)

@@ -1,6 +1,6 @@
 
 
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from fastNLP.io.data_loader import ConllLoader
 import numpy as np




reproduction/matching/data/MatchingDataLoader.py  (+1 / -1)

@@ -9,7 +9,7 @@ from typing import Union, Dict
 
 from fastNLP.core.const import Const
 from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataBundle, DataSetLoader
+from fastNLP.io.data_bundle import DataBundle, DataSetLoader
 from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
 from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
 from fastNLP.modules.encoder._bert import BertTokenizer


reproduction/seqence_labelling/chinese_ner/data/ChineseNER.py  (+1 / -1)

@@ -1,6 +1,6 @@
 
 
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from fastNLP.io import ConllLoader
 from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
 from fastNLP import Const


reproduction/seqence_labelling/cws/data/CWSDataLoader.py  (+1 / -1)

@@ -1,7 +1,7 @@
 
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance


reproduction/seqence_labelling/ner/data/Conll2003Loader.py  (+1 / -1)

@@ -1,6 +1,6 @@
 
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import Vocabulary
 from fastNLP import Const


reproduction/seqence_labelling/ner/data/OntoNoteLoader.py  (+1 / -1)

@@ -1,5 +1,5 @@
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from typing import Union, Dict
 from fastNLP import DataSet
 from fastNLP import Vocabulary


reproduction/text_classification/data/IMDBLoader.py  (+1 / -1)

@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance


reproduction/text_classification/data/MTL16Loader.py  (+1 / -1)

@@ -1,6 +1,6 @@
 from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
 from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
 from typing import Union, Dict, List, Iterator
 from fastNLP import DataSet
 from fastNLP import Instance


reproduction/text_classification/data/sstloader.py  (+1 / -1)

@@ -1,6 +1,6 @@
 from typing import Iterable
 from nltk import Tree
-from fastNLP.io.base_loader import DataBundle, DataSetLoader
+from fastNLP.io.data_bundle import DataBundle, DataSetLoader
 from fastNLP.core.vocabulary import VocabularyOption, Vocabulary
 from fastNLP import DataSet
 from fastNLP import Instance


reproduction/text_classification/data/yelpLoader.py  (+1 / -1)

@@ -4,7 +4,7 @@ from typing import Iterable
 from fastNLP import DataSet, Instance, Vocabulary
 from fastNLP.core.vocabulary import VocabularyOption
 from fastNLP.io import JsonLoader
-from fastNLP.io.base_loader import DataBundle,DataSetLoader
+from fastNLP.io.data_bundle import DataBundle,DataSetLoader
 from fastNLP.io.embed_loader import EmbeddingOption
 from fastNLP.io.file_reader import _read_json
 from typing import Union, Dict

