
merge conflict

tags/v0.4.10
yh 5 years ago
parent commit 359a176748
100 changed files with 1418 additions and 1233 deletions
  1. +7 -6    README.md
  2. +92 -83  docs/count.py
  3. +0 -68   docs/format.py
  4. +5 -3    docs/source/conf.py
  5. +3 -3    docs/source/fastNLP.core.batch.rst
  6. +3 -3    docs/source/fastNLP.core.callback.rst
  7. +3 -3    docs/source/fastNLP.core.const.rst
  8. +3 -3    docs/source/fastNLP.core.dataset.rst
  9. +3 -3    docs/source/fastNLP.core.field.rst
  10. +3 -3   docs/source/fastNLP.core.instance.rst
  11. +3 -3   docs/source/fastNLP.core.losses.rst
  12. +3 -3   docs/source/fastNLP.core.metrics.rst
  13. +3 -3   docs/source/fastNLP.core.optimizer.rst
  14. +4 -5   docs/source/fastNLP.core.rst
  15. +3 -3   docs/source/fastNLP.core.sampler.rst
  16. +3 -3   docs/source/fastNLP.core.tester.rst
  17. +3 -3   docs/source/fastNLP.core.trainer.rst
  18. +3 -3   docs/source/fastNLP.core.utils.rst
  19. +3 -3   docs/source/fastNLP.core.vocabulary.rst
  20. +5 -5   docs/source/fastNLP.embeddings.bert_embedding.rst
  21. +5 -5   docs/source/fastNLP.embeddings.char_embedding.rst
  22. +7 -0   docs/source/fastNLP.embeddings.contextual_embedding.rst
  23. +5 -5   docs/source/fastNLP.embeddings.elmo_embedding.rst
  24. +3 -3   docs/source/fastNLP.embeddings.embedding.rst
  25. +5 -5   docs/source/fastNLP.embeddings.rst
  26. +5 -5   docs/source/fastNLP.embeddings.stack_embedding.rst
  27. +5 -5   docs/source/fastNLP.embeddings.static_embedding.rst
  28. +3 -3   docs/source/fastNLP.embeddings.utils.rst
  29. +5 -5   docs/source/fastNLP.io.data_bundle.rst
  30. +0 -8   docs/source/fastNLP.io.data_loader.rst
  31. +4 -5   docs/source/fastNLP.io.dataset_loader.rst
  32. +5 -5   docs/source/fastNLP.io.embed_loader.rst
  33. +5 -5   docs/source/fastNLP.io.file_utils.rst
  34. +2 -3   docs/source/fastNLP.io.loader.rst
  35. +5 -5   docs/source/fastNLP.io.model_io.rst
  36. +2 -3   docs/source/fastNLP.io.pipe.rst
  37. +6 -15  docs/source/fastNLP.io.rst
  38. +3 -3   docs/source/fastNLP.io.utils.rst
  39. +5 -5   docs/source/fastNLP.models.biaffine_parser.rst
  40. +5 -5   docs/source/fastNLP.models.cnn_text_classification.rst
  41. +4 -5   docs/source/fastNLP.models.rst
  42. +5 -5   docs/source/fastNLP.models.sequence_labeling.rst
  43. +3 -3   docs/source/fastNLP.models.snli.rst
  44. +5 -5   docs/source/fastNLP.models.star_transformer.rst
  45. +2 -3   docs/source/fastNLP.modules.decoder.rst
  46. +2 -3   docs/source/fastNLP.modules.encoder.rst
  47. +4 -11  docs/source/fastNLP.modules.rst
  48. +3 -3   docs/source/fastNLP.modules.utils.rst
  49. +4 -5   docs/source/fastNLP.rst
  50. +73 -147  docs/source/tutorials/tutorial_2_load_dataset.rst
  51. +21 -68   docs/source/tutorials/tutorial_3_embedding.rst
  52. +5 -2   docs/source/tutorials/tutorial_4_loss_optimizer.rst
  53. +4 -1   docs/source/tutorials/tutorial_5_datasetiter.rst
  54. +1 -1   docs/source/user/tutorials.rst
  55. +3 -3   fastNLP/__init__.py
  56. +65 -2  fastNLP/core/__init__.py
  57. +20 -18  fastNLP/core/_logger.py
  58. +13 -8  fastNLP/core/_parallel_utils.py
  59. +2 -2   fastNLP/core/batch.py
  60. +9 -9   fastNLP/core/callback.py
  61. +18 -8  fastNLP/core/const.py
  62. +6 -5   fastNLP/core/dataset.py
  63. +11 -12  fastNLP/core/dist_trainer.py
  64. +22 -14  fastNLP/core/field.py
  65. +15 -13  fastNLP/core/predictor.py
  66. +1 -1   fastNLP/core/tester.py
  67. +2 -2   fastNLP/core/utils.py
  68. +20 -13  fastNLP/core/vocabulary.py
  69. +0 -1   fastNLP/embeddings/__init__.py
  70. +98 -76  fastNLP/embeddings/bert_embedding.py
  71. +43 -34  fastNLP/embeddings/char_embedding.py
  72. +26 -19  fastNLP/embeddings/contextual_embedding.py
  73. +46 -40  fastNLP/embeddings/elmo_embedding.py
  74. +33 -27  fastNLP/embeddings/embedding.py
  75. +17 -7  fastNLP/embeddings/stack_embedding.py
  76. +37 -29  fastNLP/embeddings/static_embedding.py
  77. +11 -5  fastNLP/embeddings/utils.py
  78. +4 -0   fastNLP/io/__init__.py
  79. +6 -1   fastNLP/io/data_bundle.py
  80. +1 -1   fastNLP/io/data_loader/__init__.py
  81. +3 -3   fastNLP/io/dataset_loader.py
  82. +9 -4   fastNLP/io/embed_loader.py
  83. +16 -9  fastNLP/io/file_reader.py
  84. +27 -11  fastNLP/io/file_utils.py
  85. +2 -2   fastNLP/io/loader/__init__.py
  86. +19 -8  fastNLP/io/loader/classification.py
  87. +53 -34  fastNLP/io/loader/conll.py
  88. +9 -3   fastNLP/io/loader/csv.py
  89. +12 -5  fastNLP/io/loader/cws.py
  90. +8 -2   fastNLP/io/loader/json.py
  91. +10 -3  fastNLP/io/loader/loader.py
  92. +49 -33  fastNLP/io/loader/matching.py
  93. +3 -0   fastNLP/io/pipe/__init__.py
  94. +89 -74  fastNLP/io/pipe/classification.py
  95. +47 -32  fastNLP/io/pipe/conll.py
  96. +55 -35  fastNLP/io/pipe/cws.py
  97. +48 -31  fastNLP/io/pipe/matching.py
  98. +9 -0   fastNLP/io/pipe/pipe.py
  99. +24 -14  fastNLP/io/pipe/utils.py
  100. +21 -8  fastNLP/io/utils.py

+ 7
- 6
README.md View File

@@ -6,11 +6,12 @@
 ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
 [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
 
-fastNLP is a lightweight NLP processing suite. You can use it to quickly complete tasks such as sequence labeling ([NER](reproduction/seqence_labelling/ner), POS tagging, etc.), Chinese word segmentation, [text classification](reproduction/text_classification), [Matching](reproduction/matching), [coreference resolution](reproduction/coreference_resolution) and [summarization](reproduction/Summarization), or use it to build complex network models for research. It has the following features:
+fastNLP is a lightweight NLP toolkit. You can use it to quickly complete tasks such as sequence labeling ([NER](reproduction/seqence_labelling/ner), POS tagging, etc.), Chinese word segmentation, [text classification](reproduction/text_classification), [Matching](reproduction/matching), [coreference resolution](reproduction/coreference_resolution) and [summarization](reproduction/Summarization), or use it to quickly build complex network models for research. It has the following features:
 
-- A unified tabular data container that keeps preprocessing clean and clear, with built-in DataSet Loaders for many datasets that spare you the preprocessing code;
+- A unified tabular data container that keeps preprocessing clean and clear, with built-in Loaders and Pipes for many datasets that spare you the preprocessing code;
 - A variety of training and testing components, such as the Trainer, the Tester, and assorted evaluation metrics;
 - Convenient NLP utilities, such as loading pretrained embeddings (including ELMo and BERT) and caching intermediate data;
+- Automatic download of some [datasets and pretrained models](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0);
 - Detailed Chinese [documentation](https://fastnlp.readthedocs.io/) and [tutorials](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html);
 - Many advanced modules such as Variational LSTM, Transformer and CRF;
 - Ready-to-use models for sequence labeling, Chinese word segmentation, text classification, Matching, coreference resolution, summarization and more; see [reproduction](reproduction) for details;
@@ -36,7 +37,7 @@ pip install fastNLP
 python -m spacy download en
 ```
 
-The fastNLP version currently installed via pip is 0.4.1; many features have not yet landed there, and the latest content follows the master branch.
+The fastNLP version currently installed from PyPI is 0.4.1; many features have not yet landed there, and the latest content follows the master branch.
 fastNLP 0.5.0 will be released soon; stay tuned.
 
@@ -44,7 +45,7 @@ fastNLP 0.5.0 will be released soon; stay tuned.
 
 - [0. Quick start](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
 - [1. Preprocessing text with DataSet](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
-- [2. Loading datasets with DataSetLoader](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html)
+- [2. Loading and processing datasets with Loader and Pipe](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html)
 - [3. Converting text to vectors with the Embedding module](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
 - [4. Building a text classifier I: quick training and testing with Trainer and Tester](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_loss_optimizer.html)
 - [5. Building a text classifier II: a custom training loop with DataSetIter](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_datasetiter.html)
@@ -91,7 +92,7 @@ fastNLP ships several kinds of embedding in its embeddings module: static embedd
 
 ## Project structure
 
-![](./docs/source/figures/workflow.png)
+<img src="./docs/source/figures/workflow.png" width="60%" height="60%">
 
 The overall workflow of fastNLP is shown above, and the project is organized as follows:
 
@@ -118,7 +119,7 @@ The overall workflow of fastNLP is shown above, and the project is organized as follows:
 </tr>
 <tr>
 <td><b> fastNLP.io </b></td>
-<td> Implements reading and writing, including data loading and model I/O </td>
+<td> Implements reading and writing, including data loading and preprocessing, model I/O, automatic download, etc. </td>
 </tr>
 </table>




+ 92
- 83
docs/count.py View File

@@ -1,98 +1,107 @@
import inspect
import os
import sys




def find_all(path='../fastNLP'):
head_list = []
alias_list = []
for path, dirs, files in os.walk(path):
def _colored_string(string: str, color: str or int) -> str:
"""在终端中显示一串有颜色的文字
:param string: 在终端中显示的文字
:param color: 文字的颜色
:return:
"""
if isinstance(color, str):
color = {
"black": 30, "Black": 30, "BLACK": 30,
"red": 31, "Red": 31, "RED": 31,
"green": 32, "Green": 32, "GREEN": 32,
"yellow": 33, "Yellow": 33, "YELLOW": 33,
"blue": 34, "Blue": 34, "BLUE": 34,
"purple": 35, "Purple": 35, "PURPLE": 35,
"cyan": 36, "Cyan": 36, "CYAN": 36,
"white": 37, "White": 37, "WHITE": 37
}[color]
return "\033[%dm%s\033[0m" % (color, string)


def find_all_modules():
modules = {}
children = {}
to_doc = set()
root = '../fastNLP'
for path, dirs, files in os.walk(root):
for file in files:
if file.endswith('.py'):
name = ".".join(path.split('/')[1:])
if file.split('.')[0] != "__init__":
name = name + '.' + file.split('.')[0]
if len(name.split('.')) < 3 or name.startswith('fastNLP.core'):
heads, alias = find_one(path + '/' + file)
for h in heads:
head_list.append(name + "." + h)
for a in alias:
alias_list.append(a)
heads = {}
for h in head_list:
end = h.split('.')[-1]
file = h[:-len(end) - 1]
if end not in heads:
heads[end] = set()
heads[end].add(file)
alias = {}
for a in alias_list:
for each in a:
end = each.split('.')[-1]
file = each[:-len(end) - 1]
if end not in alias:
alias[end] = set()
alias[end].add(file)
print("IN alias NOT IN heads")
for item in alias:
if item not in heads:
print(item, alias[item])
elif len(heads[item]) != 2:
print(item, alias[item], heads[item])
print("\n\nIN heads NOT IN alias")
for item in heads:
if item not in alias:
print(item, heads[item])
__import__(name)
m = sys.modules[name]
modules[name] = m
try:
m.__all__
except:
print(name, "__all__ missing")
continue
if m.__doc__ is None:
print(name, "__doc__ missing")
continue
if "undocumented" not in m.__doc__:
to_doc.add(name)
for module in to_doc:
t = ".".join(module.split('.')[:-1])
if t in to_doc:
if t not in children:
children[t] = set()
children[t].add(module)
for m in children:
children[m] = sorted(children[m])
return modules, to_doc, children


def create_rst_file(modules, name, children):
m = modules[name]
with open("./source/" + name + ".rst", "w") as fout:
t = "=" * len(name)
fout.write(name + "\n")
fout.write(t + "\n")
fout.write("\n")
fout.write(".. automodule:: " + name + "\n")
if len(m.__all__) > 0:
fout.write(" :members: " + ", ".join(m.__all__) + "\n")
fout.write(" :inherited-members:\n")
fout.write("\n")
if name in children:
fout.write("子模块\n------\n\n.. toctree::\n\n")
for module in children[name]:
fout.write(" " + module + "\n")


def check_file(m, name):
for item, obj in inspect.getmembers(m):
if inspect.isclass(obj) and obj.__module__ == name:
print(obj)
if inspect.isfunction(obj) and obj.__module__ == name:
print("FUNC", obj)




def find_class(path):
with open(path, 'r') as fin:
lines = fin.readlines()
pars = {}
for i, line in enumerate(lines):
if line.strip().startswith('class'):
line = line.strip()[len('class'):-1].strip()
if line[-1] == ')':
line = line[:-1].split('(')
name = line[0].strip()
parents = line[1].split(',')
for i in range(len(parents)):
parents[i] = parents[i].strip()
if len(parents) == 1:
pars[name] = parents[0]
else:
pars[name] = tuple(parents)
return pars
def check_files(modules):
for name in sorted(modules.keys()):
if name == 'fastNLP.core.utils':
check_file(modules[name], name)




def find_one(path):
head_list = []
alias = []
with open(path, 'r') as fin:
lines = fin.readlines()
flag = False
for i, line in enumerate(lines):
if line.strip().startswith('__all__'):
line = line.strip()[len('__all__'):].strip()
if line[-1] == ']':
line = line[1:-1].strip()[1:].strip()
head_list.append(line.strip("\"").strip("\'").strip())
else:
flag = True
elif line.strip() == ']':
flag = False
elif flag:
line = line.strip()[:-1].strip("\"").strip("\'").strip()
if len(line) == 0 or line[0] == '#':
continue
head_list.append(line)
if line.startswith('def') or line.startswith('class'):
if lines[i + 2].strip().startswith("别名:"):
names = lines[i + 2].strip()[len("别名:"):].split()
names[0] = names[0][len(":class:`"):-1]
names[1] = names[1][len(":class:`"):-1]
alias.append((names[0], names[1]))
return head_list, alias
def main():
print(_colored_string('Getting modules...', "Blue"))
modules, to_doc, children = find_all_modules()
print(_colored_string('Done!', "Green"))
print(_colored_string('Creating rst files...', "Blue"))
for name in to_doc:
create_rst_file(modules, name, children)
print(_colored_string('Done!', "Green"))
print(_colored_string('Checking all files...', "Blue"))
check_files(modules)
print(_colored_string('Done!', "Green"))




if __name__ == "__main__":
find_all() # use to check __all__
main()

+ 0
- 68
docs/format.py View File

@@ -1,68 +0,0 @@
import os


def shorten(file, to_delete, cut=False):
if file.endswith("index.rst") or file.endswith("conf.py"):
return
res = []
with open(file, "r") as fin:
lines = fin.readlines()
for line in lines:
if cut and line.rstrip() == "Submodules":
break
else:
res.append(line.rstrip())
for i, line in enumerate(res):
if line.endswith(" package"):
res[i] = res[i][:-len(" package")]
res[i + 1] = res[i + 1][:-len(" package")]
elif line.endswith(" module"):
res[i] = res[i][:-len(" module")]
res[i + 1] = res[i + 1][:-len(" module")]
else:
for name in to_delete:
if line.endswith(name):
res[i] = "del"

with open(file, "w") as fout:
for line in res:
if line != "del":
print(line, file=fout)


def clear(path='./source/'):
files = os.listdir(path)
to_delete = [
"fastNLP.core.dist_trainer",
"fastNLP.core.predictor",

"fastNLP.io.file_reader",
"fastNLP.io.config_io",

"fastNLP.embeddings.contextual_embedding",

"fastNLP.modules.dropout",
"fastNLP.models.base_model",
"fastNLP.models.bert",
"fastNLP.models.enas_utils",
"fastNLP.models.enas_controller",
"fastNLP.models.enas_model",
"fastNLP.models.enas_trainer",
]
for file in files:
if not os.path.isdir(path + file):
res = file.split('.')
if len(res) > 4:
to_delete.append(file[:-4])
elif len(res) == 4:
shorten(path + file, to_delete, True)
else:
shorten(path + file, to_delete)
for file in to_delete:
try:
os.remove(path + file + ".rst")
except:
pass


clear()

+ 5
- 3
docs/source/conf.py View File

@@ -48,12 +48,14 @@ extensions = [
 autodoc_default_options = {
     'member-order': 'bysource',
     'special-members': '__init__',
-    'undoc-members': True,
+    'undoc-members': False,
 }
 
+autoclass_content = "class"
 
 # Add any paths that contain templates here, relative to this directory.
 templates_path = ['_templates']
+# template_bridge
 # The suffix(es) of source filenames.
 # You can specify multiple suffix as a list of string:
 #
@@ -113,7 +115,7 @@ html_static_path = ['_static']
 # -- Options for HTMLHelp output ---------------------------------------------
 
 # Output file base name for HTML help builder.
-htmlhelp_basename = 'fastNLPdoc'
+htmlhelp_basename = 'fastNLP doc'
 
 # -- Options for LaTeX output ------------------------------------------------




+ 3
- 3
docs/source/fastNLP.core.batch.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.batch
==================


.. automodule:: fastNLP.core.batch
:members:
:undoc-members:
:show-inheritance:
:members: BatchIter, DataSetIter, TorchLoaderIter
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.callback.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.callback
=====================


.. automodule:: fastNLP.core.callback
:members:
:undoc-members:
:show-inheritance:
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.const.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.const
==================


.. automodule:: fastNLP.core.const
:members:
:undoc-members:
:show-inheritance:
:members: Const
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.dataset.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.dataset
====================


.. automodule:: fastNLP.core.dataset
:members:
:undoc-members:
:show-inheritance:
:members: DataSet
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.field.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.field
==================


.. automodule:: fastNLP.core.field
:members:
:undoc-members:
:show-inheritance:
:members: Padder, AutoPadder, EngChar2DPadder
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.instance.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.instance
=====================


.. automodule:: fastNLP.core.instance
:members:
:undoc-members:
:show-inheritance:
:members: Instance
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.losses.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.losses
===================


.. automodule:: fastNLP.core.losses
:members:
:undoc-members:
:show-inheritance:
:members: LossBase, LossFunc, LossInForward, CrossEntropyLoss, BCELoss, L1Loss, NLLLoss
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.metrics.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.metrics
====================


.. automodule:: fastNLP.core.metrics
:members:
:undoc-members:
:show-inheritance:
:members: MetricBase, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.optimizer.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.optimizer
======================


.. automodule:: fastNLP.core.optimizer
:members:
:undoc-members:
:show-inheritance:
:members: Optimizer, SGD, Adam, AdamW
:inherited-members:

+ 4
- 5
docs/source/fastNLP.core.rst View File

@@ -2,12 +2,11 @@ fastNLP.core
============


.. automodule:: fastNLP.core
:members:
:undoc-members:
:show-inheritance:
:members: DataSet, Instance, FieldArray, Padder, AutoPadder, EngChar2DPadder, Vocabulary, DataSetIter, BatchIter, TorchLoaderIter, Const, Tester, Trainer, cache_results, seq_len_to_mask, get_seq_len, logger, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, SequentialSampler, BucketSampler, RandomSampler, Sampler
:inherited-members:


Submodules
----------
子模块
------


.. toctree::




+ 3
- 3
docs/source/fastNLP.core.sampler.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.sampler
====================


.. automodule:: fastNLP.core.sampler
:members:
:undoc-members:
:show-inheritance:
:members: Sampler, BucketSampler, SequentialSampler, RandomSampler
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.tester.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.tester
===================


.. automodule:: fastNLP.core.tester
:members:
:undoc-members:
:show-inheritance:
:members: Tester
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.trainer.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.trainer
====================


.. automodule:: fastNLP.core.trainer
:members:
:undoc-members:
:show-inheritance:
:members: Trainer
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.utils.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.utils
==================


.. automodule:: fastNLP.core.utils
:members:
:undoc-members:
:show-inheritance:
:members: cache_results, seq_len_to_mask, get_seq_len
:inherited-members:

+ 3
- 3
docs/source/fastNLP.core.vocabulary.rst View File

@@ -2,6 +2,6 @@ fastNLP.core.vocabulary
=======================


.. automodule:: fastNLP.core.vocabulary
:members:
:undoc-members:
:show-inheritance:
:members: Vocabulary, VocabularyOption
:inherited-members:

+ 5
- 5
docs/source/fastNLP.embeddings.bert_embedding.rst View File

@@ -1,7 +1,7 @@
fastNLP.embeddings.bert\_embedding
==================================
fastNLP.embeddings.bert_embedding
=================================


.. automodule:: fastNLP.embeddings.bert_embedding
:members:
:undoc-members:
:show-inheritance:
:members: BertEmbedding, BertWordPieceEncoder
:inherited-members:

+ 5
- 5
docs/source/fastNLP.embeddings.char_embedding.rst View File

@@ -1,7 +1,7 @@
fastNLP.embeddings.char\_embedding
==================================
fastNLP.embeddings.char_embedding
=================================


.. automodule:: fastNLP.embeddings.char_embedding
:members:
:undoc-members:
:show-inheritance:
:members: CNNCharEmbedding, LSTMCharEmbedding
:inherited-members:

+ 7
- 0
docs/source/fastNLP.embeddings.contextual_embedding.rst View File

@@ -0,0 +1,7 @@
fastNLP.embeddings.contextual_embedding
=======================================

.. automodule:: fastNLP.embeddings.contextual_embedding
:members: ContextualEmbedding
:inherited-members:


+ 5
- 5
docs/source/fastNLP.embeddings.elmo_embedding.rst View File

@@ -1,7 +1,7 @@
fastNLP.embeddings.elmo\_embedding
==================================
fastNLP.embeddings.elmo_embedding
=================================


.. automodule:: fastNLP.embeddings.elmo_embedding
:members:
:undoc-members:
:show-inheritance:
:members: ElmoEmbedding
:inherited-members:

+ 3
- 3
docs/source/fastNLP.embeddings.embedding.rst View File

@@ -2,6 +2,6 @@ fastNLP.embeddings.embedding
============================


.. automodule:: fastNLP.embeddings.embedding
:members:
:undoc-members:
:show-inheritance:
:members: Embedding, TokenEmbedding
:inherited-members:

+ 5
- 5
docs/source/fastNLP.embeddings.rst View File

@@ -2,17 +2,17 @@ fastNLP.embeddings
==================


.. automodule:: fastNLP.embeddings
:members:
:undoc-members:
:show-inheritance:
:members: Embedding, TokenEmbedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, BertWordPieceEncoder, StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding, get_embeddings
:inherited-members:


Submodules
----------
子模块
------


.. toctree::


    fastNLP.embeddings.bert_embedding
    fastNLP.embeddings.char_embedding
    fastNLP.embeddings.contextual_embedding
    fastNLP.embeddings.elmo_embedding
    fastNLP.embeddings.embedding
    fastNLP.embeddings.stack_embedding


+ 5
- 5
docs/source/fastNLP.embeddings.stack_embedding.rst View File

@@ -1,7 +1,7 @@
fastNLP.embeddings.stack\_embedding
===================================
fastNLP.embeddings.stack_embedding
==================================


.. automodule:: fastNLP.embeddings.stack_embedding
:members:
:undoc-members:
:show-inheritance:
:members: StackEmbedding
:inherited-members:

+ 5
- 5
docs/source/fastNLP.embeddings.static_embedding.rst View File

@@ -1,7 +1,7 @@
fastNLP.embeddings.static\_embedding
====================================
fastNLP.embeddings.static_embedding
===================================


.. automodule:: fastNLP.embeddings.static_embedding
:members:
:undoc-members:
:show-inheritance:
:members: StaticEmbedding
:inherited-members:

+ 3
- 3
docs/source/fastNLP.embeddings.utils.rst View File

@@ -2,6 +2,6 @@ fastNLP.embeddings.utils
========================


.. automodule:: fastNLP.embeddings.utils
:members:
:undoc-members:
:show-inheritance:
:members: get_embeddings
:inherited-members:

+ 5
- 5
docs/source/fastNLP.io.data_bundle.rst View File

@@ -1,7 +1,7 @@
fastNLP.io.data\_bundle
=======================
fastNLP.io.data_bundle
======================


.. automodule:: fastNLP.io.data_bundle
:members:
:undoc-members:
:show-inheritance:
:members: DataBundle
:inherited-members:

+ 0
- 8
docs/source/fastNLP.io.data_loader.rst View File

@@ -1,8 +0,0 @@
fastNLP.io.data\_loader
=======================

.. automodule:: fastNLP.io.data_loader
:members:
:undoc-members:
:show-inheritance:


+ 4
- 5
docs/source/fastNLP.io.dataset_loader.rst View File

@@ -1,7 +1,6 @@
fastNLP.io.dataset\_loader
==========================
fastNLP.io.dataset_loader
=========================


.. automodule:: fastNLP.io.dataset_loader
:members:
:undoc-members:
:show-inheritance:
:members: CSVLoader, JsonLoader


+ 5
- 5
docs/source/fastNLP.io.embed_loader.rst View File

@@ -1,7 +1,7 @@
fastNLP.io.embed\_loader
========================
fastNLP.io.embed_loader
=======================


.. automodule:: fastNLP.io.embed_loader
:members:
:undoc-members:
:show-inheritance:
:members: EmbedLoader, EmbeddingOption
:inherited-members:

+ 5
- 5
docs/source/fastNLP.io.file_utils.rst View File

@@ -1,7 +1,7 @@
fastNLP.io.file\_utils
======================
fastNLP.io.file_utils
=====================


.. automodule:: fastNLP.io.file_utils
:members:
:undoc-members:
:show-inheritance:
:members: cached_path, get_filepath, get_cache_path, split_filename_suffix, get_from_cache
:inherited-members:

+ 2
- 3
docs/source/fastNLP.io.loader.rst View File

@@ -2,7 +2,6 @@ fastNLP.io.loader
=================


.. automodule:: fastNLP.io.loader
:members:
:undoc-members:
:show-inheritance:
:members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
:inherited-members:



+ 5
- 5
docs/source/fastNLP.io.model_io.rst View File

@@ -1,7 +1,7 @@
fastNLP.io.model\_io
====================
fastNLP.io.model_io
===================


.. automodule:: fastNLP.io.model_io
:members:
:undoc-members:
:show-inheritance:
:members: ModelLoader, ModelSaver
:inherited-members:

+ 2
- 3
docs/source/fastNLP.io.pipe.rst View File

@@ -2,7 +2,6 @@ fastNLP.io.pipe
===============


.. automodule:: fastNLP.io.pipe
:members:
:undoc-members:
:show-inheritance:
:members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
:inherited-members:



+ 6
- 15
docs/source/fastNLP.io.rst View File

@@ -2,27 +2,18 @@ fastNLP.io
==========


.. automodule:: fastNLP.io
:members:
:undoc-members:
:show-inheritance:
:members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver
:inherited-members:


Subpackages
-----------

.. toctree::

fastNLP.io.data_loader
fastNLP.io.loader
fastNLP.io.pipe

Submodules
----------
子模块
------


.. toctree::


    fastNLP.io.data_bundle
    fastNLP.io.dataset_loader
    fastNLP.io.embed_loader
    fastNLP.io.file_utils
    fastNLP.io.loader
    fastNLP.io.model_io
    fastNLP.io.pipe
    fastNLP.io.utils

+ 3
- 3
docs/source/fastNLP.io.utils.rst View File

@@ -2,6 +2,6 @@ fastNLP.io.utils
================


.. automodule:: fastNLP.io.utils
:members:
:undoc-members:
:show-inheritance:
:members: check_loader_paths
:inherited-members:

+ 5
- 5
docs/source/fastNLP.models.biaffine_parser.rst View File

@@ -1,7 +1,7 @@
fastNLP.models.biaffine\_parser
===============================
fastNLP.models.biaffine_parser
==============================


.. automodule:: fastNLP.models.biaffine_parser
:members:
:undoc-members:
:show-inheritance:
:members: BiaffineParser, GraphParser
:inherited-members:

+ 5
- 5
docs/source/fastNLP.models.cnn_text_classification.rst View File

@@ -1,7 +1,7 @@
fastNLP.models.cnn\_text\_classification
========================================
fastNLP.models.cnn_text_classification
======================================


.. automodule:: fastNLP.models.cnn_text_classification
:members:
:undoc-members:
:show-inheritance:
:members: CNNText
:inherited-members:

+ 4
- 5
docs/source/fastNLP.models.rst View File

@@ -2,12 +2,11 @@ fastNLP.models
==============


.. automodule:: fastNLP.models
:members:
:undoc-members:
:show-inheritance:
:members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser
:inherited-members:


Submodules
----------
子模块
------


.. toctree::




+ 5
- 5
docs/source/fastNLP.models.sequence_labeling.rst View File

@@ -1,7 +1,7 @@
fastNLP.models.sequence\_labeling
=================================
fastNLP.models.sequence_labeling
================================


.. automodule:: fastNLP.models.sequence_labeling
:members:
:undoc-members:
:show-inheritance:
:members: SeqLabeling, AdvSeqLabel
:inherited-members:

+ 3
- 3
docs/source/fastNLP.models.snli.rst View File

@@ -2,6 +2,6 @@ fastNLP.models.snli
===================


.. automodule:: fastNLP.models.snli
:members:
:undoc-members:
:show-inheritance:
:members: ESIM
:inherited-members:

+ 5
- 5
docs/source/fastNLP.models.star_transformer.rst View File

@@ -1,7 +1,7 @@
fastNLP.models.star\_transformer
================================
fastNLP.models.star_transformer
===============================


.. automodule:: fastNLP.models.star_transformer
:members:
:undoc-members:
:show-inheritance:
:members: StarTransEnc, STNLICls, STSeqCls, STSeqLabel
:inherited-members:

+ 2
- 3
docs/source/fastNLP.modules.decoder.rst View File

@@ -2,7 +2,6 @@ fastNLP.modules.decoder
=======================


.. automodule:: fastNLP.modules.decoder
:members:
:undoc-members:
:show-inheritance:
:members: MLP, ConditionalRandomField, viterbi_decode, allowed_transitions
:inherited-members:



+ 2
- 3
docs/source/fastNLP.modules.encoder.rst View File

@@ -2,7 +2,6 @@ fastNLP.modules.encoder
=======================


.. automodule:: fastNLP.modules.encoder
:members:
:undoc-members:
:show-inheritance:
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention
:inherited-members:



+ 4
- 11
docs/source/fastNLP.modules.rst View File

@@ -2,21 +2,14 @@ fastNLP.modules
===============


.. automodule:: fastNLP.modules
:members:
:undoc-members:
:show-inheritance:
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout
:inherited-members:


Subpackages
-----------
子模块
------


.. toctree::


    fastNLP.modules.decoder
    fastNLP.modules.encoder

Submodules
----------

.. toctree::

    fastNLP.modules.utils

+ 3
- 3
docs/source/fastNLP.modules.utils.rst View File

@@ -2,6 +2,6 @@ fastNLP.modules.utils
=====================


.. automodule:: fastNLP.modules.utils
:members:
:undoc-members:
:show-inheritance:
:members: initial_parameter, summary
:inherited-members:

+ 4
- 5
docs/source/fastNLP.rst View File

@@ -2,12 +2,11 @@ fastNLP
=======


.. automodule:: fastNLP
:members:
:undoc-members:
:show-inheritance:
:members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC, LRFinder, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger
:inherited-members:


Subpackages
-----------
子模块
------


.. toctree::




+ 73
- 147
docs/source/tutorials/tutorial_2_load_dataset.rst View File

@@ -1,57 +1,53 @@
-=================================
-Loading datasets with DataSetLoader
-=================================
+=======================================
+Loading and processing datasets with Loader and Pipe
+=======================================
 
 This section is a tutorial on how to load datasets.
 
 Tutorial contents:
 
-    - `Part I: The dataset container`_
-    - `Part II: How datasets are used`_
-    - `Part III: DataSetLoaders for different data types`_
-    - `Part IV: DataSetLoader examples`_
-    - `Part V: Dataset loaders bundled with fastNLP`_
+    - `Part I: The dataset container DataBundle`_
+    - `Part II: Loader, the base class for loading datasets`_
+    - `Part III: Basic Loaders for different file formats`_
+    - `Part IV: Preprocessing datasets with Pipe`_
+    - `Part V: Loaders and Pipes bundled with fastNLP`_
 
 
----------------------------
-Part I: The dataset container
----------------------------
+------------------------------------
+Part I: The dataset container DataBundle
+------------------------------------
 
-In fastNLP, dataset information is stored in a :class:`~fastNLP.io.base_loader.DataBundle`.
-The :class:`~fastNLP.io.base_loader.DataBundle` class holds two important members: `datasets` and `vocabs`.
+In fastNLP, dataset information is stored in a :class:`~fastNLP.io.data_bundle.DataBundle`.
+The :class:`~fastNLP.io.data_bundle.DataBundle` class holds two important members: `datasets` and `vocabs`.
 
 `datasets` is a dict whose keys are dataset names (such as `train`, `dev` and `test`) and whose values are :class:`~fastNLP.DataSet` objects.
 
 `vocabs` is a dict whose keys are vocabulary names (for example, :attr:`fastNLP.Const.INPUT` for the vocabulary of the input text, :attr:`fastNLP.Const.TARGET` for the
 vocabulary of the target labels, and so on) and whose values are the vocabularies themselves ( :class:`~fastNLP.Vocabulary` ).
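
As a rough illustration of the container described above, the two dicts can be inspected as follows. This is only a sketch: it assumes one of the bundled Pipes from fastNLP.io.pipe (SST2Pipe here), and that process_from_file downloads and processes the data when no paths are given.

.. code-block:: python

    from fastNLP import Const
    from fastNLP.io.pipe import SST2Pipe

    # returns a DataBundle with the processed datasets and the vocabularies built from them
    data_bundle = SST2Pipe().process_from_file()

    train_ds = data_bundle.datasets['train']      # a fastNLP.DataSet
    word_vocab = data_bundle.vocabs[Const.INPUT]  # a fastNLP.Vocabulary
    print(len(train_ds), len(word_vocab))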


----------------------------
-Part II: How datasets are used
----------------------------
+-------------------------------------
+Part II: Loader, the base class for loading datasets
+-------------------------------------
 
-In fastNLP, :class:`~fastNLP.io.base_loader.DataSetLoader` is the base class for loading datasets.
-:class:`~fastNLP.io.base_loader.DataSetLoader` defines the API that every DataSetLoader needs; developers should subclass it to implement their own DataSetLoaders.
-A DataSetLoader for a particular dataset should implement at least the following:
+In fastNLP, :class:`~fastNLP.io.loader.Loader` is the base class for loading datasets.
+:class:`~fastNLP.io.loader.Loader` defines the API that every Loader needs; developers should subclass it to implement their own Loaders.
+A Loader for a particular dataset should implement at least the following:
 
-    - the _load function: reads one data file into a :class:`~fastNLP.DataSet`
-    - the load function (the base-class implementation may be reused): reads one or more data files into one or more :class:`~fastNLP.DataSet` objects
-    - the process function: reads data from one or more files and turns it into a trainable :class:`~fastNLP.io.DataBundle`
+    - the _load function: reads a single data file and returns a :class:`~fastNLP.DataSet`
+    - the load function: reads data from a file or a folder and assembles it into a :class:`~fastNLP.io.data_bundle.DataBundle`
 
-**\*the process function may call load or _load internally**
-
-The :class:`~fastNLP.DataSet` returned by a DataSetLoader's _load or load function contains the raw text of the dataset, while in the
-:class:`~fastNLP.io.DataBundle` returned by process, `datasets` holds already-indexed content that can be fed directly to
-:class:`~fastNLP.Trainer`.
+The :class:`~fastNLP.io.data_bundle.DataBundle` returned by a Loader's load function contains the raw data of the dataset.
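
To make the two functions above concrete, here is a rough sketch of a custom Loader for a hypothetical tab-separated format; the file layout, the MyTSVLoader name and the raw_words/target field names are illustrative assumptions, and the sketch assumes DataBundle accepts a dict of datasets in its constructor.

.. code-block:: python

    from fastNLP import DataSet, Instance
    from fastNLP.io.data_bundle import DataBundle
    from fastNLP.io.loader import Loader

    class MyTSVLoader(Loader):

        def _load(self, path):
            # read one "<sentence>\t<label>" pair per line into a DataSet
            ds = DataSet()
            with open(path, encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    sent, label = line.split('\t')
                    ds.append(Instance(raw_words=sent.split(), target=label))
            return ds

        def load(self, paths):
            # assemble one DataSet per split into a DataBundle
            datasets = {name: self._load(path) for name, path in paths.items()}
            return DataBundle(datasets=datasets)

A call such as ``MyTSVLoader().load({'train': 'train.tsv', 'dev': 'dev.tsv'})`` would then return a DataBundle holding the raw data of both splits.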


 --------------------------------------------------------
-Part III: DataSetLoaders for different data types
+Part III: Basic Loaders for different file formats
 --------------------------------------------------------
 
-:class:`~fastNLP.io.dataset_loader.CSVLoader`
+:class:`~fastNLP.io.loader.CSVLoader`
    Reads CSV-style dataset files. For example:
 
    .. code-block:: python
 
+       from fastNLP.io.loader import CSVLoader
        data_set_loader = CSVLoader(
            headers=('words', 'target'), sep='\t'
        )
@@ -67,17 +63,18 @@ Part III: DataSetLoaders for different data types
        The performances are an absolute joy . 4
 
-:class:`~fastNLP.io.dataset_loader.JsonLoader`
+:class:`~fastNLP.io.loader.JsonLoader`
    Reads JSON-style dataset files; the data must be stored one JSON object per line, each object containing the relevant attributes. For example:
 
    .. code-block:: python
 
-       data_set_loader = JsonLoader(
+       from fastNLP.io.loader import JsonLoader
+       loader = JsonLoader(
            fields={'sentence1': 'words1', 'sentence2': 'words2', 'gold_label': 'target'}
        )
        # assigns the values of 'sentence1', 'sentence2' and 'gold_label' in each JSON object to the fields 'words1', 'words2' and 'target'
 
-       data_set = data_set_loader._load('path/to/your/file')
+       data_set = loader._load('path/to/your/file')
 
    A sample of the dataset looks like this ::
 
@@ -86,139 +83,68 @@ Part III: DataSetLoaders for different data types
       {"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"}


 ------------------------------------------
-Part IV: DataSetLoader examples
+Part IV: Preprocessing datasets with Pipe
 ------------------------------------------
 
-Take the Matching task as an example:
-
-:class:`~fastNLP.io.data_loader.MatchingLoader`
-    fastNLP bundles a data-loading class for Matching datasets: :class:`~fastNLP.io.data_loader.MatchingLoader` .
-
-    MatchingLoader wraps a function that further preprocesses the text in a dataset:
-    :meth:`~fastNLP.io.data_loader.MatchingLoader.process`
-    This function has various preprocessing options, such as:
-    - whether to lowercase the text
-    - whether sequence-length information is needed, and of what kind
-    - whether to use BertTokenizer to obtain the WordPiece information of the sequences
-    - and so on
+In fastNLP, :class:`~fastNLP.io.pipe.Pipe` is the base class for preprocessing loaded datasets.
+:class:`~fastNLP.io.pipe.Pipe` defines the API that every Pipe needs; developers should subclass it to implement their own Pipes.
+A Pipe for a particular dataset should implement at least the following:
 
-    See :meth:`fastNLP.io.MatchingLoader.process` for details.
+    - the process function: processes the input :class:`~fastNLP.io.data_bundle.DataBundle` (building vocabularies,
+      converting the text in the datasets to indices, and so on) and then returns that :class:`~fastNLP.io.data_bundle.DataBundle`
+    - the process_from_file function: given the folder containing the dataset, reads it, assembles it into a :class:`~fastNLP.io.data_bundle.DataBundle`,
+      and then calls the corresponding process function to preprocess the data
 
-:class:`~fastNLP.io.data_loader.SNLILoader`
-    A DataSetLoader for the SNLI dataset. The SNLI dataset comes from
-    `SNLI Data Set <https://nlp.stanford.edu/projects/snli/snli_1.0.zip>`_ .
+Taking the SNLI dataset as an example, a custom Pipe can be written as follows:
 
-    In the :meth:`~fastNLP.io.data_loader.SNLILoader._load` function of :class:`~fastNLP.io.data_loader.SNLILoader`,
-    the following code reads the dataset from text files into memory:
+.. code-block:: python
 
-    .. code-block:: python
+    from fastNLP.io.loader import SNLILoader
+    from fastNLP.io.pipe import MatchingPipe
 
-        data = SNLILoader().process(
-            paths='path/to/snli/data', to_lower=False, seq_len_type='seq_len',
-            get_index=True, concat=False,
-        )
-        print(data)
+    class MySNLIPipe(MatchingPipe):
 
-    The output is::
+        def process(self, data_bundle):
+            data_bundle = super(MySNLIPipe, self).process(data_bundle)
+            # MatchingPipe already implements a process function for matching tasks, which can be inherited directly
+            # add your own code here if extra preprocessing is needed
+            return data_bundle
 
-        In total 3 datasets:
-            train has 549367 instances.
-            dev has 9842 instances.
-            test has 9824 instances.
-        In total 2 vocabs:
-            words has 43154 entries.
-            target has 3 entries.
+        def process_from_file(self, paths=None):
+            data_bundle = SNLILoader().load(paths)  # read the raw dataset with SNLILoader
+            # in SNLILoader's load function, the data is downloaded automatically if paths is None
+            return self.process(data_bundle)  # call the corresponding process function on the data_bundle
 
+Calling a Pipe looks like this:
 
-    Here data is a :class:`~fastNLP.io.base_loader.DataBundle`; the contents of its ``datasets`` dict can be passed directly to
-    :class:`~fastNLP.Trainer` or :class:`~fastNLP.Tester` for training or testing.
+.. code-block:: python
 
-:class:`~fastNLP.io.data_loader.IMDBLoader`
-    Taking the IMDB dataset as an example, in the :meth:`~fastNLP.io.data_loader.IMDBLoader._load` function of
-    :class:`~fastNLP.io.data_loader.IMDBLoader` the following code reads the dataset from text files into memory:
+    from fastNLP.io.pipe import SNLIBertPipe
+    data_bundle = SNLIBertPipe(lower=True, tokenizer=arg.tokenizer).process_from_file()
+    print(data_bundle)
 
-    .. code-block:: python
+The output is::
 
-        data = IMDBLoader().process(
-            paths={'train': 'path/to/train/file', 'test': 'path/to/test/file'}
-        )
-        print(data)
+    In total 3 datasets:
+        train has 549367 instances.
+        dev has 9842 instances.
+        test has 9824 instances.
+    In total 2 vocabs:
+        words has 34184 entries.
+        target has 3 entries.
 
-    The output is::
-
-        In total 3 datasets:
-            train has 22500 instances.
-            test has 25000 instances.
-            dev has 2500 instances.
-        In total 2 vocabs:
-            words has 82846 entries.
-            target has 2 entries.
-
-    Here the original train set was split 9:1 into a training set and a validation set.
+This shows that there are 3 datasets and 2 vocabularies in total, where:
 
+    - the 3 datasets are train, dev and test, with 549367, 9842 and 9824 instances respectively
+    - the 2 vocabularies are the words vocabulary and the target vocabulary. The words vocabulary is built from the sentence text and has 34184 words;
+      the target vocabulary is built from the target labels and has 3 labels. (Note: with multiple inputs, the vocabulary built from the sentence text
+      is named words1 to match the corresponding column name.)


 ------------------------------------------
-Part V: Dataset loaders bundled with fastNLP
+Part V: Loaders and Pipes bundled with fastNLP
 ------------------------------------------
 
-The dataset loaders bundled with fastNLP cover several types of task:
-
-    - `Text classification`_
-    - `Sequence labeling`_
-    - `Matching`_
-
-Text classification
--------------------
-
-    ========================== ==================================================================
-    Dataset                    Loader
-    ========================== ==================================================================
-    IMDb                       :class:`~fastNLP.io.data_loader.IMDBLoader`
-    SST                        :class:`~fastNLP.io.data_loader.SSTLoader`
-    SST-2                      :class:`~fastNLP.io.data_loader.SST2Loader`
-    Yelp Polarity              :class:`~fastNLP.io.data_loader.YelpLoader`
-    Yelp Full                  :class:`~fastNLP.io.data_loader.YelpLoader`
-    MTL16                      :class:`~fastNLP.io.data_loader.MTL16Loader`
-    ========================== ==================================================================
-
-Sequence labeling
--------------------
-
-    ========================== ==================================================================
-    Dataset                    Loader
-    ========================== ==================================================================
-    Conll                      :class:`~fastNLP.io.data_loader.ConllLoader`
-    Conll2003                  :class:`~fastNLP.io.data_loader.Conll2003Loader`
-    People's Daily corpus      :class:`~fastNLP.io.data_loader.PeopleDailyCorpusLoader`
-    ========================== ==================================================================
-
-Matching
--------------------
-
-    ========================== ==================================================================
-    Dataset                    Loader
-    ========================== ==================================================================
-    SNLI                       :class:`~fastNLP.io.data_loader.SNLILoader`
-    MultiNLI                   :class:`~fastNLP.io.data_loader.MNLILoader`
-    QNLI                       :class:`~fastNLP.io.data_loader.QNLILoader`
-    RTE                        :class:`~fastNLP.io.data_loader.RTELoader`
-    Quora Pair Dataset         :class:`~fastNLP.io.data_loader.QuoraLoader`
-    ========================== ==================================================================
+fastNLP bundles Loaders and Pipes for many tasks and datasets and can download them automatically; for details see
+
+`the embeddings and datasets fastNLP can load <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_



+ 21
- 68
docs/source/tutorials/tutorial_3_embedding.rst View File

@@ -12,6 +12,7 @@
 - `Part IV: Using pretrained contextual embeddings (ELMo & BERT)`_
 - `Part V: Using character-level embeddings`_
 - `Part VI: Stacking multiple embeddings`_
+- `Part VII: Pretrained embeddings supported by fastNLP`_
 
 
@@ -35,12 +36,14 @@ Part II: Using randomly initialized embeddings
 
 .. code-block:: python
 
+    from fastNLP import Embedding
     embed = Embedding(10000, 50)
 
 You can also pass in an initialized parameter matrix:
 
 .. code-block:: python
 
+    from fastNLP import Embedding
     embed = Embedding(init_embed)
 
 Here init_embed can be a torch.FloatTensor, a torch.nn.Embedding or a numpy.ndarray.
@@ -59,6 +62,7 @@ Embedding, for example:
 
 .. code-block:: python
 
+    from fastNLP import StaticEmbedding
     embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
 
 vocab is the vocabulary built from the dataset; model_dir_or_name can be a path or the name of an embedding model:
@@ -67,34 +71,13 @@ vocab is the vocabulary built from the dataset; model_dir_or_name can be a path
 and word2vec-style weight files are both supported)
 
 2 If a model name is passed, fastNLP looks the embedding model up by name; if the model is found in the cache directory it is
-loaded automatically, and if it is not found it is downloaded automatically. The cache directory can be customized via the environment variable ``FASTNLP_CACHE_DIR``, e.g.::
+loaded automatically, and if it is not found it is downloaded into the cache directory automatically. The default cache directory is the `~/.fastNLP` folder. The cache directory can be customized via the environment
+variable ``FASTNLP_CACHE_DIR``, e.g.::
 
     $ FASTNLP_CACHE_DIR=~/fastnlp_cache_dir python your_python_file.py
 
 This command tells fastNLP to look for models in `~/fastnlp_cache_dir` and to download them into that directory if they are not found
 
-The static embedding models currently supported are:
-
-    ========================== ================================
-    Name                       Model
-    ========================== ================================
-    en                         glove.840B.300d
-    en-glove-840d-300          glove.840B.300d
-    en-glove-6b-50             glove.6B.50d
-    en-word2vec-300            Google word2vec, 300d
-    en-fasttext                English fastText, 300d
-    cn                         Tencent Chinese word vectors, 200d
-    cn-fasttext                Chinese fastText, 300d
-    ========================== ================================



 -----------------------------------------------------------
 Part IV: Using pretrained contextual embeddings (ELMo & BERT)
 -----------------------------------------------------------
@@ -106,62 +89,20 @@ Part IV: Using pretrained contextual embeddings (ELMo & BERT)
 
 .. code-block:: python
 
+    from fastNLP import ElmoEmbedding
     embed = ElmoEmbedding(vocab, model_dir_or_name='small', requires_grad=False)
 
-The ElmoEmbedding models currently supported are:
-
-    ========================== ================================
-    Name                       Model
-    ========================== ================================
-    small                      allennlp ELMo, small
-    medium                     allennlp ELMo, medium
-    original                   allennlp ELMo, original
-    5.5b-original              allennlp ELMo, 5.5B original
-    ========================== ================================
-
 BertEmbedding is used as follows:
 
 .. code-block:: python
 
+    from fastNLP import BertEmbedding
     embed = BertEmbedding(
        vocab, model_dir_or_name='en-base-cased', requires_grad=False, layers='4,-2,-1'
    )
 
 The layers argument specifies which encoder layers' outputs to take.
 
-The BertEmbedding models currently supported are:
-
-    ========================== ====================================
-    Name                       Model
-    ========================== ====================================
-    en                         bert-base-cased
-    en-base-uncased            bert-base-uncased
-    en-base-cased              bert-base-cased
-    en-large-uncased           bert-large-uncased
-    en-large-cased             bert-large-cased
-    en-large-cased-wwm         bert-large-cased-whole-word-mask
-    en-large-uncased-wwm       bert-large-uncased-whole-word-mask
-    en-base-cased-mrpc         bert-base-cased-finetuned-mrpc
-    multilingual               bert-base-multilingual-cased
-    multilingual-base-uncased  bert-base-multilingual-uncased
-    multilingual-base-cased    bert-base-multilingual-cased
-    ========================== ====================================

 -----------------------------------------------------
 Part V: Using character-level embeddings
 -----------------------------------------------------
@@ -173,6 +114,7 @@ An example of using CNNCharEmbedding:
 
 .. code-block:: python
 
+    from fastNLP import CNNCharEmbedding
     embed = CNNCharEmbedding(vocab, embed_size=100, char_emb_size=50)
 
 This means the character embeddings inside this CNNCharEmbedding have size 50 and the returned embedding has size 100.
@@ -181,12 +123,12 @@ An example of using CNNCharEmbedding:
 
 .. code-block:: python
 
+    from fastNLP import LSTMCharEmbedding
     embed = LSTMCharEmbedding(vocab, embed_size=100, char_emb_size=50)
 
 This means the character embeddings inside this LSTMCharEmbedding have size 50 and the returned embedding has size 100.





-----------------------------------------------------
Part VI: Stacking multiple embeddings
-----------------------------------------------------
@@ -197,6 +139,7 @@ Part VI: Stacking multiple embeddings


.. code-block:: python

    from fastNLP import StaticEmbedding, StackEmbedding
    embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
    embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)


@@ -208,7 +151,17 @@ StackEmbedding concatenates the results of several embeddings; in the example above, the sta


.. code-block:: python

    from fastNLP import StaticEmbedding, StackEmbedding, ElmoEmbedding
    elmo_embedding = ElmoEmbedding(vocab, model_dir_or_name='medium', layers='0,1,2', requires_grad=False)
    glove_embedding = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)

    stack_embed = StackEmbedding([elmo_embedding, glove_embedding])
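
The embedding size of ``stack_embed`` is simply the sum of the sizes of the stacked embeddings, which can be checked directly (assuming the objects above were built successfully):

.. code-block:: python

    print(elmo_embedding.embed_size, glove_embedding.embed_size)
    print(stack_embed.embed_size)  # equals the sum of the two sizes above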

------------------------------------------------------
Part VII: Pre-trained embeddings supported by fastNLP
------------------------------------------------------

fastNLP supports a number of pre-trained embeddings and can download them automatically; for details see the document

`Embeddings and datasets loadable by fastNLP <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_


+ 5
- 2
docs/source/tutorials/tutorial_4_loss_optimizer.rst View File

@@ -1,4 +1,4 @@
==========================================================================================
Implementing a text classifier I: fast training and testing with Trainer and Tester
==========================================================================================


@@ -19,7 +19,9 @@


loader = SSTLoader()
# all.txt here is the concatenation of train.txt, dev.txt and test.txt from the downloaded data
dataset = loader.load("./trainDevTestTrees_PTB/trees/all.txt")
# loader.load(path) first checks whether path is None; if so it downloads the data automatically, otherwise it reads the file and returns a DataBundle
databundle_ = loader.load("./trainDevTestTrees_PTB/trees/all.txt")
dataset = databundle_.datasets['train']
print(dataset[0])


The output looks like this::
@@ -31,6 +33,7 @@


Data processing
The predefined :class:`~fastNLP.io.SSTPipe` class can be used for basic preprocessing of the data; here we do the processing manually.
We use the :meth:`~fastNLP.DataSet.apply` method of the :class:`~fastNLP.DataSet` class to convert the ``target`` :mod:`~fastNLP.core.field` to integers.
.. code-block:: python
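
    # Hedged sketch of the elided snippet: the exact conversion used by the tutorial
    # is not shown in this hunk, so the lambda below is an assumption.
    dataset.apply(lambda ins: int(ins['target']), new_field_name='target')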


+ 4
- 1
docs/source/tutorials/tutorial_5_datasetiter.rst View File

@@ -20,7 +20,9 @@


loader = SSTLoader()
# all.txt here is the concatenation of train.txt, dev.txt and test.txt from the downloaded data
dataset = loader.load("./trainDevTestTrees_PTB/trees/all.txt")
# loader.load(path) first checks whether path is None; if so it downloads the data automatically, otherwise it reads the file and returns a DataBundle
databundle_ = loader.load("./trainDevTestTrees_PTB/trees/all.txt")
dataset = databundle_.datasets['train']
print(dataset[0])


The output looks like this::
@@ -32,6 +34,7 @@


Data processing
The predefined :class:`~fastNLP.io.SSTPipe` class can be used for basic preprocessing of the data; here we do the processing manually.
We use the :meth:`~fastNLP.DataSet.apply` method of the :class:`~fastNLP.DataSet` class to convert the ``target`` :mod:`~fastNLP.core.field` to integers.
.. code-block:: python


+ 1
- 1
docs/source/user/tutorials.rst View File

@@ -8,7 +8,7 @@ fastNLP detailed usage tutorials
:maxdepth: 1


Preprocessing text with DataSet </tutorials/tutorial_1_data_preprocess>
Loading datasets with DataSetLoader </tutorials/tutorial_2_load_dataset>
Loading and processing datasets with Loader and Pipe </tutorials/tutorial_2_load_dataset>
Converting text to vectors with the Embedding module </tutorials/tutorial_3_embedding>
Implementing a text classifier I: fast training and testing with Trainer and Tester </tutorials/tutorial_4_loss_optimizer>
Implementing a text classifier II: a custom training loop with DataSetIter </tutorials/tutorial_5_datasetiter>


+ 3
- 3
fastNLP/__init__.py View File

@@ -65,8 +65,8 @@ __all__ = [
]
__version__ = '0.4.5'


from .core import *
from . import embeddings
from . import models
from . import modules
from . import embeddings
from .io import data_loader
from .core import *
from .io import loader, pipe

+ 65
- 2
fastNLP/core/__init__.py View File

@@ -10,8 +10,72 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa


对于常用的功能,你只需要在 :doc:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。 对于常用的功能,你只需要在 :doc:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。


""" """
__all__ = [
"DataSet",
"Instance",
"FieldArray",
"Padder",
"AutoPadder",
"EngChar2DPadder",
"Vocabulary",
"DataSetIter",
"BatchIter",
"TorchLoaderIter",
"Const",
"Tester",
"Trainer",
"cache_results",
"seq_len_to_mask",
"get_seq_len",
"logger",
"Callback",
"GradientClipCallback",
"EarlyStopCallback",
"FitlogCallback",
"EvaluateCallback",
"LRScheduler",
"ControlC",
"LRFinder",
"TensorboardCallback",
"WarmupCallback",
'SaveModelCallback',
"EchoCallback",
"TesterCallback",
"CallbackException",
"EarlyStopError",
"LossFunc",
"CrossEntropyLoss",
"L1Loss",
"BCELoss",
"NLLLoss",
"LossInForward",
"AccuracyMetric",
"SpanFPreRecMetric",
"ExtractiveQAMetric",
"Optimizer",
"SGD",
"Adam",
"AdamW",
"SequentialSampler",
"BucketSampler",
"RandomSampler",
"Sampler",
]

from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter from .batch import DataSetIter, BatchIter, TorchLoaderIter
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
@@ -28,4 +92,3 @@ from .tester import Tester
from .trainer import Trainer from .trainer import Trainer
from .utils import cache_results, seq_len_to_mask, get_seq_len from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary from .vocabulary import Vocabulary
from ._logger import logger

+ 20
- 18
fastNLP/core/_logger.py View File

@@ -1,15 +1,15 @@
"""undocumented"""

__all__ = [
'logger',
]

import logging import logging
import logging.config import logging.config
import torch
import _pickle as pickle
import os import os
import sys import sys
import warnings import warnings


__all__ = [
'logger',
]

ROOT_NAME = 'fastNLP' ROOT_NAME = 'fastNLP'


try: try:
@@ -25,7 +25,7 @@ if tqdm is not None:
class TqdmLoggingHandler(logging.Handler): class TqdmLoggingHandler(logging.Handler):
def __init__(self, level=logging.INFO): def __init__(self, level=logging.INFO):
super().__init__(level) super().__init__(level)
def emit(self, record): def emit(self, record):
try: try:
msg = self.format(record) msg = self.format(record)
@@ -59,14 +59,14 @@ def _add_file_handler(logger, path, level='INFO'):
if os.path.abspath(path) == h.baseFilename: if os.path.abspath(path) == h.baseFilename:
# file path already added # file path already added
return return
# File Handler # File Handler
if os.path.exists(path): if os.path.exists(path):
assert os.path.isfile(path) assert os.path.isfile(path)
warnings.warn('log already exists in {}'.format(path)) warnings.warn('log already exists in {}'.format(path))
dirname = os.path.abspath(os.path.dirname(path)) dirname = os.path.abspath(os.path.dirname(path))
os.makedirs(dirname, exist_ok=True) os.makedirs(dirname, exist_ok=True)
file_handler = logging.FileHandler(path, mode='a') file_handler = logging.FileHandler(path, mode='a')
file_handler.setLevel(_get_level(level)) file_handler.setLevel(_get_level(level))
file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s', file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
@@ -87,7 +87,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
break break
if stream_handler is not None: if stream_handler is not None:
logger.removeHandler(stream_handler) logger.removeHandler(stream_handler)
# Stream Handler # Stream Handler
if stdout == 'plain': if stdout == 'plain':
stream_handler = logging.StreamHandler(sys.stdout) stream_handler = logging.StreamHandler(sys.stdout)
@@ -95,7 +95,7 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
stream_handler = TqdmLoggingHandler(level) stream_handler = TqdmLoggingHandler(level)
else: else:
stream_handler = None stream_handler = None
if stream_handler is not None: if stream_handler is not None:
stream_formatter = logging.Formatter('%(message)s') stream_formatter = logging.Formatter('%(message)s')
stream_handler.setLevel(level) stream_handler.setLevel(level)
@@ -103,38 +103,40 @@ def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
logger.addHandler(stream_handler) logger.addHandler(stream_handler)





class FastNLPLogger(logging.getLoggerClass()): class FastNLPLogger(logging.getLoggerClass()):
def __init__(self, name): def __init__(self, name):
super().__init__(name) super().__init__(name)
def add_file(self, path='./log.txt', level='INFO'): def add_file(self, path='./log.txt', level='INFO'):
"""add log output file and level""" """add log output file and level"""
_add_file_handler(self, path, level) _add_file_handler(self, path, level)
def set_stdout(self, stdout='tqdm', level='INFO'): def set_stdout(self, stdout='tqdm', level='INFO'):
"""set stdout format and level""" """set stdout format and level"""
_set_stdout_handler(self, stdout, level) _set_stdout_handler(self, stdout, level)



logging.setLoggerClass(FastNLPLogger) logging.setLoggerClass(FastNLPLogger)


# print(logging.getLoggerClass()) # print(logging.getLoggerClass())
# print(logging.getLogger()) # print(logging.getLogger())


def _init_logger(path=None, stdout='tqdm', level='INFO'): def _init_logger(path=None, stdout='tqdm', level='INFO'):
"""initialize logger""" """initialize logger"""
level = _get_level(level) level = _get_level(level)
# logger = logging.getLogger() # logger = logging.getLogger()
logger = logging.getLogger(ROOT_NAME) logger = logging.getLogger(ROOT_NAME)
logger.propagate = False logger.propagate = False
logger.setLevel(level) logger.setLevel(level)
_set_stdout_handler(logger, stdout, level) _set_stdout_handler(logger, stdout, level)
# File Handler # File Handler
if path is not None: if path is not None:
_add_file_handler(logger, path, level) _add_file_handler(logger, path, level)
return logger return logger
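
Taken together with the ``FastNLPLogger.add_file`` / ``set_stdout`` methods defined above, typical usage looks like the sketch below (the log file path is illustrative)::

    from fastNLP import logger

    logger.add_file('./fastnlp.log', level='INFO')   # also write records to a file
    logger.set_stdout('tqdm', level='INFO')          # keep console output tqdm-friendly
    logger.info('training started')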






+ 13
- 8
fastNLP/core/_parallel_utils.py View File

@@ -1,11 +1,14 @@
"""undocumented"""

__all__ = []


import threading import threading

import torch import torch
from torch import nn from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var from torch.nn.parallel.parallel_apply import get_a_var

from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather




def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -27,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
assert len(modules) == len(devices) assert len(modules) == len(devices)
else: else:
devices = [None] * len(modules) devices = [None] * len(modules)
lock = threading.Lock() lock = threading.Lock()
results = {} results = {}
grad_enabled = torch.is_grad_enabled() grad_enabled = torch.is_grad_enabled()
def _worker(i, module, input, kwargs, device=None): def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled) torch.set_grad_enabled(grad_enabled)
if device is None: if device is None:
@@ -47,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
except Exception as e: except Exception as e:
with lock: with lock:
results[i] = e results[i] = e
if len(modules) > 1: if len(modules) > 1:
threads = [threading.Thread(target=_worker, threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device)) args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))] enumerate(zip(modules, inputs, kwargs_tup, devices))]
for thread in threads: for thread in threads:
thread.start() thread.start()
for thread in threads: for thread in threads:
thread.join() thread.join()
else: else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
outputs = [] outputs = []
for i in range(len(inputs)): for i in range(len(inputs)):
output = results[i] output = results[i]
@@ -79,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
:param output_device: nn.DataParallel中的output_device :param output_device: nn.DataParallel中的output_device
:return: :return:
""" """
def wrapper(network, *inputs, **kwargs): def wrapper(network, *inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1: if len(device_ids) == 1:
@@ -86,6 +90,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
replicas = replicate(network, device_ids[:len(inputs)]) replicas = replicate(network, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
return gather(outputs, output_device) return gather(outputs, output_device)
return wrapper return wrapper




@@ -99,4 +104,4 @@ def _model_contains_inner_module(model):
if isinstance(model, nn.Module): if isinstance(model, nn.Module):
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
return True return True
return False
return False
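
Based on the isinstance checks above, ``_model_contains_inner_module`` only reports whether a model is wrapped in one of torch's parallel containers; a small sketch with an illustrative module::

    import torch.nn as nn
    from fastNLP.core._parallel_utils import _model_contains_inner_module

    model = nn.Linear(4, 2)
    print(_model_contains_inner_module(model))                   # False: a plain nn.Module
    print(_model_contains_inner_module(nn.DataParallel(model)))  # True: the real model sits at .module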

+ 2
- 2
fastNLP/core/batch.py View File

@@ -17,7 +17,7 @@ from numbers import Number


from .sampler import SequentialSampler from .sampler import SequentialSampler
from .dataset import DataSet from .dataset import DataSet
from ._logger import logger
_python_is_exit = False _python_is_exit = False




@@ -75,7 +75,7 @@ class DataSetGetter:
try: try:
data, flag = _to_tensor(data, f.dtype) data, flag = _to_tensor(data, f.dtype)
except TypeError as e: except TypeError as e:
print(f"Field {n} cannot be converted to torch.tensor.")
logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e raise e
batch_dict[n] = data batch_dict[n] = data
return batch_dict return batch_dict


+ 9
- 9
fastNLP/core/callback.py View File

@@ -83,7 +83,6 @@ try:
except: except:
tensorboardX_flag = False tensorboardX_flag = False


from ..io.model_io import ModelSaver, ModelLoader
from .dataset import DataSet from .dataset import DataSet
from .tester import Tester from .tester import Tester
from ._logger import logger from ._logger import logger
@@ -505,7 +504,7 @@ class EarlyStopCallback(Callback):
def on_exception(self, exception): def on_exception(self, exception):
if isinstance(exception, EarlyStopError): if isinstance(exception, EarlyStopError):
print("Early Stopping triggered in epoch {}!".format(self.epoch))
logger.info("Early Stopping triggered in epoch {}!".format(self.epoch))
else: else:
raise exception # 抛出陌生Error raise exception # 抛出陌生Error


@@ -752,8 +751,7 @@ class LRFinder(Callback):
self.smooth_value = SmoothValue(0.8) self.smooth_value = SmoothValue(0.8)
self.opt = None self.opt = None
self.find = None self.find = None
self.loader = ModelLoader()

@property @property
def lr_gen(self): def lr_gen(self):
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch scale = (self.end_lr - self.start_lr) / self.batch_per_epoch
@@ -768,7 +766,7 @@ class LRFinder(Callback):
self.opt = self.trainer.optimizer # pytorch optimizer self.opt = self.trainer.optimizer # pytorch optimizer
self.opt.param_groups[0]["lr"] = self.start_lr self.opt.param_groups[0]["lr"] = self.start_lr
# save model # save model
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
torch.save(self.model.state_dict(), 'tmp')
self.find = True self.find = True
def on_backward_begin(self, loss): def on_backward_begin(self, loss):
@@ -797,7 +795,9 @@ class LRFinder(Callback):
self.opt.param_groups[0]["lr"] = self.best_lr self.opt.param_groups[0]["lr"] = self.best_lr
self.find = False self.find = False
# reset model # reset model
ModelLoader().load_pytorch(self.trainer.model, "tmp")
states = torch.load('tmp')
self.model.load_state_dict(states)
os.remove('tmp')
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr)) self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))
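
The ``ModelSaver``/``ModelLoader`` calls removed above are replaced by the plain ``state_dict`` round-trip; stripped of the callback context, the pattern is (model and path are illustrative)::

    import os
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)                    # stands in for the trainer's model
    torch.save(model.state_dict(), 'tmp')      # snapshot the weights before the LR sweep
    # ... run the sweep, which may degrade the weights ...
    model.load_state_dict(torch.load('tmp'))   # restore the snapshot
    os.remove('tmp')                           # clean up the temporary file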




@@ -988,14 +988,14 @@ class SaveModelCallback(Callback):
try: try:
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) _save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
except Exception as e: except Exception as e:
print(f"The following exception:{e} happens when save model to {self.save_dir}.")
logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.")
if delete_pair: if delete_pair:
try: try:
delete_model_path = os.path.join(self.save_dir, delete_pair[1]) delete_model_path = os.path.join(self.save_dir, delete_pair[1])
if os.path.exists(delete_model_path): if os.path.exists(delete_model_path):
os.remove(delete_model_path) os.remove(delete_model_path)
except Exception as e: except Exception as e:
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")


def on_exception(self, exception): def on_exception(self, exception):
if self.save_on_exception: if self.save_on_exception:
@@ -1032,7 +1032,7 @@ class EchoCallback(Callback):


def __getattribute__(self, item): def __getattribute__(self, item):
if item.startswith('on_'): if item.startswith('on_'):
print('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()),
file=self.out) file=self.out)
return super(EchoCallback, self).__getattribute__(item) return super(EchoCallback, self).__getattribute__(item)




+ 18
- 8
fastNLP/core/const.py View File

@@ -1,3 +1,13 @@
"""
.. todo::
doc
"""

__all__ = [
"Const"
]


class Const: class Const:
""" """
fastNLP中field命名常量。 fastNLP中field命名常量。
@@ -25,47 +35,47 @@ class Const:
LOSS = 'loss' LOSS = 'loss'
RAW_WORD = 'raw_words' RAW_WORD = 'raw_words'
RAW_CHAR = 'raw_chars' RAW_CHAR = 'raw_chars'
@staticmethod @staticmethod
def INPUTS(i): def INPUTS(i):
"""得到第 i 个 ``INPUT`` 的命名""" """得到第 i 个 ``INPUT`` 的命名"""
i = int(i) + 1 i = int(i) + 1
return Const.INPUT + str(i) return Const.INPUT + str(i)
@staticmethod @staticmethod
def CHAR_INPUTS(i): def CHAR_INPUTS(i):
"""得到第 i 个 ``CHAR_INPUT`` 的命名""" """得到第 i 个 ``CHAR_INPUT`` 的命名"""
i = int(i) + 1 i = int(i) + 1
return Const.CHAR_INPUT + str(i) return Const.CHAR_INPUT + str(i)
@staticmethod @staticmethod
def RAW_WORDS(i): def RAW_WORDS(i):
i = int(i) + 1 i = int(i) + 1
return Const.RAW_WORD + str(i) return Const.RAW_WORD + str(i)
@staticmethod @staticmethod
def RAW_CHARS(i): def RAW_CHARS(i):
i = int(i) + 1 i = int(i) + 1
return Const.RAW_CHAR + str(i) return Const.RAW_CHAR + str(i)
@staticmethod @staticmethod
def INPUT_LENS(i): def INPUT_LENS(i):
"""得到第 i 个 ``INPUT_LEN`` 的命名""" """得到第 i 个 ``INPUT_LEN`` 的命名"""
i = int(i) + 1 i = int(i) + 1
return Const.INPUT_LEN + str(i) return Const.INPUT_LEN + str(i)
@staticmethod @staticmethod
def OUTPUTS(i): def OUTPUTS(i):
"""得到第 i 个 ``OUTPUT`` 的命名""" """得到第 i 个 ``OUTPUT`` 的命名"""
i = int(i) + 1 i = int(i) + 1
return Const.OUTPUT + str(i) return Const.OUTPUT + str(i)
@staticmethod @staticmethod
def TARGETS(i): def TARGETS(i):
"""得到第 i 个 ``TARGET`` 的命名""" """得到第 i 个 ``TARGET`` 的命名"""
i = int(i) + 1 i = int(i) + 1
return Const.TARGET + str(i) return Const.TARGET + str(i)
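
These helpers only append a 1-based index to the corresponding constant; a quick sketch using the constants visible above (``RAW_WORD = 'raw_words'``, ``LOSS = 'loss'``)::

    from fastNLP import Const

    print(Const.RAW_WORDS(0))  # 'raw_words1' -- index 0 maps to suffix 1
    print(Const.RAW_WORDS(1))  # 'raw_words2'
    print(Const.LOSS)          # 'loss'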
@staticmethod @staticmethod
def LOSSES(i): def LOSSES(i):
"""得到第 i 个 ``LOSS`` 的命名""" """得到第 i 个 ``LOSS`` 的命名"""


+ 6
- 5
fastNLP/core/dataset.py View File

@@ -300,6 +300,7 @@ from .utils import _get_func_signature
from .field import AppendToTargetOrInputException from .field import AppendToTargetOrInputException
from .field import SetInputOrTargetException from .field import SetInputOrTargetException
from .const import Const from .const import Const
from ._logger import logger


class DataSet(object): class DataSet(object):
""" """
@@ -452,7 +453,7 @@ class DataSet(object):
try: try:
self.field_arrays[name].append(field) self.field_arrays[name].append(field)
except AppendToTargetOrInputException as e: except AppendToTargetOrInputException as e:
print(f"Cannot append to field:{name}.")
logger.error(f"Cannot append to field:{name}.")
raise e raise e
def add_fieldarray(self, field_name, fieldarray): def add_fieldarray(self, field_name, fieldarray):
@@ -609,7 +610,7 @@ class DataSet(object):
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_target = flag self.field_arrays[name].is_target = flag
except SetInputOrTargetException as e: except SetInputOrTargetException as e:
print(f"Cannot set field:{name} as target.")
logger.error(f"Cannot set field:{name} as target.")
raise e raise e
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
@@ -633,7 +634,7 @@ class DataSet(object):
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_input = flag self.field_arrays[name].is_input = flag
except SetInputOrTargetException as e: except SetInputOrTargetException as e:
print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
raise e raise e
else: else:
raise KeyError("{} is not a valid field name.".format(name)) raise KeyError("{} is not a valid field name.".format(name))
@@ -728,7 +729,7 @@ class DataSet(object):
results.append(func(ins[field_name])) results.append(func(ins[field_name]))
except Exception as e: except Exception as e:
if idx != -1: if idx != -1:
print("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
raise e raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -795,7 +796,7 @@ class DataSet(object):
results.append(func(ins)) results.append(func(ins))
except BaseException as e: except BaseException as e:
if idx != -1: if idx != -1:
print("Exception happens at the `{}`th instance.".format(idx))
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e raise e


# results = [func(ins) for ins in self._inner_iter()] # results = [func(ins) for ins in self._inner_iter()]


+ 11
- 12
fastNLP/core/dist_trainer.py View File

@@ -1,29 +1,29 @@
"""
"""undocumented
正在开发中的分布式训练代码 正在开发中的分布式训练代码
""" """
import logging
import os
import time
from datetime import datetime

import torch import torch
import torch.cuda import torch.cuda
import torch.optim
import torch.distributed as dist import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
import torch.optim
from pkg_resources import parse_version
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
import os
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm from tqdm import tqdm
import time
from datetime import datetime, timedelta
from functools import partial


from ._logger import logger
from .batch import DataSetIter, BatchIter from .batch import DataSetIter, BatchIter
from .callback import DistCallbackManager, CallbackException, TesterCallback from .callback import DistCallbackManager, CallbackException, TesterCallback
from .dataset import DataSet from .dataset import DataSet
from .losses import _prepare_losser from .losses import _prepare_losser
from .optimizer import Optimizer from .optimizer import Optimizer
from .utils import _build_args from .utils import _build_args
from .utils import _move_dict_value_to_device
from .utils import _get_func_signature from .utils import _get_func_signature
from ._logger import logger
import logging
from pkg_resources import parse_version
from .utils import _move_dict_value_to_device


__all__ = [ __all__ = [
'get_local_rank', 'get_local_rank',
@@ -54,7 +54,6 @@ class DistTrainer():
num_workers=1, drop_last=False, num_workers=1, drop_last=False,
dev_data=None, metrics=None, metric_key=None, dev_data=None, metrics=None, metric_key=None,
update_every=1, print_every=10, validate_every=-1, update_every=1, print_every=10, validate_every=-1,
log_path=None,
save_every=-1, save_path=None, device='auto', save_every=-1, save_path=None, device='auto',
fp16='', backend=None, init_method=None): fp16='', backend=None, init_method=None):




+ 22
- 14
fastNLP/core/field.py View File

@@ -1,16 +1,24 @@
"""
.. todo::
doc
"""

__all__ = [ __all__ = [
"Padder", "Padder",
"AutoPadder", "AutoPadder",
"EngChar2DPadder", "EngChar2DPadder",
] ]


from numbers import Number
import torch
import numpy as np
from typing import Any
from abc import abstractmethod from abc import abstractmethod
from copy import deepcopy
from collections import Counter from collections import Counter
from copy import deepcopy
from numbers import Number
from typing import Any

import numpy as np
import torch

from ._logger import logger
from .utils import _is_iterable from .utils import _is_iterable




@@ -39,7 +47,7 @@ class FieldArray:
try: try:
_content = list(_content) _content = list(_content)
except BaseException as e: except BaseException as e:
print(f"Cannot convert content(of type:{type(content)}) into list.")
logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e raise e
self.name = name self.name = name
self.content = _content self.content = _content
@@ -263,7 +271,7 @@ class FieldArray:
try: try:
new_contents.append(cell.split(sep)) new_contents.append(cell.split(sep))
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
@@ -283,8 +291,8 @@ class FieldArray:
else: else:
new_contents.append(int(cell)) new_contents.append(int(cell))
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
print(e)
logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
def float(self, inplace=True): def float(self, inplace=True):
@@ -303,7 +311,7 @@ class FieldArray:
else: else:
new_contents.append(float(cell)) new_contents.append(float(cell))
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
@@ -323,7 +331,7 @@ class FieldArray:
else: else:
new_contents.append(bool(cell)) new_contents.append(bool(cell))
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
@@ -344,7 +352,7 @@ class FieldArray:
else: else:
new_contents.append(cell.lower()) new_contents.append(cell.lower())
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
@@ -364,7 +372,7 @@ class FieldArray:
else: else:
new_contents.append(cell.upper()) new_contents.append(cell.upper())
except Exception as e: except Exception as e:
print(f"Exception happens when process value in index {index}.")
logger.error(f"Exception happens when process value in index {index}.")
raise e raise e
return self._after_process(new_contents, inplace=inplace) return self._after_process(new_contents, inplace=inplace)
@@ -401,7 +409,7 @@ class FieldArray:
self.is_input = self.is_input self.is_input = self.is_input
self.is_target = self.is_input self.is_target = self.is_input
except SetInputOrTargetException as e: except SetInputOrTargetException as e:
print("The newly generated field cannot be set as input or target.")
logger.error("The newly generated field cannot be set as input or target.")
raise e raise e
return self return self
else: else:


+ 15
- 13
fastNLP/core/predictor.py View File

@@ -1,13 +1,15 @@
"""
..todo::
检查这个类是否需要
"""
"""undocumented"""

__all__ = [
"Predictor"
]

from collections import defaultdict from collections import defaultdict


import torch import torch


from . import DataSetIter
from . import DataSet from . import DataSet
from . import DataSetIter
from . import SequentialSampler from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device from .utils import _build_args, _move_dict_value_to_device, _get_model_device


@@ -21,7 +23,7 @@ class Predictor(object):


:param torch.nn.Module network: 用来完成预测任务的模型 :param torch.nn.Module network: 用来完成预测任务的模型
""" """
def __init__(self, network): def __init__(self, network):
if not isinstance(network, torch.nn.Module): if not isinstance(network, torch.nn.Module):
raise ValueError( raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
self.network = network self.network = network
self.batch_size = 1 self.batch_size = 1
self.batch_output = [] self.batch_output = []
def predict(self, data: DataSet, seq_len_field_name=None): def predict(self, data: DataSet, seq_len_field_name=None):
"""用已经训练好的模型进行inference. """用已经训练好的模型进行inference.


@@ -41,27 +43,27 @@ class Predictor(object):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
prev_training = self.network.training prev_training = self.network.training
self.network.eval() self.network.eval()
network_device = _get_model_device(self.network) network_device = _get_model_device(self.network)
batch_output = defaultdict(list) batch_output = defaultdict(list)
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
if hasattr(self.network, "predict"): if hasattr(self.network, "predict"):
predict_func = self.network.predict predict_func = self.network.predict
else: else:
predict_func = self.network.forward predict_func = self.network.forward
with torch.no_grad(): with torch.no_grad():
for batch_x, _ in data_iterator: for batch_x, _ in data_iterator:
_move_dict_value_to_device(batch_x, _, device=network_device) _move_dict_value_to_device(batch_x, _, device=network_device)
refined_batch_x = _build_args(predict_func, **batch_x) refined_batch_x = _build_args(predict_func, **batch_x)
prediction = predict_func(**refined_batch_x) prediction = predict_func(**refined_batch_x)
if seq_len_field_name is not None: if seq_len_field_name is not None:
seq_lens = batch_x[seq_len_field_name].tolist() seq_lens = batch_x[seq_len_field_name].tolist()
for key, value in prediction.items(): for key, value in prediction.items():
value = value.cpu().numpy() value = value.cpu().numpy()
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
batch_output[key].extend(tmp_batch) batch_output[key].extend(tmp_batch)
else: else:
batch_output[key].append(value) batch_output[key].append(value)
self.network.train(prev_training) self.network.train(prev_training)
return batch_output return batch_output
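
In short, ``Predictor`` batches a ``DataSet`` through the model's ``predict`` method (falling back to ``forward``) and collects the outputs per key; a hedged usage sketch, assuming ``model`` and ``test_data`` already exist and follow fastNLP's input-field conventions::

    from fastNLP.core.predictor import Predictor

    predictor = Predictor(model)            # model must be a torch.nn.Module
    outputs = predictor.predict(test_data)  # dict: output key -> per-batch results
    # outputs = predictor.predict(test_data, seq_len_field_name='seq_len')  # optionally strip padding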

+ 1
- 1
fastNLP/core/tester.py View File

@@ -192,7 +192,7 @@ class Tester(object):
dataset=self.data, check_level=0) dataset=self.data, check_level=0)
if self.verbose >= 1: if self.verbose >= 1:
print("[tester] \n{}".format(self._format_eval_results(eval_results)))
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False) self._mode(network, is_test=False)
return eval_results return eval_results


+ 2
- 2
fastNLP/core/utils.py View File

@@ -145,7 +145,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
with open(cache_filepath, 'rb') as f: with open(cache_filepath, 'rb') as f:
results = _pickle.load(f) results = _pickle.load(f)
if verbose == 1: if verbose == 1:
print("Read cache from {}.".format(cache_filepath))
logger.info("Read cache from {}.".format(cache_filepath))
refresh_flag = False refresh_flag = False
if refresh_flag: if refresh_flag:
@@ -156,7 +156,7 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1):
_prepare_cache_filepath(cache_filepath) _prepare_cache_filepath(cache_filepath)
with open(cache_filepath, 'wb') as f: with open(cache_filepath, 'wb') as f:
_pickle.dump(results, f) _pickle.dump(results, f)
print("Save cache to {}.".format(cache_filepath))
logger.info("Save cache to {}.".format(cache_filepath))
return results return results
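
``cache_results`` is used as a decorator: the first call runs the wrapped function and pickles the result to ``_cache_fp``, later calls read the cache back, exactly the two branches shown above. A sketch with an illustrative file name and wrapped function::

    from fastNLP import cache_results

    @cache_results('caches/prepared_data.pkl')
    def prepare_data():
        # expensive preprocessing goes here
        return {'vocab_size': 10000}

    data = prepare_data()  # first run: computes and writes the cache file
    data = prepare_data()  # later runs: reads caches/prepared_data.pkl instead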


+ 20
- 13
fastNLP/core/vocabulary.py View File

@@ -1,16 +1,23 @@
"""
.. todo::
doc
"""

__all__ = [ __all__ = [
"Vocabulary", "Vocabulary",
"VocabularyOption", "VocabularyOption",
] ]


from functools import wraps
from collections import Counter from collections import Counter
from functools import partial
from functools import wraps

from ._logger import logger
from .dataset import DataSet from .dataset import DataSet
from .utils import Option from .utils import Option
from functools import partial
import numpy as np
from .utils import _is_iterable from .utils import _is_iterable



class VocabularyOption(Option): class VocabularyOption(Option):
def __init__(self, def __init__(self,
max_size=None, max_size=None,
@@ -49,8 +56,8 @@ def _check_build_status(func):
if self.rebuild is False: if self.rebuild is False:
self.rebuild = True self.rebuild = True
if self.max_size is not None and len(self.word_count) >= self.max_size: if self.max_size is not None and len(self.word_count) >= self.max_size:
print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
"Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__)) self.max_size, func.__name__))
return func(self, *args, **kwargs) return func(self, *args, **kwargs)
@@ -198,7 +205,7 @@ class Vocabulary(object):
self.build_reverse_vocab() self.build_reverse_vocab()
self.rebuild = False self.rebuild = False
return self return self
def build_reverse_vocab(self): def build_reverse_vocab(self):
""" """
基于 `word to index` dict, 构建 `index to word` dict. 基于 `word to index` dict, 构建 `index to word` dict.
@@ -278,26 +285,26 @@ class Vocabulary(object):
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
raise RuntimeError("Only support field with 2 dimensions.") raise RuntimeError("Only support field with 2 dimensions.")
return [[self.to_index(c) for c in w] for w in field] return [[self.to_index(c) for c in w] for w in field]
new_field_name = new_field_name or field_name new_field_name = new_field_name or field_name
if type(new_field_name) == type(field_name): if type(new_field_name) == type(field_name):
if isinstance(new_field_name, list): if isinstance(new_field_name, list):
assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \ assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
"field_name."
"field_name."
elif isinstance(new_field_name, str): elif isinstance(new_field_name, str):
field_name = [field_name] field_name = [field_name]
new_field_name = [new_field_name] new_field_name = [new_field_name]
else: else:
raise TypeError("field_name and new_field_name can only be str or List[str].") raise TypeError("field_name and new_field_name can only be str or List[str].")
for idx, dataset in enumerate(datasets): for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet): if isinstance(dataset, DataSet):
try: try:
for f_n, n_f_n in zip(field_name, new_field_name): for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e: except Exception as e:
print("When processing the `{}` dataset, the following error occurred.".format(idx))
logger.info("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e raise e
else: else:
raise RuntimeError("Only DataSet type is allowed.") raise RuntimeError("Only DataSet type is allowed.")
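
``index_dataset`` applies ``to_index`` over one or more fields of each ``DataSet`` and writes the ids to ``new_field_name``; a small usage sketch with illustrative field names::

    from fastNLP import DataSet, Vocabulary

    ds = DataSet({'raw_words': [['hello', 'world'], ['hello', 'fastNLP']]})

    vocab = Vocabulary()
    vocab.from_dataset(ds, field_name='raw_words')                            # build the vocab from a field
    vocab.index_dataset(ds, field_name='raw_words', new_field_name='words')   # add an indexed copy of the field
    print(ds[0]['words'])  # integer ids for ['hello', 'world']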
@@ -353,7 +360,7 @@ class Vocabulary(object):
try: try:
dataset.apply(construct_vocab) dataset.apply(construct_vocab)
except BaseException as e: except BaseException as e:
print("When processing the `{}` dataset, the following error occurred:".format(idx))
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e raise e
else: else:
raise TypeError("Only DataSet type is allowed.") raise TypeError("Only DataSet type is allowed.")
@@ -376,7 +383,7 @@ class Vocabulary(object):
:return: bool :return: bool
""" """
return word in self._no_create_word return word in self._no_create_word
def to_index(self, w): def to_index(self, w):
""" """
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``:: 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``::


+ 0
- 1
fastNLP/embeddings/__init__.py View File

@@ -18,7 +18,6 @@ __all__ = [
"get_embeddings", "get_embeddings",
] ]



from .embedding import Embedding, TokenEmbedding from .embedding import Embedding, TokenEmbedding
from .static_embedding import StaticEmbedding from .static_embedding import StaticEmbedding
from .elmo_embedding import ElmoEmbedding from .elmo_embedding import ElmoEmbedding


+ 98
- 76
fastNLP/embeddings/bert_embedding.py View File

@@ -1,3 +1,12 @@
"""
.. todo::
doc
"""

__all__ = [
"BertEmbedding",
"BertWordPieceEncoder"
]


import os import os
import collections import collections
@@ -12,6 +21,8 @@ from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MOD
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
from .contextual_embedding import ContextualEmbedding from .contextual_embedding import ContextualEmbedding
import warnings import warnings
from ..core import logger



class BertEmbedding(ContextualEmbedding): class BertEmbedding(ContextualEmbedding):
""" """
@@ -54,11 +65,12 @@ class BertEmbedding(ContextualEmbedding):
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS]
来进行分类的任务将auto_truncate置为True。 来进行分类的任务将auto_truncate置为True。
""" """
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
pool_method: str='first', word_dropout=0, dropout=0, include_cls_sep: bool=False,
pooled_cls=True, requires_grad: bool=False, auto_truncate:bool=False):
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
# 根据model_dir_or_name检查是否存在并下载 # 根据model_dir_or_name检查是否存在并下载
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'): if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
@@ -71,21 +83,21 @@ class BertEmbedding(ContextualEmbedding):
model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name)) model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
else: else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.") raise ValueError(f"Cannot recognize {model_dir_or_name}.")
self._word_sep_index = None self._word_sep_index = None
if '[SEP]' in vocab: if '[SEP]' in vocab:
self._word_sep_index = vocab['[SEP]'] self._word_sep_index = vocab['[SEP]']
self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
pool_method=pool_method, include_cls_sep=include_cls_sep, pool_method=pool_method, include_cls_sep=include_cls_sep,
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)
self.requires_grad = requires_grad self.requires_grad = requires_grad
self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
def _delete_model_weights(self): def _delete_model_weights(self):
del self.model del self.model
def forward(self, words): def forward(self, words):
""" """
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
@@ -100,9 +112,9 @@ class BertEmbedding(ContextualEmbedding):
return self.dropout(outputs) return self.dropout(outputs)
outputs = self.model(words) outputs = self.model(words)
outputs = torch.cat([*outputs], dim=-1) outputs = torch.cat([*outputs], dim=-1)
return self.dropout(outputs) return self.dropout(outputs)
def drop_word(self, words): def drop_word(self, words):
""" """
按照设定随机将words设置为unknown_index。 按照设定随机将words设置为unknown_index。
@@ -114,13 +126,15 @@ class BertEmbedding(ContextualEmbedding):
with torch.no_grad(): with torch.no_grad():
if self._word_sep_index: # 不能drop sep if self._word_sep_index: # 不能drop sep
sep_mask = words.eq(self._word_sep_index) sep_mask = words.eq(self._word_sep_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
pad_mask = words.ne(0)
mask = pad_mask.__and__(mask) # pad的位置不为unk
words = words.masked_fill(mask, self._word_unk_index) words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index: if self._word_sep_index:
words.masked_fill_(sep_mask, self._word_sep_index) words.masked_fill_(sep_mask, self._word_sep_index)
return words return words
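
The reworked masking above draws a Bernoulli mask at rate ``word_dropout`` and then clears it on padding (and [SEP]) positions so they are never replaced by [UNK]; isolated from the class, the idea is::

    import torch

    words = torch.LongTensor([[5, 8, 2, 0, 0]])   # 0 is the padding index
    word_dropout, unk_index = 0.3, 1              # illustrative values

    drop_mask = torch.bernoulli(torch.full_like(words, word_dropout, dtype=torch.float)).eq(1)
    drop_mask = drop_mask & words.ne(0)           # never drop padding positions
    dropped = words.masked_fill(drop_mask, unk_index)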
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -129,12 +143,12 @@ class BertEmbedding(ContextualEmbedding):
:return: :return:
""" """
requires_grads = set([param.requires_grad for name, param in self.named_parameters() requires_grads = set([param.requires_grad for name, param in self.named_parameters()
if 'word_pieces_lengths' not in name])
if 'word_pieces_lengths' not in name])
if len(requires_grads) == 1: if len(requires_grads) == 1:
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
@@ -155,10 +169,11 @@ class BertWordPieceEncoder(nn.Module):
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
:param bool requires_grad: 是否需要gradient。 :param bool requires_grad: 是否需要gradient。
""" """
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool=False):
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
word_dropout=0, dropout=0, requires_grad: bool = False):
super().__init__() super().__init__()
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
model_url = _get_embedding_url('bert', model_dir_or_name.lower()) model_url = _get_embedding_url('bert', model_dir_or_name.lower())
model_dir = cached_path(model_url, name='embedding') model_dir = cached_path(model_url, name='embedding')
@@ -167,15 +182,16 @@ class BertWordPieceEncoder(nn.Module):
model_dir = model_dir_or_name model_dir = model_dir_or_name
else: else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.") raise ValueError(f"Cannot recognize {model_dir_or_name}.")
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls) self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
self._sep_index = self.model._sep_index self._sep_index = self.model._sep_index
self._wordpiece_pad_index = self.model._wordpiece_pad_index
self._wordpiece_unk_index = self.model._wordpiece_unknown_index self._wordpiece_unk_index = self.model._wordpiece_unknown_index
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
self.requires_grad = requires_grad self.requires_grad = requires_grad
self.word_dropout = word_dropout self.word_dropout = word_dropout
self.dropout_layer = nn.Dropout(dropout) self.dropout_layer = nn.Dropout(dropout)
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -187,24 +203,24 @@ class BertWordPieceEncoder(nn.Module):
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
param.requires_grad = value param.requires_grad = value
@property @property
def embed_size(self): def embed_size(self):
return self._embed_size return self._embed_size
@property @property
def embedding_dim(self): def embedding_dim(self):
return self._embed_size return self._embed_size
@property @property
def num_embedding(self): def num_embedding(self):
return self.model.encoder.config.vocab_size return self.model.encoder.config.vocab_size
def index_datasets(self, *datasets, field_name, add_cls_sep=True): def index_datasets(self, *datasets, field_name, add_cls_sep=True):
""" """
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
@@ -216,7 +232,7 @@ class BertWordPieceEncoder(nn.Module):
:return: :return:
""" """
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
def forward(self, word_pieces, token_type_ids=None): def forward(self, word_pieces, token_type_ids=None):
""" """
计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。
@@ -233,13 +249,13 @@ class BertWordPieceEncoder(nn.Module):
token_type_ids = sep_mask_cumsum.fmod(2) token_type_ids = sep_mask_cumsum.fmod(2)
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
token_type_ids = token_type_ids.eq(0).long() token_type_ids = token_type_ids.eq(0).long()
word_pieces = self.drop_word(word_pieces) word_pieces = self.drop_word(word_pieces)
outputs = self.model(word_pieces, token_type_ids) outputs = self.model(word_pieces, token_type_ids)
outputs = torch.cat([*outputs], dim=-1) outputs = torch.cat([*outputs], dim=-1)
return self.dropout_layer(outputs) return self.dropout_layer(outputs)
def drop_word(self, words): def drop_word(self, words):
""" """
按照设定随机将words设置为unknown_index。 按照设定随机将words设置为unknown_index。
@@ -251,8 +267,10 @@ class BertWordPieceEncoder(nn.Module):
with torch.no_grad(): with torch.no_grad():
if self._word_sep_index: # 不能drop sep if self._word_sep_index: # 不能drop sep
sep_mask = words.eq(self._wordpiece_unk_index) sep_mask = words.eq(self._wordpiece_unk_index)
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
pad_mask = words.ne(self._wordpiece_pad_index)
mask = pad_mask.__and__(mask) # pad的位置不为unk
words = words.masked_fill(mask, self._word_unk_index) words = words.masked_fill(mask, self._word_unk_index)
if self._word_sep_index: if self._word_sep_index:
words.masked_fill_(sep_mask, self._wordpiece_unk_index) words.masked_fill_(sep_mask, self._wordpiece_unk_index)
@@ -260,10 +278,10 @@ class BertWordPieceEncoder(nn.Module):




class _WordBertModel(nn.Module): class _WordBertModel(nn.Module):
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first',
include_cls_sep:bool=False, pooled_cls:bool=False, auto_truncate:bool=False, min_freq=2):
def __init__(self, model_dir: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
super().__init__() super().__init__()
self.tokenzier = BertTokenizer.from_pretrained(model_dir) self.tokenzier = BertTokenizer.from_pretrained(model_dir)
self.encoder = BertModel.from_pretrained(model_dir) self.encoder = BertModel.from_pretrained(model_dir)
self._max_position_embeddings = self.encoder.config.max_position_embeddings self._max_position_embeddings = self.encoder.config.max_position_embeddings
@@ -271,23 +289,23 @@ class _WordBertModel(nn.Module):
encoder_layer_number = len(self.encoder.encoder.layer) encoder_layer_number = len(self.encoder.encoder.layer)
self.layers = list(map(int, layers.split(','))) self.layers = list(map(int, layers.split(',')))
for layer in self.layers: for layer in self.layers:
if layer<0:
assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
if layer < 0:
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
else: else:
assert layer<encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
f"a bert model with {encoder_layer_number} layers."
assert pool_method in ('avg', 'max', 'first', 'last') assert pool_method in ('avg', 'max', 'first', 'last')
self.pool_method = pool_method self.pool_method = pool_method
self.include_cls_sep = include_cls_sep self.include_cls_sep = include_cls_sep
self.pooled_cls = pooled_cls self.pooled_cls = pooled_cls
self.auto_truncate = auto_truncate self.auto_truncate = auto_truncate
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP]
print("Start to generating word pieces for word.")
logger.info("Start to generating word pieces for word.")
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值
word_piece_dict = {'[CLS]':1, '[SEP]':1} # 用到的word_piece以及新增的
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的
found_count = 0 found_count = 0
self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids
if '[sep]' in vocab: if '[sep]' in vocab:
@@ -302,10 +320,11 @@ class _WordBertModel(nn.Module):
elif index == vocab.unknown_idx: elif index == vocab.unknown_idx:
word = '[UNK]' word = '[UNK]'
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
if len(word_pieces)==1:
if len(word_pieces) == 1:
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到
if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # 说明这个词不在原始的word里面
if vocab.word_count[word]>=min_freq and not vocab._is_word_no_create_entry(word): #出现次数大于这个次数才新增
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
word): # 出现次数大于这个次数才新增
word_piece_dict[word] = 1 # 新增一个值 word_piece_dict[word] = 1 # 新增一个值
continue continue
for word_piece in word_pieces: for word_piece in word_pieces:
@@ -327,7 +346,7 @@ class _WordBertModel(nn.Module):
new_word_piece_vocab[token] = len(new_word_piece_vocab) new_word_piece_vocab[token] = len(new_word_piece_vocab)
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
self.encoder.embeddings.word_embeddings = embed self.encoder.embeddings.word_embeddings = embed
word_to_wordpieces = [] word_to_wordpieces = []
word_pieces_lengths = [] word_pieces_lengths = []
for word, index in vocab: for word, index in vocab:
@@ -343,11 +362,11 @@ class _WordBertModel(nn.Module):
self._sep_index = self.tokenzier.vocab['[SEP]']
self._word_pad_index = vocab.padding_idx
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]']  # needed when generating word pieces
- print("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
+ logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
self.word_to_wordpieces = np.array(word_to_wordpieces)
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths))
- print("Successfully generate word pieces.")
+ logger.debug("Successfully generate word pieces.")
def forward(self, words): def forward(self, words):
""" """


@@ -358,34 +377,37 @@ class _WordBertModel(nn.Module):
batch_size, max_word_len = words.size() batch_size, max_word_len = words.size()
word_mask = words.ne(self._word_pad_index)  # positions that are 1 hold a word
seq_len = word_mask.sum(dim=-1)
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0),
                                                                       0)  # batch_size x max_len
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1)  # batch_size
word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item()  # length of the word pieces (including padding)
if word_piece_length + 2 > self._max_position_embeddings:
if self.auto_truncate:
word_pieces_lengths = word_pieces_lengths.masked_fill(
    word_pieces_lengths + 2 > self._max_position_embeddings,
    self._max_position_embeddings - 2)
else:
raise RuntimeError(
    "After split words into word pieces, the lengths of word pieces are longer than the "
    f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
# +2 because [CLS] and [SEP] have to be added
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
                             fill_value=self._wordpiece_pad_index)
attn_masks = torch.zeros_like(word_pieces)
# 1. get the word_piece ids of the words, together with the corresponding span ranges
word_indexes = words.cpu().numpy()
for i in range(batch_size):
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2:
word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
# add [cls] and [sep]
word_pieces[:, 0].fill_(self._cls_index)
batch_indexes = torch.arange(batch_size).to(words)
word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
if self._has_sep_in_vocab:  # token_type_ids should only be needed when [SEP] appears in the vocab
sep_mask = word_pieces.eq(self._sep_index)  # batch_size x max_len
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
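The flip/cumsum/flip trick above is the part of this hunk that is easiest to misread: it counts, for every position, how many [SEP] tokens occur at or after it, and fmod(2) turns that count into alternating segment ids. A minimal standalone sketch of just that computation (the tensor below is hypothetical and not from the fastNLP code):

import torch

# hypothetical word pieces: [CLS] a a [SEP] b b [SEP]  ->  sep_mask marks the two [SEP] positions
sep_mask = torch.tensor([[0, 0, 0, 1, 0, 0, 1]])
# number of [SEP] tokens at or to the right of each position
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
token_type_ids = sep_mask_cumsum.fmod(2)
print(token_type_ids)  # tensor([[0, 0, 0, 0, 1, 1, 1]]) -- first segment 0, second segment 1

With a single [SEP] the same recipe comes out as all ones, so whether the first position is 1 tells the caller that the ids still need to be flipped before use.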
@@ -396,9 +418,9 @@ class _WordBertModel(nn.Module):
# 2. get the hidden results and pool them back to word level according to word_pieces
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
output_all_encoded_layers=True)
output_all_encoded_layers=True)
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
if self.include_cls_sep: if self.include_cls_sep:
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
bert_outputs[-1].size(-1)) bert_outputs[-1].size(-1))
@@ -414,7 +436,7 @@ class _WordBertModel(nn.Module):
real_word_piece_length = output_layer.size(1) - 2 real_word_piece_length = output_layer.size(1) - 2
if word_piece_length > real_word_piece_length:  # if it was actually truncated
paddings = output_layer.new_zeros(batch_size, paddings = output_layer.new_zeros(batch_size,
word_piece_length-real_word_piece_length,
word_piece_length - real_word_piece_length,
output_layer.size(2)) output_layer.size(2))
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
# collapse from the word_piece level to word representations
@@ -423,27 +445,27 @@ class _WordBertModel(nn.Module):
if self.pool_method == 'first': if self.pool_method == 'first':
for i in range(batch_size): for i in range(batch_size):
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]]  # start position of each word
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[
    i, i_word_pieces_cum_length]  # num_layer x batch_size x len x hidden_size
elif self.pool_method == 'last':
for i in range(batch_size):
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1  # end position of each word
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length]
elif self.pool_method == 'max':
for i in range(batch_size):
for j in range(seq_len[i]):
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2)
else:
for i in range(batch_size):
for j in range(seq_len[i]):
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
if self.include_cls_sep:
if l in (len(bert_outputs) - 1, -1) and self.pooled_cls:
outputs[l_index, :, 0] = pooled_cls
else:
outputs[l_index, :, 0] = output_layer[:, 0]
outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
# 3. the final embedding result
return outputs
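The four pool branches above collapse the word-piece vectors of each word back into a single vector. The following self-contained sketch (made-up numbers, not from the commit) shows what 'first', 'last', 'max' and 'avg' produce for one word that was split into three word pieces:

import torch

# hidden states of the three word pieces of a single word (3 pieces x 4 dims, hypothetical values)
pieces = torch.tensor([[1., 2., 3., 4.],
                       [5., 6., 7., 8.],
                       [9., 10., 11., 12.]])

first = pieces[0]                        # pool_method='first': vector of the first word piece
last = pieces[-1]                        # pool_method='last' : vector of the last word piece
maxed, _ = torch.max(pieces, dim=-2)     # pool_method='max'  : element-wise max over the pieces
avg = torch.mean(pieces, dim=-2)         # pool_method='avg'  : element-wise mean over the pieces
print(maxed)  # tensor([ 9., 10., 11., 12.])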


+ 43
- 34
fastNLP/embeddings/char_embedding.py

@@ -3,6 +3,10 @@
the index of the word, without needing the indices of the characters inside the word, to obtain the representation.
"""


__all__ = [
"CNNCharEmbedding",
"LSTMCharEmbedding"
]


import torch import torch
import torch.nn as nn import torch.nn as nn
@@ -15,6 +19,8 @@ from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding from .embedding import TokenEmbedding
from .utils import _construct_char_vocab_from_vocab from .utils import _construct_char_vocab_from_vocab
from .utils import get_embeddings from .utils import get_embeddings
from ..core import logger



class CNNCharEmbedding(TokenEmbedding): class CNNCharEmbedding(TokenEmbedding):
""" """
@@ -49,14 +55,15 @@ class CNNCharEmbedding(TokenEmbedding):
(the folder should contain exactly one file with a .txt suffix) or a file path; the second is to pass in the name of an embedding, in which case the cache is checked for that model
and it is downloaded automatically if absent. If the input is None, an embedding is randomly initialized with dimension embedding_dim.
"""
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
             dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1),
             pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None):
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
for kernel in kernel_sizes: for kernel in kernel_sizes:
assert kernel % 2 == 1, "Only odd kernel is allowed." assert kernel % 2 == 1, "Only odd kernel is allowed."
assert pool_method in ('max', 'avg') assert pool_method in ('max', 'avg')
self.pool_method = pool_method self.pool_method = pool_method
# activation function # activation function
@@ -74,12 +81,12 @@ class CNNCharEmbedding(TokenEmbedding):
else: else:
raise Exception( raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
print("Start constructing character vocabulary.")
logger.info("Start constructing character vocabulary.")
# 建立char的词表 # 建立char的词表
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq) self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.")
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# 对vocab进行index # 对vocab进行index
max_word_len = max(map(lambda x: len(x[0]), vocab)) max_word_len = max(map(lambda x: len(x[0]), vocab))
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len),
@@ -95,14 +102,14 @@ class CNNCharEmbedding(TokenEmbedding):
self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed) self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed)
else: else:
self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
self.convs = nn.ModuleList([nn.Conv1d( self.convs = nn.ModuleList([nn.Conv1d(
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2) char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
for i in range(len(kernel_sizes))]) for i in range(len(kernel_sizes))])
self._embed_size = embed_size self._embed_size = embed_size
self.fc = nn.Linear(sum(filter_nums), embed_size) self.fc = nn.Linear(sum(filter_nums), embed_size)
self.reset_parameters() self.reset_parameters()
def forward(self, words):
"""
Given word indices, produce the corresponding word representations.
@@ -113,14 +120,14 @@ class CNNCharEmbedding(TokenEmbedding):
words = self.drop_word(words) words = self.drop_word(words)
batch_size, max_len = words.size() batch_size, max_len = words.size()
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
word_lengths = self.word_lengths[words] # batch_size x max_len
word_lengths = self.word_lengths[words] # batch_size x max_len
max_word_len = word_lengths.max() max_word_len = word_lengths.max()
chars = chars[:, :, :max_word_len] chars = chars[:, :, :max_word_len]
# positions that are 1 are the mask
chars_masks = chars.eq(self.char_pad_index)  # batch_size x max_len x max_word_len, marks the padding positions
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
chars = self.dropout(chars) chars = self.dropout(chars)
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
for conv in self.convs] for conv in self.convs]
@@ -128,13 +135,13 @@ class CNNCharEmbedding(TokenEmbedding):
conv_chars = self.activation(conv_chars) conv_chars = self.activation(conv_chars)
if self.pool_method == 'max': if self.pool_method == 'max':
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
else: else:
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
chars = self.fc(chars) chars = self.fc(chars)
return self.dropout(chars) return self.dropout(chars)
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -150,21 +157,21 @@ class CNNCharEmbedding(TokenEmbedding):
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these must not be included in requires_grad
continue continue
param.requires_grad = value param.requires_grad = value
def reset_parameters(self): def reset_parameters(self):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
if 'words_to_chars_embedding' in name or 'word_lengths' in name:  # these must not be reset
continue continue
if 'char_embedding' in name: if 'char_embedding' in name:
continue continue
if param.data.dim()>1:
if param.data.dim() > 1:
nn.init.xavier_uniform_(param, 1) nn.init.xavier_uniform_(param, 1)
else: else:
nn.init.uniform_(param, -1, 1) nn.init.uniform_(param, -1, 1)
@@ -202,13 +209,15 @@ class LSTMCharEmbedding(TokenEmbedding):
(the folder should contain exactly one file with a .txt suffix) or a file path; the second is to pass in the name of an embedding, in which case the cache is checked for that model
and it is downloaded automatically if absent. If the input is None, an embedding is randomly initialized with dimension embedding_dim.
"""
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
             dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu',
             min_char_freq: int = 2,
             bidirectional=True, pre_train_char_embed: str = None):
super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)

assert hidden_size % 2 == 0, "Only even kernel is allowed." assert hidden_size % 2 == 0, "Only even kernel is allowed."
assert pool_method in ('max', 'avg') assert pool_method in ('max', 'avg')
self.pool_method = pool_method self.pool_method = pool_method
# activation function # activation function
@@ -226,12 +235,12 @@ class LSTMCharEmbedding(TokenEmbedding):
else: else:
raise Exception( raise Exception(
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
print("Start constructing character vocabulary.")
logger.info("Start constructing character vocabulary.")
# 建立char的词表 # 建立char的词表
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq) self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
self.char_pad_index = self.char_vocab.padding_idx self.char_pad_index = self.char_vocab.padding_idx
print(f"In total, there are {len(self.char_vocab)} distinct characters.")
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
# 对vocab进行index # 对vocab进行index
self.max_word_len = max(map(lambda x: len(x[0]), vocab)) self.max_word_len = max(map(lambda x: len(x[0]), vocab))
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), self.max_word_len), self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), self.max_word_len),
@@ -247,14 +256,14 @@ class LSTMCharEmbedding(TokenEmbedding):
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed) self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
else: else:
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
self.fc = nn.Linear(hidden_size, embed_size) self.fc = nn.Linear(hidden_size, embed_size)
hidden_size = hidden_size // 2 if bidirectional else hidden_size hidden_size = hidden_size // 2 if bidirectional else hidden_size
self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True) self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True)
self._embed_size = embed_size self._embed_size = embed_size
self.bidirectional = bidirectional self.bidirectional = bidirectional
def forward(self, words):
"""
Given word indices, produce the corresponding word representations.
@@ -276,7 +285,7 @@ class LSTMCharEmbedding(TokenEmbedding):
char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len) char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1) lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
# B x M x M x H # B x M x M x H
lstm_chars = self.activation(lstm_chars) lstm_chars = self.activation(lstm_chars)
if self.pool_method == 'max': if self.pool_method == 'max':
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
@@ -284,11 +293,11 @@ class LSTMCharEmbedding(TokenEmbedding):
else: else:
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0) lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float() chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
chars = self.fc(chars) chars = self.fc(chars)
return self.dropout(chars) return self.dropout(chars)
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -305,7 +314,7 @@ class LSTMCharEmbedding(TokenEmbedding):
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
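For readers who only want to use these two classes, here is a minimal usage sketch (not taken from the commit; the vocabulary and sizes are made up). Both embeddings take word indices and look characters up internally through the words_to_chars_embedding buffer built above:

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import CNNCharEmbedding, LSTMCharEmbedding

vocab = Vocabulary()
vocab.add_word_lst("this is a tiny example".split())

cnn_char_embed = CNNCharEmbedding(vocab, embed_size=50, char_emb_size=30)    # CNN + pooling over characters
lstm_char_embed = LSTMCharEmbedding(vocab, embed_size=50, char_emb_size=30)  # (Bi)LSTM + pooling over characters

words = torch.LongTensor([[vocab.to_index(w) for w in "this is a".split()]])
print(cnn_char_embed(words).shape)   # torch.Size([1, 3, 50])
print(lstm_char_embed(words).shape)  # torch.Size([1, 3, 50])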


+ 26
- 19
fastNLP/embeddings/contextual_embedding.py

@@ -1,23 +1,30 @@
"""
.. todo::
doc
"""

__all__ = [
"ContextualEmbedding"
]

from abc import abstractmethod

import torch

from ..core import logger
from ..core.batch import DataSetIter
from ..core.dataset import DataSet
from ..core.sampler import SequentialSampler
from ..core.utils import _move_model_to_device, _get_model_device
from ..core.vocabulary import Vocabulary
from .embedding import TokenEmbedding




class ContextualEmbedding(TokenEmbedding):
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True):
"""
Since generating dynamic embeddings is fairly slow, the embedding of every sentence can be cached so that the generation does not have to be rerun each time.


@@ -32,14 +39,14 @@ class ContextualEmbedding(TokenEmbedding):
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed." assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input." assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
except Exception as e: except Exception as e:
print(f"Exception happens at {index} dataset.")
logger.error(f"Exception happens at {index} dataset.")
raise e raise e
sent_embeds = {} sent_embeds = {}
_move_model_to_device(self, device=device) _move_model_to_device(self, device=device)
device = _get_model_device(self) device = _get_model_device(self)
pad_index = self._word_vocab.padding_idx pad_index = self._word_vocab.padding_idx
print("Start to calculate sentence representations.")
logger.info("Start to calculate sentence representations.")
with torch.no_grad(): with torch.no_grad():
for index, dataset in enumerate(datasets): for index, dataset in enumerate(datasets):
try: try:
@@ -54,18 +61,18 @@ class ContextualEmbedding(TokenEmbedding):
word_embeds = self(words).detach().cpu().numpy() word_embeds = self(words).detach().cpu().numpy()
for b in range(words.size(0)): for b in range(words.size(0)):
length = seq_len_from_behind[b] length = seq_len_from_behind[b]
if length==0:
if length == 0:
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b] sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b]
else: else:
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length] sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
except Exception as e: except Exception as e:
print(f"Exception happens at {index} dataset.")
logger.error(f"Exception happens at {index} dataset.")
raise e raise e
print("Finish calculating sentence representations.")
logger.info("Finish calculating sentence representations.")
self.sent_embeds = sent_embeds self.sent_embeds = sent_embeds
if delete_weights: if delete_weights:
self._delete_model_weights() self._delete_model_weights()
def _get_sent_reprs(self, words): def _get_sent_reprs(self, words):
""" """
获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None
@@ -88,12 +95,12 @@ class ContextualEmbedding(TokenEmbedding):
embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device) embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
return embeds return embeds
return None return None
@abstractmethod @abstractmethod
def _delete_model_weights(self): def _delete_model_weights(self):
"""删除计算表示的模型以节省资源""" """删除计算表示的模型以节省资源"""
raise NotImplementedError raise NotImplementedError
def remove_sentence_cache(self): def remove_sentence_cache(self):
""" """
删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。
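A short, hypothetical usage sketch of the caching API above (field names and the 'en' model key are assumptions, not something this diff shows): every sentence of a DataSet is embedded once, and delete_weights=True then frees the underlying encoder.

from fastNLP import DataSet, Vocabulary
from fastNLP.embeddings import ElmoEmbedding  # any ContextualEmbedding subclass works the same way

ds = DataSet({'raw_words': ["this is a test", "another sentence"]})
ds.apply(lambda ins: ins['raw_words'].split(), new_field_name='words', is_input=True)

vocab = Vocabulary()
vocab.from_dataset(ds, field_name='words')
vocab.index_dataset(ds, field_name='words')

embed = ElmoEmbedding(vocab, model_dir_or_name='en')
embed.add_sentence_cache(ds, batch_size=32, device='cpu', delete_weights=True)
# later forward passes over sentences from `ds` are served from the cache instead of rerunning the encoder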


+ 46
- 40
fastNLP/embeddings/elmo_embedding.py

@@ -1,6 +1,13 @@
"""
.. todo::
doc
"""


__all__ = [
    "ElmoEmbedding"
]

import os
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
@@ -11,7 +18,7 @@ from ..core.vocabulary import Vocabulary
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
from .contextual_embedding import ContextualEmbedding from .contextual_embedding import ContextualEmbedding
from ..core import logger


class ElmoEmbedding(ContextualEmbedding): class ElmoEmbedding(ContextualEmbedding):
""" """
@@ -49,11 +56,11 @@ class ElmoEmbedding(ContextualEmbedding):
:param cache_word_reprs: whether to cache word representations; if True, an embedding is generated for every word at initialization time,
    the character encoder is deleted, and the cached embeddings are used directly afterwards. Defaults to False.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = False,
             word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False):
super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
# check whether model_dir_or_name exists locally, and download it if needed
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) model_url = _get_embedding_url('elmo', model_dir_or_name.lower())
@@ -64,7 +71,7 @@ class ElmoEmbedding(ContextualEmbedding):
else: else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.") raise ValueError(f"Cannot recognize {model_dir_or_name}.")
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
if layers == 'mix': if layers == 'mix':
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1),
requires_grad=requires_grad) requires_grad=requires_grad)
@@ -79,16 +86,16 @@ class ElmoEmbedding(ContextualEmbedding):
self.layers = layers self.layers = layers
self._get_outputs = self._get_layer_outputs self._get_outputs = self._get_layer_outputs
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2
self.requires_grad = requires_grad self.requires_grad = requires_grad
def _get_mixed_outputs(self, outputs): def _get_mixed_outputs(self, outputs):
# outputs: num_layers x batch_size x max_len x hidden_size # outputs: num_layers x batch_size x max_len x hidden_size
# return: batch_size x max_len x hidden_size # return: batch_size x max_len x hidden_size
weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs) weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs)
outputs = torch.einsum('l,lbij->bij', weights, outputs) outputs = torch.einsum('l,lbij->bij', weights, outputs)
return self.gamma.to(outputs) * outputs return self.gamma.to(outputs) * outputs
def set_mix_weights_requires_grad(self, flag=True): def set_mix_weights_requires_grad(self, flag=True):
""" """
When layers is set to 'mix' at ElmoEmbedding initialization, this method sets whether the mix weights are trainable. If layers is not 'mix', calling
@@ -100,15 +107,15 @@ class ElmoEmbedding(ContextualEmbedding):
if hasattr(self, 'layer_weights'): if hasattr(self, 'layer_weights'):
self.layer_weights.requires_grad = flag self.layer_weights.requires_grad = flag
self.gamma.requires_grad = flag self.gamma.requires_grad = flag
def _get_layer_outputs(self, outputs): def _get_layer_outputs(self, outputs):
if len(self.layers) == 1: if len(self.layers) == 1:
outputs = outputs[self.layers[0]] outputs = outputs[self.layers[0]]
else: else:
outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1) outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1)
return outputs return outputs
def forward(self, words: torch.LongTensor): def forward(self, words: torch.LongTensor):
""" """
Compute the elmo embedding of words. According to the ELMo paper there are actually 2L+1 layers of results, but to make the results easier to split apart, the token
@@ -125,12 +132,12 @@ class ElmoEmbedding(ContextualEmbedding):
outputs = self.model(words) outputs = self.model(words)
outputs = self._get_outputs(outputs) outputs = self._get_outputs(outputs)
return self.dropout(outputs) return self.dropout(outputs)
def _delete_model_weights(self): def _delete_model_weights(self):
for name in ['layers', 'model', 'layer_weights', 'gamma']: for name in ['layers', 'model', 'layer_weights', 'gamma']:
if hasattr(self, name): if hasattr(self, name):
delattr(self, name) delattr(self, name)
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -144,7 +151,7 @@ class ElmoEmbedding(ContextualEmbedding):
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for name, param in self.named_parameters(): for name, param in self.named_parameters():
@@ -162,7 +169,7 @@ class _ElmoModel(nn.Module):
(4) design an embedding that stores the tokens, allowing word representations to be cached.


""" """
def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
super(_ElmoModel, self).__init__() super(_ElmoModel, self).__init__()
self.model_dir = model_dir self.model_dir = model_dir
@@ -187,14 +194,14 @@ class _ElmoModel(nn.Module):
config = json.load(config_f) config = json.load(config_f)
self.weight_file = os.path.join(model_dir, weight_file) self.weight_file = os.path.join(model_dir, weight_file)
self.config = config self.config = config
OOV_TAG = '<oov>' OOV_TAG = '<oov>'
PAD_TAG = '<pad>' PAD_TAG = '<pad>'
BOS_TAG = '<bos>' BOS_TAG = '<bos>'
EOS_TAG = '<eos>' EOS_TAG = '<eos>'
BOW_TAG = '<bow>' BOW_TAG = '<bow>'
EOW_TAG = '<eow>' EOW_TAG = '<eow>'
# For the model trained with character-based word encoder. # For the model trained with character-based word encoder.
char_lexicon = {} char_lexicon = {}
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
@@ -204,29 +211,29 @@ class _ElmoModel(nn.Module):
tokens.insert(0, '\u3000') tokens.insert(0, '\u3000')
token, i = tokens token, i = tokens
char_lexicon[token] = int(i) char_lexicon[token] = int(i)
# some sanity checks
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
assert special_word in char_lexicon, f"{special_word} not found in char.dic."
# build char_vocab from the vocab
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
# <bow> and <eow> have to be present
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])
for word, index in vocab: for word, index in vocab:
char_vocab.add_word_lst(list(word)) char_vocab.add_word_lst(list(word))
self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx
# adjust according to char_lexicon; one extra slot is reserved for word padding (its char representation is all zeros)
char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
                              padding_idx=len(char_vocab))
# load the pretrained weights; elmo_model here contains the state_dict of char_cnn and lstm
elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')
char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']
found_char_count = 0
for char, index in char_vocab:  # adjust the character embedding
if char in char_lexicon: if char in char_lexicon:
@@ -235,11 +242,10 @@ class _ElmoModel(nn.Module):
else: else:
index_in_pre = char_lexicon[OOV_TAG] index_in_pre = char_lexicon[OOV_TAG]
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
# 生成words到chars的映射 # 生成words到chars的映射
max_chars = config['char_cnn']['max_characters_per_token'] max_chars = config['char_cnn']['max_characters_per_token']

self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars), self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars),
fill_value=len(char_vocab), fill_value=len(char_vocab),
dtype=torch.long)) dtype=torch.long))
@@ -257,29 +263,29 @@ class _ElmoModel(nn.Module):
char_vocab.to_index(EOW_TAG)] char_vocab.to_index(EOW_TAG)]
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)
self.char_vocab = char_vocab self.char_vocab = char_vocab
self.token_embedder = ConvTokenEmbedder( self.token_embedder = ConvTokenEmbedder(
config, self.weight_file, None, char_emb_layer) config, self.weight_file, None, char_emb_layer)
elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
self.token_embedder.load_state_dict(elmo_model["char_cnn"]) self.token_embedder.load_state_dict(elmo_model["char_cnn"])
self.output_dim = config['lstm']['projection_dim'] self.output_dim = config['lstm']['projection_dim']
# lstm encoder # lstm encoder
self.encoder = ElmobiLm(config) self.encoder = ElmobiLm(config)
self.encoder.load_state_dict(elmo_model["lstm"]) self.encoder.load_state_dict(elmo_model["lstm"])
if cache_word_reprs: if cache_word_reprs:
if config['char_cnn']['embedding']['dim'] > 0:  # only useful when chars are used
- print("Start to generate cache word representations.")
+ logger.info("Start to generate cache word representations.")
batch_size = 320 batch_size = 320
# bos eos # bos eos
word_size = self.words_to_chars_embedding.size(0) word_size = self.words_to_chars_embedding.size(0)
num_batches = word_size // batch_size + \ num_batches = word_size // batch_size + \
int(word_size % batch_size != 0) int(word_size % batch_size != 0)
self.cached_word_embedding = nn.Embedding(word_size, self.cached_word_embedding = nn.Embedding(word_size,
config['lstm']['projection_dim']) config['lstm']['projection_dim'])
with torch.no_grad(): with torch.no_grad():
@@ -290,12 +296,12 @@ class _ElmoModel(nn.Module):
word_reprs = self.token_embedder(words.unsqueeze(1), word_reprs = self.token_embedder(words.unsqueeze(1),
chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)
print("Finish generating cached word representations. Going to delete the character encoder.")
logger.info("Finish generating cached word representations. Going to delete the character encoder.")
del self.token_embedder, self.words_to_chars_embedding del self.token_embedder, self.words_to_chars_embedding
else: else:
print("There is no need to cache word representations, since no character information is used.")
logger.info("There is no need to cache word representations, since no character information is used.")
def forward(self, words): def forward(self, words):
""" """


@@ -320,7 +326,7 @@ class _ElmoModel(nn.Module):
else: else:
chars = None chars = None
token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim
encoder_output = self.encoder(token_embedding, seq_len) encoder_output = self.encoder(token_embedding, seq_len)
if encoder_output.size(2) < max_len + 2: if encoder_output.size(2) < max_len + 2:
num_layers, _, output_len, hidden_size = encoder_output.size() num_layers, _, output_len, hidden_size = encoder_output.size()
@@ -331,7 +337,7 @@ class _ElmoModel(nn.Module):
token_embedding = token_embedding.masked_fill(mask, 0) token_embedding = token_embedding.masked_fill(mask, 0)
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
encoder_output = torch.cat((token_embedding, encoder_output), dim=0) encoder_output = torch.cat((token_embedding, encoder_output), dim=0)
# drop <eos>, <bos>. The removal here is not exact, but it should not affect the final result.
encoder_output = encoder_output[:, :, 1:-1] encoder_output = encoder_output[:, :, 1:-1]
return encoder_output return encoder_output
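When layers='mix', _get_mixed_outputs above combines all ELMo layers with softmax-normalized weights and a gamma scale. A tiny standalone sketch of that computation with made-up sizes (not part of the commit):

import torch
import torch.nn.functional as F

num_layers, batch_size, max_len, hidden = 3, 2, 5, 4
outputs = torch.randn(num_layers, batch_size, max_len, hidden)   # hypothetical layer outputs
layer_weights = torch.zeros(num_layers, requires_grad=True)      # learned when layers='mix'
gamma = torch.full((1,), 0.5, requires_grad=True)

weights = F.softmax(layer_weights + 1 / num_layers, dim=0)
mixed = torch.einsum('l,lbij->bij', weights, outputs)            # batch_size x max_len x hidden
scaled = gamma * mixed
print(scaled.shape)  # torch.Size([2, 5, 4])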

+ 33
- 27
fastNLP/embeddings/embedding.py

@@ -3,6 +3,10 @@


""" """


__all__ = [
"Embedding",
"TokenEmbedding"
]


import torch.nn as nn import torch.nn as nn
from abc import abstractmethod from abc import abstractmethod
@@ -33,11 +37,11 @@ class Embedding(nn.Module):
:param float dropout: dropout applied to the output of the Embedding.
:param int unk_index: the index used to replace a dropped word. The unk_index of fastNLP's Vocabulary defaults to 1.
""" """
def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None): def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None):
super(Embedding, self).__init__() super(Embedding, self).__init__()
self.embed = get_embeddings(init_embed) self.embed = get_embeddings(init_embed)
self.dropout = nn.Dropout(dropout) self.dropout = nn.Dropout(dropout)
@@ -48,44 +52,44 @@ class Embedding(nn.Module):
self._embed_size = self.embed.embedding_dim self._embed_size = self.embed.embedding_dim
else: else:
self._embed_size = self.embed.weight.size(1) self._embed_size = self.embed.weight.size(1)
if word_dropout > 0 and not isinstance(unk_index, int):
raise ValueError("When drop word is set, you need to pass in the unk_index.") raise ValueError("When drop word is set, you need to pass in the unk_index.")
else: else:
self._embed_size = self.embed.embed_size self._embed_size = self.embed.embed_size
unk_index = self.embed.get_word_vocab().unknown_idx unk_index = self.embed.get_word_vocab().unknown_idx
self.unk_index = unk_index self.unk_index = unk_index
self.word_dropout = word_dropout self.word_dropout = word_dropout
def forward(self, words): def forward(self, words):
""" """
:param torch.LongTensor words: [batch, seq_len] :param torch.LongTensor words: [batch, seq_len]
:return: torch.Tensor : [batch, seq_len, embed_dim] :return: torch.Tensor : [batch, seq_len, embed_dim]
""" """
if self.word_dropout > 0 and self.training:
mask = torch.ones_like(words).float() * self.word_dropout
mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
words = words.masked_fill(mask, self.unk_index) words = words.masked_fill(mask, self.unk_index)
words = self.embed(words) words = self.embed(words)
return self.dropout(words) return self.dropout(words)
@property @property
def num_embedding(self) -> int:
if isinstance(self.embed, nn.Embedding): if isinstance(self.embed, nn.Embedding):
return self.embed.weight.size(0) return self.embed.weight.size(0)
else: else:
return self.embed.num_embedding return self.embed.num_embedding
def __len__(self): def __len__(self):
return len(self.embed) return len(self.embed)
@property @property
def embed_size(self) -> int: def embed_size(self) -> int:
return self._embed_size return self._embed_size
@property @property
def embedding_dim(self) -> int: def embedding_dim(self) -> int:
return self._embed_size return self._embed_size
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -96,14 +100,14 @@ class Embedding(nn.Module):
return self.embed.weight.requires_grad return self.embed.weight.requires_grad
else: else:
return self.embed.requires_grad return self.embed.requires_grad
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
if not isinstance(self.embed, TokenEmbedding): if not isinstance(self.embed, TokenEmbedding):
self.embed.weight.requires_grad = value self.embed.weight.requires_grad = value
else: else:
self.embed.requires_grad = value self.embed.requires_grad = value
@property @property
def size(self): def size(self):
if isinstance(self.embed, TokenEmbedding): if isinstance(self.embed, TokenEmbedding):
@@ -120,12 +124,12 @@ class TokenEmbedding(nn.Module):
assert vocab.padding is not None, "Vocabulary must have a padding entry." assert vocab.padding is not None, "Vocabulary must have a padding entry."
self._word_vocab = vocab self._word_vocab = vocab
self._word_pad_index = vocab.padding_idx self._word_pad_index = vocab.padding_idx
if word_dropout > 0:
assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word." assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
self.word_dropout = word_dropout self.word_dropout = word_dropout
self._word_unk_index = vocab.unknown_idx self._word_unk_index = vocab.unknown_idx
self.dropout_layer = nn.Dropout(dropout) self.dropout_layer = nn.Dropout(dropout)
def drop_word(self, words): def drop_word(self, words):
""" """
Randomly replace words with unknown_index according to the configured rate.
@@ -134,11 +138,13 @@ class TokenEmbedding(nn.Module):
:return: :return:
""" """
if self.word_dropout > 0 and self.training: if self.word_dropout > 0 and self.training:
- mask = torch.ones_like(words).float() * self.word_dropout
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
mask = torch.bernoulli(mask).eq(1)  # the larger dropout_word is, the more positions become 1
+ pad_mask = words.ne(self._word_pad_index)
+ mask = mask.__and__(pad_mask)
words = words.masked_fill(mask, self._word_unk_index)
return words return words
def dropout(self, words): def dropout(self, words):
""" """
Apply dropout to the word representations after embedding.
@@ -147,7 +153,7 @@ class TokenEmbedding(nn.Module):
:return: :return:
""" """
return self.dropout_layer(words) return self.dropout_layer(words)
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -159,23 +165,23 @@ class TokenEmbedding(nn.Module):
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for param in self.parameters(): for param in self.parameters():
param.requires_grad = value param.requires_grad = value
def __len__(self): def __len__(self):
return len(self._word_vocab) return len(self._word_vocab)
@property @property
def embed_size(self) -> int: def embed_size(self) -> int:
return self._embed_size return self._embed_size
@property @property
def embedding_dim(self) -> int: def embedding_dim(self) -> int:
return self._embed_size return self._embed_size
@property @property
def num_embedding(self) -> int: def num_embedding(self) -> int:
""" """
@@ -183,7 +189,7 @@ class TokenEmbedding(nn.Module):
:return: :return:
""" """
return len(self._word_vocab) return len(self._word_vocab)
def get_word_vocab(self): def get_word_vocab(self):
""" """
Return the vocabulary of the embedding.
@@ -191,11 +197,11 @@ class TokenEmbedding(nn.Module):
:return: Vocabulary :return: Vocabulary
""" """
return self._word_vocab return self._word_vocab
@property @property
def size(self): def size(self):
return torch.Size(self.num_embedding, self._embed_size) return torch.Size(self.num_embedding, self._embed_size)
@abstractmethod @abstractmethod
def forward(self, words): def forward(self, words):
raise NotImplementedError raise NotImplementedError
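The drop_word logic changed in this file replaces a random subset of word indices with the unknown index, and now additionally excludes padding positions. A small self-contained sketch of that masking (pad=0 and unk=1 are fastNLP's usual defaults, assumed here):

import torch

word_dropout, pad_index, unk_index = 0.3, 0, 1
words = torch.LongTensor([[5, 7, 9, 0, 0]])               # hypothetical batch, 0 is padding

mask = torch.full_like(words, fill_value=word_dropout, dtype=torch.float)
mask = torch.bernoulli(mask).eq(1)                         # roughly 30% of positions come out True
pad_mask = words.ne(pad_index)                             # never drop padding positions
mask = mask & pad_mask
dropped = words.masked_fill(mask, unk_index)               # the selected real words become <unk>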

+ 17
- 7
fastNLP/embeddings/stack_embedding.py

@@ -1,3 +1,12 @@
"""
.. todo::
doc
"""

__all__ = [
"StackEmbedding",
]

from typing import List from typing import List


import torch import torch
@@ -26,6 +35,7 @@ class StackEmbedding(TokenEmbedding):
:param float dropout: the probability of applying Dropout to the embedding representation; 0.1 means 10% of the values are randomly set to 0.


""" """
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
vocabs = [] vocabs = []
for embed in embeds: for embed in embeds:
@@ -34,14 +44,14 @@ class StackEmbedding(TokenEmbedding):
_vocab = vocabs[0] _vocab = vocabs[0]
for vocab in vocabs[1:]: for vocab in vocabs[1:]:
assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
assert isinstance(embeds, list) assert isinstance(embeds, list)
for embed in embeds: for embed in embeds:
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
self.embeds = nn.ModuleList(embeds) self.embeds = nn.ModuleList(embeds)
self._embed_size = sum([embed.embed_size for embed in self.embeds]) self._embed_size = sum([embed.embed_size for embed in self.embeds])
def append(self, embed: TokenEmbedding): def append(self, embed: TokenEmbedding):
""" """
Append an embedding at the end.
@@ -50,18 +60,18 @@ class StackEmbedding(TokenEmbedding):
""" """
assert isinstance(embed, TokenEmbedding) assert isinstance(embed, TokenEmbedding)
self.embeds.append(embed) self.embeds.append(embed)
def pop(self): def pop(self):
""" """
Pop the last embed.
:return: :return:
""" """
return self.embeds.pop() return self.embeds.pop()
@property @property
def embed_size(self): def embed_size(self):
return self._embed_size return self._embed_size
@property @property
def requires_grad(self): def requires_grad(self):
""" """
@@ -73,12 +83,12 @@ class StackEmbedding(TokenEmbedding):
return requires_grads.pop() return requires_grads.pop()
else: else:
return None return None
@requires_grad.setter @requires_grad.setter
def requires_grad(self, value): def requires_grad(self, value):
for embed in self.embeds(): for embed in self.embeds():
embed.requires_grad = value embed.requires_grad = value
def forward(self, words): def forward(self, words):
""" """
Run all the embeddings and concatenate their results in order.
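A brief hypothetical usage sketch (not from this commit): two embeddings built over the same vocabulary are stacked, and the resulting embed_size is simply the sum of the parts because forward concatenates the per-token outputs.

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding, CNNCharEmbedding, StackEmbedding

vocab = Vocabulary()
vocab.add_word_lst("a small example vocabulary".split())

word_embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=100)  # randomly initialized
char_embed = CNNCharEmbedding(vocab, embed_size=50)

stacked = StackEmbedding([word_embed, char_embed])
print(stacked.embed_size)  # 150: the per-token outputs are concatenated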


+ 37
- 29
fastNLP/embeddings/static_embedding.py

@@ -1,4 +1,11 @@
"""
.. todo::
doc
"""


__all__ = [
"StaticEmbedding"
]
import os import os


import torch import torch
@@ -12,6 +19,8 @@ from .embedding import TokenEmbedding
from ..modules.utils import _get_file_name_base_on_postfix from ..modules.utils import _get_file_name_base_on_postfix
from copy import deepcopy from copy import deepcopy
from collections import defaultdict from collections import defaultdict
from ..core import logger



class StaticEmbedding(TokenEmbedding): class StaticEmbedding(TokenEmbedding):
""" """
@@ -55,15 +64,16 @@ class StaticEmbedding(TokenEmbedding):
:param bool normalize: whether to normalize each vector so that its norm is 1.
:param int min_freq: words whose frequency in the Vocabulary is below this number are mapped to unk.
"""
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True,
             init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
if embedding_dim > 0:
model_dir_or_name = None
# work out the cache_path
if model_dir_or_name is None:
assert embedding_dim >= 1, "The dimension of embedding should be larger than 1."
embedding_dim = int(embedding_dim) embedding_dim = int(embedding_dim)
model_path = None model_path = None
elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
@@ -76,9 +86,9 @@ class StaticEmbedding(TokenEmbedding):
model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
else: else:
raise ValueError(f"Cannot recognize {model_dir_or_name}.") raise ValueError(f"Cannot recognize {model_dir_or_name}.")
# shrink the vocab according to min_freq
truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
if truncate_vocab: if truncate_vocab:
truncated_vocab = deepcopy(vocab) truncated_vocab = deepcopy(vocab)
truncated_vocab.min_freq = min_freq truncated_vocab.min_freq = min_freq
@@ -89,23 +99,23 @@ class StaticEmbedding(TokenEmbedding):
lowered_word_count[word.lower()] += count lowered_word_count[word.lower()] += count
for word in truncated_vocab.word_count.keys(): for word in truncated_vocab.word_count.keys():
word_count = truncated_vocab.word_count[word] word_count = truncated_vocab.word_count[word]
if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq:
truncated_vocab.add_word_lst([word] * (min_freq - word_count),
                             no_create_entry=truncated_vocab._is_word_no_create_entry(word))
# only words that appear in train are filtered with min_freq
if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None:
for word in truncated_vocab.word_count.keys():
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq:
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
                             no_create_entry=True)
truncated_vocab.build_vocab() truncated_vocab.build_vocab()
truncated_words_to_words = torch.arange(len(vocab)).long() truncated_words_to_words = torch.arange(len(vocab)).long()
for word, index in vocab: for word, index in vocab:
truncated_words_to_words[index] = truncated_vocab.to_index(word) truncated_words_to_words[index] = truncated_vocab.to_index(word)
print(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
vocab = truncated_vocab vocab = truncated_vocab
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
# 读取embedding # 读取embedding
if lower: if lower:
@@ -115,7 +125,7 @@ class StaticEmbedding(TokenEmbedding):
lowered_vocab.add_word(word.lower(), no_create_entry=True) lowered_vocab.add_word(word.lower(), no_create_entry=True)
else: else:
lowered_vocab.add_word(word.lower())  # first add the words that need their own entry
- print(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
+ logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
f"unique lowered words.") f"unique lowered words.")
if model_path: if model_path:
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
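To make the min_freq and lower handling above concrete, here is a small hypothetical sketch (random initialization, no pretrained file): words rarer than min_freq are redirected to <unk> through words_to_words, and lower=True merges case variants so that they share one vector.

from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst(["the"] * 10 + ["The"] * 3 + ["rare"])   # "rare" occurs only once
vocab.build_vocab()

embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50, min_freq=2, lower=True)
# "rare" now maps to the <unk> vector, while "The" and "the" point at the same row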
@@ -144,21 +154,20 @@ class StaticEmbedding(TokenEmbedding):
self.register_buffer('words_to_words', torch.arange(len(vocab)).long())
if not self.only_norm_found_vector and normalize:
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
if truncate_vocab:
for i in range(len(truncated_words_to_words)):
index_in_truncated_vocab = truncated_words_to_words[i]
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
del self.words_to_words
self.register_buffer('words_to_words', truncated_words_to_words)

self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
padding_idx=vocab.padding_idx,
max_norm=None, norm_type=2, scale_grad_by_freq=False,
sparse=False, _weight=embedding)
self._embed_size = self.embedding.weight.size(1)
self.requires_grad = requires_grad
def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
"""


@@ -168,14 +177,14 @@ class StaticEmbedding(TokenEmbedding):
:return: torch.FloatTensor
"""
embed = torch.zeros(num_embedding, embedding_dim)
if init_embed is None:
nn.init.uniform_(embed, -np.sqrt(3/embedding_dim), np.sqrt(3/embedding_dim))
nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim))
else:
init_embed(embed)
return embed
@property
def requires_grad(self):
"""
@@ -189,14 +198,14 @@ class StaticEmbedding(TokenEmbedding):
return requires_grads.pop()
else:
return None
@requires_grad.setter
def requires_grad(self, value):
for name, param in self.named_parameters():
if 'words_to_words' in name:
continue
param.requires_grad = value
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>',
error='ignore', init_method=None):
"""
@@ -249,15 +258,15 @@ class StaticEmbedding(TokenEmbedding):
index = vocab.to_index(word)
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
if self.only_norm_found_vector:
matrix[index] = matrix[index]/np.linalg.norm(matrix[index])
matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
found_count += 1
except Exception as e:
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
print("Error occurred at the {} line.".format(idx))
logger.error("Error occurred at the {} line.".format(idx))
raise e
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
for word, index in vocab:
if index not in matrix and not vocab._is_word_no_create_entry(word):
if found_unknown:  # if an unknown token was found, initialize with the unknown vector
@@ -266,21 +275,20 @@ class StaticEmbedding(TokenEmbedding):
matrix[index] = None
# the entries kept in matrix are the words that need their own entry
vectors = self._randomly_init_embed(len(matrix), dim, init_method)
if vocab.unknown is None:  # create a dedicated unknown entry
unknown_idx = len(matrix)
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
else:
unknown_idx = vocab.unknown_idx
self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long())

for index, (index_in_vocab, vec) in enumerate(matrix.items()):
if vec is not None:
vectors[index] = vec
self.words_to_words[index_in_vocab] = index
return vectors
def forward(self, words):
"""
Takes the indices of words as input


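The truncation logic above never re-indexes the tokens the model receives: it keeps a `words_to_words` buffer that maps every index of the original vocabulary onto a row of the (smaller) embedding matrix. A minimal sketch of that idea, with illustrative names rather than fastNLP's actual class:

```python
import torch
from torch import nn


class RemappedEmbedding(nn.Module):
    """Illustration of the words_to_words remapping: token indices come from the
    full vocabulary, but the weight matrix only stores rows for kept words."""

    def __init__(self, full_vocab_size, kept_rows, dim, mapping):
        super().__init__()
        assert mapping.shape == (full_vocab_size,)
        self.embedding = nn.Embedding(kept_rows, dim)
        # a buffer, not a parameter: it follows .to(device) but is never trained
        self.register_buffer('words_to_words', mapping.long())

    def forward(self, words):
        return self.embedding(self.words_to_words[words])


# toy usage: a 6-word vocab where only 4 rows are kept; dropped words map to row 0 (unk)
mapping = torch.tensor([0, 1, 2, 0, 3, 0])
emb = RemappedEmbedding(full_vocab_size=6, kept_rows=4, dim=8, mapping=mapping)
out = emb(torch.tensor([[1, 3, 5]]))
print(out.shape)  # torch.Size([1, 3, 8])
```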
+ 11
- 5
fastNLP/embeddings/utils.py View File

@@ -1,13 +1,19 @@
"""
.. todo::
doc
"""
import numpy as np
import torch
from torch import nn as nn


from ..core.vocabulary import Vocabulary


__all__ = ['get_embeddings']
__all__ = [
'get_embeddings'
]




def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1):
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1):
"""
Given a word-level vocabulary, build a character-level vocabulary.


@@ -36,8 +42,8 @@ def get_embeddings(init_embed):
if isinstance(init_embed, tuple):
res = nn.Embedding(
num_embeddings=init_embed[0], embedding_dim=init_embed[1])
nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)),
b=np.sqrt(3/res.weight.data.size(1)))
nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)),
b=np.sqrt(3 / res.weight.data.size(1)))
elif isinstance(init_embed, nn.Module):
res = init_embed
elif isinstance(init_embed, torch.Tensor):
@@ -48,4 +54,4 @@ def get_embeddings(init_embed):
else:
raise TypeError(
'invalid init_embed type: {}'.format((type(init_embed))))
return res
return res

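As the branches above show, `get_embeddings` accepts a `(num_embeddings, embedding_dim)` tuple, an existing `nn.Module`, or a `torch.Tensor` (the body of the tensor branch lies outside this hunk). A hedged usage sketch for the two branches that are fully visible:

```python
from fastNLP.embeddings.utils import get_embeddings

# a (num_embeddings, embedding_dim) tuple -> a randomly initialised nn.Embedding
emb = get_embeddings((1000, 50))

# an existing nn.Module is returned unchanged
same_emb = get_embeddings(emb)
assert same_emb is emb
```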
+ 4
- 0
fastNLP/io/__init__.py View File

@@ -45,6 +45,8 @@ __all__ = [
"QNLILoader",
"RTELoader",


"Pipe",

"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
@@ -58,6 +60,8 @@ __all__ = [
"PeopleDailyPipe",
"WeiboNERPipe",


"CWSPipe",

"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",


+ 6
- 1
fastNLP/io/data_bundle.py View File

@@ -1,10 +1,15 @@
"""
.. todo::
doc
"""
__all__ = [
'DataBundle',
]


import _pickle as pickle
from typing import Union, Dict
import os
from typing import Union, Dict

from ..core.dataset import DataSet
from ..core.vocabulary import Vocabulary




+ 1
- 1
fastNLP/io/data_loader/__init__.py View File

@@ -1,4 +1,4 @@
"""
"""undocumented
.. warning::


This module was deprecated in `version 0.5.0` and replaced by the :mod:`~fastNLP.io.loader` and :mod:`~fastNLP.io.pipe` modules.


+ 3
- 3
fastNLP/io/dataset_loader.py View File

@@ -1,4 +1,4 @@
"""
"""undocumented
.. warning::


This module will be deprecated in `version 0.5.0` and replaced by the :mod:`~fastNLP.io.loader` and :mod:`~fastNLP.io.pipe` modules.
@@ -23,10 +23,10 @@ __all__ = [
]




from .data_bundle import DataSetLoader
from .file_reader import _read_csv, _read_json
from ..core.dataset import DataSet
from ..core.instance import Instance
from .file_reader import _read_csv, _read_json
from .data_bundle import DataSetLoader




class JsonLoader(DataSetLoader):


+ 9
- 4
fastNLP/io/embed_loader.py View File

@@ -1,16 +1,21 @@
"""
.. todo::
doc
"""
__all__ = [
"EmbedLoader",
"EmbeddingOption",
]


import logging
import os
import warnings


import numpy as np


from ..core.vocabulary import Vocabulary
from .data_bundle import BaseLoader
from ..core.utils import Option
from ..core.vocabulary import Vocabulary




class EmbeddingOption(Option):
@@ -91,10 +96,10 @@ class EmbedLoader(BaseLoader):
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
print("Error occurred at the {} line.".format(idx))
logging.error("Error occurred at the {} line.".format(idx))
raise e
total_hits = sum(hit_flags)
print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
if init_method is None:
found_vectors = matrix[hit_flags]
if len(found_vectors) != 0:
@@ -157,7 +162,7 @@ class EmbedLoader(BaseLoader):
warnings.warn("Error occurred at the {} line.".format(idx))
pass
else:
print("Error occurred at the {} line.".format(idx))
logging.error("Error occurred at the {} line.".format(idx))
raise e
if dim == -1:
raise RuntimeError("{} is an empty file.".format(embed_filepath))


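The recurring change in this commit is replacing bare `print` calls with module-level logging. A minimal sketch of that pattern; the `load_rows` helper and the basicConfig setup are illustrative, not part of fastNLP (which wires its own logger in `fastNLP.core._logger`):

```python
import logging

logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)  # module-level logger


def load_rows(rows):
    hits = 0
    for idx, row in enumerate(rows):
        try:
            float(row)          # stand-in for parsing one embedding line
            hits += 1
        except ValueError as e:
            logger.error("Error occurred at the %s line.", idx)  # instead of print(...)
            raise e
    logger.info("Found %s out of %s usable rows.", hits, len(rows))
    return hits


load_rows(["0.1", "0.2", "0.3"])
```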
+ 16
- 9
fastNLP/io/file_reader.py View File

@@ -1,8 +1,13 @@
"""
"""undocumented
This module provides file-reading helpers for the other modules; it does not expose a user-facing API
"""

__all__ = []

import json
import warnings

from ..core import logger



def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
"""
@@ -23,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
headers = headers.split(sep)
start_idx += 1
elif not isinstance(headers, (list, tuple)):
raise TypeError("headers should be list or tuple, not {}." \
.format(type(headers)))
raise TypeError("headers should be list or tuple, not {}." \
.format(type(headers)))
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
@@ -81,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""
def parse_conll(sample):
sample = list(map(list, zip(*sample)))
sample = [sample[i] for i in indexes]
@@ -88,14 +94,15 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
if len(f) <= 0:
raise ValueError('empty field')
return sample
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
if start!='':
if start != '':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line.strip()
if line=='':
if line == '':
if len(sample):
try:
res = parse_conll(sample)
@@ -103,9 +110,9 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
warnings.warn('Invalid instance ends at line: {} has been dropped.'.format(line_idx))
logger.warn('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
continue
raise ValueError('Invalid instance ends at line: {}'.format(line_idx))
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
@@ -117,5 +124,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
print('invalid instance ends at line: {}'.format(line_idx))
logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e

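`_read_conll` above groups lines into samples separated by blank lines and skips comment lines starting with `#`. A simplified, self-contained stand-in for that grouping step (not the library function itself):

```python
from io import StringIO


def iter_conll_blocks(lines):
    """Yield lists of token rows; samples are separated by blank lines and
    '#' comment lines are skipped, following the convention used above."""
    sample = []
    for line in lines:
        line = line.strip()
        if line == '':
            if sample:
                yield sample
                sample = []
        elif line.startswith('#'):
            continue
        else:
            sample.append(line.split())
    if sample:
        yield sample


text = "EU NNP B-ORG\nrejects VBZ O\n\nPeter NNP B-PER\nBlackburn NNP I-PER\n"
for block in iter_conll_blocks(StringIO(text)):
    print(block)
```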
+ 27
- 11
fastNLP/io/file_utils.py View File

@@ -1,12 +1,28 @@
"""
.. todo::
doc
"""

__all__ = [
"cached_path",
"get_filepath",
"get_cache_path",
"split_filename_suffix",
"get_from_cache",
]

import os
import re
import shutil
import tempfile
from pathlib import Path
from urllib.parse import urlparse
import re
import requests
import tempfile
from tqdm import tqdm
import shutil
from requests import HTTPError
from tqdm import tqdm

from ..core import logger


PRETRAINED_BERT_MODEL_DIR = {
'en': 'bert-base-cased.zip',
@@ -58,7 +74,7 @@ PRETRAIN_STATIC_FILES = {
'en-fasttext-crawl': "crawl-300d-2M.vec.zip",


'cn': "tencent_cn.zip",
'cn-tencent': "tencent_cn.txt.zip",
'cn-tencent': "tencent_cn.zip",
'cn-fasttext': "cc.zh.300.vec.gz",
'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
}
@@ -336,7 +352,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
content_length = req.headers.get("Content-Length")
total = int(content_length) if content_length is not None else None
progress = tqdm(unit="B", total=total, unit_scale=1)
print("%s not found in cache, downloading to %s" % (url, temp_filename))
logger.info("%s not found in cache, downloading to %s" % (url, temp_filename))


with open(temp_filename, "wb") as temp_file:
for chunk in req.iter_content(chunk_size=1024 * 16):
@@ -344,12 +360,12 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
progress.update(len(chunk))
temp_file.write(chunk)
progress.close()
print(f"Finish download from {url}")
logger.info(f"Finish download from {url}")


# start uncompressing
if suffix in ('.zip', '.tar.gz', '.gz'):
uncompress_temp_dir = tempfile.mkdtemp()
print(f"Start to uncompress file to {uncompress_temp_dir}")
logger.debug(f"Start to uncompress file to {uncompress_temp_dir}")
if suffix == '.zip':
unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
elif suffix == '.gz':
@@ -362,13 +378,13 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])


cache_path.mkdir(parents=True, exist_ok=True)
print("Finish un-compressing file.")
logger.debug("Finish un-compressing file.")
else:
uncompress_temp_dir = temp_filename
cache_path = str(cache_path) + suffix


# copy to the target location
print(f"Copy file to {cache_path}")
logger.info(f"Copy file to {cache_path}")
if os.path.isdir(uncompress_temp_dir):
for filename in os.listdir(uncompress_temp_dir):
if os.path.isdir(os.path.join(uncompress_temp_dir, filename)):
@@ -379,7 +395,7 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
shutil.copyfile(uncompress_temp_dir, cache_path)
success = True
except Exception as e:
print(e)
logger.error(e)
raise e
finally:
if not success:


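`get_from_cache` downloads to a temporary file, optionally uncompresses it, and only then copies the result into the cache directory. A hedged sketch of that download-then-move pattern with illustrative names (no unpacking, retries, or error handling):

```python
import os
import shutil
import tempfile
from pathlib import Path
from urllib.parse import urlparse

import requests


def fetch_with_cache(url: str, cache_dir: Path) -> Path:
    """Download to a temp file first, then move it into the cache, so an
    interrupted download never leaves a partial file at the final path."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    target = cache_dir / os.path.basename(urlparse(url).path)
    if target.exists():
        return target
    fd, temp_filename = tempfile.mkstemp()
    os.close(fd)
    with requests.get(url, stream=True) as req, open(temp_filename, "wb") as temp_file:
        for chunk in req.iter_content(chunk_size=1024 * 16):
            if chunk:
                temp_file.write(chunk)
    shutil.move(temp_filename, target)
    return target
```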
+ 2
- 2
fastNLP/io/loader/__init__.py View File

@@ -62,8 +62,8 @@ __all__ = [
"PeopleDailyNERLoader",
"WeiboNERLoader",


# 'CSVLoader',
# 'JsonLoader',
'CSVLoader',
'JsonLoader',


'CWSLoader',




+ 19
- 8
fastNLP/io/loader/classification.py View File

@@ -1,13 +1,24 @@
from ...core.dataset import DataSet
from ...core.instance import Instance
from .loader import Loader
import warnings
"""undocumented"""

__all__ = [
"YelpLoader",
"YelpFullLoader",
"YelpPolarityLoader",
"IMDBLoader",
"SSTLoader",
"SST2Loader",
]

import glob
import os
import random
import shutil
import numpy as np
import glob
import time
import warnings

from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance




class YelpLoader(Loader):
@@ -59,7 +70,7 @@ class YelpLoader(Loader):




class YelpFullLoader(YelpLoader):
def download(self, dev_ratio: float = 0.1, re_download:bool=False):
def download(self, dev_ratio: float = 0.1, re_download: bool = False):
"""
Automatically download the dataset. If you use this dataset, please cite the following paper


@@ -128,7 +139,7 @@ class YelpPolarityLoader(YelpLoader):
if time.time() - modify_time > 1 and re_download:  # a rather ugly way to check whether the file has just been downloaded
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."


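When a corpus ships without a dev split, the Yelp loaders above carve one out of the training file according to `dev_ratio`. An illustrative version of that step (the exact splitting code is not shown in this hunk; file names mirror the `dev.csv` check above):

```python
import os
import random


def split_dev_from_train(data_dir: str, dev_ratio: float = 0.1) -> None:
    """Move roughly dev_ratio of the train lines into dev.csv if it is missing."""
    assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
    train_path = os.path.join(data_dir, 'train.csv')
    dev_path = os.path.join(data_dir, 'dev.csv')
    if os.path.exists(dev_path):
        return
    with open(train_path, encoding='utf-8') as f:
        lines = f.readlines()
    random.shuffle(lines)
    n_dev = max(1, int(len(lines) * dev_ratio))
    with open(dev_path, 'w', encoding='utf-8') as f:
        f.writelines(lines[:n_dev])
    with open(train_path, 'w', encoding='utf-8') as f:
        f.writelines(lines[n_dev:])
```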
+ 53
- 34
fastNLP/io/loader/conll.py View File

@@ -1,19 +1,33 @@
from typing import Dict, Union
"""undocumented"""

__all__ = [
"ConllLoader",
"Conll2003Loader",
"Conll2003NERLoader",
"OntoNotesNERLoader",
"CTBLoader",
"CNNERLoader",
"MsraNERLoader",
"WeiboNERLoader",
"PeopleDailyNERLoader"
]


from .loader import Loader
from ...core.dataset import DataSet
from ..file_reader import _read_conll
from ...core.instance import Instance
from ...core.const import Const
import glob
import os
import random
import shutil
import time
import random

from .loader import Loader
from ..file_reader import _read_conll
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance



class ConllLoader(Loader):
"""
Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader`
Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.loader.ConllLoader`


ConllLoader supports data in the following format: two samples are separated by a blank line, and apart from the separator line the elements on each line are separated by spaces or tabs, as in the example below:

@@ -46,6 +60,7 @@ class ConllLoader(Loader):
:param bool dropna: whether to ignore invalid data; if ``False``, a ``ValueError`` is raised when invalid data is encountered. Default: ``True``


"""
def __init__(self, headers, indexes=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
@@ -59,7 +74,7 @@ class ConllLoader(Loader):
if len(indexes) != len(headers):
raise ValueError
self.indexes = indexes
def _load(self, path):
"""
Given a file path, read that file into a DataSet; the fields are determined by the headers specified when the ConllLoader was initialized.
@@ -100,12 +115,13 @@ class Conll2003Loader(ConllLoader):
"[...]", "[...]", "[...]", "[...]"


"""
def __init__(self):
headers = [
'raw_words', 'pos', 'chunk', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)
def _load(self, path):
"""
Given a file path, read that file into a DataSet; the fields are determined by the headers specified when the ConllLoader was initialized.
@@ -126,7 +142,7 @@ class Conll2003Loader(ConllLoader):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
def download(self, output_dir=None):
raise RuntimeError("conll2003 cannot be downloaded automatically.")


@@ -157,12 +173,13 @@ class Conll2003NERLoader(ConllLoader):
"[...]", "[...]"


"""
def __init__(self):
headers = [
'raw_words', 'target',
]
super().__init__(headers=headers, indexes=[0, 3])
def _load(self, path):
"""
Given a file path, read that file into a DataSet; the fields are determined by the headers specified when the ConllLoader was initialized.
@@ -183,7 +200,7 @@ class Conll2003NERLoader(ConllLoader):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
def download(self):
raise RuntimeError("conll2003 cannot be downloaded automatically.")


@@ -203,13 +220,13 @@ class OntoNotesNERLoader(ConllLoader):
"[...]", "[...]"


"""
def __init__(self):
super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10])
def _load(self, path:str):
def _load(self, path: str):
dataset = super()._load(path)
def convert_to_bio(tags):
bio_tags = []
flag = None
@@ -226,7 +243,7 @@ class OntoNotesNERLoader(ConllLoader):
flag = None
bio_tags.append(bio_label)
return bio_tags
def convert_word(words):
converted_words = []
for word in words:
@@ -235,7 +252,7 @@ class OntoNotesNERLoader(ConllLoader):
converted_words.append(word)
continue
# the following symbols were escaped in the corpus; convert them back
tfrs = {'-LRB-':'(',
tfrs = {'-LRB-': '(',
'-RRB-': ')',
'-LSB-': '[',
'-RSB-': ']',
@@ -247,12 +264,12 @@ class OntoNotesNERLoader(ConllLoader):
else:
converted_words.append(word)
return converted_words
dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET)
return dataset
def download(self):
raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer "
"https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.")
@@ -261,13 +278,13 @@ class OntoNotesNERLoader(ConllLoader):
class CTBLoader(Loader):
def __init__(self):
super().__init__()
def _load(self, path:str):
def _load(self, path: str):
pass




class CNNERLoader(Loader):
def _load(self, path:str):
def _load(self, path: str):
"""
Supports loading content of the following form: two columns per line, separated by a space


@@ -330,10 +347,11 @@ class MsraNERLoader(CNNERLoader):
"[...]", "[...]"


"""
def __init__(self):
super().__init__()
def download(self, dev_ratio:float=0.1, re_download:bool=False)->str:
def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str:
"""
Automatically download the MSRA-NER data. If you use this data, please cite Gina-Anne Levow, 2006, The Third International Chinese Language
Processing Bakeoff: Word Segmentation and Named Entity Recognition.
@@ -355,7 +373,7 @@ class MsraNERLoader(CNNERLoader):
if time.time() - modify_time > 1 and re_download:  # a rather ugly way to check whether the file has just been downloaded
shutil.rmtree(data_dir)
data_dir = self._get_dataset_path(dataset_name=dataset_name)
if not os.path.exists(os.path.join(data_dir, 'dev.conll')):
if dev_ratio > 0:
assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
@@ -379,15 +397,15 @@ class MsraNERLoader(CNNERLoader):
finally:
if os.path.exists(os.path.join(data_dir, 'middle_file.conll')):
os.remove(os.path.join(data_dir, 'middle_file.conll'))
return data_dir




class WeiboNERLoader(CNNERLoader):
def __init__(self):
super().__init__()
def download(self)->str:
def download(self) -> str:
"""
Automatically download the Weibo-NER data. If you use this data, please cite Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for
Chinese Social Media with Jointly Trained Embeddings.
@@ -396,7 +414,7 @@ class WeiboNERLoader(CNNERLoader):
"""
dataset_name = 'weibo-ner'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
return data_dir




@@ -426,11 +444,12 @@ class PeopleDailyNERLoader(CNNERLoader):
"[...]", "[...]"


"""
def __init__(self):
super().__init__()
def download(self) -> str:
dataset_name = 'peopledaily'
data_dir = self._get_dataset_path(dataset_name=dataset_name)
return data_dir

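Putting the `ConllLoader` signature above to use, a hedged example (the dataset path is a placeholder, and the import assumes the loader is re-exported from `fastNLP.io.loader` as its `__all__` suggests; `load()` follows the Loader convention of returning a DataBundle):

```python
from fastNLP.io.loader import ConllLoader

# column layout follows the CoNLL-2003 example above
loader = ConllLoader(headers=['raw_words', 'pos', 'chunk', 'ner'],
                     indexes=[0, 1, 2, 3], dropna=True)
data_bundle = loader.load('path/to/conll2003')   # placeholder path
print(data_bundle.datasets.keys())
```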
+ 9
- 3
fastNLP/io/loader/csv.py View File

@@ -1,12 +1,18 @@
"""undocumented"""

__all__ = [
"CSVLoader",
]

from .loader import Loader
from ..file_reader import _read_csv
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..file_reader import _read_csv
from .loader import Loader




class CSVLoader(Loader):
"""
Alias: :class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.dataset_loader.CSVLoader`
Alias: :class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.loader.CSVLoader`


Reads a dataset in CSV format and returns a ``DataSet``.




+ 12
- 5
fastNLP/io/loader/cws.py View File

@@ -1,11 +1,18 @@
from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance
"""undocumented"""

__all__ = [
"CWSLoader"
]

import glob
import os
import time
import shutil
import random
import shutil
import time

from .loader import Loader
from ...core.dataset import DataSet
from ...core.instance import Instance




class CWSLoader(Loader):


+ 8
- 2
fastNLP/io/loader/json.py View File

@@ -1,7 +1,13 @@
"""undocumented"""

__all__ = [
"JsonLoader"
]

from .loader import Loader
from ..file_reader import _read_json
from ...core.dataset import DataSet
from ...core.instance import Instance
from ..file_reader import _read_json
from .loader import Loader




class JsonLoader(Loader):


+ 10
- 3
fastNLP/io/loader/loader.py View File

@@ -1,8 +1,15 @@
from ...core.dataset import DataSet
from .. import DataBundle
from ..utils import check_loader_paths
"""undocumented"""

__all__ = [
"Loader"
]

from typing import Union, Dict

from .. import DataBundle
from ..file_utils import _get_dataset_url, get_cache_path, cached_path
from ..utils import check_loader_paths
from ...core.dataset import DataSet




class Loader:


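Every loader in this package follows the convention visible above: subclass `Loader`, implement `_load(path) -> DataSet`, and optionally `download()`. A hypothetical two-column loader as a sketch (the class and its corpus are invented for illustration; only the base classes come from fastNLP):

```python
from fastNLP.io.loader import Loader
from fastNLP.core.dataset import DataSet
from fastNLP.core.instance import Instance


class TwoColumnLoader(Loader):
    """Hypothetical loader: one tab-separated sample per line."""

    def _load(self, path: str) -> DataSet:
        ds = DataSet()
        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                raw_words, target = line.rsplit('\t', 1)
                ds.append(Instance(raw_words=raw_words, target=target))
        return ds

    def download(self):
        raise RuntimeError("This toy corpus cannot be downloaded automatically.")
```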
+ 49
- 33
fastNLP/io/loader/matching.py View File

@@ -1,10 +1,21 @@
"""undocumented"""

__all__ = [
"MNLILoader",
"SNLILoader",
"QNLILoader",
"RTELoader",
"QuoraLoader",
]

import os
import warnings
from .loader import Loader
from typing import Union, Dict

from .json import JsonLoader
from ...core.const import Const
from .loader import Loader
from .. import DataBundle
import os
from typing import Union, Dict
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance


@@ -22,10 +33,11 @@ class MNLILoader(Loader):
"...", "...","."


"""
def __init__(self):
super().__init__()
def _load(self, path:str):
def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline()  # skip the header
@@ -50,8 +62,8 @@ class MNLILoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def load(self, paths:str=None):
def load(self, paths: str = None):
"""

:param str paths: directory containing the data; the loader looks there for dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
@@ -64,13 +76,13 @@ class MNLILoader(Loader):
paths = self.download()
if not os.path.isdir(paths):
raise NotADirectoryError(f"{paths} is not a valid directory.")
files = {'dev_matched':"dev_matched.tsv",
"dev_mismatched":"dev_mismatched.tsv",
"test_matched":"test_matched.tsv",
"test_mismatched":"test_mismatched.tsv",
"train":'train.tsv'}
files = {'dev_matched': "dev_matched.tsv",
"dev_mismatched": "dev_mismatched.tsv",
"test_matched": "test_matched.tsv",
"test_mismatched": "test_mismatched.tsv",
"train": 'train.tsv'}
datasets = {}
for name, filename in files.items():
filepath = os.path.join(paths, filename)
@@ -78,11 +90,11 @@ class MNLILoader(Loader):
if 'test' not in name:
raise FileNotFoundError(f"{name} not found in directory {filepath}.")
datasets[name] = self._load(filepath)
data_bundle = DataBundle(datasets=datasets)
return data_bundle
def download(self):
"""
If you use this data, please cite
@@ -106,14 +118,15 @@ class SNLILoader(JsonLoader):
"...", "...", "."


"""
def __init__(self):
super().__init__(fields={
'sentence1': Const.RAW_WORDS(0),
'sentence2': Const.RAW_WORDS(1),
'gold_label': Const.TARGET,
})
def load(self, paths: Union[str, Dict[str, str]]=None) -> DataBundle:
def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
"""
Reads data from the file(s) at one or more specified paths and returns a :class:`~fastNLP.io.DataBundle`.


@@ -138,11 +151,11 @@ class SNLILoader(JsonLoader):
paths = _paths
else:
raise NotADirectoryError(f"{paths} is not a valid directory.")
datasets = {name: self._load(path) for name, path in paths.items()}
data_bundle = DataBundle(datasets=datasets)
return data_bundle
def download(self):
"""
If your paper uses this data, please cite
@@ -169,12 +182,13 @@ class QNLILoader(JsonLoader):
the test set has no target column


"""
def __init__(self):
super().__init__()
def _load(self, path):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline()  # skip the header
if path.endswith("test.tsv"):
@@ -198,7 +212,7 @@ class QNLILoader(JsonLoader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def download(self):
"""
If your experiments use this data, please cite
@@ -225,12 +239,13 @@ class RTELoader(Loader):


the test set has no target column
"""
def __init__(self):
super().__init__()
def _load(self, path:str):
def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
f.readline()  # skip the header
if path.endswith("test.tsv"):
@@ -254,7 +269,7 @@ class RTELoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def download(self):
return self._get_dataset_path('rte')


@@ -281,12 +296,13 @@ class QuoraLoader(Loader):
"...","."


"""
def __init__(self):
super().__init__()
def _load(self, path:str):
def _load(self, path: str):
ds = DataSet()
with open(path, 'r', encoding='utf-8') as f:
for line in f:
line = line.strip()
@@ -298,6 +314,6 @@ class QuoraLoader(Loader):
if raw_words1 and raw_words2 and target:
ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
return ds
def download(self):
raise RuntimeError("Quora cannot be downloaded automatically.")

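A hedged usage sketch for the MNLI loader above (the directory is a placeholder and must already contain the train/dev/test tsv files listed in `load`; the import assumes the loader is re-exported from `fastNLP.io.loader`):

```python
from fastNLP.io.loader import MNLILoader

data_bundle = MNLILoader().load('path/to/MNLI')   # placeholder directory
for name, dataset in data_bundle.datasets.items():
    print(name, len(dataset))
```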
+ 3
- 0
fastNLP/io/pipe/__init__.py View File

@@ -10,6 +10,8 @@ A Pipe processes data read through a Loader; every Pipe contains a ``proce
__all__ = [
"Pipe",


"CWSPipe",

"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
@@ -43,3 +45,4 @@ from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe
MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
from .pipe import Pipe
from .conll import Conll2003Pipe
from .cws import CWSPipe

+ 89
- 74
fastNLP/io/pipe/classification.py View File

@@ -1,26 +1,39 @@
"""undocumented"""

__all__ = [
"YelpFullPipe",
"YelpPolarityPipe",
"SSTPipe",
"SST2Pipe",
'IMDBPipe'
]

import re

from nltk import Tree


from .pipe import Pipe
from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from ..data_bundle import DataBundle
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import Vocabulary


from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
from .pipe import Pipe
import re
nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
from ...core.utils import cache_results




class _CLSPipe(Pipe):
"""
Base class for classification pipes; it tokenizes classification data. By default it operates on the raw_words column and produces a words column


"""
def __init__(self, tokenizer:str='spacy', lang='en'):
def __init__(self, tokenizer: str = 'spacy', lang='en'):
self.tokenizer = get_tokenizer(tokenizer, lang=lang)
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
Tokenize the data in the DataBundle
@@ -33,9 +46,9 @@ class _CLSPipe(Pipe):
new_field_name = new_field_name or field_name
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
return data_bundle
def _granularize(self, data_bundle, tag_map):
"""
This function converts the contents of the 'target' column in data_bundle.
@@ -47,9 +60,9 @@ class _CLSPipe(Pipe):
"""
for name in list(data_bundle.datasets.keys()):
dataset = data_bundle.get_dataset(name)
dataset.apply_field(lambda target:tag_map.get(target, -100), field_name=Const.TARGET,
dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
new_field_name=Const.TARGET)
dataset.drop(lambda ins:ins[Const.TARGET] == -100)
dataset.drop(lambda ins: ins[Const.TARGET] == -100)
data_bundle.set_dataset(dataset, name)
return data_bundle


@@ -69,7 +82,7 @@ def _clean_str(words):
t = ''.join(tt)
if t != '':
words_collection.append(t)
return words_collection




@@ -89,19 +102,20 @@ class YelpFullPipe(_CLSPipe):
labels 1 and 2 are merged into one class, 3 into its own class, and 4 and 5 into one class; if granularity is 5, it becomes a 5-class problem.
:param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
"""
def __init__(self, lower:bool=False, granularity=5, tokenizer:str='spacy'):
def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
if granularity==2:
if granularity == 2:
self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
elif granularity==3:
self.tag_map = {"1": 0, "2": 0, "3":1, "4": 2, "5": 2}
elif granularity == 3:
self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
else:
self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
"""
Tokenize the data in the DataBundle
@@ -116,7 +130,7 @@ class YelpFullPipe(_CLSPipe):
dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
return data_bundle
def process(self, data_bundle):
"""
The DataSet passed in should have the following structure
@@ -131,30 +145,30 @@ class YelpFullPipe(_CLSPipe):
:param data_bundle:
:return:
"""
# copy a words column
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# set the tag according to granularity
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
# drop empty instances
data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""


@@ -179,27 +193,28 @@ class YelpPolarityPipe(_CLSPipe):
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
"""
def __init__(self, lower:bool=False, tokenizer:str='spacy'):
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
def process(self, data_bundle):
# copy a words column
data_bundle = _add_words_field(data_bundle, lower=self.lower)
# tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""


@@ -230,7 +245,7 @@ class SSTPipe(_CLSPipe):
labels 0 and 1 are merged into one class, 2 into its own class, and 3 and 4 into one class; if granularity is 5, it becomes a 5-class problem.
:param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
"""
def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.subtree = subtree
@@ -238,15 +253,15 @@ class SSTPipe(_CLSPipe):
self.lower = lower
assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
self.granularity = granularity
if granularity==2:
if granularity == 2:
self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
elif granularity==3:
self.tag_map = {"0": 0, "1": 0, "2":1, "3": 2, "4": 2}
elif granularity == 3:
self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2}
else:
self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
def process(self, data_bundle:DataBundle):
def process(self, data_bundle: DataBundle):
"""
Preprocess the data in the DataBundle. The input DataSet should have at least a raw_words column, with contents similar to


@@ -277,26 +292,26 @@ class SSTPipe(_CLSPipe):
instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
ds.append(instance)
data_bundle.set_dataset(ds, name)
_add_words_field(data_bundle, lower=self.lower)
# tokenize
data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
# set the tag according to granularity
data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
# index
data_bundle = _indexize(data_bundle=data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
data_bundle = SSTLoader().load(paths)
return self.process(data_bundle=data_bundle)
@@ -316,11 +331,12 @@ class SST2Pipe(_CLSPipe):
:param bool lower: whether to lowercase the input.
:param str tokenizer: which tokenization method to use to split the data into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
"""
def __init__(self, lower=False, tokenizer='spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
def process(self, data_bundle:DataBundle):
def process(self, data_bundle: DataBundle):
"""
The DataSet to be processed should have the following structure


@@ -335,15 +351,15 @@ class SST2Pipe(_CLSPipe):
:return:
"""
_add_words_field(data_bundle, self.lower)
data_bundle = self._tokenize(data_bundle=data_bundle)
src_vocab = Vocabulary()
src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name,dataset in data_bundle.datasets.items() if
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
datasets = []
@@ -351,18 +367,18 @@ class SST2Pipe(_CLSPipe):
if dataset.has_field(Const.TARGET):
datasets.append(dataset)
tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)
data_bundle.set_vocab(src_vocab, Const.INPUT)
data_bundle.set_vocab(tgt_vocab, Const.TARGET)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
data_bundle.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""


@@ -390,11 +406,12 @@ class IMDBPipe(_CLSPipe):
:param bool lower: whether to lowercase the data in the words column.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports 'spacy' and 'raw'; 'raw' splits on whitespace.
"""
def __init__(self, lower:bool=False, tokenizer:str='spacy'):
def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
super().__init__(tokenizer=tokenizer, lang='en')
self.lower = lower
def process(self, data_bundle:DataBundle):
def process(self, data_bundle: DataBundle):
"""
The input DataSet in the expected DataBundle should look like the following, with two fields, raw_words and target, both of type str


@@ -409,25 +426,26 @@ class IMDBPipe(_CLSPipe):
the target column should be str.
:return: DataBundle
"""
# replace <br />
def replace_br(raw_words):
raw_words = raw_words.replace("<br />", ' ')
return raw_words
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
_add_words_field(data_bundle, lower=self.lower)
self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
_indexize(data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(Const.INPUT, Const.INPUT_LEN)
dataset.set_target(Const.TARGET)
return data_bundle
def process_from_file(self, paths=None):
"""


@@ -437,8 +455,5 @@ class IMDBPipe(_CLSPipe):
# read the data
data_bundle = IMDBLoader().load(paths)
data_bundle = self.process(data_bundle)
return data_bundle




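The `_granularize` step used by these pipes maps raw labels through a `tag_map` and drops instances whose label maps to -100. A standalone illustration of that idea in plain Python (not the Pipe API):

```python
def granularize(labels, tag_map):
    """Map raw labels through tag_map, mark unmapped ones as -100, then drop them."""
    mapped = [tag_map.get(label, -100) for label in labels]
    return [tag for tag in mapped if tag != -100]


# 2-class Yelp setting from the pipe above: 3-star reviews are dropped
tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
print(granularize(["1", "3", "5", "2"], tag_map))  # [0, 1, 0]
```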
+ 47
- 32
fastNLP/io/pipe/conll.py View File

@@ -1,13 +1,25 @@
"""undocumented"""

__all__ = [
"Conll2003NERPipe",
"Conll2003Pipe",
"OntoNotesNERPipe",
"MsraNERPipe",
"PeopleDailyPipe",
"WeiboNERPipe"
]

from .pipe import Pipe
from .. import DataBundle
from .utils import _add_chars_field
from .utils import _indexize, _add_words_field
from .utils import iob2, iob2bioes
from ...core.const import Const
from .. import DataBundle
from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
from .utils import _indexize, _add_words_field
from .utils import _add_chars_field
from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
from ...core.const import Const
from ...core.vocabulary import Vocabulary



class _NERPipe(Pipe):
"""
Processing Pipe for NER tasks. This Pipe (1) copies the raw_words column and names it words; (2) builds vocabularies over the words and target columns
@@ -20,14 +32,14 @@ class _NERPipe(Pipe):
:param str encoding_type: which encoding scheme the target column uses; supports bioes and bio.
:param bool lower: whether to lowercase words before building the vocabulary; in almost all cases this does not need to be True.
"""
def __init__(self, encoding_type: str = 'bio', lower: bool = False):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
self.lower = lower
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
The supported DataSet fields are
@@ -46,21 +58,21 @@ class _NERPipe(Pipe):
# 转换tag # 转换tag
for name, dataset in data_bundle.datasets.items(): for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET) dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
_add_words_field(data_bundle, lower=self.lower) _add_words_field(data_bundle, lower=self.lower)
# index # index
_indexize(data_bundle) _indexize(data_bundle)
input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN] input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN] target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items(): for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT) dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(*input_fields) data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields) data_bundle.set_target(*target_fields)
return data_bundle return data_bundle
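
A sketch of what process() leaves behind (editor's note; the index values are made up for illustration, and the words/target/seq_len names assume fastNLP's Const mapping):

# before: raw_words = ['EU', 'rejects', 'German', 'call'], target = ['B-ORG', 'O', 'B-MISC', 'O']
# after:  words   = [5, 12, 31, 8]   # indices from the words vocabulary built by _indexize
#         target  = [2, 0, 3, 0]     # indices from the target vocabulary (bio or bioes tags)
#         seq_len = 4                # added by dataset.add_seq_len
# words, target and seq_len become input fields; target and seq_len become target fields.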




@@ -84,7 +96,7 @@ class Conll2003NERPipe(_NERPipe):
:param str encoding_type: which encoding scheme to use for the target column; supports bioes and bio.
:param bool lower: whether to lowercase words before building the vocabulary; in the vast majority of cases this does not need to be True.
"""
def process_from_file(self, paths) -> DataBundle:
"""


@@ -94,7 +106,7 @@ class Conll2003NERPipe(_NERPipe):
# Load the data
data_bundle = Conll2003NERLoader().load(paths)
data_bundle = self.process(data_bundle)
return data_bundle
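
A hedged usage sketch (editor's addition; the import path and data location are assumptions):

from fastNLP.io.pipe import Conll2003NERPipe  # assumed import path

# encoding_type='bioes' makes convert_tag wrap iob2 with iob2bioes, as in _NERPipe.__init__ above.
bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file('path/to/conll2003')
print(bundle)  # summarizes the datasets plus the words/target vocabularies that were built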




@@ -125,8 +137,8 @@ class Conll2003Pipe(Pipe):
else:
self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
self.lower = lower
def process(self, data_bundle)->DataBundle:
def process(self, data_bundle) -> DataBundle:
"""
The input DataSet should look like the following


@@ -145,9 +157,9 @@ class Conll2003Pipe(Pipe):
dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD])
dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
_add_words_field(data_bundle, lower=self.lower)
# index
_indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner'])
# Some chunk tags only appear in dev, not in train
@@ -155,18 +167,18 @@ class Conll2003Pipe(Pipe):
tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
data_bundle.set_vocab(tgt_vocab, 'chunk')
input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle
def process_from_file(self, paths):
"""

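A hedged sketch for Conll2003Pipe itself (editor's addition; the import path, default constructor and data path are assumptions not visible in this hunk):

from fastNLP.io.pipe import Conll2003Pipe  # assumed import path

bundle = Conll2003Pipe().process_from_file('path/to/conll2003')
# pos, ner and chunk are separate target columns; as the comment above notes, the chunk vocabulary
# is built over every split because some chunk tags never occur in train.
print(bundle.vocabs.keys())  # assumed attribute; expected to include words, pos, ner and chunk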

@@ -194,7 +206,7 @@ class OntoNotesNERPipe(_NERPipe):
:param str encoding_type: which encoding scheme to use for the target column; supports bioes and bio.
:param bool lower: whether to lowercase words before building the vocabulary; in the vast majority of cases this does not need to be True.
"""
def process_from_file(self, paths):
data_bundle = OntoNotesNERLoader().load(paths)
return self.process(data_bundle)
@@ -211,13 +223,13 @@ class _CNNERPipe(Pipe):


:param str encoding_type: which encoding scheme to use for the target column; supports bioes and bio.
"""
def __init__(self, encoding_type: str = 'bio'):
if encoding_type == 'bio':
self.convert_tag = iob2
else:
self.convert_tag = lambda words: iob2bioes(iob2(words))
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
The supported DataSet fields are
@@ -239,21 +251,21 @@ class _CNNERPipe(Pipe):
# Convert tags
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
_add_chars_field(data_bundle, lower=False)
# index
_indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle




@@ -272,6 +284,7 @@ class MsraNERPipe(_CNNERPipe):
target. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target.


"""
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = MsraNERLoader().load(paths)
return self.process(data_bundle)
@@ -291,6 +304,7 @@ class PeopleDailyPipe(_CNNERPipe):
The raw_chars column is a List[str] with the raw, unconverted data; the chars column is a List[int] with the input converted to indices; the target column is a List[int] with the
target converted to indices. In the returned DataSet, chars, target and seq_len are set as input, and target is set as target.
"""
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = PeopleDailyNERLoader().load(paths)
return self.process(data_bundle)
@@ -312,6 +326,7 @@ class WeiboNERPipe(_CNNERPipe):


:param str encoding_type: which encoding scheme to use for the target column; supports bioes and bio.
"""
def process_from_file(self, paths=None) -> DataBundle:
data_bundle = WeiboNERLoader().load(paths)
return self.process(data_bundle)
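
A hedged usage sketch for these character-level Chinese NER pipes (editor's addition; the import path is an assumption, and calling process_from_file() with no paths assumes the corresponding loader can locate or download the data itself):

from fastNLP.io.pipe import MsraNERPipe, WeiboNERPipe  # assumed import path

# Both run the shared _CNNERPipe.process: raw_chars -> indexed chars, tags -> indexed target,
# plus a seq_len column computed from chars.
msra_bundle = MsraNERPipe(encoding_type='bioes').process_from_file()
weibo_bundle = WeiboNERPipe().process_from_file()
print(msra_bundle)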

+ 55
- 35
fastNLP/io/pipe/cws.py

@@ -1,10 +1,19 @@
"""undocumented"""

__all__ = [
"CWSPipe"
]

import re
from itertools import chain

from .pipe import Pipe
from .utils import _indexize
from .. import DataBundle
from ..loader import CWSLoader
from ... import Const
from itertools import chain
from .utils import _indexize
import re
from ...core.const import Const


def _word_lens_to_bmes(word_lens):
"""


@@ -13,11 +22,11 @@ def _word_lens_to_bmes(word_lens):
"""
tags = []
for word_len in word_lens:
if word_len==1:
if word_len == 1:
tags.append('S')
else:
tags.append('B')
tags.extend(['M']*(word_len-2))
tags.extend(['M'] * (word_len - 2))
tags.append('E')
return tags
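
A worked example of the mapping above (editor's note):

# _word_lens_to_bmes([1, 2, 3]) == ['S', 'B', 'E', 'B', 'M', 'E']
# single-character words become 'S'; longer words become 'B', zero or more 'M's, then 'E'.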


@@ -30,10 +39,10 @@ def _word_lens_to_segapp(word_lens):
""" """
tags = [] tags = []
for word_len in word_lens: for word_len in word_lens:
if word_len==1:
if word_len == 1:
tags.append('SEG') tags.append('SEG')
else: else:
tags.extend(['APP']*(word_len-1))
tags.extend(['APP'] * (word_len - 1))
tags.append('SEG') tags.append('SEG')
return tags return tags
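
For comparison (editor's note):

# _word_lens_to_segapp([1, 3]) == ['SEG', 'APP', 'APP', 'SEG']
# every word ends in 'SEG'; its non-final characters are tagged 'APP'.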


@@ -97,13 +106,21 @@ def _digit_span_to_special_tag(span):
else:
return '<NUM>'



def _find_and_replace_digit_spans(line):
# only consider words start with number, contains '.', characters.
# If ends with space, will be processed
# If ends with Chinese character, will be processed
# If ends with or contains english char, not handled.
# floats are replaced by <DEC>
# otherwise unkdgt
"""
only consider words start with number, contains '.', characters.
If ends with space, will be processed
If ends with Chinese character, will be processed
If ends with or contains english char, not handled.
floats are replaced by <DEC>
otherwise unkdgt
"""
new_line = ''
pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])'
prev_end = 0
@@ -136,17 +153,18 @@ class CWSPipe(Pipe):
:param bool bigrams: whether to add a bigram column. Bigrams are formed as ['复', '旦', '大', '学', ...] -> ['复旦', '旦大', ...]
:param bool trigrams: whether to add a trigram column. Trigrams are formed as ['复', '旦', '大', '学', ...] -> ['复旦大', '旦大学', ...]
"""
def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
if encoding_type=='bmes':
if encoding_type == 'bmes':
self.word_lens_to_tags = _word_lens_to_bmes
else:
self.word_lens_to_tags = _word_lens_to_segapp
self.dataset_name = dataset_name
self.bigrams = bigrams
self.trigrams = trigrams
self.replace_num_alpha = replace_num_alpha
def _tokenize(self, data_bundle):
""" """
将data_bundle中的'chars'列切分成一个一个的word. 将data_bundle中的'chars'列切分成一个一个的word.
@@ -162,10 +180,10 @@ class CWSPipe(Pipe):
char = [] char = []
subchar = [] subchar = []
for c in word: for c in word:
if c=='<':
if c == '<':
subchar.append(c) subchar.append(c)
continue continue
if c=='>' and subchar[0]=='<':
if c == '>' and subchar[0] == '<':
char.append(''.join(subchar)) char.append(''.join(subchar))
subchar = [] subchar = []
if subchar: if subchar:
@@ -175,12 +193,12 @@ class CWSPipe(Pipe):
char.extend(subchar) char.extend(subchar)
chars.append(char) chars.append(char)
return chars return chars
for name, dataset in data_bundle.datasets.items(): for name, dataset in data_bundle.datasets.items():
dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT, dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
new_field_name=Const.CHAR_INPUT) new_field_name=Const.CHAR_INPUT)
return data_bundle return data_bundle
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
The DataSet to be processed must contain a raw_words column
@@ -196,42 +214,43 @@ class CWSPipe(Pipe):
:return:
"""
data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)
if self.replace_num_alpha:
data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
self._tokenize(data_bundle)
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars:self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
new_field_name=Const.TARGET)
dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT,
dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
new_field_name=Const.CHAR_INPUT)
input_field_names = [Const.CHAR_INPUT]
if self.bigrams:
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+['<eos>'])],
dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + ['<eos>'])],
field_name=Const.CHAR_INPUT, new_field_name='bigrams')
input_field_names.append('bigrams')
if self.trigrams:
for name, dataset in data_bundle.datasets.items():
dataset.apply_field(lambda chars: [c1+c2+c3 for c1, c2, c3 in zip(chars, chars[1:]+['<eos>'], chars[2:]+['<eos>']*2)],
dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
zip(chars, chars[1:] + ['<eos>'], chars[2:] + ['<eos>'] * 2)],
field_name=Const.CHAR_INPUT, new_field_name='trigrams')
input_field_names.append('trigrams')
_indexize(data_bundle, input_field_names, Const.TARGET)
input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
target_fields = [Const.TARGET, Const.INPUT_LEN]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.CHAR_INPUT)
data_bundle.set_input(*input_fields)
data_bundle.set_target(*target_fields)
return data_bundle
def process_from_file(self, paths=None) -> DataBundle:
"""


@@ -239,8 +258,9 @@ class CWSPipe(Pipe):
:return:
"""
if self.dataset_name is None and paths is None:
raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
raise RuntimeError(
"You have to set `paths` when calling process_from_file() or `dataset_name` when initializing.")
if self.dataset_name is not None and paths is not None:
raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
data_bundle = CWSLoader(self.dataset_name).load(paths)
return self.process(data_bundle)
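
A hedged usage sketch for CWSPipe (editor's addition; the import path is an assumption, and 'pku' is assumed to be one of the dataset names CWSLoader can fetch):

from fastNLP.io.pipe import CWSPipe  # assumed import path

# bigrams=True adds a second input column built from adjacent character pairs, as in process() above.
bundle = CWSPipe(dataset_name='pku', encoding_type='bmes', bigrams=True).process_from_file()
print(bundle.datasets['train'])  # chars, bigrams, seq_len as input; target holds the BMES tags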

+ 48
- 31
fastNLP/io/pipe/matching.py

@@ -1,9 +1,25 @@
"""undocumented"""

__all__ = [
"MatchingBertPipe",
"RTEBertPipe",
"SNLIBertPipe",
"QuoraBertPipe",
"QNLIBertPipe",
"MNLIBertPipe",
"MatchingPipe",
"RTEPipe",
"SNLIPipe",
"QuoraPipe",
"QNLIPipe",
"MNLIPipe",
]


from .pipe import Pipe
from .utils import get_tokenizer
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
from ...core.const import Const
from ...core.vocabulary import Vocabulary
from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader




class MatchingBertPipe(Pipe):
@@ -24,12 +40,13 @@ class MatchingBertPipe(Pipe):
:param bool lower: whether to lowercase the words.
:param str tokenizer: which tokenizer to use to split sentences into words. Supports spacy and raw; raw splits on whitespace.
"""
def __init__(self, lower=False, tokenizer: str='raw'):
def __init__(self, lower=False, tokenizer: str = 'raw'):
super().__init__()
self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)
def _tokenize(self, data_bundle, field_names, new_field_names):
"""


@@ -43,62 +60,62 @@ class MatchingBertPipe(Pipe):
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle
def process(self, data_bundle):
for dataset in data_bundle.datasets.values():
if dataset.has_field(Const.TARGET):
dataset.drop(lambda x: x[Const.TARGET] == '-')
for name, dataset in data_bundle.datasets.items():
dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0), )
dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1), )
if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()
data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
[Const.INPUTS(0), Const.INPUTS(1)])
# Concatenate the two word sequences
def concat(ins):
words0 = ins[Const.INPUTS(0)]
words1 = ins[Const.INPUTS(1)]
words = words0 + ['[SEP]'] + words1
return words
for name, dataset in data_bundle.datasets.items():
dataset.apply(concat, new_field_name=Const.INPUT)
dataset.delete_field(Const.INPUTS(0))
dataset.delete_field(Const.INPUTS(1))
word_vocab = Vocabulary()
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
field_name=Const.INPUT,
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
'train' not in name])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
dataset.has_field(Const.TARGET)]
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
data_bundle.set_vocab(word_vocab, Const.INPUT)
data_bundle.set_vocab(target_vocab, Const.TARGET)
input_fields = [Const.INPUT, Const.INPUT_LEN]
target_fields = [Const.TARGET]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUT)
dataset.set_input(*input_fields, flag=True)
for fields in target_fields:
if dataset.has_field(fields):
dataset.set_target(fields, flag=True)
return data_bundle
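
A sketch of the concat step above (editor's note; the example tokens are made up): for a sentence pair the pipe ends up with a single words field,

# words1 = ['a', 'man', 'sleeps']          words2 = ['nobody', 'is', 'awake']
# words  = ['a', 'man', 'sleeps', '[SEP]', 'nobody', 'is', 'awake']   # then indexed against one shared vocabulary

which is why MatchingBertPipe keeps one input sequence and one seq_len, unlike MatchingPipe below, which keeps both sequences and two length fields.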




@@ -150,12 +167,13 @@ class MatchingPipe(Pipe):
:param bool lower: whether to lowercase all raw_words.
:param str tokenizer: how to tokenize the raw data. Supports spacy and raw; spacy tokenizes with spaCy, raw splits on whitespace.
"""
def __init__(self, lower=False, tokenizer: str='raw'):
def __init__(self, lower=False, tokenizer: str = 'raw'):
super().__init__()
self.lower = bool(lower)
self.tokenizer = get_tokenizer(tokenizer=tokenizer)
def _tokenize(self, data_bundle, field_names, new_field_names):
"""


@@ -169,7 +187,7 @@ class MatchingPipe(Pipe):
dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
new_field_name=new_field_name)
return data_bundle
def process(self, data_bundle):
"""
The DataSet in the accepted DataBundle should have the following fields; the target column may be absent
@@ -181,40 +199,40 @@ class MatchingPipe(Pipe):
"This site includes a...", "The Government Executive...", "not_entailment"
"...", "..."


:param data_bundle:
:return:
:param data_bundle: the data_bundle produced by a loader, containing the raw content of the dataset
:return: data_bundle
"""
data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
[Const.INPUTS(0), Const.INPUTS(1)])
for dataset in data_bundle.datasets.values():
if dataset.has_field(Const.TARGET):
dataset.drop(lambda x: x[Const.TARGET] == '-')
if self.lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUTS(0)].lower()
dataset[Const.INPUTS(1)].lower()
word_vocab = Vocabulary()
word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
field_name=[Const.INPUTS(0), Const.INPUTS(1)],
no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
'train' not in name])
word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])
target_vocab = Vocabulary(padding=None, unknown=None)
target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
dataset.has_field(Const.TARGET)]
target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
data_bundle.set_vocab(target_vocab, Const.TARGET)
input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
target_fields = [Const.TARGET]
for name, dataset in data_bundle.datasets.items():
dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
@@ -222,7 +240,7 @@ class MatchingPipe(Pipe):
for fields in target_fields:
if dataset.has_field(fields):
dataset.set_target(fields, flag=True)
return data_bundle




@@ -254,4 +272,3 @@ class MNLIPipe(MatchingPipe):
def process_from_file(self, paths=None):
data_bundle = MNLILoader().load(paths)
return self.process(data_bundle)
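
A hedged usage sketch for the non-BERT matching pipes (editor's addition; the import path, data location and the words1/words2 field names are assumptions based on fastNLP's Const naming):

from fastNLP.io.pipe import RTEPipe  # assumed import path

bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file('path/to/RTE')
# Two tokenized inputs (words1, words2) with their own seq_len fields, plus an indexed target.
print(bundle)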


+ 9
- 0
fastNLP/io/pipe/pipe.py

@@ -1,7 +1,16 @@
"""undocumented"""

__all__ = [
"Pipe",
]

from .. import DataBundle




class Pipe:
"""
Alias: :class:`fastNLP.io.Pipe` :class:`fastNLP.io.pipe.Pipe`
"""
def process(self, data_bundle: DataBundle) -> DataBundle:
"""
Process the input DataBundle and return it.


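A minimal custom Pipe sketch to illustrate the interface documented here (editor's addition, not code from this commit; the import locations are assumptions):

from fastNLP.io import DataBundle
from fastNLP.io.pipe import Pipe  # assumed import path


class LowercasePipe(Pipe):
    """Toy pipe: lowercases the raw_words field of every DataSet in the bundle."""

    def process(self, data_bundle: DataBundle) -> DataBundle:
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(lambda s: s.lower(), field_name='raw_words', new_field_name='raw_words')
        return data_bundle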
+ 24
- 14
fastNLP/io/pipe/utils.py

@@ -1,8 +1,18 @@
"""undocumented"""

__all__ = [
"iob2",
"iob2bioes",
"get_tokenizer",
]

from typing import List
from ...core.vocabulary import Vocabulary
from ...core.const import Const
from ...core.vocabulary import Vocabulary



def iob2(tags:List[str])->List[str]:
def iob2(tags: List[str]) -> List[str]:
"""
Check whether the data is valid IOB data; IOB1 will be automatically converted to IOB2. For the difference between the two formats, see
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
@@ -25,7 +35,8 @@ def iob2(tags:List[str])->List[str]:
tags[i] = "B" + tag[1:]
return tags


def iob2bioes(tags:List[str])->List[str]:

def iob2bioes(tags: List[str]) -> List[str]:
"""
Convert IOB tags to BIOES encoding
:param tags:
@@ -38,12 +49,12 @@ def iob2bioes(tags:List[str])->List[str]:
else:
split = tag.split('-')[0]
if split == 'B':
if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
if i + 1<len(tags) and tags[i+1].split('-')[0] == 'I':
if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('I-', 'E-'))
@@ -52,7 +63,7 @@ def iob2bioes(tags:List[str])->List[str]:
return new_tags
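
A worked example of the two helpers (editor's note; iob2's IOB1-to-IOB2 behaviour follows its docstring above):

# iob2(['I-ORG', 'I-ORG', 'O'])               == ['B-ORG', 'I-ORG', 'O']
# iob2bioes(['B-ORG', 'I-ORG', 'O', 'B-PER']) == ['B-ORG', 'E-ORG', 'O', 'S-PER']
# a 'B' not followed by 'I' becomes 'S'; an 'I' not followed by 'I' becomes 'E'.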




def get_tokenizer(tokenizer:str, lang='en'):
def get_tokenizer(tokenizer: str, lang='en'):
""" """


:param str tokenizer: 获取tokenzier方法 :param str tokenizer: 获取tokenzier方法
@@ -97,13 +108,13 @@ def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Con
name != 'train'])
src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
data_bundle.set_vocab(src_vocab, input_field_name)
for target_field_name in target_field_names:
tgt_vocab = Vocabulary(unknown=None, padding=None)
tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
data_bundle.set_vocab(tgt_vocab, target_field_name)
return data_bundle




@@ -116,7 +127,7 @@ def _add_words_field(data_bundle, lower=False):
:return: the DataBundle passed in
"""
data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True)
if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.INPUT].lower()
@@ -132,7 +143,7 @@ def _add_chars_field(data_bundle, lower=False):
:return: the DataBundle passed in
"""
data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True)
if lower:
for name, dataset in data_bundle.datasets.items():
dataset[Const.CHAR_INPUT].lower()
@@ -147,6 +158,7 @@ def _drop_empty_instance(data_bundle, field_name):
:param str field_name: which field to check; if None, an instance is dropped when any of its fields is empty
:return: the DataBundle passed in
"""
def empty_instance(ins):
if field_name:
field_value = ins[field_name]
@@ -157,10 +169,8 @@ def _drop_empty_instance(data_bundle, field_name):
if field_value in ((), {}, [], ''):
return True
return False
for name, dataset in data_bundle.datasets.items():
dataset.drop(empty_instance)
return data_bundle



+ 21
- 8
fastNLP/io/utils.py

@@ -1,10 +1,20 @@
import os
"""
.. todo::
doc
"""


from typing import Union, Dict
__all__ = [
"check_loader_paths"
]

import os
from pathlib import Path
from typing import Union, Dict

from ..core import logger




def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
""" """
检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::


@@ -33,11 +43,13 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
raise Exception(
"File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
raise Exception(
"File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
@@ -46,7 +58,7 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
elif isinstance(paths, dict):
if paths:
if 'train' not in paths:
@@ -65,13 +77,14 @@ def check_loader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
else:
raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
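
A hedged example of the contract described above (editor's addition; the file names are illustrative):

from fastNLP.io.utils import check_loader_paths  # assumed import path

check_loader_paths('data/conll2003')
# -> {'train': 'data/conll2003/train.txt', 'dev': 'data/conll2003/dev.txt', 'test': 'data/conll2003/test.txt'}
#    keys are filled for whichever of train/dev/test occur in the directory's file names
check_loader_paths('data/train.txt')
# -> {'train': 'data/train.txt'}   # a single existing file is treated as the train split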



def get_tokenizer():
try:
import spacy
spacy.prefer_gpu()
en = spacy.load('en')
print('use spacy tokenizer')
logger.info('use spacy tokenizer')
return lambda x: [w.text for w in en.tokenizer(x)]
except Exception as e:
print('use raw tokenizer')
logger.error('use raw tokenizer')
return lambda x: x.split()

Some files were not shown because too many files changed in this diff
