diff --git a/.travis.yml b/.travis.yml
index 210d158a..0d63417a 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,6 +1,9 @@
language: python
python:
- "3.6"
+
+env:
+ - TRAVIS=1
# command to install dependencies
install:
- pip install --quiet -r requirements.txt
diff --git a/README.md b/README.md
index b294e54b..531fbc83 100644
--- a/README.md
+++ b/README.md
@@ -6,11 +6,12 @@

[](http://fastnlp.readthedocs.io/?badge=latest)
-fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner)、POS-Tagging等)、中文分词、[文本分类](reproduction/text_classification)、[Matching](reproduction/matching)、[指代消解](reproduction/coreference_resolution)、[摘要](reproduction/Summarization)等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性:
+fastNLP 是一款轻量级的 NLP 工具包。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner)、POS-Tagging等)、中文分词、[文本分类](reproduction/text_classification)、[Matching](reproduction/matching)、[指代消解](reproduction/coreference_resolution)、[摘要](reproduction/Summarization)等任务; 也可以使用它快速构建许多复杂的网络模型,进行科研。它具有如下的特性:
-- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码;
+- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的Loader和Pipe,省去预处理代码;
- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等;
- 各种方便的NLP工具,例如预处理embedding加载(包括ELMo和BERT); 中间数据cache等;
+- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载;
- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)以供查阅;
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等;
- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用,详细内容见 [reproduction](reproduction) 部分;
@@ -36,11 +37,15 @@ pip install fastNLP
python -m spacy download en
```
+目前使用pypi安装fastNLP的版本是0.4.1,有较多功能仍未更新,最新内容以master分支为准。
+fastNLP 0.5.0 版本将在近期推出,请密切关注。
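+
+如果希望使用 master 分支上的最新功能,也可以直接从源码安装(以下为示例命令,仅供参考):
+
+```shell
+pip install git+https://github.com/fastnlp/fastNLP.git
+```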
+
## fastNLP教程
+- [0. 快速入门](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
-- [2. 使用DataSetLoader加载数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html)
+- [2. 使用Loader和Pipe加载并处理数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html)
- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
- [4. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_loss_optimizer.html)
- [5. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_datasetiter.html)
@@ -48,17 +53,23 @@ python -m spacy download en
- [7. 使用Modules和Models快速搭建自定义模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_modules_models.html)
- [8. 使用Metric快速评测你的模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_metrics.html)
- [9. 使用Callback自定义你的训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_callback.html)
+- [10. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_fitlog.html)
## 内置组件
-大部分用于的 NLP 任务神经网络都可以看做由编码器(encoder)、解码器(decoder)两种模块组成。
+大部分用于 NLP 任务的神经网络都可以看做由词嵌入(embeddings)以及编码器(encoder)、解码器(decoder)两种模块组成。
+
+以文本分类任务为例,下图展示了一个BiLSTM+Attention实现文本分类器的模型流程图:

-fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下:
+fastNLP 在 embeddings 模块中内置了几种不同的embedding:静态embedding(GloVe、word2vec)、上下文相关embedding
+(ELMo、BERT)、字符embedding(基于CNN或者LSTM的CharEmbedding)。
+
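+下面给出一个创建 embedding 的简单示意(假设 `vocab` 是已根据数据构建好的 Vocabulary,模型名称仅作示例):
+
+```python
+from fastNLP.embeddings import StaticEmbedding, BertEmbedding
+
+# 静态词向量:按词表对齐加载预训练权重,缓存中找不到时会尝试自动下载
+static_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')
+# 上下文相关词向量:使用 BERT 的编码结果作为词表示
+bert_embed = BertEmbedding(vocab, model_dir_or_name='en-base-cased')
+```
+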
+与此同时,fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下:
@@ -81,7 +92,7 @@ fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助
## 项目结构
-
+
fastNLP的大致工作流程如上图所示,而项目结构如下:
@@ -102,9 +113,13 @@ fastNLP的大致工作流程如上图所示,而项目结构如下:
fastNLP.modules |
实现了用于搭建神经网络模型的诸多组件 |
+
+ fastNLP.embeddings |
+ 实现了将序列index转为向量序列的功能,包括读取预训练embedding等 |
+
fastNLP.io |
- 实现了读写功能,包括数据读入,模型读写等 |
+ 实现了读写功能,包括数据读入与预处理,模型读写,自动下载等 |
diff --git a/docs/Makefile b/docs/Makefile
index 2b4de2d8..b41beb44 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -14,13 +14,13 @@ help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
apidoc:
- $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ)
+ $(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) && python3 format.py
server:
cd build/html && python -m http.server
dev:
- rm -rf build/html && make html && make server
+ rm -rf build && make html && make server
.PHONY: help Makefile
diff --git a/docs/count.py b/docs/count.py
new file mode 100644
index 00000000..e1aad115
--- /dev/null
+++ b/docs/count.py
@@ -0,0 +1,65 @@
+import os
+import sys
+
+
+def find_all_modules():
+    modules = {}
+    children = {}
+    to_doc = set()
+    root = '../fastNLP'
+    for path, dirs, files in os.walk(root):
+        for file in files:
+            if file.endswith('.py'):
+                name = ".".join(path.split('/')[1:])
+                if file.split('.')[0] != "__init__":
+                    name = name + '.' + file.split('.')[0]
+                __import__(name)
+                m = sys.modules[name]
+                modules[name] = m
+                try:
+                    m.__all__
+                except:
+                    print(name, "__all__ missing")
+                    continue
+                if m.__doc__ is None:
+                    print(name, "__doc__ missing")
+                    continue
+                if "undocumented" not in m.__doc__:
+                    to_doc.add(name)
+    for module in to_doc:
+        t = ".".join(module.split('.')[:-1])
+        if t in to_doc:
+            if t not in children:
+                children[t] = set()
+            children[t].add(module)
+    for m in children:
+        children[m] = sorted(children[m])
+    return modules, to_doc, children
+
+
+def create_rst_file(modules, name, children):
+    m = modules[name]
+    with open("./source/" + name + ".rst", "w") as fout:
+        t = "=" * len(name)
+        fout.write(name + "\n")
+        fout.write(t + "\n")
+        fout.write("\n")
+        fout.write(".. automodule:: " + name + "\n")
+        if len(m.__all__) > 0:
+            fout.write("    :members: " + ", ".join(m.__all__) + "\n")
+            fout.write("    :inherited-members:\n")
+        fout.write("\n")
+        if name in children:
+            fout.write("子模块\n------\n\n.. toctree::\n\n")
+            for module in children[name]:
+                fout.write("    " + module + "\n")
+
+
+def main():
+    modules, to_doc, children = find_all_modules()
+    for name in to_doc:
+        create_rst_file(modules, name, children)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2e10bc89..83cb7185 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -48,12 +48,14 @@ extensions = [
autodoc_default_options = {
'member-order': 'bysource',
'special-members': '__init__',
- 'undoc-members': True,
+ 'undoc-members': False,
}
+autoclass_content = "class"
+
# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']
-
+# template_bridge
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
@@ -113,7 +115,7 @@ html_static_path = ['_static']
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
-htmlhelp_basename = 'fastNLPdoc'
+htmlhelp_basename = 'fastNLP doc'
# -- Options for LaTeX output ------------------------------------------------
diff --git a/docs/source/fastNLP.core.batch.rst b/docs/source/fastNLP.core.batch.rst
index 33a5b730..50ad6fed 100644
--- a/docs/source/fastNLP.core.batch.rst
+++ b/docs/source/fastNLP.core.batch.rst
@@ -2,6 +2,6 @@ fastNLP.core.batch
==================
.. automodule:: fastNLP.core.batch
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: BatchIter, DataSetIter, TorchLoaderIter
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.callback.rst b/docs/source/fastNLP.core.callback.rst
index 31ec627b..d37ddb11 100644
--- a/docs/source/fastNLP.core.callback.rst
+++ b/docs/source/fastNLP.core.callback.rst
@@ -2,6 +2,6 @@ fastNLP.core.callback
=====================
.. automodule:: fastNLP.core.callback
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.const.rst b/docs/source/fastNLP.core.const.rst
index c9e3bd97..82a1992e 100644
--- a/docs/source/fastNLP.core.const.rst
+++ b/docs/source/fastNLP.core.const.rst
@@ -2,6 +2,6 @@ fastNLP.core.const
==================
.. automodule:: fastNLP.core.const
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Const
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.dataset.rst b/docs/source/fastNLP.core.dataset.rst
index b377cb0f..e13d7f1c 100644
--- a/docs/source/fastNLP.core.dataset.rst
+++ b/docs/source/fastNLP.core.dataset.rst
@@ -2,6 +2,6 @@ fastNLP.core.dataset
====================
.. automodule:: fastNLP.core.dataset
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: DataSet
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.field.rst b/docs/source/fastNLP.core.field.rst
index 7686e79a..73dad8af 100644
--- a/docs/source/fastNLP.core.field.rst
+++ b/docs/source/fastNLP.core.field.rst
@@ -2,6 +2,6 @@ fastNLP.core.field
==================
.. automodule:: fastNLP.core.field
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Padder, AutoPadder, EngChar2DPadder
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.instance.rst b/docs/source/fastNLP.core.instance.rst
index 14393a91..010567b9 100644
--- a/docs/source/fastNLP.core.instance.rst
+++ b/docs/source/fastNLP.core.instance.rst
@@ -2,6 +2,6 @@ fastNLP.core.instance
=====================
.. automodule:: fastNLP.core.instance
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Instance
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.losses.rst b/docs/source/fastNLP.core.losses.rst
index d2dd492b..daf246f8 100644
--- a/docs/source/fastNLP.core.losses.rst
+++ b/docs/source/fastNLP.core.losses.rst
@@ -2,6 +2,6 @@ fastNLP.core.losses
===================
.. automodule:: fastNLP.core.losses
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: LossBase, LossFunc, LossInForward, CrossEntropyLoss, BCELoss, L1Loss, NLLLoss
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.metrics.rst b/docs/source/fastNLP.core.metrics.rst
index 69afff36..96748a78 100644
--- a/docs/source/fastNLP.core.metrics.rst
+++ b/docs/source/fastNLP.core.metrics.rst
@@ -2,6 +2,6 @@ fastNLP.core.metrics
====================
.. automodule:: fastNLP.core.metrics
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: MetricBase, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.optimizer.rst b/docs/source/fastNLP.core.optimizer.rst
index e2100d2e..44e45c4f 100644
--- a/docs/source/fastNLP.core.optimizer.rst
+++ b/docs/source/fastNLP.core.optimizer.rst
@@ -2,6 +2,6 @@ fastNLP.core.optimizer
======================
.. automodule:: fastNLP.core.optimizer
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Optimizer, SGD, Adam, AdamW
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst
index 82c13e46..56de46e9 100644
--- a/docs/source/fastNLP.core.rst
+++ b/docs/source/fastNLP.core.rst
@@ -2,15 +2,13 @@ fastNLP.core
============
.. automodule:: fastNLP.core
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: DataSet, Instance, FieldArray, Padder, AutoPadder, EngChar2DPadder, Vocabulary, DataSetIter, BatchIter, TorchLoaderIter, Const, Tester, Trainer, cache_results, seq_len_to_mask, get_seq_len, logger, Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, SequentialSampler, BucketSampler, RandomSampler, Sampler
+ :inherited-members:
子模块
-----------
+------
.. toctree::
- :titlesonly:
fastNLP.core.batch
fastNLP.core.callback
@@ -26,4 +24,3 @@ fastNLP.core
fastNLP.core.trainer
fastNLP.core.utils
fastNLP.core.vocabulary
-
diff --git a/docs/source/fastNLP.core.sampler.rst b/docs/source/fastNLP.core.sampler.rst
index 1810d59c..56291894 100644
--- a/docs/source/fastNLP.core.sampler.rst
+++ b/docs/source/fastNLP.core.sampler.rst
@@ -2,6 +2,6 @@ fastNLP.core.sampler
====================
.. automodule:: fastNLP.core.sampler
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Sampler, BucketSampler, SequentialSampler, RandomSampler
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.tester.rst b/docs/source/fastNLP.core.tester.rst
index a9e7e09f..90ec2a88 100644
--- a/docs/source/fastNLP.core.tester.rst
+++ b/docs/source/fastNLP.core.tester.rst
@@ -2,6 +2,6 @@ fastNLP.core.tester
===================
.. automodule:: fastNLP.core.tester
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Tester
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.trainer.rst b/docs/source/fastNLP.core.trainer.rst
index 9e518d4b..92c08718 100644
--- a/docs/source/fastNLP.core.trainer.rst
+++ b/docs/source/fastNLP.core.trainer.rst
@@ -2,6 +2,6 @@ fastNLP.core.trainer
====================
.. automodule:: fastNLP.core.trainer
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Trainer
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.utils.rst b/docs/source/fastNLP.core.utils.rst
index fcd3f50c..027a43e9 100644
--- a/docs/source/fastNLP.core.utils.rst
+++ b/docs/source/fastNLP.core.utils.rst
@@ -2,6 +2,6 @@ fastNLP.core.utils
==================
.. automodule:: fastNLP.core.utils
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: cache_results, seq_len_to_mask, get_seq_len
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.core.vocabulary.rst b/docs/source/fastNLP.core.vocabulary.rst
index b3bf4bac..ac07a8c6 100644
--- a/docs/source/fastNLP.core.vocabulary.rst
+++ b/docs/source/fastNLP.core.vocabulary.rst
@@ -2,6 +2,6 @@ fastNLP.core.vocabulary
=======================
.. automodule:: fastNLP.core.vocabulary
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Vocabulary, VocabularyOption
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.bert_embedding.rst b/docs/source/fastNLP.embeddings.bert_embedding.rst
new file mode 100644
index 00000000..51828cb0
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.bert_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.bert_embedding
+=================================
+
+.. automodule:: fastNLP.embeddings.bert_embedding
+ :members: BertEmbedding, BertWordPieceEncoder
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.char_embedding.rst b/docs/source/fastNLP.embeddings.char_embedding.rst
new file mode 100644
index 00000000..a9b129d8
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.char_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.char_embedding
+=================================
+
+.. automodule:: fastNLP.embeddings.char_embedding
+ :members: CNNCharEmbedding, LSTMCharEmbedding
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.contextual_embedding.rst b/docs/source/fastNLP.embeddings.contextual_embedding.rst
new file mode 100644
index 00000000..ee64c7a0
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.contextual_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.contextual_embedding
+=======================================
+
+.. automodule:: fastNLP.embeddings.contextual_embedding
+ :members: ContextualEmbedding
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.elmo_embedding.rst b/docs/source/fastNLP.embeddings.elmo_embedding.rst
new file mode 100644
index 00000000..06cc13af
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.elmo_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.elmo_embedding
+=================================
+
+.. automodule:: fastNLP.embeddings.elmo_embedding
+ :members: ElmoEmbedding
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.embedding.rst b/docs/source/fastNLP.embeddings.embedding.rst
new file mode 100644
index 00000000..4d5fcf46
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.embedding
+============================
+
+.. automodule:: fastNLP.embeddings.embedding
+ :members: Embedding, TokenEmbedding
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.rst b/docs/source/fastNLP.embeddings.rst
new file mode 100644
index 00000000..8376408c
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.rst
@@ -0,0 +1,20 @@
+fastNLP.embeddings
+==================
+
+.. automodule:: fastNLP.embeddings
+ :members: Embedding, TokenEmbedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, BertWordPieceEncoder, StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding, get_embeddings
+ :inherited-members:
+
+子模块
+------
+
+.. toctree::
+
+ fastNLP.embeddings.bert_embedding
+ fastNLP.embeddings.char_embedding
+ fastNLP.embeddings.contextual_embedding
+ fastNLP.embeddings.elmo_embedding
+ fastNLP.embeddings.embedding
+ fastNLP.embeddings.stack_embedding
+ fastNLP.embeddings.static_embedding
+ fastNLP.embeddings.utils
diff --git a/docs/source/fastNLP.embeddings.stack_embedding.rst b/docs/source/fastNLP.embeddings.stack_embedding.rst
new file mode 100644
index 00000000..6af91623
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.stack_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.stack_embedding
+==================================
+
+.. automodule:: fastNLP.embeddings.stack_embedding
+ :members: StackEmbedding
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.static_embedding.rst b/docs/source/fastNLP.embeddings.static_embedding.rst
new file mode 100644
index 00000000..2df1c329
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.static_embedding.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.static_embedding
+===================================
+
+.. automodule:: fastNLP.embeddings.static_embedding
+ :members: StaticEmbedding
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.embeddings.utils.rst b/docs/source/fastNLP.embeddings.utils.rst
new file mode 100644
index 00000000..13e5936b
--- /dev/null
+++ b/docs/source/fastNLP.embeddings.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.embeddings.utils
+========================
+
+.. automodule:: fastNLP.embeddings.utils
+ :members: get_embeddings
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.base_loader.rst b/docs/source/fastNLP.io.base_loader.rst
deleted file mode 100644
index c1f9ac14..00000000
--- a/docs/source/fastNLP.io.base_loader.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.io.base\_loader
-=======================
-
-.. automodule:: fastNLP.io.base_loader
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.io.data_bundle.rst b/docs/source/fastNLP.io.data_bundle.rst
new file mode 100644
index 00000000..71a921f1
--- /dev/null
+++ b/docs/source/fastNLP.io.data_bundle.rst
@@ -0,0 +1,7 @@
+fastNLP.io.data_bundle
+======================
+
+.. automodule:: fastNLP.io.data_bundle
+ :members: DataBundle
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.dataset_loader.rst b/docs/source/fastNLP.io.dataset_loader.rst
index d6663e59..c211ecf9 100644
--- a/docs/source/fastNLP.io.dataset_loader.rst
+++ b/docs/source/fastNLP.io.dataset_loader.rst
@@ -1,7 +1,6 @@
-fastNLP.io.dataset\_loader
-==========================
+fastNLP.io.dataset_loader
+=========================
.. automodule:: fastNLP.io.dataset_loader
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: CSVLoader, JsonLoader
+
diff --git a/docs/source/fastNLP.io.embed_loader.rst b/docs/source/fastNLP.io.embed_loader.rst
index 7a8e730c..581f5c1b 100644
--- a/docs/source/fastNLP.io.embed_loader.rst
+++ b/docs/source/fastNLP.io.embed_loader.rst
@@ -1,7 +1,7 @@
-fastNLP.io.embed\_loader
-========================
+fastNLP.io.embed_loader
+=======================
.. automodule:: fastNLP.io.embed_loader
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: EmbedLoader, EmbeddingOption
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.file_utils.rst b/docs/source/fastNLP.io.file_utils.rst
new file mode 100644
index 00000000..0815e068
--- /dev/null
+++ b/docs/source/fastNLP.io.file_utils.rst
@@ -0,0 +1,7 @@
+fastNLP.io.file_utils
+=====================
+
+.. automodule:: fastNLP.io.file_utils
+ :members: cached_path, get_filepath, get_cache_path, split_filename_suffix, get_from_cache
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst
new file mode 100644
index 00000000..060b5450
--- /dev/null
+++ b/docs/source/fastNLP.io.loader.rst
@@ -0,0 +1,7 @@
+fastNLP.io.loader
+=================
+
+.. automodule:: fastNLP.io.loader
+ :members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.model_io.rst b/docs/source/fastNLP.io.model_io.rst
index 50d4c25a..183122b1 100644
--- a/docs/source/fastNLP.io.model_io.rst
+++ b/docs/source/fastNLP.io.model_io.rst
@@ -1,7 +1,7 @@
-fastNLP.io.model\_io
-====================
+fastNLP.io.model_io
+===================
.. automodule:: fastNLP.io.model_io
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: ModelLoader, ModelSaver
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst
new file mode 100644
index 00000000..d35d2ddc
--- /dev/null
+++ b/docs/source/fastNLP.io.pipe.rst
@@ -0,0 +1,7 @@
+fastNLP.io.pipe
+===============
+
+.. automodule:: fastNLP.io.pipe
+ :members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst
index fad05a21..2aacb883 100644
--- a/docs/source/fastNLP.io.rst
+++ b/docs/source/fastNLP.io.rst
@@ -2,18 +2,18 @@ fastNLP.io
==========
.. automodule:: fastNLP.io
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver
+ :inherited-members:
子模块
-----------
+------
.. toctree::
- :titlesonly:
- fastNLP.io.base_loader
- fastNLP.io.dataset_loader
+ fastNLP.io.data_bundle
fastNLP.io.embed_loader
+ fastNLP.io.file_utils
+ fastNLP.io.loader
fastNLP.io.model_io
-
+ fastNLP.io.pipe
+ fastNLP.io.utils
diff --git a/docs/source/fastNLP.io.utils.rst b/docs/source/fastNLP.io.utils.rst
new file mode 100644
index 00000000..3bff3c45
--- /dev/null
+++ b/docs/source/fastNLP.io.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.io.utils
+================
+
+.. automodule:: fastNLP.io.utils
+ :members: check_loader_paths
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.models.biaffine_parser.rst b/docs/source/fastNLP.models.biaffine_parser.rst
index a3dd1836..c3dbb0a5 100644
--- a/docs/source/fastNLP.models.biaffine_parser.rst
+++ b/docs/source/fastNLP.models.biaffine_parser.rst
@@ -1,7 +1,7 @@
-fastNLP.models.biaffine\_parser
-===============================
+fastNLP.models.biaffine_parser
+==============================
.. automodule:: fastNLP.models.biaffine_parser
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: BiaffineParser, GraphParser
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.models.cnn_text_classification.rst b/docs/source/fastNLP.models.cnn_text_classification.rst
index a935d0bf..fe4bb157 100644
--- a/docs/source/fastNLP.models.cnn_text_classification.rst
+++ b/docs/source/fastNLP.models.cnn_text_classification.rst
@@ -1,7 +1,7 @@
-fastNLP.models.cnn\_text\_classification
-========================================
+fastNLP.models.cnn_text_classification
+======================================
.. automodule:: fastNLP.models.cnn_text_classification
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: CNNText
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst
index 5858ebcd..88854a79 100644
--- a/docs/source/fastNLP.models.rst
+++ b/docs/source/fastNLP.models.rst
@@ -2,19 +2,16 @@ fastNLP.models
==============
.. automodule:: fastNLP.models
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser
+ :inherited-members:
子模块
-----------
+------
.. toctree::
- :titlesonly:
fastNLP.models.biaffine_parser
fastNLP.models.cnn_text_classification
fastNLP.models.sequence_labeling
fastNLP.models.snli
fastNLP.models.star_transformer
-
diff --git a/docs/source/fastNLP.models.sequence_labeling.rst b/docs/source/fastNLP.models.sequence_labeling.rst
index 6d569fe1..b66e637e 100644
--- a/docs/source/fastNLP.models.sequence_labeling.rst
+++ b/docs/source/fastNLP.models.sequence_labeling.rst
@@ -1,7 +1,7 @@
-fastNLP.models.sequence\_labeling
-=================================
+fastNLP.models.sequence_labeling
+================================
.. automodule:: fastNLP.models.sequence_labeling
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: SeqLabeling, AdvSeqLabel
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.models.snli.rst b/docs/source/fastNLP.models.snli.rst
index 24c2cc53..8551051a 100644
--- a/docs/source/fastNLP.models.snli.rst
+++ b/docs/source/fastNLP.models.snli.rst
@@ -2,6 +2,6 @@ fastNLP.models.snli
===================
.. automodule:: fastNLP.models.snli
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: ESIM
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.models.star_transformer.rst b/docs/source/fastNLP.models.star_transformer.rst
index c93fb8cd..f4b5989e 100644
--- a/docs/source/fastNLP.models.star_transformer.rst
+++ b/docs/source/fastNLP.models.star_transformer.rst
@@ -1,7 +1,7 @@
-fastNLP.models.star\_transformer
-================================
+fastNLP.models.star_transformer
+===============================
.. automodule:: fastNLP.models.star_transformer
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: StarTransEnc, STNLICls, STSeqCls, STSeqLabel
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.modules.decoder.crf.rst b/docs/source/fastNLP.modules.decoder.crf.rst
deleted file mode 100644
index 6d5b0d5b..00000000
--- a/docs/source/fastNLP.modules.decoder.crf.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.decoder.CRF
-===========================
-
-.. automodule:: fastNLP.modules.decoder.crf
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.decoder.mlp.rst b/docs/source/fastNLP.modules.decoder.mlp.rst
deleted file mode 100644
index 7d661ebf..00000000
--- a/docs/source/fastNLP.modules.decoder.mlp.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.decoder.MLP
-===========================
-
-.. automodule:: fastNLP.modules.decoder.mlp
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.decoder.rst b/docs/source/fastNLP.modules.decoder.rst
index e42a9f39..b121f9e9 100644
--- a/docs/source/fastNLP.modules.decoder.rst
+++ b/docs/source/fastNLP.modules.decoder.rst
@@ -2,17 +2,6 @@ fastNLP.modules.decoder
=======================
.. automodule:: fastNLP.modules.decoder
- :members:
- :undoc-members:
- :show-inheritance:
-
-子模块
-----------
-
-.. toctree::
- :titlesonly:
-
- fastNLP.modules.decoder.crf
- fastNLP.modules.decoder.mlp
- fastNLP.modules.decoder.utils
+ :members: MLP, ConditionalRandomField, viterbi_decode, allowed_transitions
+ :inherited-members:
diff --git a/docs/source/fastNLP.modules.decoder.utils.rst b/docs/source/fastNLP.modules.decoder.utils.rst
deleted file mode 100644
index da979d99..00000000
--- a/docs/source/fastNLP.modules.decoder.utils.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.decoder.utils
-=============================
-
-.. automodule:: fastNLP.modules.decoder.utils
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.bert.rst b/docs/source/fastNLP.modules.encoder.bert.rst
deleted file mode 100644
index 66bd0bbd..00000000
--- a/docs/source/fastNLP.modules.encoder.bert.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.bert
-============================
-
-.. automodule:: fastNLP.modules.encoder.bert
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.char_encoder.rst b/docs/source/fastNLP.modules.encoder.char_encoder.rst
deleted file mode 100644
index 61ea3340..00000000
--- a/docs/source/fastNLP.modules.encoder.char_encoder.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.char\_encoder
-=====================================
-
-.. automodule:: fastNLP.modules.encoder.char_encoder
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.conv_maxpool.rst b/docs/source/fastNLP.modules.encoder.conv_maxpool.rst
deleted file mode 100644
index 7058a723..00000000
--- a/docs/source/fastNLP.modules.encoder.conv_maxpool.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.conv\_maxpool
-=====================================
-
-.. automodule:: fastNLP.modules.encoder.conv_maxpool
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.embedding.rst b/docs/source/fastNLP.modules.encoder.embedding.rst
deleted file mode 100644
index 4427b3bf..00000000
--- a/docs/source/fastNLP.modules.encoder.embedding.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.embedding
-=================================
-
-.. automodule:: fastNLP.modules.encoder.embedding
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.lstm.rst b/docs/source/fastNLP.modules.encoder.lstm.rst
deleted file mode 100644
index f9cbea88..00000000
--- a/docs/source/fastNLP.modules.encoder.lstm.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.lstm
-============================
-
-.. automodule:: fastNLP.modules.encoder.lstm
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst
index b15232fa..6b44a192 100644
--- a/docs/source/fastNLP.modules.encoder.rst
+++ b/docs/source/fastNLP.modules.encoder.rst
@@ -2,22 +2,6 @@ fastNLP.modules.encoder
=======================
.. automodule:: fastNLP.modules.encoder
- :members:
- :undoc-members:
- :show-inheritance:
-
-子模块
-----------
-
-.. toctree::
- :titlesonly:
-
- fastNLP.modules.encoder.bert
- fastNLP.modules.encoder.char_encoder
- fastNLP.modules.encoder.conv_maxpool
- fastNLP.modules.encoder.embedding
- fastNLP.modules.encoder.lstm
- fastNLP.modules.encoder.star_transformer
- fastNLP.modules.encoder.transformer
- fastNLP.modules.encoder.variational_rnn
+ :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention
+ :inherited-members:
diff --git a/docs/source/fastNLP.modules.encoder.star_transformer.rst b/docs/source/fastNLP.modules.encoder.star_transformer.rst
deleted file mode 100644
index 0c406782..00000000
--- a/docs/source/fastNLP.modules.encoder.star_transformer.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.star\_transformer
-=========================================
-
-.. automodule:: fastNLP.modules.encoder.star_transformer
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.transformer.rst b/docs/source/fastNLP.modules.encoder.transformer.rst
deleted file mode 100644
index 6a40c597..00000000
--- a/docs/source/fastNLP.modules.encoder.transformer.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.transformer
-===================================
-
-.. automodule:: fastNLP.modules.encoder.transformer
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.encoder.variational_rnn.rst b/docs/source/fastNLP.modules.encoder.variational_rnn.rst
deleted file mode 100644
index 348fb3d8..00000000
--- a/docs/source/fastNLP.modules.encoder.variational_rnn.rst
+++ /dev/null
@@ -1,7 +0,0 @@
-fastNLP.modules.encoder.variational\_rnn
-========================================
-
-.. automodule:: fastNLP.modules.encoder.variational_rnn
- :members:
- :undoc-members:
- :show-inheritance:
diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst
index 7f75cfdc..6134d0dd 100644
--- a/docs/source/fastNLP.modules.rst
+++ b/docs/source/fastNLP.modules.rst
@@ -2,15 +2,14 @@ fastNLP.modules
===============
.. automodule:: fastNLP.modules
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout
+ :inherited-members:
子模块
------------
+------
.. toctree::
- :titlesonly:
- fastNLP.modules.decoder
- fastNLP.modules.encoder
\ No newline at end of file
+ fastNLP.modules.decoder
+ fastNLP.modules.encoder
+ fastNLP.modules.utils
diff --git a/docs/source/fastNLP.modules.utils.rst b/docs/source/fastNLP.modules.utils.rst
new file mode 100644
index 00000000..e28ca35a
--- /dev/null
+++ b/docs/source/fastNLP.modules.utils.rst
@@ -0,0 +1,7 @@
+fastNLP.modules.utils
+=====================
+
+.. automodule:: fastNLP.modules.utils
+ :members: initial_parameter, summary
+ :inherited-members:
+
diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst
index f0c3d41c..f22ea936 100644
--- a/docs/source/fastNLP.rst
+++ b/docs/source/fastNLP.rst
@@ -1,20 +1,17 @@
-API 文档
-===============
+fastNLP
+=======
.. automodule:: fastNLP
- :members:
- :undoc-members:
- :show-inheritance:
+ :members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC, LRFinder, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger
+ :inherited-members:
-内部模块
------------
+子模块
+------
.. toctree::
- :titlesonly:
- :maxdepth: 3
-
- fastNLP.core
- fastNLP.io
- fastNLP.modules
- fastNLP.models
+ fastNLP.core
+ fastNLP.embeddings
+ fastNLP.io
+ fastNLP.models
+ fastNLP.modules
diff --git a/docs/source/figures/text_classification.png b/docs/source/figures/text_classification.png
index 0d36a2a1..21502708 100644
Binary files a/docs/source/figures/text_classification.png and b/docs/source/figures/text_classification.png differ
diff --git a/docs/source/figures/workflow.png b/docs/source/figures/workflow.png
index d2f22df8..d8e4e455 100644
Binary files a/docs/source/figures/workflow.png and b/docs/source/figures/workflow.png differ
diff --git a/docs/source/index.rst b/docs/source/index.rst
index ca000859..d48af986 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,60 +1,28 @@
fastNLP 中文文档
=====================
-fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个命名实体识别(NER)、中文分词或文本分类任务;
-也可以使用他构建许多复杂的网络模型,进行科研。它具有如下的特性:
+`fastNLP `_ 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注
+(NER、POS-Tagging等)、中文分词、文本分类、Matching、指代消解、摘要等任务
+(详见 `reproduction `_ );
+也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性:
-- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码。
-- 各种方便的NLP工具,例如预处理embedding加载; 中间数据cache等;
-- 详尽的中文文档以供查阅;
-- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等;
-- 封装CNNText,Biaffine等模型可供直接使用;
-- 便捷且具有扩展性的训练器; 提供多种内置callback函数,方便实验记录、异常捕获等。
+- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的 :mod:`~fastNLP.io.data_loader` ,省去预处理代码;
+- 多种训练、测试组件,例如训练器 :class:`~fastNLP.Trainer` ;测试器 :class:`~fastNLP.Tester` ;以及各种评测 :mod:`~fastNLP.core.metrics` 等等;
+- 各种方便的NLP工具,例如预处理 :mod:`embedding` 加载(包括ELMo和BERT); 中间数据存储 :func:`cache ` 等;
+- 提供诸多高级模块 :mod:`~fastNLP.modules`,例如 :class:`~fastNLP.modules.VarLSTM` , :class:`Transformer` , :class:`CRF` 等;
+- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种 :mod:`~fastNLP.models` 可供直接使用;
+- 训练器便捷且具有扩展性,提供多种内置 :mod:`~fastNLP.core.callback` 函数,方便实验记录、异常捕获等。
-内置组件
-------------
-
-大部分用于的 NLP 任务神经网络都可以看做由编码(encoder)、聚合(aggregator)、解码(decoder)三种模块组成。
-
-.. image:: figures/text_classification.png
-
-fastNLP 在 :mod:`~fastNLP.modules` 模块中内置了三种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。
-三种模块的功能和常见组件如下:
-
-+-----------------------+-----------------------+-----------------------+
-| module type | functionality | example |
-+=======================+=======================+=======================+
-| encoder | 将输入编码为具有具 | embedding, RNN, CNN, |
-| | 有表示能力的向量 | transformer |
-+-----------------------+-----------------------+-----------------------+
-| aggregator | 从多个向量中聚合信息 | self-attention, |
-| | | max-pooling |
-+-----------------------+-----------------------+-----------------------+
-| decoder | 将具有某种表示意义的 | MLP, CRF |
-| | 向量解码为需要的输出 | |
-| | 形式 | |
-+-----------------------+-----------------------+-----------------------+
-
-
-内置模型
-----------------
-
-fastNLP 在 :mod:`~fastNLP.models` 模块中内置了如 :class:`~fastNLP.models.CNNText` 、
-:class:`~fastNLP.models.SeqLabeling` 等完整的模型,以供用户直接使用。
-
-.. todo::
- 这些模型的介绍如下表所示:(模型名称 + 介绍 + 任务上的结果)
-
用户手册
----------------
.. toctree::
- :maxdepth: 1
+ :maxdepth: 2
安装指南
快速入门
- 详细指南
+ 详细教程
API 文档
-------------
@@ -67,11 +35,11 @@ API 文档
fastNLP
-fitlog
-------
+fitlog文档
+----------
-用户可以 `点此 `_ 查看fitlog的文档。
-fitlog 是由我们团队开发,用于帮助用户记录日志并管理代码的工具
+您可以 `点此 `_ 查看fitlog的文档。
+fitlog 是由我们团队开发的日志记录+代码管理的工具。
索引与搜索
==================
diff --git a/docs/source/modules.rst b/docs/source/modules.rst
index 9ca3c7f3..e9a92cb7 100644
--- a/docs/source/modules.rst
+++ b/docs/source/modules.rst
@@ -2,7 +2,6 @@ fastNLP
=======
.. toctree::
- :titlesonly:
:maxdepth: 4
fastNLP
diff --git a/docs/source/tutorials/tutorial_1_data_preprocess.rst b/docs/source/tutorials/tutorial_1_data_preprocess.rst
index cd97ca75..0ec63f87 100644
--- a/docs/source/tutorials/tutorial_1_data_preprocess.rst
+++ b/docs/source/tutorials/tutorial_1_data_preprocess.rst
@@ -1,5 +1,5 @@
==============================
-数据格式及预处理教程
+使用DataSet预处理文本
==============================
:class:`~fastNLP.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格,
@@ -60,7 +60,7 @@
seq_len=3)
])
-在初步构建完数据集之后,我们可可以通过 `for` 循环遍历 :class:`~fastNLP.DataSet` 中的内容。
+在初步构建完数据集之后,我们可以通过 `for` 循环遍历 :class:`~fastNLP.DataSet` 中的内容。
.. code-block:: python
diff --git a/docs/source/tutorials/tutorial_2_load_dataset.rst b/docs/source/tutorials/tutorial_2_load_dataset.rst
index 2576992d..17ad6baf 100644
--- a/docs/source/tutorials/tutorial_2_load_dataset.rst
+++ b/docs/source/tutorials/tutorial_2_load_dataset.rst
@@ -1,57 +1,53 @@
-=========================
-数据集加载教程
-=========================
+=======================================
+使用Loader和Pipe加载并处理数据集
+=======================================
这一部分是一个关于如何加载数据集的教程
教程目录:
- - `Part I: 数据集信息`_
- - `Part II: 数据集的使用方式`_
- - `Part III: 不同数据类型的DataSetLoader`_
- - `Part IV: DataSetLoader举例`_
- - `Part V: fastNLP封装好的数据集加载器`_
+ - `Part I: 数据集容器DataBundle`_
+ - `Part II: 加载数据集的基类Loader`_
+ - `Part III: 不同格式类型的基础Loader`_
+ - `Part IV: 使用Pipe对数据集进行预处理`_
+ - `Part V: fastNLP封装好的Loader和Pipe`_
-----------------------------
-Part I: 数据集信息
-----------------------------
+------------------------------------
+Part I: 数据集容器DataBundle
+------------------------------------
-在fastNLP中,我们使用 :class:`~fastNLP.io.base_loader.DataInfo` 来存储数据集信息。 :class:`~fastNLP.io.base_loader.DataInfo`
-类包含了两个重要内容: `datasets` 和 `vocabs` 。
+在fastNLP中,我们使用 :class:`~fastNLP.io.data_bundle.DataBundle` 来存储数据集信息。
+:class:`~fastNLP.io.data_bundle.DataBundle` 类包含了两个重要内容: `datasets` 和 `vocabs` 。
`datasets` 是一个 `key` 为数据集名称(如 `train` , `dev` ,和 `test` 等), `value` 为 :class:`~fastNLP.DataSet` 的字典。
`vocabs` 是一个 `key` 为词表名称(如 :attr:`fastNLP.Const.INPUT` 表示输入文本的词表名称, :attr:`fastNLP.Const.TARGET` 表示目标
的真实标签词表的名称,等等), `value` 为词表内容( :class:`~fastNLP.Vocabulary` )的字典。
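+
+下面是一个访问 `datasets` 和 `vocabs` 的简单示意(假设 data_bundle 是某个 Loader 或 Pipe 返回的 :class:`~fastNLP.io.data_bundle.DataBundle` ,其中的键名仅作示例):
+
+.. code-block:: python
+
+    train_data = data_bundle.datasets['train']   # 名为 train 的 DataSet
+    vocab = data_bundle.vocabs['words']          # 名为 words 的 Vocabulary
+    print(len(train_data), len(vocab))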
-----------------------------
-Part II: 数据集的使用方式
-----------------------------
+-------------------------------------
+Part II: 加载数据集的基类Loader
+-------------------------------------
-在fastNLP中,我们采用 :class:`~fastNLP.io.base_loader.DataSetLoader` 来作为加载数据集的基类。
-:class:`~fastNLP.io.base_loader.DataSetLoader` 定义了各种DataSetLoader所需的API接口,开发者应该继承它实现各种的DataSetLoader。
-在各种数据集的DataSetLoader当中,至少应该编写如下内容:
+在fastNLP中,我们采用 :class:`~fastNLP.io.loader.Loader` 来作为加载数据集的基类。
+:class:`~fastNLP.io.loader.Loader` 定义了各种Loader所需的API接口,开发者应该继承它实现各种的Loader。
+在各种数据集的Loader当中,至少应该编写如下内容:
- - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的 :class:`~fastNLP.io.DataInfo`
+ - _load 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet`
+ - load 函数:从文件或者文件夹中读取数据并组装成 :class:`~fastNLP.io.data_bundle.DataBundle`
- **\*process函数中可以调用load函数或_load函数**
-
-DataSetLoader的_load或者load函数返回的 :class:`~fastNLP.DataSet` 当中,内容为数据集的文本信息,process函数返回的
-:class:`~fastNLP.io.DataInfo` 当中, `datasets` 的内容为已经index好的、可以直接被 :class:`~fastNLP.Trainer`
-接受的内容。
+Loader的load函数返回的 :class:`~fastNLP.io.data_bundle.DataBundle` 里面包含了数据集的原始数据。
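+
+以 SNLI 数据集为例,调用 Loader 的最小示意如下(路径仅作示意;对于支持自动下载的数据集,paths 传入 None 时会自动下载):
+
+.. code-block:: python
+
+    from fastNLP.io.loader import SNLILoader
+
+    # 从数据所在的文件或文件夹读取内容,组装成 DataBundle
+    data_bundle = SNLILoader().load('path/to/snli')
+    print(data_bundle)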
--------------------------------------------------------
-Part III: 不同数据类型的DataSetLoader
+Part III: 不同格式类型的基础Loader
--------------------------------------------------------
-:class:`~fastNLP.io.dataset_loader.CSVLoader`
+:class:`~fastNLP.io.loader.CSVLoader`
读取CSV类型的数据集文件。例子如下:
.. code-block:: python
+ from fastNLP.io.loader import CSVLoader
data_set_loader = CSVLoader(
headers=('words', 'target'), sep='\t'
)
@@ -67,17 +63,18 @@ Part III: 不同数据类型的DataSetLoader
The performances are an absolute joy . 4
-:class:`~fastNLP.io.dataset_loader.JsonLoader`
+:class:`~fastNLP.io.loader.JsonLoader`
读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下:
.. code-block:: python
- data_set_loader = JsonLoader(
+ from fastNLP.io.loader import JsonLoader
+ loader = JsonLoader(
fields={'sentence1': 'words1', 'sentence2': 'words2', 'gold_label': 'target'}
)
# 表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'words1'、'words2'、'target'这三个fields
- data_set = data_set_loader._load('path/to/your/file')
+ data_set = loader._load('path/to/your/file')
数据集内容样例如下 ::
@@ -86,108 +83,68 @@ Part III: 不同数据类型的DataSetLoader
{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"}
------------------------------------------
-Part IV: DataSetLoader举例
+Part IV: 使用Pipe对数据集进行预处理
------------------------------------------
-以Matching任务为例子:
-
- :class:`~fastNLP.io.data_loader.matching.MatchingLoader`
- 我们在fastNLP当中封装了一个Matching任务数据集的数据加载类: :class:`~fastNLP.io.data_loader.matching.MatchingLoader` .
-
- 在MatchingLoader类当中我们封装了一个对数据集中的文本内容进行进一步的预处理的函数:
- :meth:`~fastNLP.io.data_loader.matching.MatchingLoader.process`
- 这个函数具有各种预处理option,如:
- - 是否将文本转成全小写
- - 是否需要序列长度信息,需要什么类型的序列长度信息
- - 是否需要用BertTokenizer来获取序列的WordPiece信息
- - 等等
-
- 具体内容参见 :meth:`fastNLP.io.MatchingLoader.process` 。
-
- :class:`~fastNLP.io.data_loader.matching.SNLILoader`
- 一个关于SNLI数据集的DataSetLoader。SNLI数据集来自
- `SNLI Data Set `_ .
-
- 在 :class:`~fastNLP.io.data_loader.matching.SNLILoader` 的 :meth:`~fastNLP.io.data_loader.matching.SNLILoader._load`
- 函数中,我们用以下代码将数据集内容从文本文件读入内存
+在fastNLP中,我们采用 :class:`~fastNLP.io.pipe.Pipe` 来作为加载数据集的基类。
+:class:`~fastNLP.io.pipe.Pipe` 定义了各种Pipe所需的API接口,开发者应该继承它实现各种的Pipe。
+在各种数据集的Pipe当中,至少应该编写如下内容:
- .. code-block:: python
-
- def _load(self, path):
- ds = JsonLoader._load(self, path) # SNLI数据集原始文件为Json格式,可以采用JsonLoader来读取数据集文件
-
- parentheses_table = str.maketrans({'(': None, ')': None})
- # 字符串匹配格式:SNLI数据集的文本中由括号分割开的,组成树结构,因此
- # 我们将这些括号去除。
-
- ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
- new_field_name=Const.INPUTS(0))
- # 把第一句话的内容用上面的字符串匹配格式进行替换,并将句子分割为一个由单词组成的list
- ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
- new_field_name=Const.INPUTS(1))
- # 对第二句话的内容进行同样的预处理
- ds.drop(lambda x: x[Const.TARGET] == '-') # 将标签为'-'的样本丢掉
- return ds
-
-------------------------------------------
-Part V: fastNLP封装好的数据集加载器
-------------------------------------------
+ - process 函数:对输入的 :class:`~fastNLP.io.data_bundle.DataBundle` 进行处理(如构建词表、
+ 将dataset的文本内容转成index等等),然后返回该 :class:`~fastNLP.io.data_bundle.DataBundle`
+ - process_from_file 函数:输入数据集所在文件夹,读取内容并组装成 :class:`~fastNLP.io.data_bundle.DataBundle` ,
+ 然后调用相对应的process函数对数据进行预处理
-fastNLP封装好的数据集加载器可以适用于多种类型的任务:
+以SNLI数据集为例,写一个自定义Pipe的例子如下:
- - `文本分类任务`_
- - `序列标注任务`_
- - `Matching任务`_
- - `指代消解任务`_
- - `摘要任务`_
+.. code-block:: python
+ from fastNLP.io.loader import SNLILoader
+ from fastNLP.io.pipe import MatchingPipe
-文本分类任务
--------------------
+ class MySNLIPipe(MatchingPipe):
-文本分类任务
+ def process(self, data_bundle):
+ data_bundle = super(MySNLIPipe, self).process(data_bundle)
+ # MatchingPipe类里封装了一个关于matching任务的process函数,可以直接继承使用
+ # 如果有需要进行额外的预处理操作可以在这里加入您的代码
+ return data_bundle
+ def process_from_file(self, paths=None):
+ data_bundle = SNLILoader().load(paths) # 使用SNLILoader读取原始数据集
+ # SNLILoader的load函数中,paths如果为None则会自动下载
+ return self.process(data_bundle) # 调用相对应的process函数对data_bundle进行处理
+调用Pipe示例:
-序列标注任务
--------------------
+.. code-block:: python
-序列标注任务
+ from fastNLP.io.pipe import SNLIBertPipe
+ data_bundle = SNLIBertPipe(lower=True, tokenizer=arg.tokenizer).process_from_file()
+ print(data_bundle)
+输出的内容是::
-Matching任务
--------------------
+ In total 3 datasets:
+ train has 549367 instances.
+ dev has 9842 instances.
+ test has 9824 instances.
+ In total 2 vocabs:
+ words has 34184 entries.
+ target has 3 entries.
-:class:`~fastNLP.io.data_loader.matching.SNLILoader`
- 一个关于SNLI数据集的DataSetLoader。SNLI数据集来自
- `SNLI Data Set `_ .
+这里表示一共有3个数据集和2个词表。其中:
-:class:`~fastNLP.io.data_loader.matching.MNLILoader`
- 一个关于MultiNLI数据集的DataSetLoader。MultiNLI数据集来自 `GLUE benchmark `_
+ - 3个数据集分别为train、dev、test数据集,分别有549367、9842、9824个instance
+ - 2个词表分别为words词表与target词表。其中words词表为句子文本所构建的词表,一共有34184个单词;
+ target词表为目标标签所构建的词表,一共有3种标签。(注:如果有多个输入,则句子文本所构建的词表将
+ 会被命名为words1以对应相对应的列名)
-:class:`~fastNLP.io.data_loader.matching.QNLILoader`
- 一个关于QNLI数据集的DataSetLoader。QNLI数据集来自 `GLUE benchmark `_
-
-:class:`~fastNLP.io.data_loader.matching.RTELoader`
- 一个关于Recognizing Textual Entailment数据集(RTE)的DataSetLoader。RTE数据集来自
- `GLUE benchmark `_
-
-:class:`~fastNLP.io.data_loader.matching.QuoraLoader`
- 一个关于Quora数据集的DataSetLoader。
-
-
-
-
-指代消解任务
--------------------
-
-指代消解任务
-
-
-
-摘要任务
--------------------
+------------------------------------------
+Part V: fastNLP封装好的Loader和Pipe
+------------------------------------------
-摘要任务
+fastNLP封装了多种任务/数据集的Loader和Pipe并提供自动下载功能,具体参见文档
+`fastNLP可加载的embedding与数据集 `_
diff --git a/docs/source/tutorials/tutorial_3_embedding.rst b/docs/source/tutorials/tutorial_3_embedding.rst
index 5e0a9107..07dc30bc 100644
--- a/docs/source/tutorials/tutorial_3_embedding.rst
+++ b/docs/source/tutorials/tutorial_3_embedding.rst
@@ -12,6 +12,7 @@
- `Part IV: 使用预训练的Contextual Embedding(ELMo & BERT)`_
- `Part V: 使用character-level的embedding`_
- `Part VI: 叠加使用多个embedding`_
+ - `Part VII: fastNLP支持的预训练Embedding`_
@@ -29,18 +30,20 @@ fastNLP的embedding包括了预训练embedding和随机初始化embedding。
Part II: 使用随机初始化的embedding
---------------------------------------
-使用随机初始化的embedding参见 :class:`~fastNLP.modules.encoder.embedding.Embedding` 。
+使用随机初始化的embedding参见 :class:`~fastNLP.embeddings.embedding.Embedding` 。
可以传入词表大小和embedding维度:
.. code-block:: python
+ from fastNLP.embeddings import Embedding
embed = Embedding(10000, 50)
也可以传入一个初始化的参数矩阵:
.. code-block:: python
+ from fastNLP.embeddings import Embedding
embed = Embedding(init_embed)
其中的init_embed可以是torch.FloatTensor、torch.nn.Embedding或者numpy.ndarray。
@@ -53,12 +56,13 @@ Part III: 使用预训练的静态embedding
在使用预训练的embedding之前,需要根据数据集的内容构建一个词表 :class:`~fastNLP.core.vocabulary.Vocabulary` ,在
预训练embedding类初始化的时候需要将这个词表作为参数传入。
-在fastNLP中,我们提供了 :class:`~fastNLP.modules.encoder.embedding.StaticEmbedding` 这一个类。
-通过 :class:`~fastNLP.modules.encoder.embedding.StaticEmbedding` 可以加载预训练好的静态
+在fastNLP中,我们提供了 :class:`~fastNLP.embeddings.StaticEmbedding` 这一个类。
+通过 :class:`~fastNLP.embeddings.StaticEmbedding` 可以加载预训练好的静态
Embedding,例子如下:
.. code-block:: python
+ from fastNLP.embeddings import StaticEmbedding
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
vocab为根据数据集构建的词表,model_dir_or_name可以是一个路径,也可以是embedding模型的名称:
@@ -67,112 +71,50 @@ vocab为根据数据集构建的词表,model_dir_or_name可以是一个路径
和word2vec类型的权重文件都支持)
2 如果传入的是模型名称,那么fastNLP将会根据名称查找embedding模型,如果在cache目录下找到模型则会
- 自动加载;如果找不到则会自动下载。可以通过环境变量 ``FASTNLP_CACHE_DIR`` 来自定义cache目录,如::
+ 自动加载;如果找不到则会自动下载到cache目录。默认的cache目录为 `~/.fastNLP` 文件夹。可以通过环境
+ 变量 ``FASTNLP_CACHE_DIR`` 来自定义cache目录,如::
$ FASTNLP_CACHE_DIR=~/fastnlp_cache_dir python your_python_file.py
这个命令表示fastNLP将会在 `~/fastnlp_cache_dir` 这个目录下寻找模型,找不到则会自动将模型下载到这个目录
-目前支持的静态embedding模型有:
-
- ========================== ================================
- 模型名称 模型
- -------------------------- --------------------------------
- en glove.840B.300d
- -------------------------- --------------------------------
- en-glove-840d-300 glove.840B.300d
- -------------------------- --------------------------------
- en-glove-6b-50 glove.6B.50d
- -------------------------- --------------------------------
- en-word2vec-300 谷歌word2vec 300维
- -------------------------- --------------------------------
- en-fasttext 英文fasttext 300维
- -------------------------- --------------------------------
- cn 腾讯中文词向量 200维
- -------------------------- --------------------------------
- cn-fasttext 中文fasttext 300维
- ========================== ================================
-
-
-
-----------------------------------------------------------
Part IV: 使用预训练的Contextual Embedding(ELMo & BERT)
-----------------------------------------------------------
-在fastNLP中,我们提供了ELMo和BERT的embedding: :class:`~fastNLP.modules.encoder.embedding.ElmoEmbedding`
-和 :class:`~fastNLP.modules.encoder.embedding.BertEmbedding` 。
+在fastNLP中,我们提供了ELMo和BERT的embedding: :class:`~fastNLP.embeddings.ElmoEmbedding`
+和 :class:`~fastNLP.embeddings.BertEmbedding` 。
与静态embedding类似,ELMo的使用方法如下:
.. code-block:: python
+ from fastNLP.embeddings import ElmoEmbedding
embed = ElmoEmbedding(vocab, model_dir_or_name='small', requires_grad=False)
-目前支持的ElmoEmbedding模型有:
-
- ========================== ================================
- 模型名称 模型
- -------------------------- --------------------------------
- small allennlp ELMo的small
- -------------------------- --------------------------------
- medium allennlp ELMo的medium
- -------------------------- --------------------------------
- original allennlp ELMo的original
- -------------------------- --------------------------------
- 5.5b-original allennlp ELMo的5.5B original
- ========================== ================================
-
BERT-embedding的使用方法如下:
.. code-block:: python
+ from fastNLP.embeddings import BertEmbedding
embed = BertEmbedding(
vocab, model_dir_or_name='en-base-cased', requires_grad=False, layers='4,-2,-1'
)
其中layers变量表示需要取哪几层的encode结果。
-目前支持的BertEmbedding模型有:
-
- ========================== ====================================
- 模型名称 模型
- -------------------------- ------------------------------------
- en bert-base-cased
- -------------------------- ------------------------------------
- en-base-uncased bert-base-uncased
- -------------------------- ------------------------------------
- en-base-cased bert-base-cased
- -------------------------- ------------------------------------
- en-large-uncased bert-large-uncased
- -------------------------- ------------------------------------
- en-large-cased bert-large-cased
- -------------------------- ------------------------------------
- -------------------------- ------------------------------------
- en-large-cased-wwm bert-large-cased-whole-word-mask
- -------------------------- ------------------------------------
- en-large-uncased-wwm bert-large-uncased-whole-word-mask
- -------------------------- ------------------------------------
- en-base-cased-mrpc bert-base-cased-finetuned-mrpc
- -------------------------- ------------------------------------
- -------------------------- ------------------------------------
- multilingual bert-base-multilingual-cased
- -------------------------- ------------------------------------
- multilingual-base-uncased bert-base-multilingual-uncased
- -------------------------- ------------------------------------
- multilingual-base-cased bert-base-multilingual-cased
- ========================== ====================================
-
-----------------------------------------------------
Part V: 使用character-level的embedding
-----------------------------------------------------
-除了预训练的embedding以外,fastNLP还提供了CharEmbedding: :class:`~fastNLP.modules.encoder.embedding.CNNCharEmbedding` 和
-:class:`~fastNLP.modules.encoder.embedding.LSTMCharEmbedding` 。
+除了预训练的embedding以外,fastNLP还提供了CharEmbedding: :class:`~fastNLP.embeddings.CNNCharEmbedding` 和
+:class:`~fastNLP.embeddings.LSTMCharEmbedding` 。
CNNCharEmbedding的使用例子如下:
.. code-block:: python
+ from fastNLP.embeddings import CNNCharEmbedding
embed = CNNCharEmbedding(vocab, embed_size=100, char_emb_size=50)
这表示这个CNNCharEmbedding当中character的embedding维度大小为50,返回的embedding结果维度大小为100。
@@ -181,22 +123,23 @@ CNNCharEmbedding的使用例子如下:
.. code-block:: python
+ from fastNLP.embeddings import LSTMCharEmbedding
embed = LSTMCharEmbedding(vocab, embed_size=100, char_emb_size=50)
这表示这个LSTMCharEmbedding当中character的embedding维度大小为50,返回的embedding结果维度大小为100。
-
-----------------------------------------------------
Part VI: 叠加使用多个embedding
-----------------------------------------------------
-在fastNLP中,我们使用 :class:`~fastNLP.modules.encoder.embedding.StackEmbedding` 来叠加多个embedding
+在fastNLP中,我们使用 :class:`~fastNLP.embeddings.StackEmbedding` 来叠加多个embedding
例子如下:
.. code-block:: python
+ from fastNLP.embeddings import StaticEmbedding, StackEmbedding
embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
@@ -208,7 +151,17 @@ StackEmbedding会把多个embedding的结果拼接起来,如上面例子的sta
.. code-block:: python
+ from fastNLP.embeddings import StaticEmbedding, StackEmbedding, ElmoEmbedding
elmo_embedding = ElmoEmbedding(vocab, model_dir_or_name='medium', layers='0,1,2', requires_grad=False)
glove_embedding = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
stack_embed = StackEmbedding([elmo_embedding, glove_embedding])
+
+------------------------------------------
+Part VII: fastNLP支持的预训练Embedding
+------------------------------------------
+
+fastNLP支持多种预训练Embedding并提供自动下载功能,具体参见文档
+
+`fastNLP可加载的embedding与数据集 `_
+
diff --git a/docs/source/tutorials/tutorial_4_loss_optimizer.rst b/docs/source/tutorials/tutorial_4_loss_optimizer.rst
index 2a4d159a..f863a7a8 100644
--- a/docs/source/tutorials/tutorial_4_loss_optimizer.rst
+++ b/docs/source/tutorials/tutorial_4_loss_optimizer.rst
@@ -1,8 +1,9 @@
==============================================================================
-Loss 和 optimizer 教程 ———— 以文本分类为例
+动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试
==============================================================================
-我们使用和 :doc:`/user/quickstart` 中一样的任务来进行详细的介绍。给出一段评价性文字,预测其情感倾向是积极(label=1)、消极(label=0)还是中性(label=2),使用 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester` 来进行快速训练和测试,损失函数之前的内容与 :doc:`/tutorials/tutorial_5_datasetiter` 中的完全一样,如已经阅读过可以跳过。
+我们使用和 :doc:`/user/quickstart` 中一样的任务来进行详细的介绍。给出一段评价性文字,预测其情感倾向是积极(label=1)、
+消极(label=0)还是中性(label=2),使用 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester` 来进行快速训练和测试。
--------------
数据处理
@@ -157,6 +158,7 @@ Vocabulary 的使用
损失函数
训练模型需要提供一个损失函数
,fastNLP中提供了直接可以导入使用的四种loss,分别为:
+
* :class:`~fastNLP.CrossEntropyLoss`:包装了torch.nn.functional.cross_entropy()函数,返回交叉熵损失(可以运用于多分类场景)
* :class:`~fastNLP.BCELoss`:包装了torch.nn.functional.binary_cross_entropy()函数,返回二分类的交叉熵
* :class:`~fastNLP.L1Loss`:包装了torch.nn.functional.l1_loss()函数,返回L1 损失
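这些 loss 在初始化时都可以通过参数指定 pred 与 target 对应的 field 名称;当模型 forward 的返回值或 DataSet 的列名与默认约定不一致时尤其有用。下面是一个简要示意(字段名仅为假设):

.. code-block:: python

    from fastNLP import CrossEntropyLoss

    # 假设模型 forward 返回 {'output': ...},而 DataSet 中的标签列名为 'label'
    loss = CrossEntropyLoss(pred='output', target='label')

    # 若与默认约定一致(pred='pred', target='target'),直接使用 CrossEntropyLoss() 即可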
@@ -208,7 +210,7 @@ Vocabulary 的使用
#使用CNNText的时候第一个参数输入一个tuple,作为模型定义embedding的参数
#还可以传入 kernel_nums, kernel_sizes, padding, dropout的自定义值
- model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=3, padding=2, dropout=0.1)
+ model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=3, dropout=0.1)
#如果在定义trainer的时候没有传入optimizer参数,模型默认的优化器为torch.optim.Adam且learning rate为lr=4e-3
#这里只使用了optimizer_1作为优化器输入,感兴趣可以尝试optimizer_2或者其他优化器作为输入
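下面补充一段简要的训练示意代码(仅为草图,假设 `vocab`、`train_data`、`dev_data`、`EMBED_DIM` 已按本教程前文准备好,参数取值仅作示意):

.. code-block:: python

    import torch
    from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
    from fastNLP.models import CNNText

    model_cnn = CNNText((len(vocab), EMBED_DIM), num_classes=3, dropout=0.1)
    optimizer_1 = torch.optim.Adam(model_cnn.parameters(), lr=1e-3)

    trainer = Trainer(train_data=train_data, model=model_cnn, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                      optimizer=optimizer_1, n_epochs=10, batch_size=32)
    trainer.train()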
diff --git a/docs/source/tutorials/tutorial_5_datasetiter.rst b/docs/source/tutorials/tutorial_5_datasetiter.rst
index b57bd5c8..e81b18dd 100644
--- a/docs/source/tutorials/tutorial_5_datasetiter.rst
+++ b/docs/source/tutorials/tutorial_5_datasetiter.rst
@@ -1,8 +1,10 @@
==============================================================================
-DataSetIter 教程 ———— 以文本分类为例
+动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程
==============================================================================
-我们使用和 :doc:`/user/quickstart` 中一样的任务来进行详细的介绍。给出一段评价性文字,预测其情感倾向是积极(label=1)、消极(label=0)还是中性(label=2),使用:class:`~fastNLP.DataSetIter` 类来编写自己的训练过程。自己编写训练过程之前的内容与 :doc:`/tutorials/tutorial_4_loss_optimizer` 中的完全一样,如已经阅读过可以跳过。
+我们使用和 :doc:`/user/quickstart` 中一样的任务来进行详细的介绍。给出一段评价性文字,预测其情感倾向是积极(label=1)、
+消极(label=0)还是中性(label=2),使用 :class:`~fastNLP.DataSetIter` 类来编写自己的训练过程。
+自己编写训练过程之前的内容与 :doc:`/tutorials/tutorial_4_loss_optimizer` 中的完全一样,如已经阅读过可以跳过。
--------------
数据处理
@@ -190,7 +192,7 @@ sampler
import time
embed_dim = 100
- model = CNNText((len(vocab),embed_dim), num_classes=3, padding=2, dropout=0.1)
+ model = CNNText((len(vocab),embed_dim), num_classes=3, dropout=0.1)
def train(epoch, data, devdata):
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
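接下来的自定义训练过程大致如下面的示意代码(仅为简化的草图,省略了验证与计时;假设 `data` 中的 words/seq_len 已被设置为 input、target 已被设置为 target):

.. code-block:: python

    import torch
    from fastNLP import DataSetIter, BucketSampler

    def train(epoch, data, devdata):
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
        loss_fn = torch.nn.CrossEntropyLoss()
        sampler = BucketSampler(batch_size=16, seq_len_field_name='seq_len')
        train_batch = DataSetIter(dataset=data, batch_size=16, sampler=sampler)

        for i in range(epoch):
            for batch_x, batch_y in train_batch:
                optimizer.zero_grad()
                output = model(batch_x['words'])              # CNNText 返回 {'pred': ...}
                loss = loss_fn(output['pred'], batch_y['target'])
                loss.backward()
                optimizer.step()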
diff --git a/docs/source/tutorials/tutorial_6_seq_labeling.rst b/docs/source/tutorials/tutorial_6_seq_labeling.rst
index 490db6f5..09a53cdc 100644
--- a/docs/source/tutorials/tutorial_6_seq_labeling.rst
+++ b/docs/source/tutorials/tutorial_6_seq_labeling.rst
@@ -1,5 +1,5 @@
=====================
-序列标注教程
+快速实现序列标注模型
=====================
这一部分的内容主要展示如何使用fastNLP 实现序列标注任务。你可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。
@@ -45,7 +45,7 @@ fastNLP可以方便地载入各种类型的数据。同时,针对常见的数
数据处理
----------------------------
-我们进一步处理数据。将数据和词表封装在 :class:`~fastNLP.DataInfo` 类中。data是DataInfo的实例。
+我们进一步处理数据。将数据和词表封装在 :class:`~fastNLP.DataBundle` 类中。data是DataBundle的实例。
我们输入模型的数据包括char embedding,以及word embedding。在数据处理部分,我们尝试完成词表的构建。
使用fastNLP中的Vocabulary类来构建词表。
diff --git a/docs/source/tutorials/tutorial_7_modules_models.rst b/docs/source/tutorials/tutorial_7_modules_models.rst
index d69d9d2e..680d75fd 100644
--- a/docs/source/tutorials/tutorial_7_modules_models.rst
+++ b/docs/source/tutorials/tutorial_7_modules_models.rst
@@ -1,5 +1,5 @@
======================================
-Modules 和 models 的教程
+使用Modules和Models快速搭建自定义模型
======================================
:mod:`~fastNLP.modules` 和 :mod:`~fastNLP.models` 用于构建 fastNLP 所需的神经网络模型,它可以和 torch.nn 中的模型一起使用。
@@ -181,7 +181,7 @@ FastNLP 完全支持使用 pyTorch 编写的模型,但与 pyTorch 中编写模
)
)
-FastNLP 中包含的各种模块如下表,您可以点击具体的名称查看详细的 API:
+FastNLP 中包含的各种模块如下表,您可以点击具体的名称查看详细的 API,也可以通过 :doc:`/fastNLP.modules` 进行了解。
.. csv-table::
:header: 名称, 介绍
@@ -189,7 +189,6 @@ FastNLP 中包含的各种模块如下表,您可以点击具体的名称查看
:class:`~fastNLP.modules.ConvolutionCharEncoder` , char级别的卷积 encoder
:class:`~fastNLP.modules.LSTMCharEncoder` , char级别基于LSTM的 encoder
:class:`~fastNLP.modules.ConvMaxpool` , 结合了Convolution和Max-Pooling于一体的模块
- :class:`~fastNLP.modules.Embedding` , 基础的Embedding模块
:class:`~fastNLP.modules.LSTM` , LSTM模块, 轻量封装了PyTorch的LSTM
:class:`~fastNLP.modules.StarTransformer` , Star-Transformer 的encoder部分
:class:`~fastNLP.modules.TransformerEncoder` , Transformer的encoder模块,不包含embedding层
@@ -198,8 +197,11 @@ FastNLP 中包含的各种模块如下表,您可以点击具体的名称查看
:class:`~fastNLP.modules.VarGRU` , Variational Dropout GRU 模块
:class:`~fastNLP.modules.MaxPool` , Max-pooling模块
:class:`~fastNLP.modules.MaxPoolWithMask` , 带mask矩阵的max pooling。在做 max-pooling的时候不会考虑mask值为0的位置。
+ :class:`~fastNLP.modules.AvgPool` , Average-pooling模块
+ :class:`~fastNLP.modules.AvgPoolWithMask` , 带mask矩阵的average pooling。在做 average-pooling的时候不会考虑mask值为0的位置。
:class:`~fastNLP.modules.MultiHeadAttention` , MultiHead Attention 模块
:class:`~fastNLP.modules.MLP` , 简单的多层感知器模块
:class:`~fastNLP.modules.ConditionalRandomField` , 条件随机场模块
:class:`~fastNLP.modules.viterbi_decode` , 给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数 (与 :class:`~fastNLP.modules.ConditionalRandomField` 配合使用)
:class:`~fastNLP.modules.allowed_transitions` , 给定一个id到label的映射表,返回所有可以跳转的列表(与 :class:`~fastNLP.modules.ConditionalRandomField` 配合使用)
+ :class:`~fastNLP.modules.TimestepDropout` , 简单包装过的Dropout 组件
diff --git a/docs/source/tutorials/tutorial_8_metrics.rst b/docs/source/tutorials/tutorial_8_metrics.rst
index a3c6770e..0b4f86c8 100644
--- a/docs/source/tutorials/tutorial_8_metrics.rst
+++ b/docs/source/tutorials/tutorial_8_metrics.rst
@@ -1,6 +1,6 @@
-=====================
-Metric 教程
-=====================
+===============================
+使用Metric快速评测你的模型
+===============================
在进行训练时,fastNLP提供了各种各样的 :mod:`~fastNLP.core.metrics` 。
如 :doc:`/user/quickstart` 中所介绍的,:class:`~fastNLP.AccuracyMetric` 类的对象被直接传到 :class:`~fastNLP.Trainer` 中用于训练
diff --git a/docs/source/tutorials/tutorial_9_callback.rst b/docs/source/tutorials/tutorial_9_callback.rst
index 01fbb6c3..8e2742bb 100644
--- a/docs/source/tutorials/tutorial_9_callback.rst
+++ b/docs/source/tutorials/tutorial_9_callback.rst
@@ -1,6 +1,6 @@
-==============================================================================
-Callback 教程
-==============================================================================
+===================================================
+使用Callback自定义你的训练过程
+===================================================
在训练时,我们常常要使用trick来提高模型的性能(如调节学习率),或者要打印训练中的信息。
这里我们提供Callback类,在Trainer中插入代码,完成一些自定义的操作。
@@ -44,10 +44,10 @@ Callback的构建和使用
这里,:class:`~fastNLP.Callback` 中所有以 ``on_`` 开头的类方法会在 :class:`~fastNLP.Trainer` 的训练中在特定时间调用。
如 on_train_begin() 会在训练开始时被调用,on_epoch_end() 会在每个 epoch 结束时调用。
- 具体有哪些类方法,参见文档。
+ 具体有哪些类方法,参见文档 :class:`~fastNLP.Callback` 。
另外,为了使用方便,可以在 :class:`~fastNLP.Callback` 内部访问 :class:`~fastNLP.Trainer` 中的属性,如 optimizer, epoch, step,分别对应训练时的优化器,当前epoch数,和当前的总step数。
- 具体可访问的属性,参见文档。
+ 具体可访问的属性,参见文档 :class:`~fastNLP.Callback` 。
使用Callback
在定义好 :class:`~fastNLP.Callback` 之后,就能将它传入Trainer的 ``callbacks`` 参数,在实际训练时使用。
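下面给出一个简要的自定义 Callback 草图作为补充(仅作示意,假设 `model`、`train_data`、`dev_data` 已准备好):

.. code-block:: python

    from fastNLP import Callback, Trainer, CrossEntropyLoss, AccuracyMetric

    class LRDecayCallback(Callback):
        def __init__(self, base_lr=1e-3, decay=0.9):
            super().__init__()
            self.base_lr = base_lr
            self.decay = decay

        def on_epoch_end(self):
            # 每个 epoch 结束时对学习率做指数衰减;optimizer、epoch 等属性可直接访问
            lr = self.base_lr * self.decay ** self.epoch
            for group in self.optimizer.param_groups:
                group['lr'] = lr

    trainer = Trainer(train_data=train_data, model=model, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                      callbacks=[LRDecayCallback()])
    trainer.train()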
diff --git a/docs/source/user/tutorials.rst b/docs/source/user/tutorials.rst
index cd1fba05..3e9e1b54 100644
--- a/docs/source/user/tutorials.rst
+++ b/docs/source/user/tutorials.rst
@@ -1,18 +1,20 @@
-===================
-fastNLP详细使用教程
-===================
+========================
+fastNLP 详细使用教程
+========================
+
+这里是更详细的使用教程。对于大部分的用户,我们建议你从第一篇开始顺序阅读;如果你只想了解其中的一部分,也可以进行选读。
.. toctree::
:maxdepth: 1
- 1. 使用DataSet预处理文本
- 2. 使用DataSetLoader加载数据集
- 3. 使用Embedding模块将文本转成向量
- 4. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试
- 5. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程
- 6. 快速实现序列标注模型
- 7. 使用Modules和Models快速搭建自定义模型
- 8. 使用Metric快速评测你的模型
- 9. 使用Callback自定义你的训练过程
- 10. 使用fitlog 辅助 fastNLP 进行科研
+ 使用DataSet预处理文本
+ 使用Loader和Pipe加载并处理数据集
+ 使用Embedding模块将文本转成向量
+ 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试
+ 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程
+ 快速实现序列标注模型
+ 使用Modules和Models快速搭建自定义模型
+ 使用Metric快速评测你的模型
+ 使用Callback自定义你的训练过程
+ 使用fitlog 辅助 fastNLP 进行科研
diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py
index 6b43da13..19efac31 100644
--- a/fastNLP/__init__.py
+++ b/fastNLP/__init__.py
@@ -1,22 +1,24 @@
"""
-fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.modules`、:mod:`~fastNLP.models`
-等子模块组成,你可以点进去查看每个模块的文档。
+fastNLP 由 :mod:`~fastNLP.core` 、 :mod:`~fastNLP.io` 、:mod:`~fastNLP.embeddings` 、 :mod:`~fastNLP.modules`、
+:mod:`~fastNLP.models` 等子模块组成,你可以查看每个模块的文档。
- :mod:`~fastNLP.core` 是fastNLP 的核心模块,包括 DataSet、 Trainer、 Tester 等组件。详见文档 :doc:`/fastNLP.core`
- :mod:`~fastNLP.io` 是实现输入输出的模块,包括了数据集的读取,模型的存取等功能。详见文档 :doc:`/fastNLP.io`
+- :mod:`~fastNLP.embeddings` 提供用于构建复杂网络模型所需的各种embedding。详见文档 :doc:`/fastNLP.embeddings`
- :mod:`~fastNLP.modules` 包含了用于搭建神经网络模型的诸多组件,可以帮助用户快速搭建自己所需的网络。详见文档 :doc:`/fastNLP.modules`
-- :mod:`~fastNLP.models` 包含了一些使用 fastNLP 实现的完整网络模型,包括CNNText、SeqLabeling等常见模型。详见文档 :doc:`/fastNLP.models`
+- :mod:`~fastNLP.models` 包含了一些使用 fastNLP 实现的完整网络模型,包括 :class:`~fastNLP.models.CNNText` 、 :class:`~fastNLP.models.SeqLabeling` 等常见模型。详见文档 :doc:`fastNLP.models`
fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的文档如下:
"""
__all__ = [
"Instance",
"FieldArray",
-
+
+
"DataSetIter",
"BatchIter",
"TorchLoaderIter",
-
+
"Vocabulary",
"DataSet",
"Const",
@@ -30,6 +32,7 @@ __all__ = [
"TensorboardCallback",
"LRScheduler",
"ControlC",
+ "LRFinder",
"Padder",
"AutoPadder",
@@ -42,7 +45,8 @@ __all__ = [
"Optimizer",
"SGD",
"Adam",
-
+ "AdamW",
+
"Sampler",
"SequentialSampler",
"BucketSampler",
@@ -50,15 +54,19 @@ __all__ = [
"LossFunc",
"CrossEntropyLoss",
- "L1Loss", "BCELoss",
+ "L1Loss",
+ "BCELoss",
"NLLLoss",
"LossInForward",
- "cache_results"
+ "cache_results",
+
+ 'logger'
]
__version__ = '0.4.5'
-from .core import *
+from . import embeddings
from . import models
from . import modules
-from .io import data_loader
+from .core import *
+from .io import loader, pipe
diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index efc83017..efee08b5 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -1,30 +1,94 @@
"""
core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fastNLP 包中直接 import。当然你也同样可以从 core 模块的子模块中 import,
-例如 Batch 组件有两种 import 的方式::
+例如 :class:`~fastNLP.DataSetIter` 组件有两种 import 的方式::
# 直接从 fastNLP 中 import
- from fastNLP import Batch
+ from fastNLP import DataSetIter
- # 从 core 模块的子模块 batch 中 import
- from fastNLP.core.batch import Batch
+ # 从 core 模块的子模块 batch 中 import DataSetIter
+ from fastNLP.core.batch import DataSetIter
对于常用的功能,你只需要在 :doc:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。
-.. todo::
- 介绍core 的子模块的分工,好像必要性不大
-
"""
+__all__ = [
+ "DataSet",
+
+ "Instance",
+
+ "FieldArray",
+ "Padder",
+ "AutoPadder",
+ "EngChar2DPadder",
+
+ "Vocabulary",
+
+ "DataSetIter",
+ "BatchIter",
+ "TorchLoaderIter",
+
+ "Const",
+
+ "Tester",
+ "Trainer",
+
+ "cache_results",
+ "seq_len_to_mask",
+ "get_seq_len",
+ "logger",
+
+ "Callback",
+ "GradientClipCallback",
+ "EarlyStopCallback",
+ "FitlogCallback",
+ "EvaluateCallback",
+ "LRScheduler",
+ "ControlC",
+ "LRFinder",
+ "TensorboardCallback",
+ "WarmupCallback",
+ 'SaveModelCallback',
+ "EchoCallback",
+ "TesterCallback",
+ "CallbackException",
+ "EarlyStopError",
+
+ "LossFunc",
+ "CrossEntropyLoss",
+ "L1Loss",
+ "BCELoss",
+ "NLLLoss",
+ "LossInForward",
+
+ "AccuracyMetric",
+ "SpanFPreRecMetric",
+ "ExtractiveQAMetric",
+
+ "Optimizer",
+ "SGD",
+ "Adam",
+ "AdamW",
+
+ "SequentialSampler",
+ "BucketSampler",
+ "RandomSampler",
+ "Sampler",
+]
+
+from ._logger import logger
from .batch import DataSetIter, BatchIter, TorchLoaderIter
-from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC
+from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \
+ LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \
+ TesterCallback, CallbackException, EarlyStopError
from .const import Const
from .dataset import DataSet
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder
from .instance import Instance
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric
-from .optimizer import Optimizer, SGD, Adam
+from .optimizer import Optimizer, SGD, Adam, AdamW
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler
from .tester import Tester
from .trainer import Trainer
-from .utils import cache_results, seq_len_to_mask
+from .utils import cache_results, seq_len_to_mask, get_seq_len
from .vocabulary import Vocabulary
diff --git a/fastNLP/core/_logger.py b/fastNLP/core/_logger.py
new file mode 100644
index 00000000..7198cfbd
--- /dev/null
+++ b/fastNLP/core/_logger.py
@@ -0,0 +1,155 @@
+"""undocumented"""
+
+__all__ = [
+ 'logger',
+]
+
+import logging
+import logging.config
+import os
+import sys
+import warnings
+
+ROOT_NAME = 'fastNLP'
+
+try:
+ import fitlog
+except ImportError:
+ fitlog = None
+try:
+ from tqdm.auto import tqdm
+except ImportError:
+ tqdm = None
+
+if tqdm is not None:
+ class TqdmLoggingHandler(logging.Handler):
+ def __init__(self, level=logging.INFO):
+ super().__init__(level)
+
+ def emit(self, record):
+ try:
+ msg = self.format(record)
+ tqdm.write(msg)
+ self.flush()
+ except (KeyboardInterrupt, SystemExit):
+ raise
+ except:
+ self.handleError(record)
+else:
+ class TqdmLoggingHandler(logging.StreamHandler):
+ def __init__(self, level=logging.INFO):
+ super().__init__(sys.stdout)
+ self.setLevel(level)
+
+
+def _get_level(level):
+ if isinstance(level, int):
+ pass
+ else:
+ level = level.lower()
+ level = {'info': logging.INFO, 'debug': logging.DEBUG,
+ 'warn': logging.WARN, 'warning': logging.WARN,
+ 'error': logging.ERROR}[level]
+ return level
+
+
+def _add_file_handler(logger, path, level='INFO'):
+ for h in logger.handlers:
+ if isinstance(h, logging.FileHandler):
+ if os.path.abspath(path) == h.baseFilename:
+ # file path already added
+ return
+
+ # File Handler
+ if os.path.exists(path):
+ assert os.path.isfile(path)
+ warnings.warn('log already exists in {}'.format(path))
+ dirname = os.path.abspath(os.path.dirname(path))
+ os.makedirs(dirname, exist_ok=True)
+
+ file_handler = logging.FileHandler(path, mode='a')
+ file_handler.setLevel(_get_level(level))
+ file_formatter = logging.Formatter(fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
+ datefmt='%Y/%m/%d %H:%M:%S')
+ file_handler.setFormatter(file_formatter)
+ logger.addHandler(file_handler)
+
+
+def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
+ level = _get_level(level)
+ if stdout not in ['none', 'plain', 'tqdm']:
+ raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
+ # make sure to initialize logger only once
+ stream_handler = None
+ for i, h in enumerate(logger.handlers):
+ if isinstance(h, (logging.StreamHandler, TqdmLoggingHandler)):
+ stream_handler = h
+ break
+ if stream_handler is not None:
+ logger.removeHandler(stream_handler)
+
+ # Stream Handler
+ if stdout == 'plain':
+ stream_handler = logging.StreamHandler(sys.stdout)
+ elif stdout == 'tqdm':
+ stream_handler = TqdmLoggingHandler(level)
+ else:
+ stream_handler = None
+
+ if stream_handler is not None:
+ stream_formatter = logging.Formatter('%(message)s')
+ stream_handler.setLevel(level)
+ stream_handler.setFormatter(stream_formatter)
+ logger.addHandler(stream_handler)
+
+
+class FastNLPLogger(logging.getLoggerClass()):
+ def __init__(self, name):
+ super().__init__(name)
+
+ def add_file(self, path='./log.txt', level='INFO'):
+ """add log output file and level"""
+ _add_file_handler(self, path, level)
+
+ def set_stdout(self, stdout='tqdm', level='INFO'):
+ """set stdout format and level"""
+ _set_stdout_handler(self, stdout, level)
+
+
+logging.setLoggerClass(FastNLPLogger)
+
+
+# print(logging.getLoggerClass())
+# print(logging.getLogger())
+
+def _init_logger(path=None, stdout='tqdm', level='INFO'):
+ """initialize logger"""
+ level = _get_level(level)
+
+ # logger = logging.getLogger()
+ logger = logging.getLogger(ROOT_NAME)
+ logger.propagate = False
+ logger.setLevel(level)
+
+ _set_stdout_handler(logger, stdout, level)
+
+ # File Handler
+ if path is not None:
+ _add_file_handler(logger, path, level)
+
+ return logger
+
+
+def _get_logger(name=None, level='INFO'):
+ level = _get_level(level)
+ if name is None:
+ name = ROOT_NAME
+ assert isinstance(name, str)
+ if not name.startswith(ROOT_NAME):
+ name = '{}.{}'.format(ROOT_NAME, name)
+ logger = logging.getLogger(name)
+ logger.setLevel(level)
+ return logger
+
+
+logger = _init_logger(path=None)
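按照本文件中的实现,`logger` 可以从 fastNLP 顶层包直接导入,一个简要的使用示意如下(文件路径仅为假设):

.. code-block:: python

    from fastNLP import logger

    logger.info("an info message")                    # 默认通过 tqdm 兼容的 handler 输出到 stdout
    logger.add_file('./log.txt', level='INFO')        # 追加一个文件输出
    logger.set_stdout(stdout='plain', level='DEBUG')  # 切换为普通的 StreamHandler
    logger.debug("a debug message")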
diff --git a/fastNLP/core/_parallel_utils.py b/fastNLP/core/_parallel_utils.py
index 4a7757d3..ce745820 100644
--- a/fastNLP/core/_parallel_utils.py
+++ b/fastNLP/core/_parallel_utils.py
@@ -1,10 +1,14 @@
+"""undocumented"""
+
+__all__ = []
import threading
+
import torch
+from torch import nn
from torch.nn.parallel.parallel_apply import get_a_var
-
-from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
+from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
@@ -26,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
assert len(modules) == len(devices)
else:
devices = [None] * len(modules)
-
+
lock = threading.Lock()
results = {}
grad_enabled = torch.is_grad_enabled()
-
+
def _worker(i, module, input, kwargs, device=None):
torch.set_grad_enabled(grad_enabled)
if device is None:
@@ -46,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None):
except Exception as e:
with lock:
results[i] = e
-
+
if len(modules) > 1:
threads = [threading.Thread(target=_worker,
args=(i, module, input, kwargs, device))
for i, (module, input, kwargs, device) in
enumerate(zip(modules, inputs, kwargs_tup, devices))]
-
+
for thread in threads:
thread.start()
for thread in threads:
thread.join()
else:
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])
-
+
outputs = []
for i in range(len(inputs)):
output = results[i]
@@ -78,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
:param output_device: nn.DataParallel中的output_device
:return:
"""
+
def wrapper(network, *inputs, **kwargs):
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0)
if len(device_ids) == 1:
@@ -85,4 +90,18 @@ def _data_parallel_wrapper(func_name, device_ids, output_device):
replicas = replicate(network, device_ids[:len(inputs)])
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)])
return gather(outputs, output_device)
+
return wrapper
+
+
+def _model_contains_inner_module(model):
+ """
+
+ :param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel,
+ nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。
+ :return: bool
+ """
+ if isinstance(model, nn.Module):
+ if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
+ return True
+ return False
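这个内部工具函数的行为可以用下面的小例子说明(仅作示意;`_model_contains_inner_module` 是下划线开头的内部接口,普通用户一般无需直接使用):

.. code-block:: python

    from torch import nn
    from fastNLP.core._parallel_utils import _model_contains_inner_module

    model = nn.Linear(10, 2)
    print(_model_contains_inner_module(model))                   # False,普通模型
    print(_model_contains_inner_module(nn.DataParallel(model)))  # True,内部含有 model.module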
diff --git a/fastNLP/core/batch.py b/fastNLP/core/batch.py
index 2d8c1a80..ff710b30 100644
--- a/fastNLP/core/batch.py
+++ b/fastNLP/core/batch.py
@@ -1,24 +1,23 @@
"""
-batch 模块实现了 fastNLP 所需的 Batch 类。
+batch 模块实现了 fastNLP 所需的 :class:`~fastNLP.core.batch.DataSetIter` 类。
"""
__all__ = [
+ "BatchIter",
"DataSetIter",
"TorchLoaderIter",
]
import atexit
-from queue import Empty, Full
import numpy as np
import torch
-import torch.multiprocessing as mp
import torch.utils.data
from numbers import Number
from .sampler import SequentialSampler
from .dataset import DataSet
-
+from ._logger import logger
_python_is_exit = False
@@ -49,6 +48,11 @@ class DataSetGetter:
return len(self.dataset)
def collate_fn(self, batch: list):
+ """
+
+ :param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]]
+ :return:
+ """
# TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景
batch_x = {n:[] for n in self.inputs.keys()}
batch_y = {n:[] for n in self.targets.keys()}
@@ -71,7 +75,7 @@ class DataSetGetter:
try:
data, flag = _to_tensor(data, f.dtype)
except TypeError as e:
- print(f"Field {n} cannot be converted to torch.tensor.")
+ logger.error(f"Field {n} cannot be converted to torch.tensor.")
raise e
batch_dict[n] = data
return batch_dict
@@ -94,9 +98,13 @@ class DataSetGetter:
class SamplerAdapter(torch.utils.data.Sampler):
def __init__(self, sampler, dataset):
+ super().__init__(dataset)
self.sampler = sampler
self.dataset = dataset
+ def __len__(self):
+ return len(self.dataset)
+
def __iter__(self):
return iter(self.sampler(self.dataset))
@@ -166,15 +174,19 @@ class DataSetIter(BatchIter):
timeout=0, worker_init_fn=None):
super().__init__()
assert isinstance(dataset, DataSet)
- sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
+ if not isinstance(sampler, torch.utils.data.Sampler):
+ self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset)
+ else:
+ self.sampler = sampler
dataset = DataSetGetter(dataset, as_numpy)
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None
self.dataiter = torch.utils.data.DataLoader(
- dataset=dataset, batch_size=batch_size, sampler=sampler,
+ dataset=dataset, batch_size=batch_size, sampler=self.sampler,
collate_fn=collate_fn, num_workers=num_workers,
pin_memory=pin_memory, drop_last=drop_last,
timeout=timeout, worker_init_fn=worker_init_fn)
- self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last)
+ # 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了
+ self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last)
self.batch_size = batch_size
@@ -183,7 +195,7 @@ class TorchLoaderIter(BatchIter):
super().__init__()
assert isinstance(dataset, torch.utils.data.DataLoader)
self.dataiter = dataset
- self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last)
+ self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last)
self.batch_size = dataset.batch_size
@@ -201,6 +213,13 @@ class OnlineDataIter(BatchIter):
def _to_tensor(batch, field_dtype):
+ """
+
+ :param batch: np.array()
+ :param field_dtype: 数据类型
+ :return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor,
+ 返回的batch就是原来的数据,且flag为False
+ """
try:
if field_dtype is not None and isinstance(field_dtype, type)\
and issubclass(field_dtype, Number) \
diff --git a/fastNLP/core/callback.py b/fastNLP/core/callback.py
index 8a202795..2c130061 100644
--- a/fastNLP/core/callback.py
+++ b/fastNLP/core/callback.py
@@ -2,11 +2,11 @@ r"""
callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:`~fastNLP.Trainer` 类。
虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,
-比如负采样,learning rate decay, Early Stop等。
-为了解决这个问题fastNLP引入了callback的机制,Callback 是一种在Trainer训练过程中特定阶段会运行的函数集合。
-关于Trainer的详细文档,请参见 :doc:`trainer 模块`
+比如负采样,learning rate decay 和 early stop等。
+为了解决这个问题,fastNLP引入了callback的机制,:class:`~fastNLP.Callback` 是一种在Trainer训练过程中特定阶段会运行的函数集合。
+关于 :class:`~fastNLP.Trainer` 的详细文档,请参见 :doc:`trainer 模块`
-我们将 :meth:`~fastNLP.Train.train` 这个函数内部分为以下的阶段,在对应阶段会触发相应的调用::
+我们将 :meth:`~fastNLP.Trainer.train` 这个函数内部分为以下的阶段,在对应阶段会触发相应的调用::
callback.on_train_begin() # 开始进行训练
for i in range(1, n_epochs+1):
@@ -31,8 +31,8 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:
callback.on_train_end() # 训练结束
callback.on_exception() # 这是一个特殊的步骤,在训练过程中遭遇exception会跳转到这里。
-如下面的例子所示,我们可以使用内置的 callback 类,或者继承 :class:`~fastNLP.core.callback.Callback`
-定义自己的 callback 类::
+如下面的例子所示,我们可以使用内置的 callback 组件,或者继承 :class:`~fastNLP.core.callback.Callback`
+定义自己的 callback 组件::
from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric
from fastNLP.models import CNNText
@@ -51,12 +51,19 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class:
"""
__all__ = [
"Callback",
+
"GradientClipCallback",
"EarlyStopCallback",
- "TensorboardCallback",
"FitlogCallback",
+ "EvaluateCallback",
"LRScheduler",
"ControlC",
+ "LRFinder",
+ "TensorboardCallback",
+ "WarmupCallback",
+ "SaveModelCallback",
+ "EchoCallback",
+ "TesterCallback",
"CallbackException",
"EarlyStopError"
@@ -76,9 +83,9 @@ try:
except:
tensorboardX_flag = False
-from ..io.model_io import ModelSaver, ModelLoader
from .dataset import DataSet
from .tester import Tester
+from ._logger import logger
try:
import fitlog
@@ -100,7 +107,8 @@ class Callback(object):
def __init__(self):
super(Callback, self).__init__()
self._trainer = None # 在Trainer内部被重新赋值
-
+ self._disabled = False
+
@property
def trainer(self):
"""
@@ -158,7 +166,19 @@ class Callback(object):
def batch_per_epoch(self):
"""每个epoch一共有多少个batch,只有在on_epoch_begin之后才能调用该属性。"""
return self._trainer.batch_per_epoch
-
+
+ @property
+ def is_master(self):
+ return self._trainer.is_master
+
+ @property
+ def disabled(self):
+ return self._disabled
+
+ @property
+ def logger(self):
+ return getattr(self._trainer, 'logger', logger)
+
def on_train_begin(self):
"""
在Train过程开始之前调用。
@@ -250,6 +270,14 @@ class Callback(object):
:return:
"""
pass
+
+ def on_validation(self):
+ """
+ 如果Trainer中设置了验证,则会在每次需要验证时调用该函数
+
+ :return:
+ """
+ pass
def on_epoch_end(self):
"""
@@ -281,6 +309,8 @@ def _transfer(func):
def wrapper(manager, *arg):
returns = []
for callback in manager.callbacks:
+ if callback.disabled:
+ continue
returns.append(getattr(callback, func.__name__)(*arg))
return returns
@@ -297,22 +327,28 @@ class CallbackManager(Callback):
"""
super(CallbackManager, self).__init__()
# set attribute of trainer environment
-
+ self._env = env
self.callbacks = []
- if callbacks is not None:
- if isinstance(callbacks, list):
- if all([isinstance(cb, Callback) for cb in callbacks]) is True:
- self.callbacks.extend(callbacks)
- else:
- obj = [not isinstance(cb, Callback) for cb in callbacks][0]
- raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+ if callbacks:
+ self.callbacks = self.prepare_callbacks(callbacks)
+
+ def prepare_callbacks(self, callbacks):
+ if not callbacks:
+ return []
+ if isinstance(callbacks, list):
+ if all([isinstance(cb, Callback) for cb in callbacks]) is True:
+ pass
else:
- raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
-
- for env_name, env_val in env.items():
- for callback in self.callbacks:
+                obj = [cb for cb in callbacks if not isinstance(cb, Callback)][0]
+ raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}")
+ else:
+ raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.")
+
+ for env_name, env_val in self._env.items():
+ for callback in callbacks:
setattr(callback, '_' + env_name, env_val) # Callback.trainer
-
+ return callbacks
+
@_transfer
def on_train_begin(self):
pass
@@ -352,6 +388,10 @@ class CallbackManager(Callback):
@_transfer
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval):
pass
+
+ @_transfer
+ def on_validation(self):
+ pass
@_transfer
def on_epoch_end(self):
@@ -366,6 +406,33 @@ class CallbackManager(Callback):
pass
+class DistCallbackManager(CallbackManager):
+ def __init__(self, env, callbacks_all=None, callbacks_master=None):
+ super(DistCallbackManager, self).__init__(env)
+ assert 'trainer' in env
+ self._trainer = env['trainer']
+ self.callbacks_master = []
+ self.callbacks_all = []
+ self.add_callback(callbacks_all, master=False)
+ self.add_callback(callbacks_master, master=True)
+
+ def patch_callback(self, callbacks, disabled):
+ if not callbacks:
+ return
+ if not isinstance(callbacks, (list, tuple)):
+ callbacks = [callbacks]
+ for cb in callbacks:
+ cb._disabled = disabled
+
+ def add_callback(self, cb, master=False):
+ if master:
+ self.patch_callback(cb, not self.is_master)
+ self.callbacks_master += self.prepare_callbacks(cb)
+ else:
+ self.callbacks_all += self.prepare_callbacks(cb)
+ self.callbacks = self.callbacks_all + self.callbacks_master
+
+
class GradientClipCallback(Callback):
"""
别名::class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback`
@@ -403,6 +470,9 @@ class GradientClipCallback(Callback):
def on_backward_end(self):
if self.step%self.update_every==0:
if self.parameters is None:
+                if getattr(self.trainer, 'fp16', ''):
+                    from apex import amp
+                    self.clip_fun(amp.master_params(self.optimizer), self.clip_value)
+                    return
self.clip_fun(self.model.parameters(), self.clip_value)
else:
self.clip_fun(self.parameters, self.clip_value)
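GradientClipCallback 的典型用法大致如下(仅作示意,`train_data`、`dev_data`、`model` 等均为假设已准备好的对象):

.. code-block:: python

    from fastNLP import Trainer, GradientClipCallback, CrossEntropyLoss, AccuracyMetric

    clip_cb = GradientClipCallback(clip_value=5, clip_type='value')  # 将梯度裁剪到 [-5, 5]

    trainer = Trainer(train_data=train_data, model=model, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                      callbacks=[clip_cb])
    trainer.train()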
@@ -434,7 +504,7 @@ class EarlyStopCallback(Callback):
def on_exception(self, exception):
if isinstance(exception, EarlyStopError):
- print("Early Stopping triggered in epoch {}!".format(self.epoch))
+ logger.info("Early Stopping triggered in epoch {}!".format(self.epoch))
else:
raise exception # 抛出陌生Error
@@ -448,10 +518,9 @@ class FitlogCallback(Callback):
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。
- :param ~fastNLP.DataSet,dict(~fastNLP.DataSet) data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个
- DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过
- dict的方式传入。如果仅传入DataSet, 则被命名为test
- :param ~fastNLP.Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test`
+ :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要
+ 传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。
+ :param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。
:param int verbose: 是否在终端打印evaluation的结果,0不打印。
@@ -465,21 +534,24 @@ class FitlogCallback(Callback):
self._log_exception = log_exception
assert isinstance(log_loss_every, int) and log_loss_every>=0
if tester is not None:
- assert isinstance(tester, Tester), "Only fastNLP.Tester allowed."
- assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data."
- if data is not None:
- assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed."
- setattr(tester, 'verbose', 0)
- self.testers['test'] = tester
-
+ if isinstance(tester, dict):
+ for name, test in tester.items():
+ if not isinstance(test, Tester):
+ raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
+ self.testers['tester-' + name] = test
+ if isinstance(tester, Tester):
+ self.testers['tester-test'] = tester
+ for tester in self.testers.values():
+ setattr(tester, 'verbose', 0)
+
if isinstance(data, dict):
for key, value in data.items():
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
for key, value in data.items():
- self.datasets[key] = value
+ self.datasets['data-' + key] = value
elif isinstance(data, DataSet):
- self.datasets['test'] = data
- else:
+ self.datasets['data-test'] = data
+ elif data is not None:
raise TypeError("data receives dict[DataSet] or DataSet object.")
self.verbose = verbose
@@ -492,8 +564,11 @@ class FitlogCallback(Callback):
if len(self.datasets) > 0:
for key, data in self.datasets.items():
- tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics,
- verbose=0)
+ tester = Tester(data=data, model=self.model,
+ batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
+ metrics=self.trainer.metrics,
+ verbose=0,
+ use_tqdm=self.trainer.test_use_tqdm)
self.testers[key] = tester
fitlog.add_progress(total_steps=self.n_steps)
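结合 fitlog 使用时,大致的调用方式如下(仅作示意;需要先安装并初始化 fitlog,目录与变量名均为假设):

.. code-block:: python

    import fitlog
    from fastNLP import Trainer, FitlogCallback, CrossEntropyLoss, AccuracyMetric

    fitlog.set_log_dir('logs/')    # fitlog 需要先指定日志目录

    fit_cb = FitlogCallback(data={'test': test_data}, log_loss_every=100)
    trainer = Trainer(train_data=train_data, model=model, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                      callbacks=[fit_cb])
    trainer.train()
    fitlog.finish()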
@@ -533,6 +608,68 @@ class FitlogCallback(Callback):
fitlog.add_other(repr(exception), name='except_info')
+class EvaluateCallback(Callback):
+ """
+ 别名: :class:`fastNLP.EvaluateCallback` :class:`fastNLP.core.callback.EvaluateCallback`
+
+    该callback用于解决Trainer训练过程中只能对dev数据进行验证的问题,可以在每次验证时额外对其它数据集进行评测。
+
+ :param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个
+ DataSet请通过dict的方式传入。
+    :param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。
+ """
+
+ def __init__(self, data=None, tester=None):
+ super().__init__()
+ self.datasets = {}
+ self.testers = {}
+ if tester is not None:
+ if isinstance(tester, dict):
+ for name, test in tester.items():
+ if not isinstance(test, Tester):
+ raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.")
+ self.testers['tester-' + name] = test
+ if isinstance(tester, Tester):
+ self.testers['tester-test'] = tester
+ for tester in self.testers.values():
+ setattr(tester, 'verbose', 0)
+
+ if isinstance(data, dict):
+ for key, value in data.items():
+ assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}."
+ for key, value in data.items():
+ self.datasets['data-' + key] = value
+ elif isinstance(data, DataSet):
+ self.datasets['data-test'] = data
+ elif data is not None:
+ raise TypeError("data receives dict[DataSet] or DataSet object.")
+
+ def on_train_begin(self):
+ if len(self.datasets) > 0 and self.trainer.dev_data is None:
+ raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.")
+
+ if len(self.datasets) > 0:
+ for key, data in self.datasets.items():
+ tester = Tester(data=data, model=self.model,
+ batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size),
+ metrics=self.trainer.metrics, verbose=0,
+ use_tqdm=self.trainer.test_use_tqdm)
+ self.testers[key] = tester
+
+ def on_valid_end(self, eval_result, metric_key, optimizer, better_result):
+ if len(self.testers) > 0:
+ for key, tester in self.testers.items():
+ try:
+ eval_result = tester.test()
+ # self.pbar.write("Evaluation on {}:".format(key))
+ self.logger.info("Evaluation on {}:".format(key))
+ # self.pbar.write(tester._format_eval_results(eval_result))
+ self.logger.info(tester._format_eval_results(eval_result))
+ except Exception:
+ # self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key))
+ self.logger.info("Exception happens when evaluate on DataSet named `{}`.".format(key))
+
+
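EvaluateCallback 的使用方式与 FitlogCallback 类似,下面是一个简要示意(变量名均为假设):

.. code-block:: python

    from fastNLP import Trainer, EvaluateCallback, CrossEntropyLoss, AccuracyMetric

    # 除 dev_data 外,每次验证时再额外在 test_data 上评测一次
    eval_cb = EvaluateCallback(data={'test': test_data})

    trainer = Trainer(train_data=train_data, model=model, dev_data=dev_data,
                      loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
                      callbacks=[eval_cb])
    trainer.train()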
class LRScheduler(Callback):
"""
别名::class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler`
@@ -586,7 +723,7 @@ class SmoothValue(object):
self.smooth = None
def add_value(self, val: float) -> None:
- "Add `val` to calculate updated smoothed value."
+ """Add `val` to calculate updated smoothed value."""
self.n += 1
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val
self.smooth = self.mov_avg / (1 - self.beta ** self.n)
@@ -614,8 +751,7 @@ class LRFinder(Callback):
self.smooth_value = SmoothValue(0.8)
self.opt = None
self.find = None
- self.loader = ModelLoader()
-
+
@property
def lr_gen(self):
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch
@@ -630,7 +766,7 @@ class LRFinder(Callback):
self.opt = self.trainer.optimizer # pytorch optimizer
self.opt.param_groups[0]["lr"] = self.start_lr
# save model
- ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True)
+ torch.save(self.model.state_dict(), 'tmp')
self.find = True
def on_backward_begin(self, loss):
@@ -659,7 +795,9 @@ class LRFinder(Callback):
self.opt.param_groups[0]["lr"] = self.best_lr
self.find = False
# reset model
- ModelLoader().load_pytorch(self.trainer.model, "tmp")
+ states = torch.load('tmp')
+ self.model.load_state_dict(states)
+ os.remove('tmp')
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr))
@@ -850,14 +988,14 @@ class SaveModelCallback(Callback):
try:
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param)
except Exception as e:
- print(f"The following exception:{e} happens when save model to {self.save_dir}.")
+ logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.")
if delete_pair:
try:
delete_model_path = os.path.join(self.save_dir, delete_pair[1])
if os.path.exists(delete_model_path):
os.remove(delete_model_path)
except Exception as e:
- print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
+ logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.")
def on_exception(self, exception):
if self.save_on_exception:
@@ -884,3 +1022,70 @@ class EarlyStopError(CallbackException):
def __init__(self, msg):
super(EarlyStopError, self).__init__(msg)
+
+
+class EchoCallback(Callback):
+ def __init__(self, name, out=sys.stdout):
+ super(EchoCallback, self).__init__()
+ self.name = name
+ self.out = out
+
+ def __getattribute__(self, item):
+ if item.startswith('on_'):
+            logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()))
+ return super(EchoCallback, self).__getattribute__(item)
+
+
+class TesterCallback(Callback):
+ def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
+ super(TesterCallback, self).__init__()
+ self.tester = Tester(data, model,
+ metrics=metrics, batch_size=batch_size,
+ num_workers=num_workers, verbose=0)
+ # parse metric_key
+ # increase_better is True. It means the exp result gets better if the indicator increases.
+ # It is true by default.
+ self.increase_better = True
+ if metric_key is not None:
+ self.increase_better = False if metric_key[0] == "-" else True
+ self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
+ else:
+ self.metric_key = None
+ self.score = None
+
+ def on_validation(self):
+ cur_score = self.tester.test()
+ eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
+ self.epoch, self.n_epochs, self.step, self.n_steps,
+ self.tester._format_eval_results(cur_score))
+ self.logger.info(eval_str)
+ is_better = self.compare_better(cur_score)
+ if is_better:
+ self.score = cur_score
+ return cur_score, is_better
+
+ def _get_score(self, metric_dict, key):
+        for metric in metric_dict.values():
+ if key in metric:
+ return metric[key]
+ return None
+
+ def compare_better(self, a):
+ if self.score is None:
+ return True
+ if self.metric_key is None:
+ self.metric_key = list(list(self.score.values())[0].keys())[0]
+ k = self.metric_key
+ score = self._get_score(self.score, k)
+ new_score = self._get_score(a, k)
+ if score is None or new_score is None:
+ return False
+ if self.increase_better:
+ return score <= new_score
+ else:
+ return score >= new_score
+
+ def on_train_end(self):
+ self.logger.info('Evaluate on training ends.')
+ self.on_validation()
diff --git a/fastNLP/core/const.py b/fastNLP/core/const.py
index 89ff51a2..ad5d1f1e 100644
--- a/fastNLP/core/const.py
+++ b/fastNLP/core/const.py
@@ -1,3 +1,13 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "Const"
+]
+
+
class Const:
"""
fastNLP中field命名常量。
@@ -7,12 +17,14 @@ class Const:
具体列表::
- INPUT 模型的序列输入 words(复数words1, words2)
- CHAR_INPUT 模型character输入 chars(复数chars1, chars2)
- INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2)
- OUTPUT 模型输出 pred(复数pred1, pred2)
- TARGET 真实目标 target(复数target1,target2)
- LOSS 损失函数 loss (复数loss1,loss2)
+ INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, )
+ CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2)
+ INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2)
+ OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2)
+ TARGET 真实目标 target(具有多列target时,依次使用target1,target2)
+ LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2)
+ RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2)
+ RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2)
"""
INPUT = 'words'
@@ -21,37 +33,49 @@ class Const:
OUTPUT = 'pred'
TARGET = 'target'
LOSS = 'loss'
-
+ RAW_WORD = 'raw_words'
+ RAW_CHAR = 'raw_chars'
+
@staticmethod
def INPUTS(i):
"""得到第 i 个 ``INPUT`` 的命名"""
i = int(i) + 1
return Const.INPUT + str(i)
-
+
@staticmethod
def CHAR_INPUTS(i):
"""得到第 i 个 ``CHAR_INPUT`` 的命名"""
i = int(i) + 1
return Const.CHAR_INPUT + str(i)
-
+
+ @staticmethod
+ def RAW_WORDS(i):
+ i = int(i) + 1
+ return Const.RAW_WORD + str(i)
+
+ @staticmethod
+ def RAW_CHARS(i):
+ i = int(i) + 1
+ return Const.RAW_CHAR + str(i)
+
@staticmethod
def INPUT_LENS(i):
"""得到第 i 个 ``INPUT_LEN`` 的命名"""
i = int(i) + 1
return Const.INPUT_LEN + str(i)
-
+
@staticmethod
def OUTPUTS(i):
"""得到第 i 个 ``OUTPUT`` 的命名"""
i = int(i) + 1
return Const.OUTPUT + str(i)
-
+
@staticmethod
def TARGETS(i):
"""得到第 i 个 ``TARGET`` 的命名"""
i = int(i) + 1
return Const.TARGET + str(i)
-
+
@staticmethod
def LOSSES(i):
"""得到第 i 个 ``LOSS`` 的命名"""
diff --git a/fastNLP/core/dataset.py b/fastNLP/core/dataset.py
index 8d2c13e7..51bcef43 100644
--- a/fastNLP/core/dataset.py
+++ b/fastNLP/core/dataset.py
@@ -1,7 +1,7 @@
"""
:class:`~fastNLP.core.dataset.DataSet` 是fastNLP中用于承载数据的容器。可以将DataSet看做是一个表格,
-每一行是一个sample (在fastNLP中被称为 :mod:`~.instance` ),
-每一列是一个feature (在fastNLP中称为 :mod:`.field` )。
+每一行是一个sample (在fastNLP中被称为 :mod:`~fastNLP.core.instance` ),
+每一列是一个feature (在fastNLP中称为 :mod:`~fastNLP.core.field` )。
.. csv-table:: Following is a demo layout of DataSet
:header: "sentence", "words", "seq_len"
@@ -13,57 +13,64 @@
在fastNLP内部每一行是一个 :class:`~fastNLP.Instance` 对象; 每一列是一个 :class:`~fastNLP.FieldArray` 对象。
-1 DataSet的创建
- 创建DataSet主要有以下的3种方式
+----------------------------
+1.DataSet的创建
+----------------------------
-1.1 传入dict
+创建DataSet主要有以下的3种方式
- Example::
+1.1 传入dict
+----------------------------
- from fastNLP import DataSet
- data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."],
- 'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.'],
- 'seq_len': [6, 3, 3]}
- dataset = DataSet(data)
- # 传入的dict的每个key的value应该为具有相同长度的list
+ .. code-block::
-1.2 通过构建Instance
+ from fastNLP import DataSet
+ data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."],
+            'words': [['this', 'is', 'the', 'first', 'instance', '.'], ['Second', 'instance', '.'], ['Third', 'instance', '.']],
+ 'seq_len': [6, 3, 3]}
+ dataset = DataSet(data)
+ # 传入的dict的每个key的value应该为具有相同长度的list
- Example::
+1.2 通过 Instance 构建
+----------------------------
- from fastNLP import DataSet
- from fastNLP import Instance
- dataset = DataSet()
- instance = Instance(sentence="This is the first instance",
- words=['this', 'is', 'the', 'first', 'instance', '.'],
- seq_len=6)
- dataset.append(instance)
- # 可以继续append更多内容,但是append的instance应该和第一个instance拥有完全相同的field
+ .. code-block::
-1.3 通过list(Instance)
+ from fastNLP import DataSet
+ from fastNLP import Instance
+ dataset = DataSet()
+ instance = Instance(sentence="This is the first instance",
+ words=['this', 'is', 'the', 'first', 'instance', '.'],
+ seq_len=6)
+ dataset.append(instance)
+ # 可以继续append更多内容,但是append的instance应该和第一个instance拥有完全相同的field
- Example::
+1.3 通过 List[Instance] 构建
+--------------------------------------
- from fastNLP import DataSet
- from fastNLP import Instance
- instances = []
- instances.append(Instance(sentence="This is the first instance",
- words=['this', 'is', 'the', 'first', 'instance', '.'],
- seq_len=6))
- instances.append(Instance(sentence="Second instance .",
- words=['Second', 'instance', '.'],
- seq_len=3))
- dataset = DataSet(instances)
+ .. code-block::
-2 DataSet与预处理
- 常见的预处理有如下几种
+ from fastNLP import DataSet
+ from fastNLP import Instance
+ instances = []
+    instances.append(Instance(sentence="This is the first instance",
+                              words=['this', 'is', 'the', 'first', 'instance', '.'],
+ seq_len=6))
+ instances.append(Instance(sentence="Second instance .",
+ words=['Second', 'instance', '.'],
+ seq_len=3))
+ dataset = DataSet(instances)
+
+--------------------------------------
+2.DataSet与预处理
+--------------------------------------
-2.1 从某个文本文件读取内容 #
+常见的预处理有如下几种
- .. todo::
- 引用DataLoader
+2.1 从某个文本文件读取内容
+--------------------------------------
- Example::
+ .. code-block::
from fastNLP import DataSet
from fastNLP import Instance
@@ -78,9 +85,13 @@
sent, label = line.strip().split('\t')
dataset.append(Instance(sentence=sent, label=label))
+ .. note::
+ 直接读取特定数据集的数据请参考 :doc:`/tutorials/tutorial_2_load_dataset`
+
2.2 对DataSet中的内容处理
+--------------------------------------
- Example::
+ .. code-block::
from fastNLP import DataSet
data = {'sentence':["This is the first instance .", "Second instance .", "Third instance ."]}
@@ -97,8 +108,9 @@
dataset.apply(get_words, new_field_name='words')
2.3 删除DataSet的内容
+--------------------------------------
- Example::
+ .. code-block::
from fastNLP import DataSet
dataset = DataSet({'a': list(range(-5, 5))})
@@ -113,15 +125,17 @@
2.4 遍历DataSet的内容
+--------------------------------------
- Example::
+ .. code-block::
for instance in dataset:
# do something
2.5 一些其它操作
+--------------------------------------
- Example::
+ .. code-block::
# 检查是否存在名为'a'的field
dataset.has_field('a') # 或 ('a' in dataset)
@@ -129,21 +143,25 @@
dataset.rename_field('a', 'b')
# DataSet的长度
len(dataset)
+
+--------------------------------------
+3.DataSet与自然语言处理(NLP)
+--------------------------------------
-3 DataSet与自然语言处理(NLP)
- 在目前深度学习的模型中,大都依赖于随机梯度下降法(SGD)进行模型的优化。随机梯度下降需要将数据切分成一个一个的Batch,
- 一个Batch进行一次前向计算(forward)与梯度后向传播(backward)。在自然语言处理的场景下,往往还需要对数据进行pad。这是
- 由于句子的长度一般是不同的,但是一次Batch中的每个field都必须是一个tensor,所以需要将所有句子都补齐到相同的长度。
+在目前深度学习的模型中,大都依赖于随机梯度下降法(SGD)进行模型的优化。随机梯度下降需要将数据切分成一个个的 batch,
+一个batch进行一次前向计算(forward)与梯度后向传播(backward)。在自然语言处理的场景下,往往还需要对数据进行pad。这是
+由于句子的长度一般是不同的,但是一次batch中的每个field都必须是一个tensor,所以需要将所有句子都补齐到相同的长度。
-3.1 DataSet与Batch
+3.1 DataSet与DataSetIter
+--------------------------------------
- 我们先看fastNLP中如何将数据分成一个一个的Batch的例子, 这里我们使用随机生成的数据来模拟一个二分类文本分类任务,
+ 我们先看fastNLP中如何将数据分成一个一个的batch的例子, 这里我们使用随机生成的数据来模拟一个二分类文本分类任务,
words和characters是输入,labels是文本类别
- Example::
+ .. code-block::
from fastNLP import DataSet
- from fastNLP import Batch
+ from fastNLP import DataSetIter
from fastNLP import SequentialSampler
from fastNLP import EngChar2DPadder
@@ -163,7 +181,7 @@
d.set_target('label')
d.set_input('words', 'chars')
- for batch_x, batch_y in Batch(d, sampler=SequentialSampler(), batch_size=2):
+ for batch_x, batch_y in DataSetIter(d, sampler=SequentialSampler(), batch_size=2):
print("batch_x:", batch_x)
print("batch_y:", batch_y)
break
@@ -182,23 +200,26 @@
# [ 0, 0, 0, 0, 0]]])}
# {'label': tensor([0, 0])}
- 其中 :class:`~fastNLP.Batch` 是用于从DataSet中按照batch_size为大小取出batch的迭代器,
- :class:`~fastNLP.SequentialSampler` 用于指示 Batch 以怎样的
+ 其中 :class:`~fastNLP.DataSetIter` 是用于从DataSet中按照batch_size为大小取出batch的迭代器,
+ :class:`~fastNLP.SequentialSampler` 用于指示 :class:`~fastNLP.DataSetIter` 以怎样的
顺序从DataSet中取出instance以组成一个batch,
- 更详细的说明请参照 :class:`~fastNLP.Batch` 和 :class:`~fastNLP.SequentialSampler` 文档。
+ 更详细的说明请参照 :class:`~fastNLP.DataSetIter` 和 :class:`~fastNLP.SequentialSampler` 文档。
- 通过DataSet.set_input('words', 'chars'), fastNLP将认为'words'和'chars'这两个field都是input,并将它们都放入迭代器
- 生成的第一个dict中; DataSet.set_target('labels'), fastNLP将认为'labels'这个field是target,并将其放入到迭代器的第
+ 通过 ``DataSet.set_input('words', 'chars')`` , fastNLP将认为 `words` 和 `chars` 这两个field都是input,并将它们都放入迭代器
+ 生成的第一个dict中; ``DataSet.set_target('labels')`` , fastNLP将认为 `labels` 这个field是target,并将其放入到迭代器的第
二个dict中。如上例中所打印结果。分为input和target的原因是由于它们在被 :class:`~fastNLP.Trainer` 所使用时会有所差异,
详见 :class:`~fastNLP.Trainer`
- 当把某个field设置为'target'或者'input'的时候(两者不是互斥的,可以同时设为input和target),fastNLP不仅仅只是将其放
- 置到不同的dict中,而还会对被设置为input或target的field进行类型检查。类型检查的目的是为了看能否把该field转为
- pytorch的torch.LongTensor或torch.FloatTensor类型(也可以在Batch中设置输出numpy类型,参考 :class:`~fastNLP.Batch` ),如上例所示,
- fastNLP已将words,chars和label转为了Tensor类型。如果field在每个instance都拥有相同的维度(不能超过两维),且最内层
- 的元素都为相同的type(int, float, np.int*, np.float*),则fastNLP默认将对该field进行pad。也支持全为str的field作为
- target和input,这种情况下,fastNLP默认不进行pad。另外,当某个field已经被设置为了target或者input后,之后append的
- instance对应的field必须要和前面已有的内容一致,否则会报错。
+ 当把某个field设置为 `target` 或者 `input` 的时候(两者不是互斥的,可以同时设为两种),fastNLP不仅仅只是将其放
+ 置到不同的dict中,而还会对被设置为 `input` 或 `target` 的 field 进行类型检查。类型检查的目的是为了看能否把该 field 转为
+ pytorch的 :class:`torch.LongTensor` 或 :class:`torch.FloatTensor` 类型
+ (也可以在 :class:`~fastNLP.DataSetIter` 中设置输出numpy类型,参考 :class:`~fastNLP.DataSetIter` )。
+
+ 如上例所示,fastNLP已将 `words` ,`chars` 和 `label` 转为了 :class:`Tensor` 类型。
+ 如果 field 在每个 `instance` 都拥有相同的维度(不能超过两维),且最内层的元素都为相同的 type(int, float, np.int*, np.float*),
+ 则fastNLP默认将对该 field 进行pad。也支持全为str的field作为target和input,这种情况下,fastNLP默认不进行pad。
+ 另外,当某个 field 已经被设置为了 target 或者 input 后,之后 `append` 的
+ `instance` 对应的 field 必须要和前面已有的内容一致,否则会报错。
可以查看field的dtype::
@@ -217,6 +238,7 @@
错误::
from fastNLP import DataSet
+
d = DataSet({'data': [1, 'a']})
d.set_input('data')
>> RuntimeError: Mixed data types in Field data: [, ]
@@ -231,6 +253,7 @@
当某个field被设置为忽略type之后,fastNLP将不对其进行pad。
3.2 DataSet与pad
+--------------------------------------
在fastNLP里,pad是与一个field绑定的。即不同的field可以使用不同的pad方式,比如在英文任务中word需要的pad和
character的pad方式往往是不同的。fastNLP是通过一个叫做 :class:`~fastNLP.Padder` 的子类来完成的。
@@ -240,7 +263,7 @@
如果 :class:`~fastNLP.AutoPadder` 或 :class:`~fastNLP.EngChar2DPadder` 无法满足需求,
也可以自己写一个 :class:`~fastNLP.Padder` 。
- Example::
+ .. code-block::
from fastNLP import DataSet
from fastNLP import EngChar2DPadder
@@ -268,6 +291,7 @@ import _pickle as pickle
import warnings
import numpy as np
+from copy import deepcopy
from .field import AutoPadder
from .field import FieldArray
@@ -275,6 +299,8 @@ from .instance import Instance
from .utils import _get_func_signature
from .field import AppendToTargetOrInputException
from .field import SetInputOrTargetException
+from .const import Const
+from ._logger import logger
class DataSet(object):
"""
@@ -326,7 +352,11 @@ class DataSet(object):
self.idx])
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx)
return self.dataset.field_arrays[item][self.idx]
-
+
+ def items(self):
+ ins = self.dataset[self.idx]
+ return ins.items()
+
def __repr__(self):
return self.dataset[self.idx].__repr__()
@@ -405,7 +435,7 @@ class DataSet(object):
"""
将一个instance对象append到DataSet后面。
- :param instance: :class:`~fastNLP.Instance` 类型。若DataSet不为空,则instance应该拥有和DataSet完全一样的field。
+ :param ~fastNLP.Instance instance: 若DataSet不为空,则instance应该拥有和DataSet完全一样的field。
"""
if len(self.field_arrays) == 0:
@@ -423,7 +453,7 @@ class DataSet(object):
try:
self.field_arrays[name].append(field)
except AppendToTargetOrInputException as e:
- print(f"Cannot append to field:{name}.")
+ logger.error(f"Cannot append to field:{name}.")
raise e
def add_fieldarray(self, field_name, fieldarray):
@@ -431,7 +461,7 @@ class DataSet(object):
将fieldarray添加到DataSet中.
:param str field_name: 新加入的field的名称
- :param fieldarray: :class:`~fastNLP.FieldArray` 类型。需要加入DataSet的field的内容
+ :param ~fastNLP.core.FieldArray fieldarray: 需要加入DataSet的field的内容
:return:
"""
if not isinstance(fieldarray, FieldArray):
@@ -447,8 +477,7 @@ class DataSet(object):
:param str field_name: 新增的field的名称
:param list fields: 需要新增的field的内容
- :param None, padder: :class:`~fastNLP.Padder` 类型,
- 如果为None,则不进行pad,默认使用 :class:`~fastNLP.AutoPadder` 自动判断是否需要做pad。
+ :param None,~fastNLP.Padder padder: 如果为None,则不进行pad,默认使用 :class:`~fastNLP.AutoPadder` 自动判断是否需要做pad。
:param bool is_input: 新加入的field是否是input
:param bool is_target: 新加入的field是否是target
:param bool ignore_type: 是否忽略对新加入的field的类型检查
@@ -465,7 +494,7 @@ class DataSet(object):
"""
删除第index个instance
- :param int index: 需要删除的instance的index,从0开始
+ :param int index: 需要删除的instance的index,序号从0开始。
"""
assert isinstance(index, int), "Only integer supported."
if len(self) <= index:
@@ -475,6 +504,7 @@ class DataSet(object):
else:
for field in self.field_arrays.values():
field.pop(index)
+ return self
def delete_field(self, field_name):
"""
@@ -483,7 +513,22 @@ class DataSet(object):
:param str field_name: 需要删除的field的名称.
"""
self.field_arrays.pop(field_name)
-
+ return self
+
+ def copy_field(self, field_name, new_field_name):
+ """
+ 深度copy名为field_name的field到new_field_name
+
+ :param str field_name: 需要copy的field。
+ :param str new_field_name: copy生成的field名称
+ :return: self
+ """
+ if not self.has_field(field_name):
+ raise KeyError(f"Field:{field_name} not found in DataSet.")
+ fieldarray = deepcopy(self.get_field(field_name))
+ self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray)
+ return self
+
def has_field(self, field_name):
"""
判断DataSet中是否有名为field_name这个field
@@ -510,7 +555,7 @@ class DataSet(object):
"""
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray`
- :return: dict: 返回如上所述的字典
+ :return dict: 返回如上所述的字典
"""
return self.field_arrays
@@ -518,7 +563,7 @@ class DataSet(object):
"""
返回一个list,包含所有 field 的名字
- :return: list: 返回如上所述的列表
+ :return list: 返回如上所述的列表
"""
return sorted(self.field_arrays.keys())
@@ -544,7 +589,7 @@ class DataSet(object):
raise KeyError("DataSet has no field named {}.".format(old_name))
return self
- def set_target(self, *field_names, flag=True):
+ def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
"""
将field_names的field设置为target
@@ -555,19 +600,23 @@ class DataSet(object):
:param str field_names: field的名称
:param bool flag: 将field_name的target状态设置为flag
+ :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
+ 行的数据进行类型和维度推断本列的数据的类型和维度。
"""
assert isinstance(flag, bool), "Only bool type supported."
for name in field_names:
if name in self.field_arrays:
try:
+ self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_target = flag
except SetInputOrTargetException as e:
- print(f"Cannot set field:{name} as target.")
+ logger.error(f"Cannot set field:{name} as target.")
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
+ return self
- def set_input(self, *field_names, flag=True):
+ def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True):
"""
将field_names的field设置为input::
@@ -576,16 +625,20 @@ class DataSet(object):
:param str field_names: field的名称
:param bool flag: 将field_name的input状态设置为flag
+ :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一
+ 行的数据进行类型和维度推断本列的数据的类型和维度。
"""
for name in field_names:
if name in self.field_arrays:
try:
+ self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self.field_arrays[name].is_input = flag
except SetInputOrTargetException as e:
- print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
+ logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.")
raise e
else:
raise KeyError("{} is not a valid field name.".format(name))
+ return self
def set_ignore_type(self, *field_names, flag=True):
"""
@@ -602,6 +655,7 @@ class DataSet(object):
self.field_arrays[name].ignore_type = flag
else:
raise KeyError("{} is not a valid field name.".format(name))
+ return self
def set_padder(self, field_name, padder):
"""
@@ -612,11 +666,12 @@ class DataSet(object):
dataset.set_padder('chars', padder) # 则chars这个field会使用EngChar2DPadder进行pad操作
:param str field_name: 设置field的padding方式为padder
- :param None, Padder padder: 设置为None即删除padder, 即对该field不进行pad操作。
+ :param None,~fastNLP.Padder padder: 设置为None即删除padder, 即对该field不进行pad操作。
"""
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_padder(padder)
+ return self
def set_pad_val(self, field_name, pad_val):
"""
@@ -628,6 +683,7 @@ class DataSet(object):
if field_name not in self.field_arrays:
raise KeyError("There is no field named {}.".format(field_name))
self.field_arrays[field_name].set_pad_val(pad_val)
+ return self
def get_input_name(self):
"""
@@ -660,7 +716,7 @@ class DataSet(object):
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
- :return: list(Any), 里面的元素为func的返回值,所以list长度为DataSet的长度
+ :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -673,7 +729,7 @@ class DataSet(object):
results.append(func(ins[field_name]))
except Exception as e:
if idx != -1:
- print("Exception happens at the `{}`th instance.".format(idx))
+ logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx+1))
raise e
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -687,7 +743,7 @@ class DataSet(object):
"""
将results作为加入到新的field中,field名称为new_field_name
- :param list(str) results: 一般是apply*()之后的结果
+ :param List[str] results: 一般是apply*()之后的结果
:param str new_field_name: 新加入的field的名称
:param dict kwargs: 用户apply*()时传入的自定义参数
:return:
@@ -730,7 +786,7 @@ class DataSet(object):
3. ignore_type: bool, 如果为True则将 `new_field_name` 的field的ignore_type设置为true, 忽略其类型
- :return: list(Any), 里面的元素为func的返回值,所以list长度为DataSet的长度
+ :return List[Any]: 里面的元素为func的返回值,所以list长度为DataSet的长度
"""
assert len(self) != 0, "Null DataSet cannot use apply()."
idx = -1
@@ -738,10 +794,11 @@ class DataSet(object):
results = []
for idx, ins in enumerate(self._inner_iter()):
results.append(func(ins))
- except Exception as e:
+ except BaseException as e:
if idx != -1:
- print("Exception happens at the `{}`th instance.".format(idx))
+ logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
+
# results = [func(ins) for ins in self._inner_iter()]
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None
raise ValueError("{} always return None.".format(_get_func_signature(func=func)))
@@ -751,7 +808,7 @@ class DataSet(object):
return results
- def add_seq_len(self, field_name:str, new_field_name='seq_len'):
+ def add_seq_len(self, field_name:str, new_field_name=Const.INPUT_LEN):
"""
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。
@@ -795,7 +852,7 @@ class DataSet(object):
+        if parse_version(torch.__version__) >= parse_version('1.1'):
+ self.model = DDP(model, device_ids=[self.local_rank],
+ output_device=self.local_rank, find_unused_parameters=True)
+ else:
+ self.model = DDP(model, device_ids=[self.local_rank],
+ output_device=self.local_rank)
+
+ self.optimizer = optimizer
+ self.sampler = DistributedSampler(self.train_data)
+ self.data_iterator = self._get_data_iter(self.train_data)
+ self.n_steps = self._get_n_steps()
+
+ # for evaluation, only run eval on master proc
+ if dev_data and metrics:
+ cb = TesterCallback(
+ dev_data, model, metrics,
+ batch_size=batch_size_per_gpu, num_workers=num_workers)
+ self.callback_manager.add_callback([cb], master=True)
+
+ # Setup logging
+ dist.barrier()
+ self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M')
+ if self.save_path:
+ self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time)
+ else:
+ self.cp_save_path = None
+
+ # use INFO in the master, WARN for others
+ logger.setLevel(logging.INFO if self.is_master else logging.WARNING)
+ self.logger = logger
+ self.logger.info("Setup Distributed Trainer")
+ self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
+ os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False))
+ self.logger.info("Num of processes: {}".format(self.world_size))
+ self.logger.info("Use device: {}".format(device))
+ self.logger.info("Training with fp16: {}, optimization level: {}".format(
+ len(self.fp16) > 0, self.fp16 if self.fp16 else None))
+
+ def _get_n_steps(self):
+ batch_size = self.world_size * self.batch_size_per_gpu
+ return (len(self.train_data) // batch_size + int(
+            len(self.train_data) % batch_size != 0 and not self.drop_last)) * self.n_epochs
+
+ def _get_data_iter(self, dataset):
+ if isinstance(dataset, DataSet):
+ return DataSetIter(
+ dataset=dataset, batch_size=self.batch_size_per_gpu,
+ num_workers=self.num_data_workers, sampler=self.sampler,
+ drop_last=self.drop_last
+ )
+ elif isinstance(dataset, BatchIter):
+ return dataset
+ else:
+ raise TypeError("train_data type {} not support".format(type(dataset)))
+
+ def _get_optimizer(self, optimizer):
+ if isinstance(optimizer, torch.optim.Optimizer):
+ return optimizer
+ elif isinstance(optimizer, Optimizer):
+ return optimizer.construct_from_pytorch(self.model.parameters())
+ elif optimizer is None:
+ return torch.optim.Adam(self.model.parameters(), lr=4e-3)
+ else:
+ raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
+
+ @property
+ def is_master(self):
+ return self.rank == 0
+
+ def train(self, on_exception='auto'):
+ try:
+ self.logger.info("###### Training epochs started ######")
+ self.logger.info('Total epochs: %d'% self.n_epochs)
+ self.logger.info('Total steps: %d'% self.n_steps)
+ self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu)
+            self.logger.info('Total batch_size: %d' % (self.batch_size_per_gpu * dist.get_world_size()))
+ self.logger.info('Total num of samples: %d'% len(self.train_data))
+ self.logger.info("Num of callbacks for all workers: {}".format(
+ len(self.callback_manager.callbacks_all)))
+ self.logger.info("Num of callbacks for master workers: {}".format(
+ len(self.callback_manager.callbacks_master)))
+ self.logger.info("Callbacks for all workers: {}".format(
+ [repr(cb) for cb in self.callback_manager.callbacks_all]))
+ self.logger.info("Callbacks for master workers: {}".format(
+ [repr(cb) for cb in self.callback_manager.callbacks_master]))
+
+ start_time = time.time()
+ results = {}
+ if self.n_epochs <= 0:
+ self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
+ results['seconds'] = 0.
+ return results
+
+ try:
+ self.callback_manager.on_train_begin()
+ self._train()
+ self.callback_manager.on_train_end()
+
+ except BaseException as e:
+ self.callback_manager.on_exception(e)
+ if on_exception == 'auto':
+ if not isinstance(e, (CallbackException, KeyboardInterrupt)):
+ raise e
+ else:
+ self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
+ elif on_exception == 'raise':
+ raise e
+
+ results['seconds'] = round(time.time() - start_time, 2)
+ self.logger.info("###### Train finished ######")
+ self.logger.info('Total train time: {} seconds.'. format(results['seconds']))
+ return results
+ finally:
+ self.close()
+
+ def _train(self):
+ if self.fp16:
+ # skip check, done in __init__()
+ from apex import amp
+ self.step = 0
+ self.epoch = 0
+ self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
+ leave=False, dynamic_ncols=True, disable=not self.is_master)
+ pbar = self.pbar
+ avg_loss = 0
+ data_iterator = self.data_iterator
+ self.model.zero_grad()
+ for epoch in range(1, self.n_epochs + 1):
+ self.epoch = epoch
+ pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
+ # early stopping
+ self.callback_manager.on_epoch_begin()
+ for batch_x, batch_y in data_iterator:
+ self.model.train()
+ self.step += 1
+ _move_dict_value_to_device(batch_x, batch_y, device=self.device)
+ indices = data_iterator.get_batch_indices()
+ # negative sampling; replace unknown; re-weight batch_y
+ self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
+ prediction = self._data_forward(self.model, batch_x)
+
+ # edit prediction
+ self.callback_manager.on_loss_begin(batch_y, prediction)
+ loss = self._compute_loss(prediction, batch_y)
+ avg_loss += loss.item()
+
+ # Is loss NaN or inf? requires_grad = False
+ self.callback_manager.on_backward_begin(loss)
+
+ if self.fp16:
+ with amp.scale_loss(loss, self.optimizer) as scale_loss:
+ scale_loss.backward()
+ else:
+ loss.backward()
+
+ self.callback_manager.on_backward_end()
+
+ self._update()
+ self.callback_manager.on_step_end()
+
+ if self.step % self.print_every == 0:
+ avg_loss = float(avg_loss) / self.print_every
+ print_output = "loss:{:<6.5f}".format(avg_loss)
+ pbar.update(self.print_every)
+ pbar.set_postfix_str(print_output)
+ avg_loss = 0
+
+ self.callback_manager.on_batch_end()
+
+ if (self.validate_every > 0 and self.step % self.validate_every == 0):
+ self._do_validation()
+
+ if self.cp_save_path and \
+ self.save_every > 0 and \
+ self.step % self.save_every == 0:
+ self.save_check_point()
+
+ # ================= mini-batch end ==================== #
+ if self.validate_every < 0:
+ self._do_validation()
+
+ if self.save_every < 0 and self.cp_save_path:
+ self.save_check_point()
+ # lr decay; early stopping
+ self.callback_manager.on_epoch_end()
+ # =============== epochs end =================== #
+ pbar.close()
+ self.pbar = None
+ # ============ tqdm end ============== #
+
+ def _update(self):
+ """Perform weight update on a model.
+
+ """
+ if self.step % self.update_every == 0:
+ self.optimizer.step()
+ self.model.zero_grad()
+
+ def _data_forward(self, network, x):
+ x = _build_args(self._forward_func, **x)
+ y = network(**x)
+ if not isinstance(y, dict):
+ raise TypeError(
+ f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
+ return y
+
+ def _compute_loss(self, predict, truth):
+ """Compute loss given prediction and ground truth.
+
+ :param predict: prediction dict, produced by model.forward
+ :param truth: ground truth dict, produced by batch_y
+ :return: a scalar
+ """
+ loss = self.losser(predict, truth)
+ if self.update_every > 1:
+ loss = loss / self.update_every
+ return loss.mean()
+
+ def save_check_point(self, only_params=False):
+ # only master save models
+ if self.is_master:
+ os.makedirs(self.cp_save_path, exist_ok=True)
+ path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step))
+ self.logger.info("Save checkpoint to {}".format(path))
+ model_to_save = self.model.module
+ if only_params:
+ model_to_save = model_to_save.state_dict()
+ torch.save(model_to_save, path)
+
+ def _do_validation(self):
+ self.callback_manager.on_valid_begin()
+ eval_res = self.callback_manager.on_validation()
+ eval_res = list(filter(lambda x: x is not None, eval_res))
+ if len(eval_res):
+ eval_res, is_better = list(zip(*eval_res))
+ else:
+ eval_res, is_better = None, None
+ self.callback_manager.on_valid_end(
+ eval_res, self.metric_key, self.optimizer, is_better)
+ dist.barrier()
+
+ def close(self):
+ dist.destroy_process_group()
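
For reference, the update logic implemented by ``_update()`` and ``_compute_loss()`` above amounts to plain gradient accumulation: the loss is scaled by ``1/update_every`` and the optimizer only steps every ``update_every`` batches. A standalone PyTorch sketch of that pattern (toy model, not the fastNLP API):

.. code-block:: python

    import torch
    from torch import nn

    model = nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=4e-3)  # same default lr as _get_optimizer()
    update_every = 4

    model.zero_grad()
    for step in range(1, 17):
        x, y = torch.randn(8, 4), torch.randint(0, 2, (8,))
        loss = nn.functional.cross_entropy(model(x), y) / update_every  # scale like _compute_loss()
        loss.backward()                       # gradients accumulate across batches
        if step % update_every == 0:          # mirrors _update()
            optimizer.step()
            model.zero_grad()
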
diff --git a/fastNLP/core/field.py b/fastNLP/core/field.py
index 65eb0194..05f987c2 100644
--- a/fastNLP/core/field.py
+++ b/fastNLP/core/field.py
@@ -1,36 +1,53 @@
+"""
+.. todo::
+ doc
+"""
+__all__ = [
+ "Padder",
+ "AutoPadder",
+ "EngChar2DPadder",
+]
-from numbers import Number
-import torch
-import numpy as np
-from typing import Any
from abc import abstractmethod
-from copy import deepcopy
from collections import Counter
+from copy import deepcopy
+from numbers import Number
+from typing import Any
+
+import numpy as np
+import torch
+
+from ._logger import logger
+from .utils import _is_iterable
+
class SetInputOrTargetException(Exception):
def __init__(self, msg, index=None, field_name=None):
super().__init__(msg)
self.msg = msg
self.index = index # 标示在哪个数据遭遇到问题了
- self.field_name = field_name # 标示当前field的名称
+ self.field_name = field_name # 标示当前field的名称
+
class AppendToTargetOrInputException(Exception):
def __init__(self, msg, index=None, field_name=None):
super().__init__(msg)
self.msg = msg
self.index = index # 标示在哪个数据遭遇到问题了
- self.field_name = field_name # 标示当前field的名称
+ self.field_name = field_name # 标示当前field的名称
+
class FieldArray:
- def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False):
- if len(content)==0:
+ def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False,
+ use_1st_ins_infer_dim_type=True):
+ if len(content) == 0:
raise RuntimeError("Empty fieldarray is not allowed.")
_content = content
try:
_content = list(_content)
except BaseException as e:
- print(f"Cannot convert content(of type:{type(content)}) into list.")
+ logger.error(f"Cannot convert content(of type:{type(content)}) into list.")
raise e
self.name = name
self.content = _content
@@ -38,36 +55,37 @@ class FieldArray:
# 根据input的情况设置input,target等
self._cell_ndim = None # 多少维度
self.dtype = None # 最内层的element都是什么类型的
+ self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type)
self._is_input = False
self._is_target = False
-
+
if is_input:
self.is_input = is_input
if is_target:
self.is_target = is_target
-
+
if padder is None:
padder = AutoPadder(pad_val=0)
else:
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder."
padder = deepcopy(padder)
self.set_padder(padder)
-
+
@property
def ignore_type(self):
return self._ignore_type
-
+
@ignore_type.setter
def ignore_type(self, value):
if value:
self._cell_ndim = None
self.dtype = None
self._ignore_type = value
-
+
@property
def is_input(self):
return self._is_input
-
+
@is_input.setter
def is_input(self, value):
"""
@@ -77,16 +95,16 @@ class FieldArray:
if value is True and \
self._is_target is False and \
self._ignore_type is False:
- self._check_dtype_and_ndim()
+ self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
if value is False and self._is_target is False:
self.dtype = None
self._cell_ndim = None
self._is_input = value
-
+
@property
def is_target(self):
return self._is_target
-
+
@is_target.setter
def is_target(self, value):
"""
@@ -95,70 +113,82 @@ class FieldArray:
if value is True and \
self._is_input is False and \
self._ignore_type is False:
- self._check_dtype_and_ndim()
+ self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type)
if value is False and self._is_input is False:
self.dtype = None
self._cell_ndim = None
self._is_target = value
-
- def _check_dtype_and_ndim(self):
+
+ def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True):
"""
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有
通过将直接报错.
+ :param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim
:return:
"""
cell_0 = self.content[0]
index = 0
try:
type_0, dim_0 = _get_ele_type_and_dim(cell_0)
- for cell in self.content[1:]:
- index += 1
- type_i, dim_i = _get_ele_type_and_dim(cell)
- if type_i!=type_0:
- raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}."
- ".".format(type_i, index, type_0))
- if dim_0!=dim_i:
- raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with "
- "dimension:{}.".format(dim_i, index, dim_0))
+ if not only_check_1st_ins_dim_type:
+ for cell in self.content[1:]:
+ index += 1
+ type_i, dim_i = _get_ele_type_and_dim(cell)
+ if type_i != type_0:
+ raise SetInputOrTargetException(
+ "Type:{} in index {} is different from the first element with type:{}."
+ ".".format(type_i, index, type_0))
+ if dim_0 != dim_i:
+ raise SetInputOrTargetException(
+ "Dimension:{} in index {} is different from the first element with "
+ "dimension:{}.".format(dim_i, index, dim_0))
self._cell_ndim = dim_0
self.dtype = type_0
except SetInputOrTargetException as e:
e.index = index
raise e
-
- def append(self, val:Any):
+
+ def append(self, val: Any):
"""
:param val: 把该val append到fieldarray。
:return:
"""
- if (self._is_target or self._is_input) and self._ignore_type is False:
+ if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type:
type_, dim_ = _get_ele_type_and_dim(val)
- if self.dtype!=type_:
+ if self.dtype != type_:
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with "
f"previous values(type:{self.dtype}).")
- if self._cell_ndim!=dim_:
+ if self._cell_ndim != dim_:
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with "
f"previous values(dim:{self._cell_ndim}).")
self.content.append(val)
else:
self.content.append(val)
-
+
+ def pop(self, index):
+ """
+ 删除该field中index处的元素
+ :param int index: 从0开始的数据下标。
+ :return:
+ """
+ self.content.pop(index)
+
def __getitem__(self, indices):
return self.get(indices, pad=False)
-
+
def __setitem__(self, idx, val):
assert isinstance(idx, int)
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型
type_, dim_ = _get_ele_type_and_dim(val)
- if self.dtype!=type_:
+ if self.dtype != type_:
raise RuntimeError(f"Value(type:{type_}) are of different types with "
- f"other values(type:{self.dtype}).")
- if self._cell_ndim!=dim_:
+ f"other values(type:{self.dtype}).")
+ if self._cell_ndim != dim_:
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with "
- f"previous values(dim:{self._cell_ndim}).")
+ f"previous values(dim:{self._cell_ndim}).")
self.content[idx] = val
-
+
def get(self, indices, pad=True):
"""
根据给定的indices返回内容
@@ -171,16 +201,16 @@ class FieldArray:
return self.content[indices]
if self.is_input is False and self.is_target is False:
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name))
-
+
contents = [self.content[i] for i in indices]
if self.padder is None or pad is False:
return np.array(contents)
else:
return self.pad(contents)
-
+
def pad(self, contents):
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim)
-
+
def set_padder(self, padder):
"""
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。
@@ -192,7 +222,7 @@ class FieldArray:
self.padder = deepcopy(padder)
else:
self.padder = None
-
+
def set_pad_val(self, pad_val):
"""
修改padder的pad_val.
@@ -202,7 +232,7 @@ class FieldArray:
if self.padder is not None:
self.padder.set_pad_val(pad_val)
return self
-
+
def __len__(self):
"""
Returns the size of FieldArray.
@@ -210,7 +240,7 @@ class FieldArray:
:return int length:
"""
return len(self.content)
-
+
def to(self, other):
"""
将other的属性复制给本FieldArray(other必须为FieldArray类型).
@@ -220,15 +250,15 @@ class FieldArray:
:return: :class:`~fastNLP.FieldArray`
"""
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other))
-
+
self.ignore_type = other.ignore_type
self.is_input = other.is_input
self.is_target = other.is_target
self.padder = other.padder
-
+
return self
-
- def split(self, sep:str=None, inplace:bool=True):
+
+ def split(self, sep: str = None, inplace: bool = True):
"""
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值
@@ -241,11 +271,11 @@ class FieldArray:
try:
new_contents.append(cell.split(sep))
except Exception as e:
- print(f"Exception happens when process value in index {index}.")
+ logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
-
- def int(self, inplace:bool=True):
+
+ def int(self, inplace: bool = True):
"""
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。)
@@ -261,10 +291,10 @@ class FieldArray:
else:
new_contents.append(int(cell))
except Exception as e:
- print(f"Exception happens when process value in index {index}.")
- print(e)
+ logger.error(f"Exception happens when process value in index {index}.")
+ raise e
return self._after_process(new_contents, inplace=inplace)
-
+
def float(self, inplace=True):
"""
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -281,10 +311,10 @@ class FieldArray:
else:
new_contents.append(float(cell))
except Exception as e:
- print(f"Exception happens when process value in index {index}.")
+ logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
-
+
def bool(self, inplace=True):
"""
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -301,11 +331,11 @@ class FieldArray:
else:
new_contents.append(bool(cell))
except Exception as e:
- print(f"Exception happens when process value in index {index}.")
+ logger.error(f"Exception happens when process value in index {index}.")
raise e
-
+
return self._after_process(new_contents, inplace=inplace)
-
+
def lower(self, inplace=True):
"""
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -322,10 +352,10 @@ class FieldArray:
else:
new_contents.append(cell.lower())
except Exception as e:
- print(f"Exception happens when process value in index {index}.")
+ logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
-
+
def upper(self, inplace=True):
"""
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的),
@@ -342,10 +372,10 @@ class FieldArray:
else:
new_contents.append(cell.upper())
except Exception as e:
- print(f"Exception happens when process value in index {index}.")
+ logger.error(f"Exception happens when process value in index {index}.")
raise e
return self._after_process(new_contents, inplace=inplace)
-
+
def value_count(self):
"""
返回该field下不同value的数量。多用于统计label数量
@@ -353,17 +383,18 @@ class FieldArray:
:return: Counter, key是label,value是出现次数
"""
count = Counter()
-
+
def cum(cell):
if _is_iterable(cell) and not isinstance(cell, str):
for cell_ in cell:
cum(cell_)
else:
count[cell] += 1
+
for cell in self.content:
cum(cell)
return count
-
+
def _after_process(self, new_contents, inplace):
"""
当调用处理函数之后,决定是否要替换field。
@@ -378,14 +409,14 @@ class FieldArray:
self.is_input = self.is_input
self.is_target = self.is_input
except SetInputOrTargetException as e:
- print("The newly generated field cannot be set as input or target.")
+ logger.error("The newly generated field cannot be set as input or target.")
raise e
return self
else:
return new_contents
-def _get_ele_type_and_dim(cell:Any, dim=0):
+def _get_ele_type_and_dim(cell: Any, dim=0):
"""
识别cell的类别与dimension的数量
@@ -401,13 +432,13 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
elif isinstance(cell, list):
dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
- types = set([i for i,j in res])
- dims = set([j for i,j in res])
- if len(types)>1:
+ types = set([i for i, j in res])
+ dims = set([j for i, j in res])
+ if len(types) > 1:
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
- elif len(types)==0:
+ elif len(types) == 0:
raise SetInputOrTargetException("Empty value encountered.")
- if len(dims)>1:
+ if len(dims) > 1:
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
return types.pop(), dims.pop()
elif isinstance(cell, torch.Tensor):
@@ -418,28 +449,19 @@ def _get_ele_type_and_dim(cell:Any, dim=0):
# 否则需要继续往下iterate
dim += 1
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell]
- types = set([i for i,j in res])
- dims = set([j for i,j in res])
- if len(types)>1:
+ types = set([i for i, j in res])
+ dims = set([j for i, j in res])
+ if len(types) > 1:
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types)))
- elif len(types)==0:
+ elif len(types) == 0:
raise SetInputOrTargetException("Empty value encountered.")
- if len(dims)>1:
+ if len(dims) > 1:
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims)))
return types.pop(), dims.pop()
- else: # 包含tuple, set, dict以及其它的类型
+ else: # 包含tuple, set, dict以及其它的类型
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.")
-def _is_iterable(value):
- # 检查是否是iterable的, duck typing
- try:
- iter(value)
- return True
- except BaseException as e:
- return False
-
-
class Padder:
"""
别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder`
@@ -448,28 +470,29 @@ class Padder:
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。
.. py:function:: __call__(self, contents, field_name, field_ele_dtype):
+
传入的是List内容。假设有以下的DataSet。
- :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
+ :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
deepcopy一份。
:param str, field_name: field的名称。
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。
:return: np.array([padded_element])
"""
-
+
def __init__(self, pad_val=0, **kwargs):
self.pad_val = pad_val
-
+
def set_pad_val(self, pad_val):
self.pad_val = pad_val
-
+
@abstractmethod
- def __call__(self, contents, field_name, field_ele_dtype, dim:int):
+ def __call__(self, contents, field_name, field_ele_dtype, dim: int):
"""
传入的是List内容。假设有以下的DataSet。
- :param list(Any) contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
+ :param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前
deepcopy一份。
:param str, field_name: field的名称。
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,
@@ -532,23 +555,24 @@ class AutoPadder(Padder):
3 其它情况不进行处理,返回一个np.array类型。
"""
+
def __init__(self, pad_val=0):
super().__init__(pad_val=pad_val)
-
+
def __call__(self, contents, field_name, field_ele_dtype, dim):
if field_ele_dtype:
- if dim>3:
+ if dim > 3:
return np.array(contents)
if isinstance(field_ele_dtype, type) and \
(issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)):
- if dim==0:
+ if dim == 0:
array = np.array(contents, dtype=field_ele_dtype)
- elif dim==1:
+ elif dim == 1:
max_len = max(map(len, contents))
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
array[i, :len(content_i)] = content_i
- elif dim==2:
+ elif dim == 2:
max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents])
@@ -558,20 +582,21 @@ class AutoPadder(Padder):
array[i, j, :len(content_ii)] = content_ii
else:
shape = np.shape(contents)
- if len(shape)==4: # 说明各dimension是相同的大小
+ if len(shape) == 4: # 说明各dimension是相同的大小
array = np.array(contents, dtype=field_ele_dtype)
else:
- raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+ raise RuntimeError(
+ f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return array
elif str(field_ele_dtype).startswith('torch'):
- if dim==0:
+ if dim == 0:
tensor = torch.tensor(contents).to(field_ele_dtype)
- elif dim==1:
+ elif dim == 1:
max_len = max(map(len, contents))
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
tensor[i, :len(content_i)] = torch.tensor(content_i)
- elif dim==2:
+ elif dim == 2:
max_len = max(map(len, contents))
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for
content_i in contents])
@@ -582,15 +607,18 @@ class AutoPadder(Padder):
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii)
else:
shapes = set([np.shape(content_i) for content_i in contents])
- if len(shapes)>1:
- raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+ if len(shapes) > 1:
+ raise RuntimeError(
+ f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
shape = shapes.pop()
- if len(shape)==3:
- tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype)
+ if len(shape) == 3:
+ tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val,
+ dtype=field_ele_dtype)
for i, content_i in enumerate(contents):
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype)
else:
- raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
+ raise RuntimeError(
+ f"Field:{field_name} has 3 dimensions, every sample should have the same shape.")
return tensor
else:
return np.array(contents) # 不进行任何操作
@@ -621,7 +649,7 @@ class EngChar2DPadder(Padder):
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder
"""
-
+
def __init__(self, pad_val=0, pad_length=0):
"""
:param pad_val: int, pad的位置使用该index
@@ -629,9 +657,9 @@ class EngChar2DPadder(Padder):
都pad或截取到该长度.
"""
super().__init__(pad_val=pad_val)
-
+
self.pad_length = pad_length
-
+
def __call__(self, contents, field_name, field_ele_dtype, dim):
"""
期望输入类似于
@@ -650,7 +678,7 @@ class EngChar2DPadder(Padder):
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format(
field_name, field_ele_dtype
))
- assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
+ assert dim == 2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions."
if self.pad_length < 1:
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents])
else:
@@ -658,12 +686,12 @@ class EngChar2DPadder(Padder):
max_sent_length = max(len(word_lst) for word_lst in contents)
batch_size = len(contents)
dtype = type(contents[0][0][0])
-
+
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val,
dtype=dtype)
for b_idx, word_lst in enumerate(contents):
for c_idx, char_lst in enumerate(word_lst):
chars = char_lst[:max_char_length]
padded_array[b_idx, c_idx, :len(chars)] = chars
-
+
return padded_array
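
The padders exported here are normally driven by ``FieldArray.pad()``, but they can also be called directly; a small sketch with assumed toy inputs:

.. code-block:: python

    import numpy as np
    from fastNLP.core.field import AutoPadder, EngChar2DPadder

    auto = AutoPadder(pad_val=0)
    words = [[1, 2, 3], [4, 5]]                       # dim=1, int content
    print(auto(words, field_name='words', field_ele_dtype=np.int64, dim=1))
    # pads the second row to [4, 5, 0]

    char_padder = EngChar2DPadder(pad_val=0, pad_length=0)
    chars = [[[1, 2], [3]], [[4, 5, 6]]]              # per-word char ids, dim=2
    print(char_padder(chars, field_name='chars', field_ele_dtype=np.int64, dim=2))
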
diff --git a/fastNLP/core/instance.py b/fastNLP/core/instance.py
index 5408522e..9a5d9edf 100644
--- a/fastNLP/core/instance.py
+++ b/fastNLP/core/instance.py
@@ -35,6 +35,13 @@ class Instance(object):
:param Any field: 新增field的内容
"""
self.fields[field_name] = field
+
+ def items(self):
+ """
+ 返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value
+ :return:
+ """
+ return self.fields.items()
def __getitem__(self, name):
if name in self.fields:
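
A quick sketch of the new ``Instance.items()`` iterator:

.. code-block:: python

    from fastNLP import Instance

    ins = Instance(words=[1, 2, 3], label=0)
    for field_name, field_value in ins.items():
        print(field_name, field_value)   # words [1, 2, 3] / label 0
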
diff --git a/fastNLP/core/losses.py b/fastNLP/core/losses.py
index 14aacef0..d5549cec 100644
--- a/fastNLP/core/losses.py
+++ b/fastNLP/core/losses.py
@@ -28,6 +28,7 @@ from .utils import _check_arg_dict_list
from .utils import _check_function_or_method
from .utils import _get_func_signature
from .utils import seq_len_to_mask
+import warnings
class LossBase(object):
@@ -205,10 +206,14 @@ class CrossEntropyLoss(LossBase):
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
- :param seq_len: 句子的长度, 长度之外的token不会计算loss。。
+ :param seq_len: 句子的长度, 长度之外的token不会计算loss。
+ :param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes)
+ 或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第
+ 二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等,
+        那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显式设置该值)。其它大于0的值则认为该维度是class的维度。
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替
传入seq_len.
- :param str reduction: 支持'mean','sum'和'none'.
+ :param str reduction: 支持 `mean` ,`sum` 和 `none` .
Example::
@@ -216,17 +221,21 @@ class CrossEntropyLoss(LossBase):
"""
- def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'):
+ def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'):
super(CrossEntropyLoss, self).__init__()
self._init_param_map(pred=pred, target=target, seq_len=seq_len)
self.padding_idx = padding_idx
assert reduction in ('mean', 'sum', 'none')
self.reduction = reduction
+ self.class_in_dim = class_in_dim
def get_loss(self, pred, target, seq_len=None):
if pred.dim() > 2:
- if pred.size(1) != target.size(1):
- pred = pred.transpose(1, 2)
+ if self.class_in_dim == -1:
+ if pred.size(1) != target.size(1): # 有可能顺序替换了
+ pred = pred.transpose(1, 2)
+ else:
+                pred = pred.transpose(-1, self.class_in_dim)
pred = pred.reshape(-1, pred.size(-1))
target = target.reshape(-1)
if seq_len is not None:
@@ -265,9 +274,9 @@ class BCELoss(LossBase):
二分类交叉熵损失函数
- :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
- :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
- :param str reduction: 支持'mean','sum'和'none'.
+ :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
+ :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
+ :param str reduction: 支持 `mean` ,`sum` 和 `none` .
"""
def __init__(self, pred=None, target=None, reduction='mean'):
@@ -286,11 +295,11 @@ class NLLLoss(LossBase):
负对数似然损失函数
- :param pred: 参数映射表中`pred`的映射关系,None表示映射关系为`pred`->`pred`
- :param target: 参数映射表中`target`的映射关系,None表示映射关系为`target`->`target`
+ :param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred`
+ :param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target`
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替
传入seq_len.
- :param str reduction: 支持'mean','sum'和'none'.
+ :param str reduction: 支持 `mean` ,`sum` 和 `none` .
"""
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'):
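
A sketch of the new ``class_in_dim`` argument with toy tensors; in normal use the loss object is passed to ``Trainer`` and ``get_loss()`` is called internally:

.. code-block:: python

    import torch
    from fastNLP import CrossEntropyLoss

    # pred is (batch, num_classes, max_len) here, so dim 1 holds the classes.
    loss_fn = CrossEntropyLoss(pred='pred', target='target', seq_len='seq_len', class_in_dim=1)
    pred = torch.randn(2, 5, 7)               # batch=2, num_classes=5, max_len=7
    target = torch.randint(0, 5, (2, 7))
    seq_len = torch.LongTensor([7, 4])
    print(loss_fn.get_loss(pred, target, seq_len))
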
diff --git a/fastNLP/core/metrics.py b/fastNLP/core/metrics.py
index f75b6c90..1d1e3819 100644
--- a/fastNLP/core/metrics.py
+++ b/fastNLP/core/metrics.py
@@ -27,14 +27,14 @@ from abc import abstractmethod
class MetricBase(object):
"""
- 所有metrics的基类,,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。
+ 所有metrics的基类,所有的传入到Trainer, Tester的Metric需要继承自该对象,需要覆盖写入evaluate(), get_metric()方法。
evaluate(xxx)中传入的是一个batch的数据。
get_metric(xxx)当所有数据处理完毕,调用该方法得到最终的metric值
以分类问题中,Accuracy计算为例
- 假设model的forward返回dict中包含'pred'这个key, 并且该key需要用于Accuracy::
+ 假设model的forward返回dict中包含 `pred` 这个key, 并且该key需要用于Accuracy::
class Model(nn.Module):
def __init__(xxx):
@@ -43,7 +43,7 @@ class MetricBase(object):
# do something
return {'pred': pred, 'other_keys':xxx} # pred's shape: batch_size x num_classes
- 假设dataset中'label'这个field是需要预测的值,并且该field被设置为了target
+ 假设dataset中 `label` 这个field是需要预测的值,并且该field被设置为了target
对应的AccMetric可以按如下的定义, version1, 只使用这一次::
class AccMetric(MetricBase):
@@ -118,6 +118,7 @@ class MetricBase(object):
def __init__(self):
self._param_map = {} # key is param in function, value is input param.
self._checked = False
+ self._metric_name = self.__class__.__name__
@property
def param_map(self):
@@ -135,6 +136,23 @@ class MetricBase(object):
@abstractmethod
def get_metric(self, reset=True):
raise NotImplemented
+
+ def set_metric_name(self, name:str):
+ """
+ 设置metric的名称,默认是Metric的class name.
+
+ :param str name:
+ :return: self
+ """
+ self._metric_name = name
+ return self
+
+ def get_metric_name(self):
+ """
+ 返回metric的名称
+ :return:
+ """
+ return self._metric_name
def _init_param_map(self, key_map=None, **kwargs):
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map
@@ -358,6 +376,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None):
"""
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间)
+ 也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列
:param tags: List[str],
:param ignore_labels: List[str], 在该list中的label将被忽略
@@ -478,7 +497,7 @@ class SpanFPreRecMetric(MetricBase):
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric`
在序列标注问题中,以span的方式计算F, pre, rec.
- 比如中文Part of speech中,会以character的方式进行标注,句子'中国在亚洲'对应的POS可能为(以BMES为例)
+ 比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例)
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。
最后得到的metric结果为::
@@ -502,15 +521,15 @@ class SpanFPreRecMetric(MetricBase):
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN),
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'.
- :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用'pred'取数据
- :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用'target'取数据
- :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用'seq_len'取数据。
+ :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据
+ :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据
+ :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这
个label
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个
label的f1, pre, rec
- :param str f_type: 'micro'或'macro'. 'micro':通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; 'macro':
+ :param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` :
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同)
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` .
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。
@@ -624,7 +643,7 @@ class SpanFPreRecMetric(MetricBase):
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp)
f_sum += f
pre_sum += pre
- rec_sum + rec
+ rec_sum += rec
if not self.only_gross and tag != '': # tag!=''防止无tag的情况
f_key = 'f-{}'.format(tag)
pre_key = 'pre-{}'.format(tag)
@@ -814,8 +833,8 @@ class ExtractiveQAMetric(MetricBase):
if not self.right_open:
e += 1
te += 1
- if ts == 0 and te == int(not self.right_open):
- if s == 0 and e == int(not self.right_open):
+ if ts == 0 and te == 1:
+ if s == 0 and e == 1:
self.no_ans_correct += 1
self.no2no += 1
else:
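
The metric name set through the new API becomes the key of the result dict reported by ``Tester``; a short sketch:

.. code-block:: python

    from fastNLP import AccuracyMetric

    metric = AccuracyMetric(pred='pred', target='label').set_metric_name('dev_acc')
    print(metric.get_metric_name())   # 'dev_acc' -- also used as the key in Tester's results
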
diff --git a/fastNLP/core/optimizer.py b/fastNLP/core/optimizer.py
index 1fe035bf..e95047b4 100644
--- a/fastNLP/core/optimizer.py
+++ b/fastNLP/core/optimizer.py
@@ -5,7 +5,8 @@ optimizer 模块定义了 fastNLP 中所需的各种优化器,一般做为 :cl
__all__ = [
"Optimizer",
"SGD",
- "Adam"
+ "Adam",
+ "AdamW"
]
import torch
@@ -48,7 +49,7 @@ class NullOptimizer(Optimizer):
super().__init__(None)
def construct_from_pytorch(self, model_params):
- pass
+ return self
def __getattr__(self, item):
def pass_func(*args, **kwargs):
@@ -103,21 +104,28 @@ class Adam(Optimizer):
class AdamW(TorchOptimizer):
- r"""对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入
+ r"""
+ 别名::class:`fastNLP.AdamW` :class:`fastNLP.core.optimizer.AdamW`
+
+ 对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入
+
+ .. todo::
+ 翻译成中文
+
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
- Arguments:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups
- lr (float, optional): learning rate (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its square (default: (0.9, 0.99))
- eps (float, optional): term added to the denominator to improve
- numerical stability (default: 1e-8)
- weight_decay (float, optional): weight decay coefficient (default: 1e-2)
- amsgrad (boolean, optional): whether to use the AMSGrad variant of this
- algorithm from the paper `On the Convergence of Adam and Beyond`_
- (default: False)
+
+ :param params (iterable): iterable of parameters to optimize or dicts defining
+ parameter groups
+ :param lr (float, optional): learning rate (default: 1e-3)
+ :param betas (Tuple[float, float], optional): coefficients used for computing
+ running averages of gradient and its square (default: (0.9, 0.99))
+ :param eps (float, optional): term added to the denominator to improve
+ numerical stability (default: 1e-8)
+    :param weight_decay (float, optional): weight decay coefficient (default: 1e-2)
+    :param amsgrad (boolean, optional): whether to use the AMSGrad variant of this
+        algorithm from the paper `On the Convergence of Adam and Beyond`_
+        (default: False)
+
.. _Adam\: A Method for Stochastic Optimization:
https://arxiv.org/abs/1412.6980
.. _Decoupled Weight Decay Regularization:
@@ -147,9 +155,9 @@ class AdamW(TorchOptimizer):
def step(self, closure=None):
"""Performs a single optimization step.
- Arguments:
- closure (callable, optional): A closure that reevaluates the model
- and returns the loss.
+
+ :param closure: (callable, optional) A closure that reevaluates the model
+ and returns the loss.
"""
loss = None
if closure is not None:
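
A usage sketch of the newly exported ``AdamW``, assuming it is re-exported at package level like ``SGD`` and ``Adam`` (otherwise import it from ``fastNLP.core.optimizer``):

.. code-block:: python

    import torch
    from fastNLP import AdamW

    model = torch.nn.Linear(10, 2)
    optimizer = AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=1e-2)
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()
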
diff --git a/fastNLP/core/predictor.py b/fastNLP/core/predictor.py
index 2d6a7380..c6b8fc90 100644
--- a/fastNLP/core/predictor.py
+++ b/fastNLP/core/predictor.py
@@ -1,13 +1,15 @@
-"""
- ..todo::
- 检查这个类是否需要
-"""
+"""undocumented"""
+
+__all__ = [
+ "Predictor"
+]
+
from collections import defaultdict
import torch
-from . import DataSetIter
from . import DataSet
+from . import DataSetIter
from . import SequentialSampler
from .utils import _build_args, _move_dict_value_to_device, _get_model_device
@@ -21,7 +23,7 @@ class Predictor(object):
:param torch.nn.Module network: 用来完成预测任务的模型
"""
-
+
def __init__(self, network):
if not isinstance(network, torch.nn.Module):
raise ValueError(
@@ -29,7 +31,7 @@ class Predictor(object):
self.network = network
self.batch_size = 1
self.batch_output = []
-
+
def predict(self, data: DataSet, seq_len_field_name=None):
"""用已经训练好的模型进行inference.
@@ -41,27 +43,27 @@ class Predictor(object):
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data)))
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays:
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data))
-
+
prev_training = self.network.training
self.network.eval()
network_device = _get_model_device(self.network)
batch_output = defaultdict(list)
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False)
-
+
if hasattr(self.network, "predict"):
predict_func = self.network.predict
else:
predict_func = self.network.forward
-
+
with torch.no_grad():
for batch_x, _ in data_iterator:
_move_dict_value_to_device(batch_x, _, device=network_device)
refined_batch_x = _build_args(predict_func, **batch_x)
prediction = predict_func(**refined_batch_x)
-
+
if seq_len_field_name is not None:
seq_lens = batch_x[seq_len_field_name].tolist()
-
+
for key, value in prediction.items():
value = value.cpu().numpy()
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1):
@@ -74,6 +76,6 @@ class Predictor(object):
batch_output[key].extend(tmp_batch)
else:
batch_output[key].append(value)
-
+
self.network.train(prev_training)
return batch_output
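
A minimal inference sketch with the cleaned-up ``Predictor``; the toy model only illustrates the dict-returning ``forward()`` convention:

.. code-block:: python

    from torch import nn
    from fastNLP import DataSet
    from fastNLP.core.predictor import Predictor

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(10, 2)

        def forward(self, x):
            return {'pred': self.fc(x)}   # fastNLP convention: forward() returns a dict

    ds = DataSet({'x': [[1.0] * 10 for _ in range(4)]})
    ds.set_input('x')
    outputs = Predictor(ToyModel()).predict(ds)   # dict: 'pred' -> list of numpy batches
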
diff --git a/fastNLP/core/sampler.py b/fastNLP/core/sampler.py
index c5784f59..9ca04fa0 100644
--- a/fastNLP/core/sampler.py
+++ b/fastNLP/core/sampler.py
@@ -25,9 +25,9 @@ class Sampler(object):
def __call__(self, data_set):
"""
- :param DataSet data_set: `DataSet` 对象, 需要Sample的数据
- :return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出
- """
+ :param DataSet data_set: `DataSet` 对象, 需要Sample的数据
+ :return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出
+ """
raise NotImplementedError
@@ -62,16 +62,27 @@ class BucketSampler(Sampler):
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素
:param int num_buckets: bucket的数量
- :param int batch_size: batch的大小
+ :param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需
+        要显式传递该值
:param str seq_len_field_name: 对应序列长度的 `field` 的名字
"""
- def __init__(self, num_buckets=10, batch_size=32, seq_len_field_name='seq_len'):
+ def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'):
self.num_buckets = num_buckets
self.batch_size = batch_size
self.seq_len_field_name = seq_len_field_name
-
+
+ def set_batch_size(self, batch_size):
+ """
+
+ :param int batch_size: 每个batch的大小
+ :return:
+ """
+ self.batch_size = batch_size
+
def __call__(self, data_set):
+ if self.batch_size is None:
+ raise RuntimeError("batch_size is None.")
seq_lens = data_set.get_all_fields()[self.seq_len_field_name].content
total_sample_num = len(seq_lens)
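
Since ``batch_size`` now defaults to ``None``, using ``BucketSampler`` outside ``Trainer`` requires setting it explicitly; a sketch with toy data:

.. code-block:: python

    from fastNLP import DataSet, DataSetIter, BucketSampler

    ds = DataSet({'words': [[1] * n for n in (3, 7, 2, 9, 5)]})
    ds.add_seq_len('words')
    ds.set_input('words', 'seq_len')

    sampler = BucketSampler(num_buckets=2, seq_len_field_name='seq_len')
    sampler.set_batch_size(2)          # required now; __call__ raises RuntimeError otherwise
    for batch_x, batch_y in DataSetIter(ds, batch_size=2, sampler=sampler):
        print(batch_x['words'].shape)
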
diff --git a/fastNLP/core/tester.py b/fastNLP/core/tester.py
index 7048d0ae..e549df81 100644
--- a/fastNLP/core/tester.py
+++ b/fastNLP/core/tester.py
@@ -1,7 +1,7 @@
"""
tester模块实现了 fastNLP 所需的Tester类,能在提供数据、模型以及metric的情况下进行性能测试。
-Example::
+.. code-block::
import numpy as np
import torch
@@ -32,9 +32,16 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation
"""
+import time
+
import torch
import torch.nn as nn
+try:
+ from tqdm.auto import tqdm
+except:
+ from .utils import _pseudo_tqdm as tqdm
+
from .batch import BatchIter, DataSetIter
from .dataset import DataSet
from .metrics import _prepare_metrics
@@ -47,7 +54,9 @@ from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device
from ._parallel_utils import _data_parallel_wrapper
+from ._parallel_utils import _model_contains_inner_module
from functools import partial
+from ._logger import logger
__all__ = [
"Tester"
@@ -60,15 +69,14 @@ class Tester(object):
Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。
- :param data: 需要测试的数据集, :class:`~fastNLP.DataSet` 类型
+ :param ~fastNLP.DataSet data: 需要测试的数据集
:param torch.nn.module model: 使用的模型
- :param metrics: :class:`~fastNLP.core.metrics.MetricBase` 或者一个列表的 :class:`~fastNLP.core.metrics.MetricBase`
+ :param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics
:param int batch_size: evaluation时使用的batch_size有多大。
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
的计算位置进行管理。支持以下的输入:
- 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中,
- 可见的第二个GPU中;
+ 1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中;
2. torch.device:将模型装载到torch.device上。
@@ -80,13 +88,12 @@ class Tester(object):
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。
+ :param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。
"""
- def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1):
+ def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True):
super(Tester, self).__init__()
-
- if not isinstance(data, DataSet):
- raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.")
+
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.")
@@ -96,6 +103,8 @@ class Tester(object):
self._model = _move_model_to_device(model, device=device)
self.batch_size = batch_size
self.verbose = verbose
+ self.use_tqdm = use_tqdm
+ self.logger = logger
if isinstance(data, DataSet):
self.data_iterator = DataSetIter(
@@ -107,19 +116,22 @@ class Tester(object):
# check predict
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \
- (isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and
- callable(self._model.module.predict)):
+ (_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and
+ callable(self._model.module.predict)):
if isinstance(self._model, nn.DataParallel):
self._predict_func_wrapper = partial(_data_parallel_wrapper('predict',
self._model.device_ids,
self._model.output_device),
network=self._model.module)
+ self._predict_func = self._model.module.predict # 用于匹配参数
+ elif isinstance(self._model, nn.parallel.DistributedDataParallel):
self._predict_func = self._model.module.predict
+ self._predict_func_wrapper = self._model.module.predict # 用于调用
else:
self._predict_func = self._model.predict
self._predict_func_wrapper = self._model.predict
else:
- if isinstance(self._model, nn.DataParallel):
+ if _model_contains_inner_module(model):
self._predict_func_wrapper = self._model.forward
self._predict_func = self._model.module.forward
else:
@@ -140,21 +152,39 @@ class Tester(object):
eval_results = {}
try:
with torch.no_grad():
- for batch_x, batch_y in data_iterator:
- _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
- pred_dict = self._data_forward(self._predict_func, batch_x)
- if not isinstance(pred_dict, dict):
- raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
- f"must be `dict`, got {type(pred_dict)}.")
+ if not self.use_tqdm:
+ from .utils import _pseudo_tqdm as inner_tqdm
+ else:
+ inner_tqdm = tqdm
+ with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar:
+ pbar.set_description_str(desc="Test")
+
+ start_time = time.time()
+
+ for batch_x, batch_y in data_iterator:
+ _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
+ pred_dict = self._data_forward(self._predict_func, batch_x)
+ if not isinstance(pred_dict, dict):
+ raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "
+ f"must be `dict`, got {type(pred_dict)}.")
+ for metric in self.metrics:
+ metric(pred_dict, batch_y)
+
+ if self.use_tqdm:
+ pbar.update()
+
for metric in self.metrics:
- metric(pred_dict, batch_y)
- for metric in self.metrics:
- eval_result = metric.get_metric()
- if not isinstance(eval_result, dict):
- raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
- f"`dict`, got {type(eval_result)}")
- metric_name = metric.__class__.__name__
- eval_results[metric_name] = eval_result
+ eval_result = metric.get_metric()
+ if not isinstance(eval_result, dict):
+ raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be "
+ f"`dict`, got {type(eval_result)}")
+ metric_name = metric.get_metric_name()
+ eval_results[metric_name] = eval_result
+ pbar.close()
+ end_time = time.time()
+ test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!'
+ # pbar.write(test_str)
+ self.logger.info(test_str)
except _CheckError as e:
prev_func_signature = _get_func_signature(self._predict_func)
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature,
@@ -162,7 +192,7 @@ class Tester(object):
dataset=self.data, check_level=0)
if self.verbose >= 1:
- print("[tester] \n{}".format(self._format_eval_results(eval_results)))
+ logger.info("[tester] \n{}".format(self._format_eval_results(eval_results)))
self._mode(network, is_test=False)
return eval_results
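
A sketch of the updated ``Tester``: ``use_tqdm`` toggles the new progress bar and results are keyed by ``metric.get_metric_name()``; the toy model and data are assumptions:

.. code-block:: python

    from torch import nn
    from fastNLP import DataSet, Tester, AccuracyMetric

    class ToyClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(4, 2)

        def forward(self, x):
            return {'pred': self.fc(x)}

    ds = DataSet({'x': [[0.0, 1.0, 0.0, 1.0]] * 8, 'label': [1] * 8})
    ds.set_input('x')
    ds.set_target('label')

    metric = AccuracyMetric(pred='pred', target='label').set_metric_name('dev_acc')
    tester = Tester(data=ds, model=ToyClassifier(), metrics=metric,
                    batch_size=4, use_tqdm=False)
    print(tester.test())               # e.g. {'dev_acc': {'acc': 0.5}}
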
diff --git a/fastNLP/core/trainer.py b/fastNLP/core/trainer.py
index eabda99c..290a89c1 100644
--- a/fastNLP/core/trainer.py
+++ b/fastNLP/core/trainer.py
@@ -11,288 +11,310 @@ Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在
(5) 保存获得更好验证性能的模型。
-1 Trainer的基本使用
- 下面的例子是使用神经网络来进行预测一个序列中是否有偶数个1。
-
- Example::
-
- import numpy as np
- from torch import nn
- import torch
- import torch.nn.functional as F
- from torch.optim import SGD
-
- from fastNLP import DataSet
- from fastNLP import Trainer
- from fastNLP import CrossEntropyLoss
- from fastNLP import AccuracyMetric
- from fastNLP.modules.decoder import MLP
-
- # 模型
- class Model(nn.Module):
- def __init__(self, input_num):
- super().__init__()
- self.fcs = MLP([input_num, 40, 40, 2], 'relu')
-
- def forward(self, x):
- x = self.fcs(x)
- return {'pred': x}
- model = Model(10)
-
- # 生成数据
- def generate_psedo_dataset(num_samples):
- dataset = DataSet()
- data = np.random.randint(2, size=(num_samples, 10))
- label = np.sum(data, axis=1)%2
- dataset = DataSet({'x':data.astype(float), 'label': label})
- dataset.set_input('x')
- dataset.set_target('label')
- return dataset
- tr_dataset = generate_psedo_dataset(1000)
- dev_data = generate_psedo_dataset(100)
-
- # 训练
- trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
- optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
- dev_data = dev_data, metrics=AccuracyMetric(target='label'))
- trainer.train()
-
- 由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。
- 使用Trainer需要满足以下几个条件:
+
+----------------------------
+1. Trainer的基本使用
+----------------------------
+
+下面的例子是使用神经网络来进行预测一个序列中是否有偶数个1。
+
+.. code-block:: python
+
+ import numpy as np
+ from torch import nn
+ import torch
+ import torch.nn.functional as F
+ from torch.optim import SGD
+
+ from fastNLP import DataSet
+ from fastNLP import Trainer
+ from fastNLP import CrossEntropyLoss
+ from fastNLP import AccuracyMetric
+ from fastNLP.modules.decoder import MLP
+
+ # 模型
+ class Model(nn.Module):
+ def __init__(self, input_num):
+ super().__init__()
+ self.fcs = MLP([input_num, 40, 40, 2], 'relu')
+
+ def forward(self, x):
+ x = self.fcs(x)
+ return {'pred': x}
+ model = Model(10)
+
+ # 生成数据
+ def generate_psedo_dataset(num_samples):
+ dataset = DataSet()
+ data = np.random.randint(2, size=(num_samples, 10))
+ label = np.sum(data, axis=1)%2
+ dataset = DataSet({'x':data.astype(float), 'label': label})
+ dataset.set_input('x')
+ dataset.set_target('label')
+ return dataset
+ tr_dataset = generate_psedo_dataset(1000)
+ dev_data = generate_psedo_dataset(100)
+
+ # 训练
+ trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
+ optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
+ dev_data = dev_data, metrics=AccuracyMetric(target='label'))
+ trainer.train()
+
+由上面的例子可以看出通过使用Trainer,可以使得训练部分的代码大幅减少。
+使用Trainer需要满足以下几个条件:
1.1 模型
- 1 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是
- 通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该
- 改名为'data'。
+----------------------------
+
+1 模型的forward()的参数名需要与DataSet中的名字对应。实际上fastNLP在将DataSet中的数据传递给模型forward()时,是
+通过匹配名称实现的。所以上例中,如果Model的forward函数修改为forward(self, data), 则DataSet中的'x'这个field就应该
+改名为'data'。
- 2 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递
- 给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。
+2 传递给forward()的参数是DataSet中被设置为input的那些field。但如果forward()中没有对应的参数,则不会将数据传递
+给forward()。例如,DataSet中'x1', 'x2'都是input,但是模型的函数为forward(self, x1), 那么'x2'不会传递给forward()。
- 3 模型的forward()返回值需要为一个dict。
+3 模型的forward()返回值需要为一个dict。
1.2 Loss
- fastNLP中的为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing,
- :mod:`Loss` 与 :mod:`Metric` 都使用了通过名称来匹配相应内容的策略。如上面的例子中
+----------------------------
- Example::
+fastNLP中为了不限制forward函数的返回内容数量(比如一些复杂任务需要返回多个内容,如Dependency Parsing),
+:mod:`Loss` 与 :mod:`Metric` 都使用了通过名称来匹配相应内容的策略。如上面的例子中
- trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
- optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
- dev_data = dev_data, metrics=AccuracyMetric(target='label'))
+.. code-block:: python
- loss被设置为了 :class:`~fastNLP.CrossEntropyLoss` , 但在初始化的时候传入了target='label'这个参数,
- :class:`~fastNLP.CrossEntropyLoss` 的初始化参数为(pred=None, target=None, padding_idx=-100)。
-
- 这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值与真实值。
- 其中 `pred` 一般来自于模型forward()的返回结果,`target` 一般是来自于DataSet中被设置为target的field。
- 由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 :mod:`Loss` 提供一种类似于映射的机制来匹配对应的值,
- 比如这里 :class:`~fastNLP.CrossEntropyLoss` 将尝试找到名为'label'的内容来作为真实值得到loss;
- 而pred=None, 则 :class:`~fastNLP.CrossEntropyLoss` 使用'pred'作为名称匹配预测值,
- 正好forward的返回值也叫pred,所以这里不需要申明pred。
-
- 尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。
- fastNLP中提供了 :class:`~fastNLP.LossInForward` 这个loss。
- 这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor,并使用它作为loss。
- 如果Trainer初始化没有提供loss则默认使用 :class:`~fastNLP.LossInForward` 。
-
- .. todo::
- 补充一个例子 详细例子可以参照
+ trainer = Trainer(tr_dataset, model, loss=CrossEntropyLoss(target='label'),
+ optimizer=SGD(model.parameters(), lr=0.1),n_epochs=1000,
+ dev_data = dev_data, metrics=AccuracyMetric(target='label'))
+
+loss被设置为了 :class:`~fastNLP.CrossEntropyLoss` , 但在初始化的时候传入了target='label'这个参数,
+:class:`~fastNLP.CrossEntropyLoss` 的初始化参数为(pred=None, target=None, padding_idx=-100)。
+
+这里的两个参数分别为计算CrossEntropy时需要使用到的模型的预测值与真实值。
+其中 `pred` 一般来自于模型forward()的返回结果,`target` 一般是来自于DataSet中被设置为target的field。
+由于每个人对真实值或者model的返回值取名并不一样,所以fastNLP的 :mod:`Loss` 提供一种类似于映射的机制来匹配对应的值,
+比如这里 :class:`~fastNLP.CrossEntropyLoss` 将尝试找到名为'label'的内容来作为真实值得到loss;
+而pred=None, 则 :class:`~fastNLP.CrossEntropyLoss` 使用'pred'作为名称匹配预测值,
+正好forward的返回值也叫pred,所以这里不需要申明pred。
+
+尽管fastNLP使用了映射机制来使得loss的计算变得比较灵活,但有些情况下loss必须在模型中进行计算,比如使用了CRF的模型。
+fastNLP中提供了 :class:`~fastNLP.LossInForward` 这个loss。
+这个loss的原理是直接在forward()的返回结果中找到loss_key(默认寻找'loss')指定的那个tensor,并使用它作为loss。
+如果Trainer初始化没有提供loss则默认使用 :class:`~fastNLP.LossInForward` 。
+
+.. todo::
+ 补充一个例子 详细例子可以参照
1.3 Metric
- :mod:`Metric` 使用了与上述Loss一样的策略,即使用名称进行匹配。
- AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。
-
- 在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法,
- 如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,
- 传入到predict()的参数也是从DataSet中被设置为input的field中选择出来的;
- 与forward()一样,返回值需要为一个dict。
+----------------------------
+
+:mod:`Metric` 使用了与上述Loss一样的策略,即使用名称进行匹配。
+AccuracyMetric(target='label')的情况与CrossEntropyLoss 是同理的。
+
+在进行验证时,可能用到的计算与forward()中不太一致,没有办法直接从forward()的结果中得到预测值,这时模型可以提供一个predict()方法,
+如果提供的模型具有predict方法,则在模型验证时将调用predict()方法获取预测结果,
+传入到predict()的参数也是从DataSet中被设置为input的field中选择出来的;
+与forward()一样,返回值需要为一个dict。
+
+.. todo::
+ 补充一个例子 具体例子可以参考
- .. todo::
- 补充一个例子 具体例子可以参考
+----------------------------
+2. Trainer的代码检查
+----------------------------
-2 Trainer的代码检查
- 由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制
- 比如下面的例子中,由于各种原因产生的报错
+由于在fastNLP中采取了映射的机制,所以难免可能存在对应出错的情况。Trainer提供一种映射检查机制,可以通过check_code_level来进行控制
+比如下面的例子中,由于各种原因产生的报错
Example2.1
- ::
-
- import numpy as np
- from torch import nn
- import torch
- from torch.optim import SGD
- from fastNLP import Trainer
- from fastNLP import DataSet
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(1, 1)
- def forward(self, x, b):
- loss = torch.mean((self.fc(x)-b)**2)
- return {'loss': loss}
- model = Model()
-
- dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
- dataset.set_input('a', 'b')
-
- trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))
-
- trainer = Trainer(dataset, model, SGD(model.parameters()))
- # 会报以下的错误
- # input fields after batch(if batch size is 2):
- # a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- # b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
- # There is no target field.
- # ....
- # NameError:
- # Problems occurred when calling Model.forward(self, x, b)
- # missing param: ['x']
- # unused field: ['a']
- # Suggestion: You need to provide ['x'] in DataSet and set it as input.
-
- 这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch去运行一遍forward()以及backward()。这里有两类
- 信息可以为你提供参考
-
- 1 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及进行shape。这里
- 因为train dataset没有target所以没有显示。根据这里可以看出是否正确将需要的内容设置为了input或target。
-
- 2 NameError,NameError发生在映射出错的情况。这里报错的原因是由于尝试进行forward计算时(可以通过Model.forward(self, x, b)判断
- 出当前是在调取forward),却没有获取到forward()函数中需要的'x';在报错信息中同时指出了缺'x',而'a'没有被使用,那么可能
- 就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x',或者model的参数从'x'修改为'a'都可以解决问题。
-
- 下面的例子是由于loss计算的时候找不到需要的值
+----------------------------
+
+.. code-block:: python
+
+ import numpy as np
+ from torch import nn
+ import torch
+ from torch.optim import SGD
+ from fastNLP import Trainer
+ from fastNLP import DataSet
+
+ class Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc = nn.Linear(1, 1)
+ def forward(self, x, b):
+ loss = torch.mean((self.fc(x)-b)**2)
+ return {'loss': loss}
+ model = Model()
+
+ dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
+ dataset.set_input('a', 'b')
+
+ trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))
+
+ trainer = Trainer(dataset, model, SGD(model.parameters()))
+ # 会报以下的错误
+ # input fields after batch(if batch size is 2):
+ # a: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
+ # b: (1)type:torch.Tensor (2)dtype:torch.int64, (3)shape:torch.Size([2])
+ # There is no target field.
+ # ....
+ # NameError:
+ # Problems occurred when calling Model.forward(self, x, b)
+ # missing param: ['x']
+ # unused field: ['a']
+ # Suggestion: You need to provide ['x'] in DataSet and set it as input.
+
+这里就是由于在Trainer初始化的时候,fastNLP会尝试使用一个batch_size=2的batch运行一遍forward()以及backward()。这里有两类
+信息可以为你提供参考:
+
+1 'input fields after batch...'这部分显示的是train dataset经过Batch操作后,每个field对应的类型以及shape。这里
+因为train dataset没有设置target所以没有显示target部分。根据这部分信息可以检查是否正确地将需要的内容设置为了input或target。
+
+2 NameError发生在映射出错的情况。这里报错的原因是尝试进行forward计算时(可以通过Model.forward(self, x, b)判断
+出当前是在调用forward),却没有获取到forward()函数需要的'x';报错信息中同时指出缺少'x',而'a'没有被使用,那么很可能
+就是由于field的名称不对。这里将dataset中'a'这个field的名称改为'x',或者将model的参数从'x'修改为'a',都可以解决这个问题。
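+
+例如,将dataset中的field直接命名为forward()所需要的'x',即可通过检查(仅为示意,沿用上例中的model与import)::
+
+    dataset = DataSet({'x': np.arange(10), 'b': np.arange(10) * 2})
+    dataset.set_input('x', 'b')
+    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001))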
+
+下面的例子是由于loss计算的时候找不到需要的值而产生的报错。
Example2.2
- ::
-
- import numpy as np
- from torch import nn
- from torch.optim import SGD
- from fastNLP import Trainer
- from fastNLP import DataSet
- from fastNLP import L1Loss
- import torch
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(1, 1)
- def forward(self, a):
- return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}
-
- model = Model()
-
- dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})
-
- dataset.set_input('a')
- dataset.set_target('b')
-
- trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
- # 报错信息如下
- # input fields after batch(if batch size is 2):
- # a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
- # target fields after batch(if batch size is 2):
- # b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
- # ....
- # NameError:
- # Problems occurred when calling L1Loss.get_loss(self, pred, target)
- # missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
- # unused field: ['b']
- # unused param: ['pred_b', 'No use']
- # target field: ['b']
- # param from Model.forward(self, a): ['pred_b', 'No use']
- # Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
- # (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).
-
- 报错信息也包含两部分:
-
- 1 第一部分与上面是一样的
-
- 2 这里报错的原因是由于计算loss的时候找不到相应的值(通过L1Loss.get_loss(self, pred, target)判断出来的);
- 报错的原因是因为 `pred` 和 `label` (我们在初始化L1Loss时将target指定为了label)都没有找到。
- 这里'unused field'是DataSet中出现了,但却没有被设置为input或者target的field;
- 'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了target的field;
- 'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误处理的建议。
-
- 但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行匹配,而直接将forward()的结果作为pred,
- 将DataSet中的target设置为target。上面的例子在返回值中加入了一个'No use'则只是为了使得Loss去匹配结果。
-
-
- 下面是带有dev dataset时如果出现错误会发生的报错,
+----------------------------
+
+.. code-block:: python
+
+ import numpy as np
+ from torch import nn
+ from torch.optim import SGD
+ from fastNLP import Trainer
+ from fastNLP import DataSet
+ from fastNLP import L1Loss
+ import torch
+
+ class Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc = nn.Linear(1, 1)
+ def forward(self, a):
+ return {'pred_b': self.fc(a.unsqueeze(1)).squeeze(1), 'No use':1}
+
+ model = Model()
+
+ dataset = DataSet({'a': np.arange(10, dtype=float), 'b':np.arange(10, dtype=float)*2})
+
+ dataset.set_input('a')
+ dataset.set_target('b')
+
+ trainer = Trainer(dataset, model, loss=L1Loss(target='label'), optimizer=SGD(model.parameters(), lr=0.001))
+ # 报错信息如下
+ # input fields after batch(if batch size is 2):
+ # a: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
+ # target fields after batch(if batch size is 2):
+ # b: (1)type:torch.Tensor (2)dtype:torch.float32, (3)shape:torch.Size([2])
+ # ....
+ # NameError:
+ # Problems occurred when calling L1Loss.get_loss(self, pred, target)
+ # missing param: ['pred(assign to `pred` in `L1Loss`)', 'label(assign to `target` in `L1Loss`)']
+ # unused field: ['b']
+ # unused param: ['pred_b', 'No use']
+ # target field: ['b']
+ # param from Model.forward(self, a): ['pred_b', 'No use']
+ # Suggestion: (1). Check key assignment for `target` when initialize L1Loss. Or provide `label` in DataSet or output of Model.forward(self, a).
+ # (2). Check key assignment for `pred` when initialize L1Loss. Or provide `pred` in DataSet or output of Model.forward(self, a).
+
+报错信息也包含两部分:
+
+1 第一部分与上面是一样的。
+
+2 这里报错是因为计算loss的时候找不到相应的值(可以通过L1Loss.get_loss(self, pred, target)判断出来),
+即 `pred` 和 `label` (我们在初始化L1Loss时将target指定为了label)都没有找到。
+这里'unused field'是DataSet中出现了,但却没有被设置为input或者target的field;
+'unused param'是forward()中返回且没有被使用到的内容;'target field'是被设置为了target的field;
+'param from Model.forward(self, a)'是forward()返回的所有key。"Suggestion"是关于当前错误的处理建议。
+
+但是在一些情况下,比如forward()返回值只有一个,target也只有一个,fastNLP不会进行名称匹配,而是直接将forward()的结果作为pred,
+将DataSet中的target作为真实值。上面的例子在返回值中加入一个'No use',只是为了迫使Loss进行名称匹配。
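+
+针对这个例子,一种简单的修复示意是在初始化loss时显式指定映射,让 `pred` 对应forward()返回的'pred_b'、
+`target` 对应DataSet中的'b'(沿用上例中的model与数据)::
+
+    trainer = Trainer(dataset, model, loss=L1Loss(pred='pred_b', target='b'),
+                      optimizer=SGD(model.parameters(), lr=0.001))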
+
+
+下面的例子是带有dev dataset时,如果出现错误会产生的报错。
Example2.3
- ::
+----------------------------
+
+.. code-block:: python
+
+ import numpy as np
+ from torch import nn
+ from torch.optim import SGD
+ from fastNLP import Trainer
+ from fastNLP import DataSet
+ from fastNLP import AccuracyMetric
+ import torch
+
+ class Model(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.fc = nn.Linear(1, 1)
+ def forward(self, a, b):
+ loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
+ return {'loss': loss}
+ def predict(self, a): # 使用predict()进行验证
+ return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key
+ model = Model()
+
+ dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
+ dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})
+
+ dataset.set_input('a', 'b')
+ dev_data.set_input('a') # 这里没有设置target
+
+ trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
+ dev_data=dev_data, metrics=AccuracyMetric())
+
+ # 报错信息
+ # ...
+ # NameError:
+ # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
+ # missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
+ # unused param: ['output']
+ # target field: []
+ # param from Model.predict(self, a): ['output']
+ # Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
+ # (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).
+
+报错信息和前面是类似的,但可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是在evaluation
+的时候发生了错误,这样就避免了要等完成一整个epoch的训练之后才发现evaluation写错的情况。这里的修改方式是在初始化metric的时候
+指明通过'output'获取 `pred` ,即AccuracyMetric(pred='output');同时由于dev_data中没有设置target,还需要将'b'设置为target并通过target='b'指明真实值的来源。
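+
+按照上面的说明修改后的写法示意如下(沿用上例中的model与数据,仅作演示)::
+
+    dev_data.set_target('b')
+    trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
+                      dev_data=dev_data, metrics=AccuracyMetric(pred='output', target='b'))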
+
+可以通过check_code_level调节检查的强度:默认为0,即进行检查;将其设置为-1则会关闭这一检查。
+
+----------------------------
+3. Trainer与callback
+----------------------------
+
+虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要用到的功能,比如负采样、learning rate decay、Early Stop等。
+为了解决这个问题,fastNLP引入了callback的机制。:class:`~fastNLP.Callback` 是一组在Trainer训练过程中特定阶段会被运行的函数,
+所有的 :class:`~fastNLP.Callback` 都具有on_*(比如on_train_start, on_backward_begin)等函数。
+如果 Callback 实现了该函数,则Trainer运行至对应阶段时会对其进行调用,例如::
+
+    import time
+
+    from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric
+    from fastNLP.models import CNNText
+
+    start_time = time.time()
- import numpy as np
- from torch import nn
- from torch.optim import SGD
- from fastNLP import Trainer
- from fastNLP import DataSet
- from fastNLP import AccuracyMetric
- import torch
-
- class Model(nn.Module):
- def __init__(self):
- super().__init__()
- self.fc = nn.Linear(1, 1)
- def forward(self, a, b):
- loss = torch.mean((self.fc(a.float().unsqueeze(1))-b.float())**2)
- return {'loss': loss}
- def predict(self, a): # 使用predict()进行验证
- return {'output':self.fc(a.float().unsqueeze(1))} #这里return的值不包含'pred'这个key
- model = Model()
-
- dataset = DataSet({'a': np.arange(10), 'b':np.arange(10)*2})
- dev_data = DataSet({'a': np.arange(10, 20), 'b':np.arange(10, 20)*2})
-
- dataset.set_input('a', 'b')
- dev_data.set_input('a') # 这里没有设置target
-
- trainer = Trainer(dataset, model, loss=None, optimizer=SGD(model.parameters(), lr=0.001),
- dev_data=dev_data, metrics=AccuracyMetric())
-
- # 报错信息
- # ...
- # NameError:
- # Problems occurred when calling AccuracyMetric.evaluate(self, pred, target, seq_len=None)
- # missing param: ['pred(assign to `pred` in `AccuracyMetric`)', 'target(assign to `target` in `AccuracyMetric`)']
- # unused param: ['output']
- # target field: []
- # param from Model.predict(self, a): ['output']
- # Suggestion: (1). Check key assignment for `pred` when initialize AccuracyMetric. Or provide `pred` in DataSet or output of Model.predict(self, a).
- # (2). Check key assignment for `target` when initialize AccuracyMetric. Or provide `target` in DataSet or output of Model.predict(self, a).
-
- 报错信息和前面都是类似的,但是可以通过'AccuracyMetric.evaluate(self, pred, target, seq_len=None)'看出这里是evaluation
- 的时候发生了错误。这样避免了需要在完成一整个epoch的训练才能发现evaluation弄错的情况。这里的修改是通过在初始化metric的时候
- 指明通过'output'获取`pred`, 即AccuracyMetric(pred='output')。
-
- 可以通过check_code_level调节检查的强度。默认为0,即进行检查。
-
-3 Trainer与callback
- 虽然Trainer本身已经集成了一些功能,但仍然不足以囊括训练过程中可能需要到的功能,比如负采样,learning rate decay, Early Stop等。
- 为了解决这个问题fastNLP引入了callback的机制,:class:`~fastNLP.Callback` 是一种在Trainer训练过程中特定阶段会运行的函数集合,
- 所有的 :class:`~fastNLP.Callback` 都具有on_*(比如on_train_start, on_backward_begin)等函数。
- 如果 Callback 实现了该函数,则Trainer运行至对应阶段,会进行调用,例如::
+ class MyCallback(Callback):
+ def on_epoch_end(self):
+ print('{:d}ms\n\n'.format(round((time.time()-start_time)*1000)))
- from fastNLP import Callback, EarlyStopCallback, Trainer, CrossEntropyLoss, AccuracyMetric
- from fastNLP.models import CNNText
-
- start_time = time.time()
-
- class MyCallback(Callback):
- def on_epoch_end(self):
- print('{:d}ms\n\n'.format(round((time.time()-start_time)*1000)))
-
- model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1)
- trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, loss=CrossEntropyLoss(),
- metrics=AccuracyMetric(), callbacks=[MyCallback(),EarlyStopCallback(10)])
- trainer.train()
-
- 这里,我们通过继承 :class:`~fastNLP.Callback` 类定义了自己的 callback 的,并和内置的 :class:`~fastNLP.EarlyStopCallback`
- 一起传给了 :class:`~fastNLP.Trainer` ,增强了 :class:`~fastNLP.Trainer` 的功能
+ model = CNNText((len(vocab),50), num_classes=5, padding=2, dropout=0.1)
+ trainer = Trainer(model=model, train_data=train_data, dev_data=dev_data, loss=CrossEntropyLoss(),
+ metrics=AccuracyMetric(), callbacks=[MyCallback(),EarlyStopCallback(10)])
+ trainer.train()
- fastNLP已经自带了很多callback函数供使用,可以参考 :doc:`fastNLP.core.callback` 。
+这里,我们通过继承 :class:`~fastNLP.Callback` 类定义了自己的 callback,并和内置的 :class:`~fastNLP.EarlyStopCallback`
+一起传给了 :class:`~fastNLP.Trainer` ,增强了 :class:`~fastNLP.Trainer` 的功能。
+
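+除了上面的用法,还可以在训练的其它阶段插入自己的逻辑。例如下面这个简单示意(这里假设 on_backward_begin 的参数
+与Trainer中对应的调用一致,即传入当前batch的loss)用它来监控loss是否出现了NaN::
+
+    import torch
+    from fastNLP import Callback
+
+    class CheckNaNCallback(Callback):
+        def on_backward_begin(self, loss):
+            # 在反向传播开始之前被调用
+            if torch.isnan(loss).any():
+                print("NaN loss detected, please check your data or learning rate.")
+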
+fastNLP已经自带了很多callback函数供使用,可以参考 :doc:`fastNLP.core.callback` 。
"""
__all__ = [
@@ -314,7 +336,7 @@ except:
import warnings
from .batch import DataSetIter, BatchIter
-from .callback import CallbackManager, CallbackException
+from .callback import CallbackManager, CallbackException, Callback
from .dataset import DataSet
from .losses import _prepare_losser
from .metrics import _prepare_metrics
@@ -330,7 +352,8 @@ from .utils import _move_dict_value_to_device
from .utils import _get_func_signature
from .utils import _get_model_device
from .utils import _move_model_to_device
-
+from ._parallel_utils import _model_contains_inner_module
+from ._logger import logger
class Trainer(object):
"""
@@ -367,8 +390,8 @@ class Trainer(object):
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。
- :param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。
- 保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
+ :param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存
+ 最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型
的计算位置进行管理。支持以下的输入:
@@ -398,33 +421,28 @@ class Trainer(object):
batch_size=32, sampler=None, drop_last=False, update_every=1,
num_workers=0, n_epochs=10, print_every=5,
dev_data=None, metrics=None, metric_key=None,
- validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False,
- callbacks=None, check_code_level=0):
- if prefetch and num_workers==0:
- num_workers = 1
- if prefetch:
- warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.")
-
+ validate_every=-1, save_path=None, use_tqdm=True, device=None,
+ callbacks=None, check_code_level=0, **kwargs):
super(Trainer, self).__init__()
if not isinstance(model, nn.Module):
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.")
-
+
# check metrics and dev_data
if (not metrics) and dev_data is not None:
raise ValueError("No metric for dev_data evaluation.")
if metrics and (dev_data is None):
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ")
-
+
# check update every
assert update_every >= 1, "update_every must be no less than 1."
self.update_every = int(update_every)
-
+
# check save_path
if not (save_path is None or isinstance(save_path, str)):
raise ValueError("save_path can only be None or `str`.")
# prepare evaluate
metrics = _prepare_metrics(metrics)
-
+
# parse metric_key
# increase_better is True. It means the exp result gets better if the indicator increases.
# It is true by default.
@@ -436,28 +454,69 @@ class Trainer(object):
self.metric_key = None
# prepare loss
losser = _prepare_losser(loss)
-
- # sampler check
- if sampler is not None and not isinstance(sampler, Sampler):
- raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler)))
- if sampler is None:
- sampler = RandomSampler()
+ if isinstance(train_data, BatchIter):
+ if sampler is not None:
+ warnings.warn("sampler is ignored when train_data is a BatchIter.")
+ if num_workers>0:
+ warnings.warn("num_workers is ignored when train_data is BatchIter.")
+ if drop_last:
+ warnings.warn("drop_last is ignored when train_data is BatchIter.")
+
+ if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的
+ # device为None
+ if device is not None:
+ warnings.warn("device is ignored when model is nn.parallel.DistributedDataParallel.")
+ device = None
+ # Sampler要是分布式的
+ if sampler is None:
+ sampler = torch.utils.data.DistributedSampler(train_data)
+ elif not isinstance(sampler, torch.utils.data.DistributedSampler):
+ raise TypeError("When using nn.parallel.DistributedDataParallel, "
+ "sampler must be None or torch.utils.data.DistributedSampler.")
+ # 不能保存模型
+ if save_path:
+ raise RuntimeError("Saving model in Distributed situation is not allowed right now.")
+ else:
+ # sampler check
+ if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)):
+ raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}")
+ if sampler is None:
+ sampler = RandomSampler()
+ elif hasattr(sampler, 'set_batch_size'):
+ sampler.set_batch_size(batch_size)
if isinstance(train_data, DataSet):
self.data_iterator = DataSetIter(
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last)
elif isinstance(train_data, BatchIter):
self.data_iterator = train_data
+ train_data = train_data.dataset
else:
raise TypeError("train_data type {} not support".format(type(train_data)))
- if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter):
- _check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data,
- metric_key=self.metric_key, check_level=check_code_level,
- batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE))
- # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码
self.model = _move_model_to_device(model, device=device)
+ if _model_contains_inner_module(self.model):
+ self._forward_func = self.model.module.forward
+ else:
+ self._forward_func = self.model.forward
+ if check_code_level > -1:
+ # _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的field名与模型的输入
+ # 名是否匹配
+ dev_dataset = dev_data
+ if isinstance(dev_data, BatchIter):
+ dev_dataset = None
+ warnings.warn("dev_data is of BatchIter type, ignore validation checking.")
+ check_batch_size = min(batch_size, DEFAULT_CHECK_BATCH_SIZE)
+ if isinstance(self.model, nn.DataParallel):
+ _num_devices = len(self.model.device_ids)
+ if batch_size//_num_devices>1: # 如果多卡是每个卡可以分多个数据的,则用每个卡给两个sample
+ check_batch_size = max(len(self.model.device_ids)*2, check_batch_size)
+ else:
+ check_batch_size = max(len(self.model.device_ids), check_batch_size)
+ _check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics,
+ dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level,
+ batch_size=check_batch_size)
self.train_data = train_data
self.dev_data = dev_data # If None, No validation.
@@ -472,8 +531,7 @@ class Trainer(object):
self.best_dev_epoch = None
self.best_dev_step = None
self.best_dev_perf = None
- self.n_steps = (len(self.train_data) // self.batch_size + int(
- len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs
+ self.n_steps = len(self.data_iterator) * self.n_epochs
if isinstance(optimizer, torch.optim.Optimizer):
self.optimizer = optimizer
@@ -483,22 +541,32 @@ class Trainer(object):
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3)
else:
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer)))
-
+
+ self.logger = logger
+
self.use_tqdm = use_tqdm
+ if 'test_use_tqdm' in kwargs:
+ self.test_use_tqdm = kwargs.get('test_use_tqdm')
+ else:
+ self.test_use_tqdm = self.use_tqdm
self.pbar = None
self.print_every = abs(self.print_every)
-
+ self.kwargs = kwargs
if self.dev_data is not None:
self.tester = Tester(model=self.model,
data=self.dev_data,
metrics=self.metrics,
- batch_size=self.batch_size,
+ batch_size=kwargs.get("dev_batch_size", self.batch_size),
device=None, # 由上面的部分处理device
- verbose=0)
-
+ verbose=0,
+ use_tqdm=self.test_use_tqdm)
+
self.step = 0
self.start_time = None # start timestamp
-
+
+ if isinstance(callbacks, Callback):
+ callbacks = [callbacks]
+
self.callback_manager = CallbackManager(env={"trainer": self},
callbacks=callbacks)
@@ -524,7 +592,7 @@ class Trainer(object):
"""
results = {}
if self.n_epochs <= 0:
- print(f"training epoch is {self.n_epochs}, nothing was done.")
+ self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.")
results['seconds'] = 0.
return results
try:
@@ -533,8 +601,8 @@ class Trainer(object):
self._load_best_model = load_best_model
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S'))
start_time = time.time()
- print("training epochs started " + self.start_time, flush=True)
-
+ self.logger.info("training epochs started " + self.start_time)
+
try:
self.callback_manager.on_train_begin()
self._train()
@@ -547,11 +615,11 @@ class Trainer(object):
raise e
elif on_exception == 'raise':
raise e
-
+
if self.dev_data is not None and self.best_dev_perf is not None:
- print(
- "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) +
- self.tester._format_eval_results(self.best_dev_perf), )
+ self.logger.info(
+ "\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step))
+ self.logger.info(self.tester._format_eval_results(self.best_dev_perf))
results['best_eval'] = self.best_dev_perf
results['best_epoch'] = self.best_dev_epoch
results['best_step'] = self.best_dev_step
@@ -559,27 +627,23 @@ class Trainer(object):
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])
load_succeed = self._load_model(self.model, model_name)
if load_succeed:
- print("Reloaded the best model.")
+ self.logger.info("Reloaded the best model.")
else:
- print("Fail to reload best model.")
+ self.logger.info("Fail to reload best model.")
finally:
pass
results['seconds'] = round(time.time() - start_time, 2)
-
+
return results
-
+
def _train(self):
if not self.use_tqdm:
- from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm
+ from .utils import _pseudo_tqdm as inner_tqdm
else:
inner_tqdm = tqdm
self.step = 0
self.epoch = 0
start = time.time()
- if isinstance(self.model, nn.DataParallel):
- self._forward_func = self.model.module.forward
- else:
- self._forward_func = self.model.forward
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
self.pbar = pbar
avg_loss = 0
@@ -597,21 +661,21 @@ class Trainer(object):
# negative sampling; replace unknown; re-weight batch_y
self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
prediction = self._data_forward(self.model, batch_x)
-
+
# edit prediction
self.callback_manager.on_loss_begin(batch_y, prediction)
loss = self._compute_loss(prediction, batch_y).mean()
avg_loss += loss.item()
loss = loss / self.update_every
-
+
# Is loss NaN or inf? requires_grad = False
self.callback_manager.on_backward_begin(loss)
self._grad_backward(loss)
self.callback_manager.on_backward_end()
-
+
self._update()
self.callback_manager.on_step_end()
-
+
if self.step % self.print_every == 0:
avg_loss = float(avg_loss) / self.print_every
if self.use_tqdm:
@@ -625,29 +689,29 @@ class Trainer(object):
pbar.set_postfix_str(print_output)
avg_loss = 0
self.callback_manager.on_batch_end()
-
+
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \
and self.dev_data is not None:
eval_res = self._do_validation(epoch=epoch, step=self.step)
- eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step,
- self.n_steps) + \
- self.tester._format_eval_results(eval_res)
- pbar.write(eval_str + '\n')
-
+ eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
+ self.n_steps)
+ # pbar.write(eval_str + '\n')
+ self.logger.info(eval_str)
+ self.logger.info(self.tester._format_eval_results(eval_res)+'\n')
# ================= mini-batch end ==================== #
-
+
# lr decay; early stopping
self.callback_manager.on_epoch_end()
# =============== epochs end =================== #
pbar.close()
self.pbar = None
# ============ tqdm end ============== #
-
+
def _do_validation(self, epoch, step):
self.callback_manager.on_valid_begin()
res = self.tester.test()
-
+
is_better_eval = False
if self._better_eval_result(res):
if self.save_path is not None:
@@ -662,7 +726,7 @@ class Trainer(object):
# get validation results; adjust optimizer
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval)
return res
-
+
def _mode(self, model, is_test=False):
"""Train mode or Test mode. This is for PyTorch currently.
@@ -674,14 +738,14 @@ class Trainer(object):
model.eval()
else:
model.train()
-
+
def _update(self):
"""Perform weight update on a model.
"""
if self.step % self.update_every == 0:
self.optimizer.step()
-
+
def _data_forward(self, network, x):
x = _build_args(self._forward_func, **x)
y = network(**x)
@@ -689,7 +753,7 @@ class Trainer(object):
raise TypeError(
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
return y
-
+
def _grad_backward(self, loss):
"""Compute gradient with link rules.
@@ -700,7 +764,7 @@ class Trainer(object):
if (self.step-1) % self.update_every == 0:
self.model.zero_grad()
loss.backward()
-
+
def _compute_loss(self, predict, truth):
"""Compute loss given prediction and ground truth.
@@ -709,7 +773,7 @@ class Trainer(object):
:return: a scalar
"""
return self.losser(predict, truth)
-
+
def _save_model(self, model, model_name, only_param=False):
""" 存储不含有显卡信息的state_dict或model
:param model:
@@ -721,7 +785,7 @@ class Trainer(object):
model_path = os.path.join(self.save_path, model_name)
if not os.path.exists(self.save_path):
os.makedirs(self.save_path, exist_ok=True)
- if isinstance(model, nn.DataParallel):
+ if _model_contains_inner_module(model):
model = model.module
if only_param:
state_dict = model.state_dict()
@@ -732,7 +796,7 @@ class Trainer(object):
model.cpu()
torch.save(model, model_path)
model.to(self._model_device)
-
+
def _load_model(self, model, model_name, only_param=False):
# 返回bool值指示是否成功reload模型
if self.save_path is not None:
@@ -741,7 +805,7 @@ class Trainer(object):
states = torch.load(model_path)
else:
states = torch.load(model_path).state_dict()
- if isinstance(model, nn.DataParallel):
+ if _model_contains_inner_module(model):
model.module.load_state_dict(states)
else:
model.load_state_dict(states)
@@ -750,7 +814,7 @@ class Trainer(object):
else:
return False
return True
-
+
def _better_eval_result(self, metrics):
"""Check if the current epoch yields better validation results.
@@ -765,17 +829,20 @@ class Trainer(object):
self.best_metric_indicator = indicator_val
else:
if self.increase_better is True:
- if indicator_val > self.best_metric_indicator:
+ if indicator_val >= self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
else:
- if indicator_val < self.best_metric_indicator:
+ if indicator_val <= self.best_metric_indicator:
self.best_metric_indicator = indicator_val
else:
is_better = False
return is_better
+ @property
+ def is_master(self):
+ return True
DEFAULT_CHECK_BATCH_SIZE = 2
DEFAULT_CHECK_NUM_BATCH = 2
@@ -797,14 +864,15 @@ def _get_value_info(_dict):
strs.append(_str)
return strs
+
from numbers import Number
from .batch import _to_tensor
-def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE,
- dev_data=None, metric_key=None,
- check_level=0):
+
+
+def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE,
+ dev_data=None, metric_key=None, check_level=0):
# check get_loss 方法
- model_devcie = _get_model_device(model=model)
-
+ model_device = _get_model_device(model=model)
def _iter():
start_idx = 0
while start_idx>> seq_len = torch.arange(2, 16)
>>> mask = seq_len_to_mask(seq_len)
@@ -691,7 +636,7 @@ def seq_len_to_mask(seq_len, max_len=None):
:param np.ndarray,torch.LongTensor seq_len: shape将是(B,)
:param int max_len: 将长度pad到这个长度。默认(None)使用的是seq_len中最长的长度。但在nn.DataParallel的场景下可能不同卡的seq_len会有
区别,所以需要传入一个max_len使得mask的长度是pad到该长度。
- :return: np.ndarray or torch.Tensor, shape将是(B, max_length)。 元素类似为bool或torch.uint8
+    :return: np.ndarray, torch.Tensor。shape将是(B, max_length),元素类型为bool或torch.uint8
"""
if isinstance(seq_len, np.ndarray):
assert len(np.shape(seq_len)) == 1, f"seq_len can only have one dimension, got {len(np.shape(seq_len))}."
@@ -715,15 +660,14 @@ class _pseudo_tqdm:
"""
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据
"""
-
def __init__(self, **kwargs):
- pass
+ self.logger = logger
def write(self, info):
- print(info)
+ self.logger.info(info)
def set_postfix_str(self, info):
- print(info)
+ self.logger.info(info)
def __getattr__(self, item):
def pass_func(*args, **kwargs):
@@ -737,7 +681,8 @@ class _pseudo_tqdm:
def __exit__(self, exc_type, exc_val, exc_tb):
del self
-def iob2(tags:List[str])->List[str]:
+
+def iob2(tags: List[str]) -> List[str]:
"""
检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两者的差异见
https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
@@ -760,7 +705,8 @@ def iob2(tags:List[str])->List[str]:
tags[i] = "B" + tag[1:]
return tags
-def iob2bioes(tags:List[str])->List[str]:
+
+def iob2bioes(tags: List[str]) -> List[str]:
"""
将iob的tag转换为bioes编码
:param tags: List[str]. 编码需要是大写的。
@@ -773,15 +719,35 @@ def iob2bioes(tags:List[str])->List[str]:
else:
split = tag.split('-')[0]
if split == 'B':
- if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
+ if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
new_tags.append(tag)
else:
new_tags.append(tag.replace('B-', 'S-'))
elif split == 'I':
- if i + 1= self.max_size:
- print("[Warning] Vocabulary has reached the max size {} when calling {} method. "
- "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
+ logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. "
+ "Adding more words may cause unexpected behaviour of Vocabulary. ".format(
self.max_size, func.__name__))
return func(self, *args, **kwargs)
@@ -92,7 +100,7 @@ class Vocabulary(object):
self.rebuild = True
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法
self._no_create_word = Counter()
-
+
@_check_build_status
def update(self, word_lst, no_create_entry=False):
"""依次增加序列中词在词典中的出现频率
@@ -107,6 +115,7 @@ class Vocabulary(object):
"""
self._add_no_create_entry(word_lst, no_create_entry)
self.word_count.update(word_lst)
+ return self
@_check_build_status
def add(self, word, no_create_entry=False):
@@ -123,23 +132,24 @@ class Vocabulary(object):
"""
self._add_no_create_entry(word, no_create_entry)
self.word_count[word] += 1
-
+ return self
+
def _add_no_create_entry(self, word, no_create_entry):
"""
在新加入word时,检查_no_create_word的设置。
- :param str, List[str] word:
+        :param str,List[str] word:
:param bool no_create_entry:
:return:
"""
- if isinstance(word, str):
+ if isinstance(word, str) or not _is_iterable(word):
word = [word]
for w in word:
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0):
self._no_create_word[w] += 1
elif not no_create_entry and w in self._no_create_word:
self._no_create_word.pop(w)
-
+
@_check_build_status
def add_word(self, word, no_create_entry=False):
"""
@@ -169,6 +179,7 @@ class Vocabulary(object):
则这个词将认为是需要创建单独的vector的。
"""
self.update(word_lst, no_create_entry=no_create_entry)
+ return self
def build_vocab(self):
"""
@@ -193,13 +204,15 @@ class Vocabulary(object):
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)})
self.build_reverse_vocab()
self.rebuild = False
-
+ return self
+
def build_reverse_vocab(self):
"""
- 基于 "word to index" dict, 构建 "index to word" dict.
+ 基于 `word to index` dict, 构建 `index to word` dict.
"""
self.idx2word = {i: w for w, i in self.word2idx.items()}
+ return self
@_check_build_vocab
def __len__(self):
@@ -250,46 +263,57 @@ class Vocabulary(object):
# remember to use `field_name`
vocab.index_dataset(train_data, dev_data, test_data, field_name='words')
- :param datasets: 需要转index的 class:`~fastNLP.DataSet` , 支持一个或多个(list)
- :param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field.
- 目前仅支持 ``str`` , ``list(str)`` , ``list(list(str))``
- :param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field.
- Default: ``None``
+ :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集
+ :param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field.
+ 目前支持 ``str`` , ``List[str]``
+ :param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field.
+ Default: ``None``.
"""
- def index_instance(ins):
+ def index_instance(field):
"""
有几种情况, str, 1d-list, 2d-list
:param ins:
:return:
"""
- field = ins[field_name]
- if isinstance(field, str):
+ if isinstance(field, str) or not _is_iterable(field):
return self.to_index(field)
- elif isinstance(field, list):
- if not isinstance(field[0], list):
+ else:
+ if isinstance(field[0], str) or not _is_iterable(field[0]):
return [self.to_index(w) for w in field]
else:
- if isinstance(field[0][0], list):
+ if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
raise RuntimeError("Only support field with 2 dimensions.")
return [[self.to_index(c) for c in w] for w in field]
- if new_field_name is None:
- new_field_name = field_name
+ new_field_name = new_field_name or field_name
+
+ if type(new_field_name) == type(field_name):
+ if isinstance(new_field_name, list):
+ assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \
+ "field_name."
+ elif isinstance(new_field_name, str):
+ field_name = [field_name]
+ new_field_name = [new_field_name]
+ else:
+ raise TypeError("field_name and new_field_name can only be str or List[str].")
+
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
try:
- dataset.apply(index_instance, new_field_name=new_field_name)
+ for f_n, n_f_n in zip(field_name, new_field_name):
+ dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n)
except Exception as e:
- print("When processing the `{}` dataset, the following error occurred.".format(idx))
+ logger.info("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
else:
raise RuntimeError("Only DataSet type is allowed.")
-
+ return self
+
@property
def _no_create_word_length(self):
return len(self._no_create_word)
-
+
def from_dataset(self, *datasets, field_name, no_create_entry_dataset=None):
"""
使用dataset的对应field中词构建词典::
@@ -297,11 +321,10 @@ class Vocabulary(object):
# remember to use `field_name`
vocab.from_dataset(train_data1, train_data2, field_name='words')
- :param datasets: 需要转index的 class:`~fastNLP.DataSet` , 支持一个或多个(list)
- :param field_name: 可为 ``str`` 或 ``list(str)`` .
- 构建词典所使用的 field(s), 支持一个或多个field
- 若有多个 DataSet, 每个DataSet都必须有这些field.
- 目前仅支持的field结构: ``str`` , ``list(str)`` , ``list(list(str))``
+ :param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集
+ :param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` .
+ 构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构
+ : ``str`` , ``List[str]``
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。
@@ -319,29 +342,29 @@ class Vocabulary(object):
def construct_vocab(ins, no_create_entry=False):
for fn in field_name:
field = ins[fn]
- if isinstance(field, str):
+ if isinstance(field, str) or not _is_iterable(field):
self.add_word(field, no_create_entry=no_create_entry)
- elif isinstance(field, (list, np.ndarray)):
- if not isinstance(field[0], (list, np.ndarray)):
+ else:
+ if isinstance(field[0], str) or not _is_iterable(field[0]):
for word in field:
self.add_word(word, no_create_entry=no_create_entry)
else:
- if isinstance(field[0][0], (list, np.ndarray)):
+ if not isinstance(field[0][0], str) and _is_iterable(field[0][0]):
raise RuntimeError("Only support field with 2 dimensions.")
for words in field:
for word in words:
self.add_word(word, no_create_entry=no_create_entry)
-
+
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
try:
dataset.apply(construct_vocab)
- except Exception as e:
- print("When processing the `{}` dataset, the following error occurred.".format(idx))
+ except BaseException as e:
+ log("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
else:
raise TypeError("Only DataSet type is allowed.")
-
+
if no_create_entry_dataset is not None:
partial_construct_vocab = partial(construct_vocab, no_create_entry=True)
if isinstance(no_create_entry_dataset, DataSet):
@@ -352,7 +375,7 @@ class Vocabulary(object):
raise TypeError("Only DataSet type is allowed.")
dataset.apply(partial_construct_vocab)
return self
-
+
def _is_word_no_create_entry(self, word):
"""
判断当前的word是否是不需要创建entry的,具体参见from_dataset的说明
@@ -360,11 +383,10 @@ class Vocabulary(object):
:return: bool
"""
return word in self._no_create_word
-
+
def to_index(self, w):
"""
- 将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出
- ``ValueError``::
+        将词转为数字. 若词不在词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` ::
index = vocab.to_index('abc')
# equals to
@@ -416,6 +438,7 @@ class Vocabulary(object):
self.idx2word = None
self.rebuild = True
self._no_create_word.clear()
+ return self
def __getstate__(self):
"""Use to prepare data for pickle.
diff --git a/fastNLP/embeddings/__init__.py b/fastNLP/embeddings/__init__.py
new file mode 100644
index 00000000..8a970e25
--- /dev/null
+++ b/fastNLP/embeddings/__init__.py
@@ -0,0 +1,27 @@
+"""
+embeddings 模块主要用于从各种预训练的模型中获取词语的分布式表示,目前支持的预训练模型包括word2vec, glove, ELMO, BERT等。这里所有
+embedding的forward输入都是形状为 ``(batch_size, max_len)`` 的torch.LongTensor,输出都是 ``(batch_size, max_len, embedding_dim)`` 的
+torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获取最大的输入index范围, 用 `self.embedding_dim` 或 `self.embed_size` 获取embedding的
+输出维度。
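+
+下面是一个关于上述输入输出约定的简单示意(以本模块中的CNNCharEmbedding为例,词表内容仅作演示)::
+
+    import torch
+    from fastNLP import Vocabulary
+    from fastNLP.embeddings import CNNCharEmbedding
+
+    vocab = Vocabulary().add_word_lst("this is a demo .".split())
+    embed = CNNCharEmbedding(vocab, embed_size=50)
+    words = torch.LongTensor([[vocab.to_index(w) for w in "this is a demo .".split()]])  # (batch_size, max_len)
+    print(embed(words).size())  # torch.Size([1, 5, 50]),即 (batch_size, max_len, embed_size)
+    print(embed.embed_size)     # 50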
+"""
+
+__all__ = [
+ "Embedding",
+ "TokenEmbedding",
+ "StaticEmbedding",
+ "ElmoEmbedding",
+ "BertEmbedding",
+ "BertWordPieceEncoder",
+ "StackEmbedding",
+ "LSTMCharEmbedding",
+ "CNNCharEmbedding",
+ "get_embeddings",
+]
+
+from .embedding import Embedding, TokenEmbedding
+from .static_embedding import StaticEmbedding
+from .elmo_embedding import ElmoEmbedding
+from .bert_embedding import BertEmbedding, BertWordPieceEncoder
+from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding
+from .stack_embedding import StackEmbedding
+from .utils import get_embeddings
diff --git a/fastNLP/embeddings/bert_embedding.py b/fastNLP/embeddings/bert_embedding.py
new file mode 100644
index 00000000..047048d8
--- /dev/null
+++ b/fastNLP/embeddings/bert_embedding.py
@@ -0,0 +1,471 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "BertEmbedding",
+ "BertWordPieceEncoder"
+]
+
+import os
+import collections
+
+from torch import nn
+import torch
+import numpy as np
+from itertools import chain
+
+from ..core.vocabulary import Vocabulary
+from ..io.file_utils import _get_embedding_url, cached_path, PRETRAINED_BERT_MODEL_DIR
+from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer
+from .contextual_embedding import ContextualEmbedding
+import warnings
+from ..core import logger
+
+
+class BertEmbedding(ContextualEmbedding):
+ """
+ 别名::class:`fastNLP.embeddings.BertEmbedding` :class:`fastNLP.embeddings.bert_embedding.BertEmbedding`
+
+ 使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于
+    预训练的bert模型长度限制为512个token,而输入的word是未经word piece分割的(word piece的分割由BertEmbedding在输入word
+    时完成),分割之后长度可能会超过最大长度限制。
+
+ BertEmbedding可以支持自动下载权重,当前支持的模型有以下的几种(待补充):
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import BertEmbedding
+        >>> vocab = Vocabulary().add_word_lst("The weather is good .".split())
+ >>> embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')
+        >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The weather is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ >>> # torch.Size([1, 5, 2304])
+
+ :param ~fastNLP.Vocabulary vocab: 词表
+ :param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名),
+ 权重文件(以.bin作为文件后缀名), 配置文件(以.json作为后缀名)。
+ :param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是
+ 从0开始,可以以负数去索引倒数几层。
+ :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
+ 中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+ :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
+ :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样
+        会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class:`StackEmbedding` 时可能会与其它类型的
+ embedding长度不匹配。
+ :param bool pooled_cls: 返回的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取[CLS]做预测,
+ 一般该值为True。
+ :param bool requires_grad: 是否需要gradient以更新Bert的权重。
+ :param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个
+ word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS]
+ 来进行分类的任务将auto_truncate置为True。
+ """
+
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1',
+ pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False,
+ pooled_cls=True, requires_grad: bool = False, auto_truncate: bool = False):
+ super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ # 根据model_dir_or_name检查是否存在并下载
+ if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
+ if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'):
+ warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve"
+ " faster speed.")
+ model_url = _get_embedding_url('bert', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+ # 检查是否存在
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_dir = os.path.abspath(os.path.expanduser(model_dir_or_name))
+ else:
+ raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+
+ self._word_sep_index = None
+ if '[SEP]' in vocab:
+ self._word_sep_index = vocab['[SEP]']
+
+ self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
+ pool_method=pool_method, include_cls_sep=include_cls_sep,
+ pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2)
+
+ self.requires_grad = requires_grad
+ self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
+
+ def _delete_model_weights(self):
+ del self.model
+
+ def forward(self, words):
+ """
+ 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
+ 删除这两个token的表示。
+
+ :param torch.LongTensor words: [batch_size, max_len]
+ :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
+ """
+ words = self.drop_word(words)
+ outputs = self._get_sent_reprs(words)
+ if outputs is not None:
+ return self.dropout(outputs)
+ outputs = self.model(words)
+ outputs = torch.cat([*outputs], dim=-1)
+
+ return self.dropout(outputs)
+
+ def drop_word(self, words):
+ """
+ 按照设定随机将words设置为unknown_index。
+
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ with torch.no_grad():
+ if self._word_sep_index: # 不能drop sep
+ sep_mask = words.eq(self._word_sep_index)
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+ mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
+ pad_mask = words.ne(0)
+ mask = pad_mask.__and__(mask) # pad的位置不为unk
+ words = words.masked_fill(mask, self._word_unk_index)
+ if self._word_sep_index:
+ words.masked_fill_(sep_mask, self._word_sep_index)
+ return words
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+
+ :return:
+ """
+ requires_grads = set([param.requires_grad for name, param in self.named_parameters()
+ if 'word_pieces_lengths' not in name])
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for name, param in self.named_parameters():
+ if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中
+ continue
+ param.requires_grad = value
+
+
+class BertWordPieceEncoder(nn.Module):
+ """
+    读取bert模型,读取之后调用index_datasets方法在dataset中生成word_pieces这一列。
+
+ :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased``
+ :param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
+ :param bool pooled_cls: 返回的句子开头的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取
+ [CLS]做预测,一般该值为True。
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+ :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
+ :param bool requires_grad: 是否需要gradient。
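+
+    Example::
+
+        >>> from fastNLP import DataSet
+        >>> from fastNLP.embeddings import BertWordPieceEncoder
+        >>> # 以下仅为用法示意,数据内容为随意构造
+        >>> encoder = BertWordPieceEncoder(model_dir_or_name='en-base-uncased', layers='-1')
+        >>> ds = DataSet({'words': [['this', 'is', 'a', 'demo', '.']]})
+        >>> encoder.index_datasets(ds, field_name='words')  # 在ds中生成word_pieces这一列并设置为input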
+ """
+
+ def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False,
+ word_dropout=0, dropout=0, requires_grad: bool = False):
+ super().__init__()
+
+ if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
+ model_url = _get_embedding_url('bert', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+ # 检查是否存在
+ elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
+ model_dir = model_dir_or_name
+ else:
+ raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+
+ self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers, pooled_cls=pooled_cls)
+ self._sep_index = self.model._sep_index
+ self._wordpiece_pad_index = self.model._wordpiece_pad_index
+ self._wordpiece_unk_index = self.model._wordpiece_unknown_index
+ self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
+ self.requires_grad = requires_grad
+ self.word_dropout = word_dropout
+ self.dropout_layer = nn.Dropout(dropout)
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+ :return:
+ """
+ requires_grads = set([param.requires_grad for name, param in self.named_parameters()])
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for name, param in self.named_parameters():
+ param.requires_grad = value
+
+ @property
+ def embed_size(self):
+ return self._embed_size
+
+ @property
+ def embedding_dim(self):
+ return self._embed_size
+
+ @property
+ def num_embedding(self):
+ return self.model.encoder.config.vocab_size
+
+ def index_datasets(self, *datasets, field_name, add_cls_sep=True):
+ """
+ 使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了
+ bert的pad value。
+
+ :param ~fastNLP.DataSet datasets: DataSet对象
+ :param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
+ :param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。
+ :return:
+ """
+ self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep)
+
+ def forward(self, word_pieces, token_type_ids=None):
+ """
+        计算word_pieces的bert embedding表示。传入的word_pieces中应该自行包含[CLS]与[SEP]的tag。
+
+        :param word_pieces: batch_size x max_len
+ :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话. 如果不传入,则自动生成(大部分情况,都不需要输入),
+ 第一个[SEP]及之前为0, 第二个[SEP]及到第一个[SEP]之间为1; 第三个[SEP]及到第二个[SEP]之间为0,依次往后推。
+ :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
+ """
+ with torch.no_grad():
+ sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
+ if token_type_ids is None:
+ sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+ token_type_ids = sep_mask_cumsum.fmod(2)
+ if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
+ token_type_ids = token_type_ids.eq(0).long()
+
+ word_pieces = self.drop_word(word_pieces)
+ outputs = self.model(word_pieces, token_type_ids)
+ outputs = torch.cat([*outputs], dim=-1)
+
+ return self.dropout_layer(outputs)
+
+ def drop_word(self, words):
+ """
+ 按照设定随机将words设置为unknown_index。
+
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ with torch.no_grad():
+                sep_mask = words.eq(self._sep_index)  # 不能drop [SEP]
+                mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+                mask = torch.bernoulli(mask).eq(1)  # dropout_word越大,越多位置为1
+                pad_mask = words.ne(self._wordpiece_pad_index)
+                mask = pad_mask.__and__(mask)  # pad的位置不为unk
+                words = words.masked_fill(mask, self._wordpiece_unk_index)
+                words.masked_fill_(sep_mask, self._sep_index)  # 将[SEP]的位置还原
+ return words
+
+
+class _WordBertModel(nn.Module):
+ def __init__(self, model_dir: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first',
+ include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2):
+ super().__init__()
+
+ self.tokenzier = BertTokenizer.from_pretrained(model_dir)
+ self.encoder = BertModel.from_pretrained(model_dir)
+ self._max_position_embeddings = self.encoder.config.max_position_embeddings
+ # 检查encoder_layer_number是否合理
+ encoder_layer_number = len(self.encoder.encoder.layer)
+ self.layers = list(map(int, layers.split(',')))
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a bert model with {encoder_layer_number} layers."
+ else:
+ assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a bert model with {encoder_layer_number} layers."
+
+ assert pool_method in ('avg', 'max', 'first', 'last')
+ self.pool_method = pool_method
+ self.include_cls_sep = include_cls_sep
+ self.pooled_cls = pooled_cls
+ self.auto_truncate = auto_truncate
+
+ # 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP]
+ logger.info("Start to generating word pieces for word.")
+ # 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值
+ word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的
+ found_count = 0
+ self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids
+ if '[sep]' in vocab:
+ warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.")
+ if "[CLS]" in vocab:
+ warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CSL] and [SEP] to the begin "
+ "and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin"
+ " and end.")
+ for word, index in vocab:
+ if index == vocab.padding_idx: # pad是个特殊的符号
+ word = '[PAD]'
+ elif index == vocab.unknown_idx:
+ word = '[UNK]'
+ word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+ if len(word_pieces) == 1:
+ if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到
+ if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面
+ if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry(
+ word): # 出现次数大于这个次数才新增
+ word_piece_dict[word] = 1 # 新增一个值
+ continue
+ for word_piece in word_pieces:
+ word_piece_dict[word_piece] = 1
+ found_count += 1
+ original_embed = self.encoder.embeddings.word_embeddings.weight.data
+ # 特殊词汇要特殊处理
+ embed = nn.Embedding(len(word_piece_dict), original_embed.size(1)) # 新的embed
+ new_word_piece_vocab = collections.OrderedDict()
+ for index, token in enumerate(['[PAD]', '[UNK]']):
+ word_piece_dict.pop(token, None)
+ embed.weight.data[index] = original_embed[self.tokenzier.vocab[token]]
+ new_word_piece_vocab[token] = index
+ for token in word_piece_dict.keys():
+ if token in self.tokenzier.vocab:
+ embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab[token]]
+ else:
+ embed.weight.data[len(new_word_piece_vocab)] = original_embed[self.tokenzier.vocab['[UNK]']]
+ new_word_piece_vocab[token] = len(new_word_piece_vocab)
+ self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab)
+ self.encoder.embeddings.word_embeddings = embed
+
+ word_to_wordpieces = []
+ word_pieces_lengths = []
+ for word, index in vocab:
+ if index == vocab.padding_idx: # pad是个特殊的符号
+ word = '[PAD]'
+ elif index == vocab.unknown_idx:
+ word = '[UNK]'
+ word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+ word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces)
+ word_to_wordpieces.append(word_pieces)
+ word_pieces_lengths.append(len(word_pieces))
+ self._cls_index = self.tokenzier.vocab['[CLS]']
+ self._sep_index = self.tokenzier.vocab['[SEP]']
+ self._word_pad_index = vocab.padding_idx
+ self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece
+ logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab)))
+ self.word_to_wordpieces = np.array(word_to_wordpieces)
+ self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False)
+ logger.debug("Successfully generate word pieces.")
+
+ def forward(self, words):
+ """
+
+ :param words: torch.LongTensor, batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
+ """
+ with torch.no_grad():
+ batch_size, max_word_len = words.size()
+ word_mask = words.ne(self._word_pad_index) # 为1的地方有word
+ seq_len = word_mask.sum(dim=-1)
+ batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0),
+ 0) # batch_size x max_len
+ word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size
+ word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding)
+ if word_piece_length + 2 > self._max_position_embeddings:
+ if self.auto_truncate:
+ word_pieces_lengths = word_pieces_lengths.masked_fill(
+ word_pieces_lengths + 2 > self._max_position_embeddings,
+ self._max_position_embeddings - 2)
+ else:
+ raise RuntimeError(
+ "After split words into word pieces, the lengths of word pieces are longer than the "
+ f"maximum allowed sequence length:{self._max_position_embeddings} of bert.")
+
+ # +2是由于需要加入[CLS]与[SEP]
+ word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)),
+ fill_value=self._wordpiece_pad_index)
+ attn_masks = torch.zeros_like(word_pieces)
+ # 1. 获取words的word_pieces的id,以及对应的span范围
+ word_indexes = words.cpu().numpy()
+ for i in range(batch_size):
+ word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]]))
+ if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2:
+ word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2]
+ word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i)
+ attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1)
+ # 添加[cls]和[sep]
+ word_pieces[:, 0].fill_(self._cls_index)
+ batch_indexes = torch.arange(batch_size).to(words)
+ word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index
+ if self._has_sep_in_vocab: # 但[SEP]在vocab中出现应该才会需要token_ids
+ sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len
+ sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1])
+ token_type_ids = sep_mask_cumsum.fmod(2)
+ if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0
+ token_type_ids = token_type_ids.eq(0).long()
+ else:
+ token_type_ids = torch.zeros_like(word_pieces)
+ # 2. 获取hidden的结果,根据word_pieces进行对应的pool计算
+ # all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...]
+ bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+ output_all_encoded_layers=True)
+ # output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size
+
+ if self.include_cls_sep:
+ outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2,
+ bert_outputs[-1].size(-1))
+ s_shift = 1
+ else:
+ outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len,
+ bert_outputs[-1].size(-1))
+ s_shift = 0
+ batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1)
+ batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len
+ for l_index, l in enumerate(self.layers):
+ output_layer = bert_outputs[l]
+ real_word_piece_length = output_layer.size(1) - 2
+ if word_piece_length > real_word_piece_length: # 如果实际上是截取出来的
+ paddings = output_layer.new_zeros(batch_size,
+ word_piece_length - real_word_piece_length,
+ output_layer.size(2))
+ output_layer = torch.cat((output_layer, paddings), dim=1).contiguous()
+ # 从word_piece collapse到word的表示
+ truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size
+ outputs_seq_len = seq_len + s_shift
+ if self.pool_method == 'first':
+ for i in range(batch_size):
+ i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # 每个word的start位置
+ outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[
+ i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size
+ elif self.pool_method == 'last':
+ for i in range(batch_size):
+ i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i] + 1] - 1 # 每个word的end
+ outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length]
+ elif self.pool_method == 'max':
+ for i in range(batch_size):
+ for j in range(seq_len[i]):
+ start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
+ outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2)
+ else:
+ for i in range(batch_size):
+ for j in range(seq_len[i]):
+ start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1]
+ outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2)
+ if self.include_cls_sep:
+ if l in (len(bert_outputs) - 1, -1) and self.pooled_cls:
+ outputs[l_index, :, 0] = pooled_cls
+ else:
+ outputs[l_index, :, 0] = output_layer[:, 0]
+ outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift]
+ # 3. 最终的embedding结果
+ return outputs
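
A standalone sketch of the span pooling implemented above (the tensors and word-piece counts are invented for illustration): the cumulative word-piece lengths give each word a [start, end) range, and pool_method='first'/'max' reduce that range to one vector per word.

import torch

hidden = torch.randn(5, 8)                 # 5 word pieces of one sentence, hidden size 8
pieces_per_word = torch.tensor([1, 3, 1])  # hypothetical: three words made of 1, 3 and 1 pieces
cum_length = torch.zeros(len(pieces_per_word) + 1, dtype=torch.long)
cum_length[1:] = pieces_per_word.cumsum(dim=-1)     # [0, 1, 4, 5]

first_pool = hidden[cum_length[:-1]]                # pool_method='first': first piece of each word
max_pool = torch.stack([hidden[s:e].max(dim=0)[0]   # pool_method='max': max over each span
                        for s, e in zip(cum_length[:-1].tolist(), cum_length[1:].tolist())])
print(first_pool.shape, max_pool.shape)             # torch.Size([3, 8]) torch.Size([3, 8])
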
diff --git a/fastNLP/embeddings/char_embedding.py b/fastNLP/embeddings/char_embedding.py
new file mode 100644
index 00000000..acffa054
--- /dev/null
+++ b/fastNLP/embeddings/char_embedding.py
@@ -0,0 +1,325 @@
+"""
+该文件中主要包含的是character的Embedding,包括基于CNN与LSTM的character Embedding。与其它Embedding一样,这里的Embedding输入也是
+词的index而不需要使用词语中的char的index来获取表达。
+"""
+
+__all__ = [
+ "CNNCharEmbedding",
+ "LSTMCharEmbedding"
+]
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+from .static_embedding import StaticEmbedding
+from ..modules.encoder.lstm import LSTM
+from ..core.vocabulary import Vocabulary
+from .embedding import TokenEmbedding
+from .utils import _construct_char_vocab_from_vocab
+from .utils import get_embeddings
+from ..core import logger
+
+
+class CNNCharEmbedding(TokenEmbedding):
+ """
+ 别名::class:`fastNLP.embeddings.CNNCharEmbedding` :class:`fastNLP.embeddings.char_embedding.CNNCharEmbedding`
+
+ 使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
+ 不同kernel大小的filter结果会被concat起来,然后通过一层fully connected layer,最后输出word的表示。
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import CNNCharEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> embed = CNNCharEmbedding(vocab, embed_size=50)
+ >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ >>> # torch.Size([1, 5, 50])
+
+ :param vocab: 词表
+ :param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50.
+ :param char_emb_size: character的embed的维度。character是从vocab中生成的。默认值为50.
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+ :param float dropout: 以多大的概率drop分布式表示与char embedding的输出。
+ :param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20].
+ :param kernel_sizes: kernel的大小. 默认值为[5, 3, 1].
+ :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
+ :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
+ :param min_char_freq: character的最少出现次数。默认值为2.
+ :param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹
+ (文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,
+ 没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
+ """
+
+ def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
+ dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1),
+ pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None):
+ super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ for kernel in kernel_sizes:
+ assert kernel % 2 == 1, "Only odd kernel is allowed."
+
+ assert pool_method in ('max', 'avg')
+ self.pool_method = pool_method
+ # activation function
+ if isinstance(activation, str):
+ if activation.lower() == 'relu':
+ self.activation = F.relu
+ elif activation.lower() == 'sigmoid':
+ self.activation = F.sigmoid
+ elif activation.lower() == 'tanh':
+ self.activation = F.tanh
+ elif activation is None:
+ self.activation = lambda x: x
+ elif callable(activation):
+ self.activation = activation
+ else:
+ raise Exception(
+ "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
+
+ logger.info("Start constructing character vocabulary.")
+ # 建立char的词表
+ self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
+ self.char_pad_index = self.char_vocab.padding_idx
+ logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
+ # 对vocab进行index
+ max_word_len = max(map(lambda x: len(x[0]), vocab))
+ self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
+ fill_value=self.char_pad_index, dtype=torch.long),
+ requires_grad=False)
+ self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
+ for word, index in vocab:
+ # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的也是同一个embed
+ self.words_to_chars_embedding[index, :len(word)] = \
+ torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+ self.word_lengths[index] = len(word)
+ # self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+ if pre_train_char_embed:
+ self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed)
+ else:
+ self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size))
+
+ self.convs = nn.ModuleList([nn.Conv1d(
+ char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
+ for i in range(len(kernel_sizes))])
+ self._embed_size = embed_size
+ self.fc = nn.Linear(sum(filter_nums), embed_size)
+ self.reset_parameters()
+
+ def forward(self, words):
+ """
+ 输入words的index后,生成对应的words的表示。
+
+ :param words: [batch_size, max_len]
+ :return: [batch_size, max_len, embed_size]
+ """
+ words = self.drop_word(words)
+ batch_size, max_len = words.size()
+ chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
+ word_lengths = self.word_lengths[words] # batch_size x max_len
+ max_word_len = word_lengths.max()
+ chars = chars[:, :, :max_word_len]
+ # 为1的地方为mask
+ chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
+ chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
+ chars = self.dropout(chars)
+ reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
+ reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M
+ conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
+ for conv in self.convs]
+ conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters)
+ conv_chars = self.activation(conv_chars)
+ if self.pool_method == 'max':
+ conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
+ chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
+ else:
+ conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
+ chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+ chars = self.fc(chars)
+ return self.dropout(chars)
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+ :return:
+ """
+ params = []
+ for name, param in self.named_parameters():
+ if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
+ params.append(param.requires_grad)
+ requires_grads = set(params)
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for name, param in self.named_parameters():
+ if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
+ continue
+ param.requires_grad = value
+
+ def reset_parameters(self):
+ for name, param in self.named_parameters():
+ if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset
+ continue
+ if 'char_embedding' in name:
+ continue
+ if param.data.dim() > 1:
+ nn.init.xavier_uniform_(param, 1)
+ else:
+ nn.init.uniform_(param, -1, 1)
+
+
+class LSTMCharEmbedding(TokenEmbedding):
+ """
+ 别名::class:`fastNLP.embeddings.LSTMCharEmbedding` :class:`fastNLP.embeddings.char_embedding.LSTMCharEmbedding`
+
+ 使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import LSTMCharEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> embed = LSTMCharEmbedding(vocab, embed_size=50)
+ >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ >>> # torch.Size([1, 5, 50])
+
+ :param vocab: 词表
+ :param embed_size: LSTMCharEmbedding的输出维度。默认值为50.
+ :param char_emb_size: character的embedding的维度。默认值为50.
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+ :param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。
+ :param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50.
+ :param pool_method: 支持'max', 'avg'。
+ :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
+ :param min_char_freq: character的最小出现次数。默认值为2.
+ :param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。
+ :param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹
+ (文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,
+ 没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding.
+ """
+
+ def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0,
+ dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu',
+ min_char_freq: int = 2,
+ bidirectional=True, pre_train_char_embed: str = None):
+ super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ assert hidden_size % 2 == 0, "Only even hidden_size is allowed."
+
+ assert pool_method in ('max', 'avg')
+ self.pool_method = pool_method
+ # activation function
+ if isinstance(activation, str):
+ if activation.lower() == 'relu':
+ self.activation = F.relu
+ elif activation.lower() == 'sigmoid':
+ self.activation = F.sigmoid
+ elif activation.lower() == 'tanh':
+ self.activation = F.tanh
+ elif activation is None:
+ self.activation = lambda x: x
+ elif callable(activation):
+ self.activation = activation
+ else:
+ raise Exception(
+ "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
+
+ logger.info("Start constructing character vocabulary.")
+ # 建立char的词表
+ self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
+ self.char_pad_index = self.char_vocab.padding_idx
+ logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.")
+ # 对vocab进行index
+ self.max_word_len = max(map(lambda x: len(x[0]), vocab))
+ self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
+ fill_value=self.char_pad_index, dtype=torch.long),
+ requires_grad=False)
+ self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
+ for word, index in vocab:
+ # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否
+ self.words_to_chars_embedding[index, :len(word)] = \
+ torch.LongTensor([self.char_vocab.to_index(c) for c in word])
+ self.word_lengths[index] = len(word)
+ # self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+ if pre_train_char_embed:
+ self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed)
+ else:
+ self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
+
+ self.fc = nn.Linear(hidden_size, embed_size)
+ hidden_size = hidden_size // 2 if bidirectional else hidden_size
+
+ self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True)
+ self._embed_size = embed_size
+ self.bidirectional = bidirectional
+
+ def forward(self, words):
+ """
+ 输入words的index后,生成对应的words的表示。
+
+ :param words: [batch_size, max_len]
+ :return: [batch_size, max_len, embed_size]
+ """
+ words = self.drop_word(words)
+ batch_size, max_len = words.size()
+ chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
+ word_lengths = self.word_lengths[words] # batch_size x max_len
+ max_word_len = word_lengths.max()
+ chars = chars[:, :, :max_word_len]
+ # 为mask的地方为1
+ chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了
+ chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
+ chars = self.dropout(chars)
+ reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
+ char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
+ lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
+ # B x M x M x H
+
+ lstm_chars = self.activation(lstm_chars)
+ if self.pool_method == 'max':
+ lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
+ chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
+ else:
+ lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
+ chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
+
+ chars = self.fc(chars)
+
+ return self.dropout(chars)
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+
+ :return:
+ """
+ params = []
+ for name, param in self.named_parameters():
+ if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
+ params.append(param)
+ requires_grads = set(params)
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for name, param in self.named_parameters():
+ if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
+ continue
+ param.requires_grad = value
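
Both CNNCharEmbedding and LSTMCharEmbedding above finish with the same masked pooling over character positions. A small self-contained sketch of the 'max' variant; the shapes are made up and not part of the patch:

import torch

char_reprs = torch.randn(2, 4, 6, 16)              # batch x max_len x max_word_len x hidden
char_pad_mask = torch.zeros(2, 4, 6, dtype=torch.bool)
char_pad_mask[..., 4:] = True                      # pretend the last two char slots are padding

masked = char_reprs.masked_fill(char_pad_mask.unsqueeze(-1), float('-inf'))
word_reprs, _ = masked.max(dim=-2)                 # one vector per word: batch x max_len x hidden
print(word_reprs.shape)                            # torch.Size([2, 4, 16])
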
diff --git a/fastNLP/embeddings/contextual_embedding.py b/fastNLP/embeddings/contextual_embedding.py
new file mode 100644
index 00000000..9910a44b
--- /dev/null
+++ b/fastNLP/embeddings/contextual_embedding.py
@@ -0,0 +1,110 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "ContextualEmbedding"
+]
+
+from abc import abstractmethod
+
+import torch
+
+from .embedding import TokenEmbedding
+from ..core import logger
+from ..core.batch import DataSetIter
+from ..core.dataset import DataSet
+from ..core.sampler import SequentialSampler
+from ..core.utils import _move_model_to_device, _get_model_device
+from ..core.vocabulary import Vocabulary
+
+
+class ContextualEmbedding(TokenEmbedding):
+ def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0):
+ super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True):
+ """
+ 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。
+
+ :param datasets: DataSet对象
+ :param batch_size: int, 生成cache的sentence表示时使用的batch的大小
+ :param device: 参考 :class::fastNLP.Trainer 的device
+ :param delete_weights: 是否在生成了cache之后删除权重,在不需要finetune动态模型的情况下,删除权重会大量减少内存占用。
+ :return:
+ """
+ for index, dataset in enumerate(datasets):
+ try:
+ assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
+ assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
+ except Exception as e:
+ logger.error(f"Exception happens at {index} dataset.")
+ raise e
+
+ sent_embeds = {}
+ _move_model_to_device(self, device=device)
+ device = _get_model_device(self)
+ pad_index = self._word_vocab.padding_idx
+ logger.info("Start to calculate sentence representations.")
+ with torch.no_grad():
+ for index, dataset in enumerate(datasets):
+ try:
+ batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
+ for batch_x, batch_y in batch:
+ words = batch_x['words'].to(device)
+ words_list = words.tolist()
+ seq_len = words.ne(pad_index).sum(dim=-1)
+ max_len = words.size(1)
+ # 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。
+ seq_len_from_behind = (max_len - seq_len).tolist()
+ word_embeds = self(words).detach().cpu().numpy()
+ for b in range(words.size(0)):
+ length = seq_len_from_behind[b]
+ if length == 0:
+ sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b]
+ else:
+ sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
+ except Exception as e:
+ logger.error(f"Exception happens at {index} dataset.")
+ raise e
+ logger.info("Finish calculating sentence representations.")
+ self.sent_embeds = sent_embeds
+ if delete_weights:
+ self._delete_model_weights()
+
+ def _get_sent_reprs(self, words):
+ """
+ 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None
+
+ :param words: torch.LongTensor
+ :return:
+ """
+ if hasattr(self, 'sent_embeds'):
+ words_list = words.tolist()
+ seq_len = words.ne(self._word_pad_index).sum(dim=-1)
+ _embeds = []
+ for b in range(len(words)):
+ words_i = tuple(words_list[b][:seq_len[b]])
+ embed = self.sent_embeds[words_i]
+ _embeds.append(embed)
+ max_sent_len = max(map(len, _embeds))
+ embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
+ device=words.device)
+ for i, embed in enumerate(_embeds):
+ embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
+ return embeds
+ return None
+
+ @abstractmethod
+ def _delete_model_weights(self):
+ """删除计算表示的模型以节省资源"""
+ raise NotImplementedError
+
+ def remove_sentence_cache(self):
+ """
+ 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。
+
+ :return:
+ """
+ del self.sent_embeds
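
A hedged usage sketch for add_sentence_cache. The toy dataset below is invented for illustration, and constructing ElmoEmbedding with model_dir_or_name='en' may trigger an automatic model download:

from fastNLP import DataSet, Vocabulary
from fastNLP.embeddings import ElmoEmbedding

ds = DataSet({'raw_words': [['this', 'is', 'a', 'test'], ['another', 'sentence']]})
vocab = Vocabulary()
vocab.from_dataset(ds, field_name='raw_words')
vocab.index_dataset(ds, field_name='raw_words', new_field_name='words')
ds.set_input('words')                               # add_sentence_cache requires a 'words' input field

embed = ElmoEmbedding(vocab, model_dir_or_name='en', requires_grad=False)
embed.add_sentence_cache(ds, batch_size=32, device='cpu', delete_weights=True)
# Subsequent forward() calls on cached sentences reuse the stored representations
# instead of re-running the ELMo encoder.
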
diff --git a/fastNLP/embeddings/elmo_embedding.py b/fastNLP/embeddings/elmo_embedding.py
new file mode 100644
index 00000000..3df424a2
--- /dev/null
+++ b/fastNLP/embeddings/elmo_embedding.py
@@ -0,0 +1,345 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "ElmoEmbedding"
+]
+
+import os
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import json
+import codecs
+
+from ..core.vocabulary import Vocabulary
+from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR
+from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder
+from .contextual_embedding import ContextualEmbedding
+from ..core import logger
+
+class ElmoEmbedding(ContextualEmbedding):
+ """
+ 别名::class:`fastNLP.embeddings.ElmoEmbedding` :class:`fastNLP.embeddings.elmo_embedding.ElmoEmbedding`
+
+ 使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。当前支持的使用名称初始化的模型有以下的这些(待补充)
+
+ Example::
+
+ >>> import torch
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import ElmoEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> # 使用不同层的concat的结果
+ >>> embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='1,2', requires_grad=False)
+ >>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]])
+ >>> outputs = embed(words)
+ >>> outputs.size()
+ >>> # torch.Size([1, 5, 2048])
+
+ >>> # 使用不同层的weighted sum。
+ >>> embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='mix', requires_grad=False)
+ >>> embed.set_mix_weights_requires_grad() # 使得weighted的权重是可以学习的,但ELMO的LSTM部分是不更新
+
+ :param vocab: 词表
+ :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件,
+ 其中一个是以json为后缀的配置文件,另一个是以pkl为后缀的权重文件;第二种是传入ELMo版本的名称,将自动查看缓存中是否存在该模型,
+ 没有的话将自动下载并缓存。
+ :param layers: str, 指定返回的层数(从0开始), 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果
+ 按照这个顺序concat起来,默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致,
+ 初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。)
+ :param requires_grad: bool, 该层是否需要gradient, 默认为False.
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+ :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
+ :param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding,
+ 并删除character encoder,之后将直接使用cache的embedding。默认为False。
+ """
+
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = False,
+ word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False):
+ super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+
+ # 根据model_dir_or_name检查是否存在并下载
+ if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
+ model_url = _get_embedding_url('elmo', model_dir_or_name.lower())
+ model_dir = cached_path(model_url, name='embedding')
+ # 检查是否存在
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_dir = model_dir_or_name
+ else:
+ raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+ self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
+
+ if layers == 'mix':
+ self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1),
+ requires_grad=requires_grad)
+ self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad)
+ self._get_outputs = self._get_mixed_outputs
+ self._embed_size = self.model.config['lstm']['projection_dim'] * 2
+ else:
+ layers = list(map(int, layers.split(',')))
+ assert len(layers) > 0, "Must choose one output"
+ for layer in layers:
+ assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
+ self.layers = layers
+ self._get_outputs = self._get_layer_outputs
+ self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2
+
+ self.requires_grad = requires_grad
+
+ def _get_mixed_outputs(self, outputs):
+ # outputs: num_layers x batch_size x max_len x hidden_size
+ # return: batch_size x max_len x hidden_size
+ weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs)
+ outputs = torch.einsum('l,lbij->bij', weights, outputs)
+ return self.gamma.to(outputs) * outputs
+
+ def set_mix_weights_requires_grad(self, flag=True):
+ """
+ 当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用
+ 该方法没有用。
+
+ :param bool flag: 混合不同层表示的结果是否可以训练。
+ :return:
+ """
+ if hasattr(self, 'layer_weights'):
+ self.layer_weights.requires_grad = flag
+ self.gamma.requires_grad = flag
+
+ def _get_layer_outputs(self, outputs):
+ if len(self.layers) == 1:
+ outputs = outputs[self.layers[0]]
+ else:
+ outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1)
+
+ return outputs
+
+ def forward(self, words: torch.LongTensor):
+ """
+ 计算words的elmo embedding表示。根据elmo文章中介绍的ELMo实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的
+ embedding被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens;
+ backward_hiddens].
+
+ :param words: batch_size x max_len
+ :return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
+ """
+ words = self.drop_word(words)
+ outputs = self._get_sent_reprs(words)
+ if outputs is not None:
+ return self.dropout(outputs)
+ outputs = self.model(words)
+ outputs = self._get_outputs(outputs)
+ return self.dropout(outputs)
+
+ def _delete_model_weights(self):
+ for name in ['layers', 'model', 'layer_weights', 'gamma']:
+ if hasattr(self, name):
+ delattr(self, name)
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+
+ :return:
+ """
+ requires_grads = set([param.requires_grad for name, param in self.named_parameters()
+ if 'words_to_chars_embedding' not in name and 'words_to_words' not in name])
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for name, param in self.named_parameters():
+ if 'words_to_chars_embedding' in name or 'words_to_words' in name: # 这个不能加入到requires_grad中
+ continue
+ param.requires_grad = value
+
+
+class _ElmoModel(nn.Module):
+ """
+ 该Module是ElmoEmbedding中进行所有的heavy lifting的地方。做的工作,包括
+ (1) 根据配置,加载模型;
+ (2) 根据vocab,对模型中的embedding进行调整. 并将其正确初始化
+ (3) 保存一个words与chars的对应转换,获取时自动进行相应的转换
+ (4) 设计一个保存token的embedding,允许缓存word的表示。
+
+ """
+
+ def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False):
+ super(_ElmoModel, self).__init__()
+ self.model_dir = model_dir
+ dir = os.walk(self.model_dir)
+ config_file = None
+ weight_file = None
+ config_count = 0
+ weight_count = 0
+ for path, dir_list, file_list in dir:
+ for file_name in file_list:
+ if file_name.__contains__(".json"):
+ config_file = file_name
+ config_count += 1
+ elif file_name.__contains__(".pkl"):
+ weight_file = file_name
+ weight_count += 1
+ if config_count > 1 or weight_count > 1:
+ raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.")
+ elif config_count == 0 or weight_count == 0:
+ raise Exception(f"No config file or weight file found in {model_dir}")
+ with open(os.path.join(model_dir, config_file), 'r') as config_f:
+ config = json.load(config_f)
+ self.weight_file = os.path.join(model_dir, weight_file)
+ self.config = config
+
+ OOV_TAG = '<oov>'
+ PAD_TAG = '<pad>'
+ BOS_TAG = '<bos>'
+ EOS_TAG = '<eos>'
+ BOW_TAG = '<bow>'
+ EOW_TAG = '<eow>'
+
+ # For the model trained with character-based word encoder.
+ char_lexicon = {}
+ with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
+ for line in fpi:
+ tokens = line.strip().split('\t')
+ if len(tokens) == 1:
+ tokens.insert(0, '\u3000')
+ token, i = tokens
+ char_lexicon[token] = int(i)
+
+ # 做一些sanity check
+ for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
+ assert special_word in char_lexicon, f"{special_word} not found in char.dic."
+
+ # 从vocab中构建char_vocab
+ char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
+ # 需要保证<bow>与<eow>在里面
+ char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])
+
+ for word, index in vocab:
+ char_vocab.add_word_lst(list(word))
+
+ self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx
+ # 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示)
+ char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
+ padding_idx=len(char_vocab))
+
+ # 读入预训练权重 这里的elmo_model 包含char_cnn和 lstm 的 state_dict
+ elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')
+
+ char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']
+
+ found_char_count = 0
+ for char, index in char_vocab: # 调整character embedding
+ if char in char_lexicon:
+ index_in_pre = char_lexicon.get(char)
+ found_char_count += 1
+ else:
+ index_in_pre = char_lexicon[OOV_TAG]
+ char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
+
+ logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
+ # 生成words到chars的映射
+ max_chars = config['char_cnn']['max_characters_per_token']
+
+ self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars),
+ fill_value=len(char_vocab),
+ dtype=torch.long),
+ requires_grad=False)
+ for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]:
+ if len(word) + 2 > max_chars:
+ word = word[:max_chars - 2]
+ if index == self._pad_index:
+ continue
+ elif word == BOS_TAG or word == EOS_TAG:
+ char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [
+ char_vocab.to_index(EOW_TAG)]
+ char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
+ else:
+ char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [
+ char_vocab.to_index(EOW_TAG)]
+ char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
+ self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)
+
+ self.char_vocab = char_vocab
+
+ self.token_embedder = ConvTokenEmbedder(
+ config, self.weight_file, None, char_emb_layer)
+ elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
+ self.token_embedder.load_state_dict(elmo_model["char_cnn"])
+
+ self.output_dim = config['lstm']['projection_dim']
+
+ # lstm encoder
+ self.encoder = ElmobiLm(config)
+ self.encoder.load_state_dict(elmo_model["lstm"])
+
+ if cache_word_reprs:
+ if config['char_cnn']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用
+ logger.info("Start to generate cache word representations.")
+ batch_size = 320
+ # bos eos
+ word_size = self.words_to_chars_embedding.size(0)
+ num_batches = word_size // batch_size + \
+ int(word_size % batch_size != 0)
+
+ self.cached_word_embedding = nn.Embedding(word_size,
+ config['lstm']['projection_dim'])
+ with torch.no_grad():
+ for i in range(num_batches):
+ words = torch.arange(i * batch_size,
+ min((i + 1) * batch_size, word_size)).long()
+ chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars
+ word_reprs = self.token_embedder(words.unsqueeze(1),
+ chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
+ self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)
+
+ logger.info("Finish generating cached word representations. Going to delete the character encoder.")
+ del self.token_embedder, self.words_to_chars_embedding
+ else:
+ logger.info("There is no need to cache word representations, since no character information is used.")
+
+ def forward(self, words):
+ """
+
+ :param words: batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size
+ """
+ # 扩展<bos>, <eos>
+ batch_size, max_len = words.size()
+ expanded_words = words.new_zeros(batch_size, max_len + 2) # 因为pad一定为0,
+ seq_len = words.ne(self._pad_index).sum(dim=-1)
+ expanded_words[:, 1:-1] = words
+ expanded_words[:, 0].fill_(self.bos_index)
+ expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
+ seq_len = seq_len + 2
+ zero_tensor = expanded_words.new_zeros(expanded_words.shape)
+ mask = (expanded_words == zero_tensor).unsqueeze(-1)
+ if hasattr(self, 'cached_word_embedding'):
+ token_embedding = self.cached_word_embedding(expanded_words)
+ else:
+ if hasattr(self, 'words_to_chars_embedding'):
+ chars = self.words_to_chars_embedding[expanded_words]
+ else:
+ chars = None
+ token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim
+
+ encoder_output = self.encoder(token_embedding, seq_len)
+ if encoder_output.size(2) < max_len + 2:
+ num_layers, _, output_len, hidden_size = encoder_output.size()
+ dummy_tensor = encoder_output.new_zeros(num_layers, batch_size,
+ max_len + 2 - output_len, hidden_size)
+ encoder_output = torch.cat((encoder_output, dummy_tensor), 2)
+ sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
+ token_embedding = token_embedding.masked_fill(mask, 0)
+ token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
+ encoder_output = torch.cat((token_embedding, encoder_output), dim=0)
+
+ # 删除<bos>, <eos>. 这里没有精确地删除,但应该也不会影响最后的结果了。
+ encoder_output = encoder_output[:, :, 1:-1]
+ return encoder_output
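
For layers='mix', _get_mixed_outputs combines the per-layer outputs with a softmax over learnable weights plus a scalar gamma, following the ELMo weighting scheme. A small sketch with illustrative tensor sizes:

import torch
import torch.nn.functional as F

layer_outputs = torch.randn(3, 2, 5, 1024)       # num_layers x batch x max_len x hidden
layer_weights = torch.zeros(3, requires_grad=True)
gamma = torch.ones(1, requires_grad=True)

weights = F.softmax(layer_weights + 1 / len(layer_outputs), dim=0)
mixed = gamma * torch.einsum('l,lbij->bij', weights, layer_outputs)
print(mixed.shape)                               # torch.Size([2, 5, 1024])
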
diff --git a/fastNLP/embeddings/embedding.py b/fastNLP/embeddings/embedding.py
new file mode 100644
index 00000000..5e7b9803
--- /dev/null
+++ b/fastNLP/embeddings/embedding.py
@@ -0,0 +1,207 @@
+"""
+该模块中的Embedding主要用于随机初始化的embedding(更推荐使用 :class:`fastNLP.embeddings.StaticEmbedding` ),或按照预训练权重初始化Embedding。
+
+"""
+
+__all__ = [
+ "Embedding",
+ "TokenEmbedding"
+]
+
+import torch.nn as nn
+from abc import abstractmethod
+import torch
+
+from .utils import get_embeddings
+
+
+class Embedding(nn.Module):
+ """
+ 别名::class:`fastNLP.embeddings.Embedding` :class:`fastNLP.embeddings.embedding.Embedding`
+
+ 词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度.
+
+ Example::
+
+ >>> import numpy as np
+ >>> from fastNLP.embeddings import Embedding
+ >>> init_embed = (2000, 100)
+ >>> embed = Embedding(init_embed) # 随机初始化一个具有2000个词,每个词表示为100维的词向量
+ >>> init_embed = np.zeros((2000, 100))
+ >>> embed = Embedding(init_embed) # 使用numpy.ndarray的值作为初始化值初始化一个Embedding
+
+ :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: 支持传入Embedding的大小(传入tuple(int, int),
+ 第一个int为vocab_size, 第二个int为embed_dim); 或传入Tensor, Embedding, numpy.ndarray等则直接使用该值初始化Embedding;
+ :param float word_dropout: 按照一定概率随机将word设置为unk_index,这样可以使得unk这个token得到足够的训练, 且会对网络有
+ 一定的regularize的作用。设置该值时,必须同时设置unk_index
+ :param float dropout: 对Embedding的输出的dropout。
+ :param int unk_index: drop word时替换为的index。fastNLP的Vocabulary的unk_index默认为1。
+ """
+
+ def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None):
+
+ super(Embedding, self).__init__()
+
+ self.embed = get_embeddings(init_embed)
+
+ self.dropout = nn.Dropout(dropout)
+ if not isinstance(self.embed, TokenEmbedding):
+ if hasattr(self.embed, 'embed_size'):
+ self._embed_size = self.embed.embed_size
+ elif hasattr(self.embed, 'embedding_dim'):
+ self._embed_size = self.embed.embedding_dim
+ else:
+ self._embed_size = self.embed.weight.size(1)
+ if word_dropout > 0 and not isinstance(unk_index, int):
+ raise ValueError("When drop word is set, you need to pass in the unk_index.")
+ else:
+ self._embed_size = self.embed.embed_size
+ unk_index = self.embed.get_word_vocab().unknown_idx
+ self.unk_index = unk_index
+ self.word_dropout = word_dropout
+
+ def forward(self, words):
+ """
+ :param torch.LongTensor words: [batch, seq_len]
+ :return: torch.Tensor : [batch, seq_len, embed_dim]
+ """
+ if self.word_dropout > 0 and self.training:
+ mask = torch.ones_like(words).float() * self.word_dropout
+ mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
+ words = words.masked_fill(mask, self.unk_index)
+ words = self.embed(words)
+ return self.dropout(words)
+
+ @property
+ def num_embedding(self) -> int:
+ if isinstance(self.embed, nn.Embedding):
+ return self.embed.weight.size(0)
+ else:
+ return self.embed.num_embedding
+
+ def __len__(self):
+ return len(self.embed)
+
+ @property
+ def embed_size(self) -> int:
+ return self._embed_size
+
+ @property
+ def embedding_dim(self) -> int:
+ return self._embed_size
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+ :return:
+ """
+ if not isinstance(self.embed, TokenEmbedding):
+ return self.embed.weight.requires_grad
+ else:
+ return self.embed.requires_grad
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ if not isinstance(self.embed, TokenEmbedding):
+ self.embed.weight.requires_grad = value
+ else:
+ self.embed.requires_grad = value
+
+ @property
+ def size(self):
+ if isinstance(self.embed, TokenEmbedding):
+ return self.embed.size
+ else:
+ return self.embed.weight.size()
+
+
+class TokenEmbedding(nn.Module):
+ def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
+ super(TokenEmbedding, self).__init__()
+ if vocab.rebuild:
+ vocab.build_vocab()
+ assert vocab.padding is not None, "Vocabulary must have a padding entry."
+ self._word_vocab = vocab
+ self._word_pad_index = vocab.padding_idx
+ if word_dropout > 0:
+ assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
+ self.word_dropout = word_dropout
+ self._word_unk_index = vocab.unknown_idx
+ self.dropout_layer = nn.Dropout(dropout)
+
+ def drop_word(self, words):
+ """
+ 按照设定随机将words设置为unknown_index。
+
+ :param torch.LongTensor words: batch_size x max_len
+ :return:
+ """
+ if self.word_dropout > 0 and self.training:
+ mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device)
+ mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1
+ pad_mask = words.ne(self._word_pad_index)
+ mask = mask.__and__(pad_mask)
+ words = words.masked_fill(mask, self._word_unk_index)
+ return words
+
+ def dropout(self, words):
+ """
+ 对embedding后的word表示进行drop。
+
+ :param torch.FloatTensor words: batch_size x max_len x embed_size
+ :return:
+ """
+ return self.dropout_layer(words)
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+ :return:
+ """
+ requires_grads = set([param.requires_grad for param in self.parameters()])
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for param in self.parameters():
+ param.requires_grad = value
+
+ def __len__(self):
+ return len(self._word_vocab)
+
+ @property
+ def embed_size(self) -> int:
+ return self._embed_size
+
+ @property
+ def embedding_dim(self) -> int:
+ return self._embed_size
+
+ @property
+ def num_embedding(self) -> int:
+ """
+ 这个值可能会大于实际的embedding矩阵的大小。
+ :return:
+ """
+ return len(self._word_vocab)
+
+ def get_word_vocab(self):
+ """
+ 返回embedding的词典。
+
+ :return: Vocabulary
+ """
+ return self._word_vocab
+
+ @property
+ def size(self):
+ return torch.Size([self.num_embedding, self._embed_size])
+
+ @abstractmethod
+ def forward(self, words):
+ raise NotImplementedError
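
The word_dropout trick in TokenEmbedding.drop_word can be reproduced in isolation; the indices below are made up for illustration:

import torch

words = torch.tensor([[2, 5, 7, 0, 0]])          # 0 is the padding index in this toy example
pad_index, unk_index, p = 0, 1, 0.3

mask = torch.bernoulli(torch.full_like(words, p, dtype=torch.float)).eq(1)
mask = mask & words.ne(pad_index)                # never drop padding positions
dropped = words.masked_fill(mask, unk_index)
print(dropped)                                   # some non-pad ids are replaced by 1 (unk)
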
diff --git a/fastNLP/embeddings/stack_embedding.py b/fastNLP/embeddings/stack_embedding.py
new file mode 100644
index 00000000..14781945
--- /dev/null
+++ b/fastNLP/embeddings/stack_embedding.py
@@ -0,0 +1,104 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "StackEmbedding",
+]
+
+from typing import List
+
+import torch
+from torch import nn as nn
+
+from .embedding import TokenEmbedding
+
+
+class StackEmbedding(TokenEmbedding):
+ """
+ 别名::class:`fastNLP.embeddings.StackEmbedding` :class:`fastNLP.embeddings.stack_embedding.StackEmbedding`
+
+ 支持将多个embedding集合成一个embedding。
+
+ Example::
+
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import StaticEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True)
+ >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
+
+ :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedding会在相同的位置
+ 被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。
+ :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
+
+ """
+
+ def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
+ vocabs = []
+ for embed in embeds:
+ if hasattr(embed, 'get_word_vocab'):
+ vocabs.append(embed.get_word_vocab())
+ _vocab = vocabs[0]
+ for vocab in vocabs[1:]:
+ assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
+
+ super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
+ assert isinstance(embeds, list)
+ for embed in embeds:
+ assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
+ self.embeds = nn.ModuleList(embeds)
+ self._embed_size = sum([embed.embed_size for embed in self.embeds])
+
+ def append(self, embed: TokenEmbedding):
+ """
+ 添加一个embedding到结尾。
+ :param embed:
+ :return:
+ """
+ assert isinstance(embed, TokenEmbedding)
+ self.embeds.append(embed)
+
+ def pop(self):
+ """
+ 弹出最后一个embed
+ :return:
+ """
+ return self.embeds.pop()
+
+ @property
+ def embed_size(self):
+ return self._embed_size
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+ :return:
+ """
+ requires_grads = set([embed.requires_grad for embed in self.embeds])
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for embed in self.embeds:
+ embed.requires_grad = value
+
+ def forward(self, words):
+ """
+ 得到多个embedding的结果,并把结果按照顺序concat起来。
+
+ :param words: batch_size x max_len
+ :return: 返回的shape和当前这个stack embedding中embedding的组成有关
+ """
+ outputs = []
+ words = self.drop_word(words)
+ for embed in self.embeds:
+ outputs.append(embed(words))
+ outputs = self.dropout(torch.cat(outputs, dim=-1))
+ return outputs
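
A hedged usage sketch of StackEmbedding combining a word-level and a character-level embedding; the pretrained name follows the docstring example above and may trigger a download if the vectors are not cached:

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding, CNNCharEmbedding, StackEmbedding

vocab = Vocabulary().add_word_lst("The whether is good .".split())
word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
char_embed = CNNCharEmbedding(vocab, embed_size=30)
stacked = StackEmbedding([word_embed, char_embed])

words = torch.LongTensor([[vocab.to_index(w) for w in "The whether is good .".split()]])
print(stacked.embed_size)       # 80, the sum of 50 + 30
print(stacked(words).size())    # torch.Size([1, 5, 80])
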
diff --git a/fastNLP/embeddings/static_embedding.py b/fastNLP/embeddings/static_embedding.py
new file mode 100644
index 00000000..98986565
--- /dev/null
+++ b/fastNLP/embeddings/static_embedding.py
@@ -0,0 +1,308 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "StaticEmbedding"
+]
+import os
+
+import torch
+import torch.nn as nn
+import numpy as np
+import warnings
+
+from ..core.vocabulary import Vocabulary
+from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path
+from .embedding import TokenEmbedding
+from ..modules.utils import _get_file_name_base_on_postfix
+from copy import deepcopy
+from collections import defaultdict
+from ..core import logger
+
+
+class StaticEmbedding(TokenEmbedding):
+ """
+ 别名::class:`fastNLP.embeddings.StaticEmbedding` :class:`fastNLP.embeddings.static_embedding.StaticEmbedding`
+
+ StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来,
+ 如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。
+ 当前支持自动下载的预训练vector有以下的几种(待补充);
+
+ Example::
+
+ >>> from fastNLP import Vocabulary
+ >>> from fastNLP.embeddings import StaticEmbedding
+ >>> vocab = Vocabulary().add_word_lst("The whether is good .".split())
+ >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d')
+
+ >>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"])
+ >>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True)
+ >>> # "the", "The", "THE"它们共用一个vector,且将使用"the"在预训练词表中寻找它们的初始化表示。
+
+ >>> vocab = Vocabulary().add_word_lst(["The", "the", "THE"])
+ >>> embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
+ >>> words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE"]]])
+ >>> embed(words)
+ >>> tensor([[[ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849],
+ [ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849],
+ [ 0.5773, 0.7251, -0.3104, 0.0777, 0.4849]]],
+ grad_fn=<EmbeddingBackward>) # 每种word的输出是一致的。
+
+ :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。
+ :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个
+ 以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
+ 如果输入为None则使用embedding_dim的维度随机初始化一个embedding。
+ :param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。
+ :param bool requires_grad: 是否需要gradient. 默认为True
+ :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对
+ :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
+ 为大写的词语开辟一个vector表示,则将lower设置为False。
+ :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
+ :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
+ :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
+ :param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。
+ """
+
+ def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True,
+ init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs):
+ super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
+ if embedding_dim > 0:
+ model_dir_or_name = None
+
+ # 得到cache_path
+ if model_dir_or_name is None:
+ assert embedding_dim >= 1, "The dimension of embedding should be larger than 1."
+ embedding_dim = int(embedding_dim)
+ model_path = None
+ elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
+ model_url = _get_embedding_url('static', model_dir_or_name.lower())
+ model_path = cached_path(model_url, name='embedding')
+ # 检查是否存在
+ elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_path = os.path.abspath(os.path.expanduser(model_dir_or_name))
+ elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))):
+ model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt')
+ else:
+ raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+
+ # 根据min_freq缩小vocab
+ truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq)
+ if truncate_vocab:
+ truncated_vocab = deepcopy(vocab)
+ truncated_vocab.min_freq = min_freq
+ truncated_vocab.word2idx = None
+ if lower: # 如果有lower,大小写的freq需要同时考虑到
+ lowered_word_count = defaultdict(int)
+ for word, count in truncated_vocab.word_count.items():
+ lowered_word_count[word.lower()] += count
+ for word in truncated_vocab.word_count.keys():
+ word_count = truncated_vocab.word_count[word]
+ if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq:
+ truncated_vocab.add_word_lst([word] * (min_freq - word_count),
+ no_create_entry=truncated_vocab._is_word_no_create_entry(word))
+
+ # 只限制在train里面的词语使用min_freq筛选
+ if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None:
+ for word in truncated_vocab.word_count.keys():
+ if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq:
+ truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]),
+ no_create_entry=True)
+ truncated_vocab.build_vocab()
+ truncated_words_to_words = torch.arange(len(vocab)).long()
+ for word, index in vocab:
+ truncated_words_to_words[index] = truncated_vocab.to_index(word)
+ logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.")
+ vocab = truncated_vocab
+
+ self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False)
+ # 读取embedding
+ if lower:
+ lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
+ for word, index in vocab:
+ if vocab._is_word_no_create_entry(word):
+ lowered_vocab.add_word(word.lower(), no_create_entry=True)
+ else:
+ lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的
+ logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} "
+ f"unique lowered words.")
+ if model_path:
+ embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method)
+ else:
+ embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
+ self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+ if lowered_vocab.unknown:
+ unknown_idx = lowered_vocab.unknown_idx
+ else:
+ unknown_idx = embedding.size(0) - 1 # 否则最后一个为unknown
+ self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+ words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
+ requires_grad=False)
+ for word, index in vocab:
+ if word not in lowered_vocab:
+ word = word.lower()
+ if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word):
+ continue # 如果不需要创建entry,已经默认unknown了
+ words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
+ self.words_to_words = words_to_words
+ self._word_unk_index = lowered_vocab.unknown_idx # 替换一下unknown的index
+ else:
+ if model_path:
+ embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method)
+ else:
+ embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method)
+ self.words_to_words = nn.Parameter(torch.arange(len(vocab)).long(), requires_grad=False)
+ if not self.only_norm_found_vector and normalize:
+ embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12)
+
+ if truncate_vocab:
+ for i in range(len(truncated_words_to_words)):
+ index_in_truncated_vocab = truncated_words_to_words[i]
+ truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab]
+ del self.words_to_words
+ self.words_to_words = nn.Parameter(truncated_words_to_words, requires_grad=False)
+
+ self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
+ padding_idx=vocab.padding_idx,
+ max_norm=None, norm_type=2, scale_grad_by_freq=False,
+ sparse=False, _weight=embedding)
+ self._embed_size = self.embedding.weight.size(1)
+ self.requires_grad = requires_grad
+
+ def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None):
+ """
+
+ :param int num_embedding: embedding的entry的数量
+ :param int embedding_dim: embedding的维度大小
+ :param callable init_embed: 初始化方法
+ :return: torch.FloatTensor
+ """
+ embed = torch.zeros(num_embedding, embedding_dim)
+
+ if init_embed is None:
+ nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim))
+ else:
+ init_embed(embed)
+
+ return embed
+
+ @property
+ def requires_grad(self):
+ """
+ Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+
+ :return:
+ """
+ requires_grads = set([param.requires_grad for name, param in self.named_parameters()
+ if 'words_to_words' not in name])
+ if len(requires_grads) == 1:
+ return requires_grads.pop()
+ else:
+ return None
+
+ @requires_grad.setter
+ def requires_grad(self, value):
+ for name, param in self.named_parameters():
+ if 'words_to_words' in name:
+ continue
+ param.requires_grad = value
+
+ def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>',
+ error='ignore', init_method=None):
+ """
+ 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是
+ word2vec(第一行只有两个元素)还是glove格式的数据。
+
+ :param str embed_filepath: 预训练的embedding的路径。
+ :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。
+ 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。
+ :param dtype: 读出的embedding的类型
+ :param str padding: 词表中padding的token
+ :param str unknown: 词表中unknown的token
+ :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。
+ 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。
+ :param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.zeros_
+ :return torch.tensor: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
+ """
+ assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
+ if not os.path.exists(embed_filepath):
+ raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
+ with open(embed_filepath, 'r', encoding='utf-8') as f:
+ line = f.readline().strip()
+ parts = line.split()
+ start_idx = 0
+ if len(parts) == 2:
+ dim = int(parts[1])
+ start_idx += 1
+ else:
+ dim = len(parts) - 1
+ f.seek(0)
+ matrix = {}
+ if vocab.padding:
+ matrix[vocab.padding_idx] = torch.zeros(dim)
+ if vocab.unknown:
+ matrix[vocab.unknown_idx] = torch.zeros(dim)
+ found_count = 0
+ found_unknown = False
+ for idx, line in enumerate(f, start_idx):
+ try:
+ parts = line.strip().split()
+ word = ''.join(parts[:-dim])
+ nums = parts[-dim:]
+ # 对齐unk与pad
+ if word == padding and vocab.padding is not None:
+ word = vocab.padding
+ elif word == unknown and vocab.unknown is not None:
+ word = vocab.unknown
+ found_unknown = True
+ if word in vocab:
+ index = vocab.to_index(word)
+ matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
+ if self.only_norm_found_vector:
+ matrix[index] = matrix[index] / np.linalg.norm(matrix[index])
+ found_count += 1
+ except Exception as e:
+ if error == 'ignore':
+ warnings.warn("Error occurred at the {} line.".format(idx))
+ else:
+ logger.error("Error occurred at the {} line.".format(idx))
+ raise e
+ logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
+ for word, index in vocab:
+ if index not in matrix and not vocab._is_word_no_create_entry(word):
+ if found_unknown: # 如果有unknown,用unknown初始化
+ matrix[index] = matrix[vocab.unknown_idx]
+ else:
+ matrix[index] = None
+ # matrix中代表是需要建立entry的词
+ vectors = self._randomly_init_embed(len(matrix), dim, init_method)
+
+ if vocab.unknown is None: # 创建一个专门的unknown
+ unknown_idx = len(matrix)
+ vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
+ else:
+ unknown_idx = vocab.unknown_idx
+ self.words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
+ requires_grad=False)
+
+ for index, (index_in_vocab, vec) in enumerate(matrix.items()):
+ if vec is not None:
+ vectors[index] = vec
+ self.words_to_words[index_in_vocab] = index
+
+ return vectors
+
+ def forward(self, words):
+ """
+ 传入words的index
+
+ :param words: torch.LongTensor, [batch_size, max_len]
+ :return: torch.FloatTensor, [batch_size, max_len, embed_size]
+ """
+ if hasattr(self, 'words_to_words'):
+ words = self.words_to_words[words]
+ words = self.drop_word(words)
+ words = self.embedding(words)
+ words = self.dropout(words)
+ return words
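
A hedged sketch of loading StaticEmbedding from a small GloVe-style text file (one word followed by its vector per line); the file name and vector values are invented for illustration:

import torch
from fastNLP import Vocabulary
from fastNLP.embeddings import StaticEmbedding

with open('toy_glove.txt', 'w', encoding='utf-8') as f:
    f.write("the 0.1 0.2 0.3\n")
    f.write("good 0.4 0.5 0.6\n")

vocab = Vocabulary().add_word_lst("The weather is good .".split())
embed = StaticEmbedding(vocab, model_dir_or_name='toy_glove.txt', lower=True)

words = torch.LongTensor([[vocab.to_index('The'), vocab.to_index('good')]])
print(embed(words).size())      # torch.Size([1, 2, 3]); words missing from the file get random vectors
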
diff --git a/fastNLP/embeddings/utils.py b/fastNLP/embeddings/utils.py
new file mode 100644
index 00000000..844a0c93
--- /dev/null
+++ b/fastNLP/embeddings/utils.py
@@ -0,0 +1,57 @@
+"""
+.. todo::
+ doc
+"""
+import numpy as np
+import torch
+from torch import nn as nn
+
+from ..core.vocabulary import Vocabulary
+
+__all__ = [
+ 'get_embeddings'
+]
+
+
+def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1):
+ """
+ 给定一个word的vocabulary生成character的vocabulary.
+
+ :param vocab: 给定的word的Vocabulary,将基于其中的词构建character的vocabulary
+ :param min_freq:
+ :return:
+ """
+ char_vocab = Vocabulary(min_freq=min_freq)
+ for word, index in vocab:
+ if not vocab._is_word_no_create_entry(word):
+ char_vocab.add_word_lst(list(word))
+ return char_vocab
+
+
+def get_embeddings(init_embed):
+ """
+ 根据输入的init_embed返回Embedding对象。如果输入是tuple, 则随机初始化一个nn.Embedding; 如果输入是numpy.ndarray, 则按照ndarray
+ 的值将nn.Embedding初始化; 如果输入是torch.Tensor, 则按该值初始化nn.Embedding; 如果输入是fastNLP中的embedding将不做处理
+ 返回原对象。
+
+ :param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入
+ nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化;
+ 传入torch.Tensor, 将使用传入的值作为Embedding初始化。
+ :return nn.Embedding: embeddings
+ """
+ if isinstance(init_embed, tuple):
+ res = nn.Embedding(
+ num_embeddings=init_embed[0], embedding_dim=init_embed[1])
+ nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)),
+ b=np.sqrt(3 / res.weight.data.size(1)))
+ elif isinstance(init_embed, nn.Module):
+ res = init_embed
+ elif isinstance(init_embed, torch.Tensor):
+ res = nn.Embedding.from_pretrained(init_embed, freeze=False)
+ elif isinstance(init_embed, np.ndarray):
+ init_embed = torch.tensor(init_embed, dtype=torch.float32)
+ res = nn.Embedding.from_pretrained(init_embed, freeze=False)
+ else:
+ raise TypeError(
+ 'invalid init_embed type: {}'.format((type(init_embed))))
+ return res
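
A quick sketch of the three input kinds get_embeddings dispatches on, following its docstring:

import numpy as np
import torch
from fastNLP.embeddings.utils import get_embeddings

rand_embed = get_embeddings((10, 4))                              # random nn.Embedding, 10 words x 4 dims
array_embed = get_embeddings(np.zeros((10, 4), dtype='float32'))  # initialised from an ndarray
tensor_embed = get_embeddings(torch.randn(10, 4))                 # initialised from a Tensor
print(rand_embed.weight.shape, array_embed.weight.shape, tensor_embed.weight.shape)
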
diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py
index d1d1dc5d..8ed1956a 100644
--- a/fastNLP/io/__init__.py
+++ b/fastNLP/io/__init__.py
@@ -3,40 +3,87 @@
1. 用于读入 embedding 的 :doc:`EmbedLoader <fastNLP.io.embed_loader>` 类,
-2. 用于读入数据的 :doc:`DataSetLoader <fastNLP.io.dataset_loader>` 类
+2. 用于读入不同格式数据的 :doc:`Loader <fastNLP.io.loader>` 类
-3. 用于保存和载入模型的类, 参考 :doc:`/fastNLP.io.model_io`
+3. 用于处理读入数据的 :doc:`Pipe <fastNLP.io.pipe>` 类
+
+4. 用于保存和载入模型的类, 参考 :doc:`model_io文档 </fastNLP.io.model_io>`
这些类的使用方法如下:
"""
__all__ = [
+ 'DataBundle',
+
'EmbedLoader',
+
+ 'Loader',
+
+ 'YelpLoader',
+ 'YelpFullLoader',
+ 'YelpPolarityLoader',
+ 'IMDBLoader',
+ 'SSTLoader',
+ 'SST2Loader',
- 'DataBundle',
- 'DataSetLoader',
+ 'ConllLoader',
+ 'Conll2003Loader',
+ 'Conll2003NERLoader',
+ 'OntoNotesNERLoader',
+ 'CTBLoader',
+ "MsraNERLoader",
+ "WeiboNERLoader",
+ "PeopleDailyNERLoader",
'CSVLoader',
'JsonLoader',
-
+
+ 'CWSLoader',
+
+ 'MNLILoader',
+ "QuoraLoader",
+ "SNLILoader",
+ "QNLILoader",
+ "RTELoader",
+
+ "Pipe",
+
+ "YelpFullPipe",
+ "YelpPolarityPipe",
+ "SSTPipe",
+ "SST2Pipe",
+ "IMDBPipe",
+ "Conll2003Pipe",
+
+ "Conll2003NERPipe",
+ "OntoNotesNERPipe",
+ "MsraNERPipe",
+ "PeopleDailyPipe",
+ "WeiboNERPipe",
+
+ "CWSPipe",
+
+ "MatchingBertPipe",
+ "RTEBertPipe",
+ "SNLIBertPipe",
+ "QuoraBertPipe",
+ "QNLIBertPipe",
+ "MNLIBertPipe",
+ "MatchingPipe",
+ "RTEPipe",
+ "SNLIPipe",
+ "QuoraPipe",
+ "QNLIPipe",
+ "MNLIPipe",
+
'ModelLoader',
'ModelSaver',
- 'ConllLoader',
- 'Conll2003Loader',
- 'MatchingLoader',
- 'PeopleDailyCorpusLoader',
- 'SNLILoader',
- 'SSTLoader',
- 'SST2Loader',
- 'MNLILoader',
- 'QNLILoader',
- 'QuoraLoader',
- 'RTELoader',
]
from .embed_loader import EmbedLoader
-from .base_loader import DataBundle, DataSetLoader
+from .data_bundle import DataBundle
from .dataset_loader import CSVLoader, JsonLoader
from .model_io import ModelLoader, ModelSaver
-from .data_loader import *
+from .loader import *
+from .pipe import *
diff --git a/fastNLP/io/base_loader.py b/fastNLP/io/base_loader.py
deleted file mode 100644
index 62793836..00000000
--- a/fastNLP/io/base_loader.py
+++ /dev/null
@@ -1,220 +0,0 @@
-__all__ = [
- "BaseLoader",
- 'DataBundle',
- 'DataSetLoader',
-]
-
-import _pickle as pickle
-import os
-from typing import Union, Dict
-import os
-from ..core.dataset import DataSet
-
-
-class BaseLoader(object):
- """
- 各个 Loader 的基类,提供了 API 的参考。
-
- """
-
- def __init__(self):
- super(BaseLoader, self).__init__()
-
- @staticmethod
- def load_lines(data_path):
- """
- 按行读取,舍弃每行两侧空白字符,返回list of str
-
- :param data_path: 读取数据的路径
- """
- with open(data_path, "r", encoding="utf=8") as f:
- text = f.readlines()
- return [line.strip() for line in text]
-
- @classmethod
- def load(cls, data_path):
- """
- 先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str
-
- :param data_path:
- """
- with open(data_path, "r", encoding="utf-8") as f:
- text = f.readlines()
- return [[word for word in sent.strip()] for sent in text]
-
- @classmethod
- def load_with_cache(cls, data_path, cache_path):
- """缓存版的load
- """
- if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
- with open(cache_path, 'rb') as f:
- return pickle.load(f)
- else:
- obj = cls.load(data_path)
- with open(cache_path, 'wb') as f:
- pickle.dump(obj, f)
- return obj
-
-
-def _download_from_url(url, path):
- try:
- from tqdm.auto import tqdm
- except:
- from ..core.utils import _pseudo_tqdm as tqdm
- import requests
-
- """Download file"""
- r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
- chunk_size = 16 * 1024
- total_size = int(r.headers.get('Content-length', 0))
- with open(path, "wb") as file, \
- tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
- for chunk in r.iter_content(chunk_size):
- if chunk:
- file.write(chunk)
- t.update(len(chunk))
-
-
-def _uncompress(src, dst):
- import zipfile
- import gzip
- import tarfile
- import os
-
- def unzip(src, dst):
- with zipfile.ZipFile(src, 'r') as f:
- f.extractall(dst)
-
- def ungz(src, dst):
- with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
- length = 16 * 1024 # 16KB
- buf = f.read(length)
- while buf:
- uf.write(buf)
- buf = f.read(length)
-
- def untar(src, dst):
- with tarfile.open(src, 'r:gz') as f:
- f.extractall(dst)
-
- fn, ext = os.path.splitext(src)
- _, ext_2 = os.path.splitext(fn)
- if ext == '.zip':
- unzip(src, dst)
- elif ext == '.gz' and ext_2 != '.tar':
- ungz(src, dst)
- elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
- untar(src, dst)
- else:
- raise ValueError('unsupported file {}'.format(src))
-
-
-class DataBundle:
- """
- 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)及它们所用的词表和词嵌入。
-
- :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
- :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
- """
-
- def __init__(self, vocabs: dict = None, datasets: dict = None):
- self.vocabs = vocabs or {}
- self.datasets = datasets or {}
-
- def __repr__(self):
- _str = 'In total {} datasets:\n'.format(len(self.datasets))
- for name, dataset in self.datasets.items():
- _str += '\t{} has {} instances.\n'.format(name, len(dataset))
- _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
- for name, vocab in self.vocabs.items():
- _str += '\t{} has {} entries.\n'.format(name, len(vocab))
- return _str
-
-
-class DataSetLoader:
- """
- 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
-
- 定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
-
- 开发者至少应该编写如下内容:
-
- - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
- - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
- - process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
-
- **process 函数中可以 调用load 函数或 _load 函数**
-
- """
- URL = ''
- DATA_DIR = ''
-
- ROOT_DIR = '.fastnlp/datasets/'
- UNCOMPRESS = True
-
- def _download(self, url: str, pdir: str, uncompress=True) -> str:
- """
-
- 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。
-
- :param url: 下载的网站
- :param pdir: 下载到的目录
- :param uncompress: 是否自动解压缩
- :return: 数据的存放路径
- """
- fn = os.path.basename(url)
- path = os.path.join(pdir, fn)
- """check data exists"""
- if not os.path.exists(path):
- os.makedirs(pdir, exist_ok=True)
- _download_from_url(url, path)
- if uncompress:
- dst = os.path.join(pdir, 'data')
- if not os.path.exists(dst):
- _uncompress(path, dst)
- return dst
- return path
-
- def download(self):
- return self._download(
- self.URL,
- os.path.join(self.ROOT_DIR, self.DATA_DIR),
- uncompress=self.UNCOMPRESS)
-
- def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
- """
- 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
- 如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。
-
- :param Union[str, Dict[str, str]] paths: 文件路径
- :return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
- """
- if isinstance(paths, str):
- return self._load(paths)
- return {name: self._load(path) for name, path in paths.items()}
-
- def _load(self, path: str) -> DataSet:
- """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
-
- :param str path: 文件路径
- :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
- """
- raise NotImplementedError
-
- def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
- """
- 对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。
-
- 从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。
- 如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。
-
- 返回的 :class:`DataBundle` 对象有如下属性:
-
- - vocabs: 由从数据集中获取的词表组成的字典,每个词表
- - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
-
- :param paths: 原始数据读取的路径
- :param options: 根据不同的任务和数据集,设计自己的参数
- :return: 返回一个 DataBundle
- """
- raise NotImplementedError
diff --git a/fastNLP/io/config_io.py b/fastNLP/io/config_io.py
deleted file mode 100644
index 4acdbb96..00000000
--- a/fastNLP/io/config_io.py
+++ /dev/null
@@ -1,311 +0,0 @@
-"""
-用于读入和处理和保存 config 文件
- .. todo::
- 这个模块中的类可能被抛弃?
-"""
-__all__ = [
- "ConfigLoader",
- "ConfigSection",
- "ConfigSaver"
-]
-
-import configparser
-import json
-import os
-
-from .base_loader import BaseLoader
-
-
-class ConfigLoader(BaseLoader):
- """
- 别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader`
-
- 读取配置文件的Loader
-
- :param str data_path: 配置文件的路径
-
- """
-
- def __init__(self, data_path=None):
- super(ConfigLoader, self).__init__()
- if data_path is not None:
- self.config = self.parse(super(ConfigLoader, self).load(data_path))
-
- @staticmethod
- def parse(string):
- raise NotImplementedError
-
- @staticmethod
- def load_config(file_path, sections):
- """
- 把配置文件的section 存入提供的 ``sections`` 中
-
- :param str file_path: 配置文件的路径
- :param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection`
-
- Example::
-
- test_args = ConfigSection()
- ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args})
-
- """
- assert isinstance(sections, dict)
- cfg = configparser.ConfigParser()
- if not os.path.exists(file_path):
- raise FileNotFoundError("config file {} not found. ".format(file_path))
- cfg.read(file_path)
- for s in sections:
- attr_list = [i for i in sections[s].__dict__.keys() if
- not callable(getattr(sections[s], i)) and not i.startswith("__")]
- if s not in cfg:
- print('section %s not found in config file' % (s))
- continue
- gen_sec = cfg[s]
- for attr in gen_sec.keys():
- try:
- val = json.loads(gen_sec[attr])
- # print(s, attr, val, type(val))
- if attr in attr_list:
- assert type(val) == type(getattr(sections[s], attr)), \
- 'type not match, except %s but got %s' % \
- (type(getattr(sections[s], attr)), type(val))
- """
- if attr in attr_list then check its type and
- update its value.
- else add a new attr in sections[s]
- """
- setattr(sections[s], attr, val)
- except Exception as e:
- print("cannot load attribute %s in section %s"
- % (attr, s))
- pass
-
-
-class ConfigSection(object):
- """
- 别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection`
-
- ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用
-
- """
-
- def __init__(self):
- super(ConfigSection, self).__init__()
-
- def __getitem__(self, key):
- """
- :param key: str, the name of the attribute
- :return attr: the value of this attribute
- if key not in self.__dict__.keys():
- return self[key]
- else:
- raise AttributeError
- """
- if key in self.__dict__.keys():
- return getattr(self, key)
- raise AttributeError("do NOT have attribute %s" % key)
-
- def __setitem__(self, key, value):
- """
- :param key: str, the name of the attribute
- :param value: the value of this attribute
- if key not in self.__dict__.keys():
- self[key] will be added
- else:
- self[key] will be updated
- """
- if key in self.__dict__.keys():
- if not isinstance(value, type(getattr(self, key))):
- raise AttributeError("attr %s except %s but got %s" %
- (key, str(type(getattr(self, key))), str(type(value))))
- setattr(self, key, value)
-
- def __contains__(self, item):
- """
- :param item: The key of item.
- :return: True if the key in self.__dict__.keys() else False.
- """
- return item in self.__dict__.keys()
-
- def __eq__(self, other):
- """Overwrite the == operator
-
- :param other: Another ConfigSection() object which to be compared.
- :return: True if value of each key in each ConfigSection() object are equal to the other, else False.
- """
- for k in self.__dict__.keys():
- if k not in other.__dict__.keys():
- return False
- if getattr(self, k) != getattr(self, k):
- return False
-
- for k in other.__dict__.keys():
- if k not in self.__dict__.keys():
- return False
- if getattr(self, k) != getattr(self, k):
- return False
-
- return True
-
- def __ne__(self, other):
- """Overwrite the != operator
-
- :param other:
- :return:
- """
- return not self.__eq__(other)
-
- @property
- def data(self):
- return self.__dict__
-
-
-class ConfigSaver(object):
- """
- 别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver`
-
- ConfigSaver 是用来存储配置文件并解决相关冲突的类
-
- :param str file_path: 配置文件的路径
-
- """
-
- def __init__(self, file_path):
- self.file_path = file_path
- if not os.path.exists(self.file_path):
- raise FileNotFoundError("file {} NOT found!".__format__(self.file_path))
-
- def _get_section(self, sect_name):
- """
- This is the function to get the section with the section name.
-
- :param sect_name: The name of section what wants to load.
- :return: The section.
- """
- sect = ConfigSection()
- ConfigLoader().load_config(self.file_path, {sect_name: sect})
- return sect
-
- def _read_section(self):
- """
- This is the function to read sections from the config file.
-
- :return: sect_list, sect_key_list
- sect_list: A list of ConfigSection().
- sect_key_list: A list of names in sect_list.
- """
- sect_name = None
-
- sect_list = {}
- sect_key_list = []
-
- single_section = {}
- single_section_key = []
-
- with open(self.file_path, 'r') as f:
- lines = f.readlines()
-
- for line in lines:
- if line.startswith('[') and line.endswith(']\n'):
- if sect_name is None:
- pass
- else:
- sect_list[sect_name] = single_section, single_section_key
- single_section = {}
- single_section_key = []
- sect_key_list.append(sect_name)
- sect_name = line[1: -2]
- continue
-
- if line.startswith('#'):
- single_section[line] = '#'
- single_section_key.append(line)
- continue
-
- if line.startswith('\n'):
- single_section_key.append('\n')
- continue
-
- if '=' not in line:
- raise RuntimeError("can NOT load config file {}".__format__(self.file_path))
-
- key = line.split('=', maxsplit=1)[0].strip()
- value = line.split('=', maxsplit=1)[1].strip() + '\n'
- single_section[key] = value
- single_section_key.append(key)
-
- if sect_name is not None:
- sect_list[sect_name] = single_section, single_section_key
- sect_key_list.append(sect_name)
- return sect_list, sect_key_list
-
- def _write_section(self, sect_list, sect_key_list):
- """
- This is the function to write config file with section list and name list.
-
- :param sect_list: A list of ConfigSection() need to be writen into file.
- :param sect_key_list: A list of name of sect_list.
- :return:
- """
- with open(self.file_path, 'w') as f:
- for sect_key in sect_key_list:
- single_section, single_section_key = sect_list[sect_key]
- f.write('[' + sect_key + ']\n')
- for key in single_section_key:
- if key == '\n':
- f.write('\n')
- continue
- if single_section[key] == '#':
- f.write(key)
- continue
- f.write(key + ' = ' + single_section[key])
- f.write('\n')
-
- def save_config_file(self, section_name, section):
- """
- 这个方法可以用来修改并保存配置文件中单独的一个 section
-
- :param str section_name: 需要保存的 section 的名字.
- :param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型
- """
- section_file = self._get_section(section_name)
- if len(section_file.__dict__.keys()) == 0: # the section not in the file before
- # append this section to config file
- with open(self.file_path, 'a') as f:
- f.write('[' + section_name + ']\n')
- for k in section.__dict__.keys():
- f.write(k + ' = ')
- if isinstance(section[k], str):
- f.write('\"' + str(section[k]) + '\"\n\n')
- else:
- f.write(str(section[k]) + '\n\n')
- else:
- # the section exists
- change_file = False
- for k in section.__dict__.keys():
- if k not in section_file:
- # find a new key in this section
- change_file = True
- break
- if section_file[k] != section[k]:
- change_file = True
- break
- if not change_file:
- return
-
- sect_list, sect_key_list = self._read_section()
- if section_name not in sect_key_list:
- raise AttributeError()
-
- sect, sect_key = sect_list[section_name]
- for k in section.__dict__.keys():
- if k not in sect_key:
- if sect_key[-1] != '\n':
- sect_key.append('\n')
- sect_key.append(k)
- sect[k] = str(section[k])
- if isinstance(section[k], str):
- sect[k] = "\"" + sect[k] + "\""
- sect[k] = sect[k] + "\n"
- sect_list[section_name] = sect, sect_key
- self._write_section(sect_list, sect_key_list)
diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py
new file mode 100644
index 00000000..db60a86f
--- /dev/null
+++ b/fastNLP/io/data_bundle.py
@@ -0,0 +1,401 @@
+"""
+.. todo::
+ doc
+"""
+__all__ = [
+ 'DataBundle',
+]
+
+import _pickle as pickle
+import os
+from typing import Union, Dict
+
+from ..core.dataset import DataSet
+from ..core.vocabulary import Vocabulary
+
+
+class BaseLoader(object):
+ """
+ 各个 Loader 的基类,提供了 API 的参考。
+
+ """
+
+ def __init__(self):
+ super(BaseLoader, self).__init__()
+
+ @staticmethod
+ def load_lines(data_path):
+ """
+ 按行读取,舍弃每行两侧空白字符,返回list of str
+
+ :param data_path: 读取数据的路径
+ """
+        with open(data_path, "r", encoding="utf-8") as f:
+ text = f.readlines()
+ return [line.strip() for line in text]
+
+ @classmethod
+ def load(cls, data_path):
+ """
+ 先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str
+
+ :param data_path:
+ """
+ with open(data_path, "r", encoding="utf-8") as f:
+ text = f.readlines()
+ return [[word for word in sent.strip()] for sent in text]
+
+ @classmethod
+ def load_with_cache(cls, data_path, cache_path):
+ """缓存版的load
+ """
+ if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path):
+ with open(cache_path, 'rb') as f:
+ return pickle.load(f)
+ else:
+ obj = cls.load(data_path)
+ with open(cache_path, 'wb') as f:
+ pickle.dump(obj, f)
+ return obj
+
+
+def _download_from_url(url, path):
+ try:
+ from tqdm.auto import tqdm
+ except:
+ from ..core.utils import _pseudo_tqdm as tqdm
+ import requests
+
+ """Download file"""
+ r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True)
+ chunk_size = 16 * 1024
+ total_size = int(r.headers.get('Content-length', 0))
+ with open(path, "wb") as file, \
+ tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t:
+ for chunk in r.iter_content(chunk_size):
+ if chunk:
+ file.write(chunk)
+ t.update(len(chunk))
+
+
+def _uncompress(src, dst):
+ import zipfile
+ import gzip
+ import tarfile
+ import os
+
+ def unzip(src, dst):
+ with zipfile.ZipFile(src, 'r') as f:
+ f.extractall(dst)
+
+ def ungz(src, dst):
+ with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf:
+ length = 16 * 1024 # 16KB
+ buf = f.read(length)
+ while buf:
+ uf.write(buf)
+ buf = f.read(length)
+
+ def untar(src, dst):
+ with tarfile.open(src, 'r:gz') as f:
+ f.extractall(dst)
+
+ fn, ext = os.path.splitext(src)
+ _, ext_2 = os.path.splitext(fn)
+ if ext == '.zip':
+ unzip(src, dst)
+ elif ext == '.gz' and ext_2 != '.tar':
+ ungz(src, dst)
+ elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz':
+ untar(src, dst)
+ else:
+ raise ValueError('unsupported file {}'.format(src))
+
+
+class DataBundle:
+ """
+ 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种
+ Loader的load函数生成,可以通过以下的方法获取里面的内容
+
+ Example::
+
+ data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'})
+ train_vocabs = data_bundle.vocabs['train']
+ train_data = data_bundle.datasets['train']
+            dev_data = data_bundle.datasets['dev']
+
+ :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
+ :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict
+ """
+
+ def __init__(self, vocabs: dict = None, datasets: dict = None):
+ self.vocabs = vocabs or {}
+ self.datasets = datasets or {}
+
+ def set_vocab(self, vocab, field_name):
+ """
+        向DataBundle中增加vocab
+
+ :param ~fastNLP.Vocabulary vocab: 词表
+ :param str field_name: 这个vocab对应的field名称
+ :return: self
+ """
+        assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
+ self.vocabs[field_name] = vocab
+ return self
+
+ def set_dataset(self, dataset, name):
+ """
+        向DataBundle中增加一个名为name的DataSet
+
+ :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet
+ :param str name: dataset的名称
+ :return: self
+ """
+ self.datasets[name] = dataset
+ return self
+
+ def get_dataset(self, name:str)->DataSet:
+ """
+ 获取名为name的dataset
+
+ :param str name: dataset的名称,一般为'train', 'dev', 'test'
+ :return: DataSet
+ """
+ return self.datasets[name]
+
+ def delete_dataset(self, name:str):
+ """
+ 删除名为name的DataSet
+
+ :param str name:
+ :return: self
+ """
+ self.datasets.pop(name, None)
+ return self
+
+ def get_vocab(self, field_name:str)->Vocabulary:
+ """
+ 获取field名为field_name对应的vocab
+
+ :param str field_name: 名称
+ :return: Vocabulary
+ """
+ return self.vocabs[field_name]
+
+ def delete_vocab(self, field_name:str):
+ """
+        删除field_name对应的vocab
+
+        :param str field_name:
+ :return: self
+ """
+ self.vocabs.pop(field_name, None)
+ return self
+
+ def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
+ """
+ 将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作::
+
+ data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True
+ data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False
+
+ :param str field_names: field的名称
+ :param bool flag: 将field_name的input状态设置为flag
+        :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据的维度和类型都相同,
+            而是直接使用第一行的数据推断本列数据的类型和维度。
+        :param bool ignore_miss_dataset: 当某个field名称在某个dataset中不存在时,如果为True,则直接忽略该DataSet;
+            如果为False,则报错
+        :return: self
+ """
+ for field_name in field_names:
+ for name, dataset in self.datasets.items():
+ if not ignore_miss_dataset and not dataset.has_field(field_name):
+ raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
+ if not dataset.has_field(field_name):
+ continue
+ else:
+ dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
+ return self
+
+ def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True):
+ """
+ 将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作::
+
+            data_bundle.set_target('target', 'seq_len') # 将target和seq_len这两个field的target属性设置为True
+            data_bundle.set_target('target', flag=False) # 将target这个field的target属性设置为False
+
+ :param str field_names: field的名称
+ :param bool flag: 将field_name的target状态设置为flag
+        :param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据的维度和类型都相同,
+            而是直接使用第一行的数据推断本列数据的类型和维度。
+        :param bool ignore_miss_dataset: 当某个field名称在某个dataset中不存在时,如果为True,则直接忽略该DataSet;
+            如果为False,则报错
+        :return: self
+ """
+ for field_name in field_names:
+ for name, dataset in self.datasets.items():
+ if not ignore_miss_dataset and not dataset.has_field(field_name):
+ raise KeyError(f"Field:{field_name} was not found in DataSet:{name}")
+ if not dataset.has_field(field_name):
+ continue
+ else:
+ dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type)
+ return self
+
+ def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True):
+ """
+        将DataBundle中所有DataSet里名为field_name的field复制一份,并命名为new_field_name。
+
+ :param str field_name:
+ :param str new_field_name:
+ :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
+ 如果为False,则报错
+ :return: self
+ """
+ for name, dataset in self.datasets.items():
+ if dataset.has_field(field_name=field_name):
+ dataset.copy_field(field_name=field_name, new_field_name=new_field_name)
+ elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found in DataSet:{name}.")
+ return self
+
+ def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs):
+ """
+        对DataBundle中所有的dataset使用apply_field方法
+
+ :param callable func: input是instance中名为 `field_name` 的field的内容。
+ :param str field_name: 传入func的是哪个field。
+ :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
+ 盖之前的field。如果为None则不创建新的field。
+ :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
+ 如果为False,则报错
+ :param optional kwargs: 支持输入is_input,is_target,ignore_type
+
+ 1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input
+
+ 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
+
+            3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为True, 忽略其类型
+ """
+ for name, dataset in self.datasets.items():
+ if dataset.has_field(field_name=field_name):
+ dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs)
+ elif not ignore_miss_dataset:
+                raise KeyError(f"{field_name} not found in DataSet:{name}.")
+ return self
+
+ def apply(self, func, new_field_name:str, **kwargs):
+ """
+ 对DataBundle中所有的dataset使用apply方法
+
+        :param callable func: input是DataSet中的每个instance。
+ :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
+ 盖之前的field。如果为None则不创建新的field。
+ :param optional kwargs: 支持输入is_input,is_target,ignore_type
+
+ 1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input
+
+ 2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target
+
+            3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为True, 忽略其类型
+ """
+ for name, dataset in self.datasets.items():
+ dataset.apply(func, new_field_name=new_field_name, **kwargs)
+ return self
+
+ def __repr__(self):
+ _str = 'In total {} datasets:\n'.format(len(self.datasets))
+ for name, dataset in self.datasets.items():
+ _str += '\t{} has {} instances.\n'.format(name, len(dataset))
+ _str += 'In total {} vocabs:\n'.format(len(self.vocabs))
+ for name, vocab in self.vocabs.items():
+ _str += '\t{} has {} entries.\n'.format(name, len(vocab))
+ return _str
+
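A hedged sketch of how the class above is typically assembled and queried by hand (the tiny datasets and vocabulary below are placeholders)::

    from fastNLP import DataSet, Vocabulary
    from fastNLP.io import DataBundle

    train_ds = DataSet({'words': [[1, 2, 3], [4, 5]], 'target': [0, 1]})
    dev_ds = DataSet({'words': [[6, 2]], 'target': [1]})
    data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds})

    # attach a vocabulary for the 'words' field
    word_vocab = Vocabulary()
    word_vocab.add_word_lst(['I', 'like', 'it', 'great'])
    data_bundle.set_vocab(word_vocab, field_name='words')

    # add a seq_len field to every dataset, then mark inputs/targets everywhere
    data_bundle.apply_field(lambda words: len(words), field_name='words', new_field_name='seq_len')
    data_bundle.set_input('words', 'seq_len')
    data_bundle.set_target('target')

    print(data_bundle)                     # uses __repr__ above
    print(data_bundle.get_dataset('dev'))  # the dev DataSet
    print(data_bundle.get_vocab('words'))  # the attached Vocabulary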
+
+class DataSetLoader:
+ """
+ 别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader`
+
+ 定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。
+
+ 开发者至少应该编写如下内容:
+
+ - _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet`
+ - load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet`
+    - process 函数:从一个或多个数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet`
+
+ **process 函数中可以 调用load 函数或 _load 函数**
+
+ """
+ URL = ''
+ DATA_DIR = ''
+
+ ROOT_DIR = '.fastnlp/datasets/'
+ UNCOMPRESS = True
+
+ def _download(self, url: str, pdir: str, uncompress=True) -> str:
+ """
+
+ 从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。
+
+ :param url: 下载的网站
+ :param pdir: 下载到的目录
+ :param uncompress: 是否自动解压缩
+ :return: 数据的存放路径
+ """
+ fn = os.path.basename(url)
+ path = os.path.join(pdir, fn)
+ """check data exists"""
+ if not os.path.exists(path):
+ os.makedirs(pdir, exist_ok=True)
+ _download_from_url(url, path)
+ if uncompress:
+ dst = os.path.join(pdir, 'data')
+ if not os.path.exists(dst):
+ _uncompress(path, dst)
+ return dst
+ return path
+
+ def download(self):
+ return self._download(
+ self.URL,
+ os.path.join(self.ROOT_DIR, self.DATA_DIR),
+ uncompress=self.UNCOMPRESS)
+
+ def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]:
+ """
+ 从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。
+        如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保持一致。
+
+ :param Union[str, Dict[str, str]] paths: 文件路径
+ :return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典
+ """
+ if isinstance(paths, str):
+ return self._load(paths)
+ return {name: self._load(path) for name, path in paths.items()}
+
+ def _load(self, path: str) -> DataSet:
+ """从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象
+
+ :param str path: 文件路径
+ :return: 一个 :class:`~fastNLP.DataSet` 类型的对象
+ """
+ raise NotImplementedError
+
+ def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle:
+ """
+        对于特定的任务和数据集,读取并处理数据,返回处理后的 :class:`DataBundle` 对象或字典。
+
+        从指定一个或多个路径中的文件中读取数据,返回的 :class:`DataBundle` 对象中可以包含一个或多个数据集。
+        如果处理多个路径,传入的 dict 的 key 与返回的 :class:`DataBundle` 中 datasets 的 key 保持一致。
+
+        返回的 :class:`DataBundle` 对象有如下属性:
+
+        - vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict
+        - datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const`
+
+ :param paths: 原始数据读取的路径
+ :param options: 根据不同的任务和数据集,设计自己的参数
+ :return: 返回一个 DataBundle
+ """
+ raise NotImplementedError
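To illustrate the _load/load/process contract described in the DataSetLoader docstring, here is a minimal, purely illustrative subclass that reads one whitespace-tokenised sentence per line; new code should prefer the Loader/Pipe API that replaces this class in 0.5.0::

    from fastNLP import DataSet, Instance
    from fastNLP.io.data_bundle import DataBundle, DataSetLoader

    class LineLoader(DataSetLoader):
        """Hypothetical loader: one whitespace-tokenised sentence per line."""

        def _load(self, path: str) -> DataSet:
            ds = DataSet()
            with open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    line = line.strip()
                    if line:
                        ds.append(Instance(raw_words=line.split()))
            return ds

        def process(self, paths, **options) -> DataBundle:
            # the base-class load() already handles str vs dict paths
            datasets = self.load(paths)
            if isinstance(datasets, DataSet):
                datasets = {'train': datasets}
            return DataBundle(datasets=datasets)

    # usage: bundle = LineLoader().process({'train': 'train.txt', 'dev': 'dev.txt'})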
diff --git a/fastNLP/io/data_loader/__init__.py b/fastNLP/io/data_loader/__init__.py
index d4777ff8..8a9dd60b 100644
--- a/fastNLP/io/data_loader/__init__.py
+++ b/fastNLP/io/data_loader/__init__.py
@@ -1,13 +1,18 @@
-"""
-用于读数据集的模块, 具体包括:
+"""undocumented
+.. warning::
+
+ 本模块在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。
+
+用于读数据集的模块, 可以读取文本分类、序列标注、Matching任务的数据集
-这些模块的使用方法如下:
+这些模块的具体介绍如下,您可以通过阅读 :doc:`教程` 来进行了解。
"""
__all__ = [
'ConllLoader',
'Conll2003Loader',
'IMDBLoader',
'MatchingLoader',
+ 'SNLILoader',
'MNLILoader',
'MTL16Loader',
'PeopleDailyCorpusLoader',
@@ -16,7 +21,6 @@ __all__ = [
'RTELoader',
'SSTLoader',
'SST2Loader',
- 'SNLILoader',
'YelpLoader',
]
diff --git a/fastNLP/io/data_loader/conll.py b/fastNLP/io/data_loader/conll.py
index 61f4f61b..31a90881 100644
--- a/fastNLP/io/data_loader/conll.py
+++ b/fastNLP/io/data_loader/conll.py
@@ -1,40 +1,49 @@
from ...core.dataset import DataSet
from ...core.instance import Instance
-from ..base_loader import DataSetLoader
+from ..data_bundle import DataSetLoader
from ..file_reader import _read_conll
-
+from typing import Union, Dict
+from ..utils import check_loader_paths
+from ..data_bundle import DataBundle
class ConllLoader(DataSetLoader):
"""
别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader`
- 读取Conll格式的数据. 数据格式详见 http://conll.cemantix.org/2012/data.html. 数据中以"-DOCSTART-"开头的行将被忽略,因为
- 该符号在conll 2003中被用为文档分割符。
-
- 列号从0开始, 每列对应内容为::
-
- Column Type
- 0 Document ID
- 1 Part number
- 2 Word number
- 3 Word itself
- 4 Part-of-Speech
- 5 Parse bit
- 6 Predicate lemma
- 7 Predicate Frameset ID
- 8 Word sense
- 9 Speaker/Author
- 10 Named Entities
- 11:N Predicate Arguments
- N Coreference
-
- :param headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
- :param indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
- :param dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``False``
+ 该ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示:
+
+ Example::
+
+ # 文件中的内容
+ Nadim NNP B-NP B-PER
+ Ladki NNP I-NP I-PER
+
+ AL-AIN NNP B-NP B-LOC
+ United NNP B-NP B-LOC
+ Arab NNP I-NP I-LOC
+ Emirates NNPS I-NP I-LOC
+ 1996-12-06 CD I-NP O
+ ...
+
+ # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列
+ dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')
+ # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列
+ dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll')
+ # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field
+ dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll')
+
+        # 如上,ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')返回的DataSet中,
+        # raw_words列与pos列的内容均为List[str]
+
+ 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。
+
+ :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
+ :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
+ :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
"""
- def __init__(self, headers, indexes=None, dropna=False):
+ def __init__(self, headers, indexes=None, dropna=True):
super(ConllLoader, self).__init__()
if not isinstance(headers, (list, tuple)):
raise TypeError(
@@ -49,25 +58,52 @@ class ConllLoader(DataSetLoader):
self.indexes = indexes
def _load(self, path):
+ """
+ 传入的一个文件路径,将该文件读入DataSet中,field由Loader初始化时指定的headers决定。
+
+ :param str path: 文件的路径
+ :return: DataSet
+ """
ds = DataSet()
for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
ins = {h: data[i] for i, h in enumerate(self.headers)}
ds.append(Instance(**ins))
return ds
+ def load(self, paths: Union[str, Dict[str, str]]) -> DataBundle:
+ """
+ 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。
+
+ 读取的field根据ConllLoader初始化时传入的headers决定。
+
+        :param Union[str, Dict[str, str]] paths: 传入一个文件路径,或者名称到文件路径的dict
+        :return: :class:`~fastNLP.io.DataBundle`,其中的datasets包含读取到的一个或多个 :class:`~fastNLP.DataSet`
+ """
+ paths = check_loader_paths(paths)
+ datasets = {name: self._load(path) for name, path in paths.items()}
+ data_bundle = DataBundle(datasets=datasets)
+ return data_bundle
+
class Conll2003Loader(ConllLoader):
"""
- 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.dataset_loader.Conll2003Loader`
+ 别名::class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader`
+
+ 该Loader用以读取Conll2003数据,conll2003的数据可以在https://github.com/davidsbatista/NER-datasets/tree/master/CONLL2003
+ 找到。数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。
+
+ 返回的DataSet将具有以下['raw_words', 'pos', 'chunks', 'ner']四个field, 每个field中的内容都是List[str]。
+
+    .. csv-table:: Conll2003Loader处理之后的DataSet示例
+       :header: "raw_words", "words", "target", "seq_len"
+
- 读取Conll2003数据
+ "[Nadim, Ladki]", "[1, 2]", "[1, 2]", 2
+ "[AL-AIN, United, Arab, ...]", "[3, 4, 5,...]", "[3, 4]", 5
+ "[...]", "[...]", "[...]", .
- 关于数据集的更多信息,参考:
- https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
"""
def __init__(self):
headers = [
- 'tokens', 'pos', 'chunks', 'ner',
+ 'raw_words', 'pos', 'chunks', 'ner',
]
super(Conll2003Loader, self).__init__(headers=headers)
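A usage sketch for the load() method added above; the paths are placeholders::

    from fastNLP.io.data_loader import Conll2003Loader

    loader = Conll2003Loader()  # headers: raw_words, pos, chunks, ner
    data_bundle = loader.load({'train': '/path/to/train.txt',
                               'dev': '/path/to/dev.txt'})
    print(data_bundle)                   # how many instances each split holds
    train = data_bundle.get_dataset('train')
    print(train[0]['raw_words'])         # List[str] tokens of the first sentence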
diff --git a/fastNLP/io/data_loader/imdb.py b/fastNLP/io/data_loader/imdb.py
index bf53c5be..c9dda76e 100644
--- a/fastNLP/io/data_loader/imdb.py
+++ b/fastNLP/io/data_loader/imdb.py
@@ -2,7 +2,7 @@
from typing import Union, Dict
from ..embed_loader import EmbeddingOption, EmbedLoader
-from ..base_loader import DataSetLoader, DataBundle
+from ..data_bundle import DataSetLoader, DataBundle
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.instance import Instance
@@ -13,9 +13,12 @@ from ..utils import get_tokenizer
class IMDBLoader(DataSetLoader):
"""
+ 别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.data_loader.IMDBLoader`
+
读取IMDB数据集,DataSet包含以下fields:
words: list(str), 需要分类的文本
+
target: str, 文本的标签
"""
diff --git a/fastNLP/io/data_loader/matching.py b/fastNLP/io/data_loader/matching.py
index cecaee96..41c9a98d 100644
--- a/fastNLP/io/data_loader/matching.py
+++ b/fastNLP/io/data_loader/matching.py
@@ -4,9 +4,9 @@ from typing import Union, Dict, List
from ...core.const import Const
from ...core.vocabulary import Vocabulary
-from ..base_loader import DataBundle, DataSetLoader
+from ..data_bundle import DataBundle, DataSetLoader
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from ...modules.encoder._bert import BertTokenizer
+from ...modules.encoder.bert import BertTokenizer
class MatchingLoader(DataSetLoader):
@@ -121,7 +121,7 @@ class MatchingLoader(DataSetLoader):
PRETRAIN_URL = _get_base_url('bert')
model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
model_url = PRETRAIN_URL + model_name
- model_dir = cached_path(model_url)
+ model_dir = cached_path(model_url, name='embedding')
# 检查是否存在
elif os.path.isdir(bert_tokenizer):
model_dir = bert_tokenizer
diff --git a/fastNLP/io/data_loader/mnli.py b/fastNLP/io/data_loader/mnli.py
index 5d857533..65863f3d 100644
--- a/fastNLP/io/data_loader/mnli.py
+++ b/fastNLP/io/data_loader/mnli.py
@@ -12,7 +12,9 @@ class MNLILoader(MatchingLoader, CSVLoader):
读取MNLI数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
+
words2: list(str), 第二句文本, hypothesis
+
target: str, 真实标签
数据来源:
diff --git a/fastNLP/io/data_loader/mtl.py b/fastNLP/io/data_loader/mtl.py
index 940ece51..923aadfb 100644
--- a/fastNLP/io/data_loader/mtl.py
+++ b/fastNLP/io/data_loader/mtl.py
@@ -1,18 +1,21 @@
from typing import Union, Dict
-from ..base_loader import DataBundle
+from ..data_bundle import DataBundle
from ..dataset_loader import CSVLoader
from ...core.vocabulary import Vocabulary, VocabularyOption
from ...core.const import Const
-from ..utils import check_dataloader_paths
+from ..utils import check_loader_paths
class MTL16Loader(CSVLoader):
"""
+ 别名::class:`fastNLP.io.MTL16Loader` :class:`fastNLP.io.data_loader.MTL16Loader`
+
读取MTL16数据集,DataSet包含以下fields:
words: list(str), 需要分类的文本
+
target: str, 文本的标签
数据来源:https://pan.baidu.com/s/1c2L6vdA
@@ -35,7 +38,7 @@ class MTL16Loader(CSVLoader):
src_vocab_opt: VocabularyOption = None,
tgt_vocab_opt: VocabularyOption = None,):
- paths = check_dataloader_paths(paths)
+ paths = check_loader_paths(paths)
datasets = {}
info = DataBundle()
for name, path in paths.items():
diff --git a/fastNLP/io/data_loader/people_daily.py b/fastNLP/io/data_loader/people_daily.py
index d8c55aef..afd66744 100644
--- a/fastNLP/io/data_loader/people_daily.py
+++ b/fastNLP/io/data_loader/people_daily.py
@@ -1,5 +1,5 @@
-from ..base_loader import DataSetLoader
+from ..data_bundle import DataSetLoader
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.const import Const
@@ -7,7 +7,7 @@ from ...core.const import Const
class PeopleDailyCorpusLoader(DataSetLoader):
"""
- 别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.dataset_loader.PeopleDailyCorpusLoader`
+ 别名::class:`fastNLP.io.PeopleDailyCorpusLoader` :class:`fastNLP.io.data_loader.PeopleDailyCorpusLoader`
读取人民日报数据集
"""
diff --git a/fastNLP/io/data_loader/qnli.py b/fastNLP/io/data_loader/qnli.py
index ff6302b2..84b0f3d6 100644
--- a/fastNLP/io/data_loader/qnli.py
+++ b/fastNLP/io/data_loader/qnli.py
@@ -12,7 +12,9 @@ class QNLILoader(MatchingLoader, CSVLoader):
读取QNLI数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
+
words2: list(str), 第二句文本, hypothesis
+
target: str, 真实标签
数据来源:
diff --git a/fastNLP/io/data_loader/quora.py b/fastNLP/io/data_loader/quora.py
index 12cc42ce..d0ee41ec 100644
--- a/fastNLP/io/data_loader/quora.py
+++ b/fastNLP/io/data_loader/quora.py
@@ -12,7 +12,9 @@ class QuoraLoader(MatchingLoader, CSVLoader):
读取MNLI数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
+
words2: list(str), 第二句文本, hypothesis
+
target: str, 真实标签
数据来源:
diff --git a/fastNLP/io/data_loader/rte.py b/fastNLP/io/data_loader/rte.py
index c6c64ef8..f8c5e2fc 100644
--- a/fastNLP/io/data_loader/rte.py
+++ b/fastNLP/io/data_loader/rte.py
@@ -12,7 +12,9 @@ class RTELoader(MatchingLoader, CSVLoader):
读取RTE数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
+
words2: list(str), 第二句文本, hypothesis
+
target: str, 真实标签
数据来源:
diff --git a/fastNLP/io/data_loader/snli.py b/fastNLP/io/data_loader/snli.py
index 8334fcfd..1db0ac5b 100644
--- a/fastNLP/io/data_loader/snli.py
+++ b/fastNLP/io/data_loader/snli.py
@@ -12,7 +12,9 @@ class SNLILoader(MatchingLoader, JsonLoader):
读取SNLI数据集,读取的DataSet包含fields::
words1: list(str),第一句文本, premise
+
words2: list(str), 第二句文本, hypothesis
+
target: str, 真实标签
数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
diff --git a/fastNLP/io/data_loader/sst.py b/fastNLP/io/data_loader/sst.py
index df46b47f..2034fc2b 100644
--- a/fastNLP/io/data_loader/sst.py
+++ b/fastNLP/io/data_loader/sst.py
@@ -2,13 +2,13 @@
from typing import Union, Dict
from nltk import Tree
-from ..base_loader import DataBundle, DataSetLoader
+from ..data_bundle import DataBundle, DataSetLoader
from ..dataset_loader import CSVLoader
from ...core.vocabulary import VocabularyOption, Vocabulary
from ...core.dataset import DataSet
from ...core.const import Const
from ...core.instance import Instance
-from ..utils import check_dataloader_paths, get_tokenizer
+from ..utils import check_loader_paths, get_tokenizer
class SSTLoader(DataSetLoader):
@@ -67,7 +67,7 @@ class SSTLoader(DataSetLoader):
paths, train_subtree=True,
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,):
- paths = check_dataloader_paths(paths)
+ paths = check_loader_paths(paths)
input_name, target_name = 'words', 'target'
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
@@ -104,7 +104,9 @@ class SSTLoader(DataSetLoader):
class SST2Loader(CSVLoader):
"""
- 数据来源"SST":'https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
+ 别名::class:`fastNLP.io.SST2Loader` :class:`fastNLP.io.data_loader.SST2Loader`
+
+ 数据来源 SST: https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
"""
def __init__(self):
@@ -127,11 +129,12 @@ class SST2Loader(CSVLoader):
tgt_vocab_opt: VocabularyOption = None,
char_level_op=False):
- paths = check_dataloader_paths(paths)
+ paths = check_loader_paths(paths)
datasets = {}
info = DataBundle()
for name, path in paths.items():
dataset = self.load(path)
+ dataset.apply_field(lambda words:words.copy(), field_name='words', new_field_name='raw_words')
datasets[name] = dataset
def wordtochar(words):
@@ -152,7 +155,9 @@ class SST2Loader(CSVLoader):
for dataset in datasets.values():
dataset.apply_field(wordtochar, field_name=Const.INPUT, new_field_name=Const.CHAR_INPUT)
src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
- src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT)
+ src_vocab.from_dataset(datasets['train'], field_name=Const.INPUT, no_create_entry_dataset=[
+ dataset for name, dataset in datasets.items() if name!='train'
+ ])
src_vocab.index_dataset(*datasets.values(), field_name=Const.INPUT)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
diff --git a/fastNLP/io/data_loader/yelp.py b/fastNLP/io/data_loader/yelp.py
index c287a90c..f2bc60c8 100644
--- a/fastNLP/io/data_loader/yelp.py
+++ b/fastNLP/io/data_loader/yelp.py
@@ -6,19 +6,24 @@ from ...core.const import Const
from ...core.dataset import DataSet
from ...core.instance import Instance
from ...core.vocabulary import VocabularyOption, Vocabulary
-from ..base_loader import DataBundle, DataSetLoader
+from ..data_bundle import DataBundle, DataSetLoader
from typing import Union, Dict
-from ..utils import check_dataloader_paths, get_tokenizer
+from ..utils import check_loader_paths, get_tokenizer
class YelpLoader(DataSetLoader):
"""
+ 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.data_loader.YelpLoader`
读取Yelp_full/Yelp_polarity数据集, DataSet包含fields:
+
words: list(str), 需要分类的文本
+
target: str, 文本的标签
+
chars:list(str),未index的字符列表
数据集:yelp_full/yelp_polarity
+
:param fine_grained: 是否使用SST-5标准,若 ``False`` , 使用SST-2。Default: ``False``
:param lower: 是否需要自动转小写,默认为False。
"""
@@ -57,7 +62,7 @@ class YelpLoader(DataSetLoader):
src_vocab_op: VocabularyOption = None,
tgt_vocab_op: VocabularyOption = None,
char_level_op=False):
- paths = check_dataloader_paths(paths)
+ paths = check_loader_paths(paths)
info = DataBundle(datasets=self.load(paths))
src_vocab = Vocabulary() if src_vocab_op is None else Vocabulary(**src_vocab_op)
tgt_vocab = Vocabulary(unknown=None, padding=None) \
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index ad6bbdc1..fca0de69 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -1,4 +1,8 @@
-"""
+"""undocumented
+.. warning::
+
+ 本模块将在 `0.5.0版本` 中被废弃,由 :mod:`~fastNLP.io.loader` 和 :mod:`~fastNLP.io.pipe` 模块替代。
+
dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的数据, 并返回 `DataSet` ,
得到的 :class:`~fastNLP.DataSet` 对象可以直接传入 :class:`~fastNLP.Trainer` 和 :class:`~fastNLP.Tester`, 用于模型的训练和测试。
以SNLI数据集为例::
@@ -11,6 +15,7 @@ dataset_loader模块实现了许多 DataSetLoader, 用于读取不同格式的
# ... do stuff
为 fastNLP 提供 DataSetLoader 的开发者请参考 :class:`~fastNLP.io.DataSetLoader` 的介绍。
+
"""
__all__ = [
'CSVLoader',
@@ -18,10 +23,10 @@ __all__ = [
]
+from .data_bundle import DataSetLoader
+from .file_reader import _read_csv, _read_json
from ..core.dataset import DataSet
from ..core.instance import Instance
-from .file_reader import _read_csv, _read_json
-from .base_loader import DataSetLoader
class JsonLoader(DataSetLoader):
@@ -114,25 +119,3 @@ def _cut_long_sentence(sent, max_sample_length=200):
else:
cutted_sentence.append(sent)
return cutted_sentence
-
-
-def _add_seg_tag(data):
- """
-
- :param data: list of ([word], [pos], [heads], [head_tags])
- :return: list of ([word], [pos])
- """
-
- _processed = []
- for word_list, pos_list, _, _ in data:
- new_sample = []
- for word, pos in zip(word_list, pos_list):
- if len(word) == 1:
- new_sample.append((word, 'S-' + pos))
- else:
- new_sample.append((word[0], 'B-' + pos))
- for c in word[1:-1]:
- new_sample.append((c, 'M-' + pos))
- new_sample.append((word[-1], 'E-' + pos))
- _processed.append(list(map(list, zip(*new_sample))))
- return _processed
diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py
index 91a0919c..780d91e4 100644
--- a/fastNLP/io/embed_loader.py
+++ b/fastNLP/io/embed_loader.py
@@ -1,16 +1,21 @@
+"""
+.. todo::
+ doc
+"""
__all__ = [
"EmbedLoader",
"EmbeddingOption",
]
+import logging
import os
import warnings
import numpy as np
-from ..core.vocabulary import Vocabulary
-from .base_loader import BaseLoader
+from .data_bundle import BaseLoader
from ..core.utils import Option
+from ..core.vocabulary import Vocabulary
class EmbeddingOption(Option):
@@ -91,10 +96,10 @@ class EmbedLoader(BaseLoader):
if error == 'ignore':
warnings.warn("Error occurred at the {} line.".format(idx))
else:
- print("Error occurred at the {} line.".format(idx))
+ logging.error("Error occurred at the {} line.".format(idx))
raise e
total_hits = sum(hit_flags)
- print("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
+ logging.info("Found {} out of {} words in the pre-training embedding.".format(total_hits, len(vocab)))
if init_method is None:
found_vectors = matrix[hit_flags]
if len(found_vectors) != 0:
@@ -157,7 +162,7 @@ class EmbedLoader(BaseLoader):
warnings.warn("Error occurred at the {} line.".format(idx))
pass
else:
- print("Error occurred at the {} line.".format(idx))
+ logging.error("Error occurred at the {} line.".format(idx))
raise e
if dim == -1:
raise RuntimeError("{} is an empty file.".format(embed_filepath))
diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py
index 0ae0a319..7a953098 100644
--- a/fastNLP/io/file_reader.py
+++ b/fastNLP/io/file_reader.py
@@ -1,8 +1,13 @@
-"""
+"""undocumented
此模块用于给其它模块提供读取文件的函数,没有为用户提供 API
"""
+
+__all__ = []
+
import json
+from ..core import logger
+
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
"""
@@ -23,8 +28,8 @@ def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True):
headers = headers.split(sep)
start_idx += 1
elif not isinstance(headers, (list, tuple)):
- raise TypeError("headers should be list or tuple, not {}." \
- .format(type(headers)))
+ raise TypeError("headers should be list or tuple, not {}." \
+ .format(type(headers)))
for line_idx, line in enumerate(f, start_idx):
contents = line.rstrip('\r\n').split(sep)
if len(contents) != len(headers):
@@ -81,6 +86,7 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
:if False, raise ValueError when reading invalid data. default: True
:return: generator, every time yield (line number, conll item)
"""
+
def parse_conll(sample):
sample = list(map(list, zip(*sample)))
sample = [sample[i] for i in indexes]
@@ -88,14 +94,15 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
if len(f) <= 0:
raise ValueError('empty field')
return sample
+
with open(path, 'r', encoding=encoding) as f:
sample = []
start = next(f).strip()
- if '-DOCSTART-' not in start and start!='':
+ if start != '':
sample.append(start.split())
for line_idx, line in enumerate(f, 1):
line = line.strip()
- if line=='':
+ if line == '':
if len(sample):
try:
res = parse_conll(sample)
@@ -103,13 +110,13 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
yield line_idx, res
except Exception as e:
if dropna:
+ logger.warn('Invalid instance which ends at line: {} has been dropped.'.format(line_idx))
continue
- raise ValueError('invalid instance ends at line: {}'.format(line_idx))
+ raise ValueError('Invalid instance which ends at line: {}'.format(line_idx))
elif line.startswith('#'):
continue
else:
- if not line.startswith('-DOCSTART-'):
- sample.append(line.split())
+ sample.append(line.split())
if len(sample) > 0:
try:
res = parse_conll(sample)
@@ -117,5 +124,5 @@ def _read_conll(path, encoding='utf-8', indexes=None, dropna=True):
except Exception as e:
if dropna:
return
- print('invalid instance ends at line: {}'.format(line_idx))
+ logger.error('invalid instance ends at line: {}'.format(line_idx))
raise e
diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py
index cb762eb7..8ecdff25 100644
--- a/fastNLP/io/file_utils.py
+++ b/fastNLP/io/file_utils.py
@@ -1,71 +1,154 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "cached_path",
+ "get_filepath",
+ "get_cache_path",
+ "split_filename_suffix",
+ "get_from_cache",
+]
import os
+import re
+import shutil
+import tempfile
from pathlib import Path
from urllib.parse import urlparse
-import re
+
import requests
-import tempfile
+from requests import HTTPError
from tqdm import tqdm
-import shutil
-import hashlib
+from ..core import logger
PRETRAINED_BERT_MODEL_DIR = {
- 'en': 'bert-base-cased-f89bfe08.zip',
- 'en-base-uncased': 'bert-base-uncased-3413b23c.zip',
- 'en-base-cased': 'bert-base-cased-f89bfe08.zip',
- 'en-large-uncased': 'bert-large-uncased-20939f45.zip',
- 'en-large-cased': 'bert-large-cased-e0cf90fc.zip',
-
- 'en-large-cased-wwm': 'bert-large-cased-wwm-a457f118.zip',
- 'en-large-uncased-wwm': 'bert-large-uncased-wwm-92a50aeb.zip',
- 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc-c7099855.zip',
-
- 'cn': 'bert-base-chinese-29d0a84a.zip',
- 'cn-base': 'bert-base-chinese-29d0a84a.zip',
-
- 'multilingual': 'bert-base-multilingual-cased-1bd364ee.zip',
- 'multilingual-base-uncased': 'bert-base-multilingual-uncased-f8730fe4.zip',
- 'multilingual-base-cased': 'bert-base-multilingual-cased-1bd364ee.zip',
+ 'en': 'bert-base-cased.zip',
+ 'en-large-cased-wwm': 'bert-large-cased-wwm.zip',
+ 'en-large-uncased-wwm': 'bert-large-uncased-wwm.zip',
+
+ 'en-large-uncased': 'bert-large-uncased.zip',
+ 'en-large-cased': 'bert-large-cased.zip',
+
+ 'en-base-uncased': 'bert-base-uncased.zip',
+ 'en-base-cased': 'bert-base-cased.zip',
+
+ 'en-base-cased-mrpc': 'bert-base-cased-finetuned-mrpc.zip',
+
+ 'multi-base-cased': 'bert-base-multilingual-cased.zip',
+ 'multi-base-uncased': 'bert-base-multilingual-uncased.zip',
+
+ 'cn': 'bert-chinese-wwm.zip',
+ 'cn-base': 'bert-base-chinese.zip',
+ 'cn-wwm': 'bert-chinese-wwm.zip',
+ 'cn-wwm-ext': "bert-chinese-wwm-ext.zip"
}
PRETRAINED_ELMO_MODEL_DIR = {
- 'en': 'elmo_en-d39843fe.tar.gz',
- 'cn': 'elmo_cn-5e9b34e2.tar.gz'
+ 'en': 'elmo_en_Medium.zip',
+ 'en-small': "elmo_en_Small.zip",
+ 'en-original-5.5b': 'elmo_en_Original_5.5B.zip',
+ 'en-original': 'elmo_en_Original.zip',
+ 'en-medium': 'elmo_en_Medium.zip'
}
PRETRAIN_STATIC_FILES = {
- 'en': 'glove.840B.300d-cc1ad5e1.tar.gz',
- 'en-glove-840b-300': 'glove.840B.300d-cc1ad5e1.tar.gz',
- 'en-glove-6b-50': "glove.6B.50d-a6028c70.tar.gz",
- 'en-word2vec-300': "GoogleNews-vectors-negative300-be166d9d.tar.gz",
- 'en-fasttext': "cc.en.300.vec-d53187b2.gz",
- 'cn': "tencent_cn-dab24577.tar.gz",
- 'cn-fasttext': "cc.zh.300.vec-d68a9bcf.gz",
+ 'en': 'glove.840B.300d.zip',
+
+ 'en-glove-6b-50d': 'glove.6B.50d.zip',
+ 'en-glove-6b-100d': 'glove.6B.100d.zip',
+ 'en-glove-6b-200d': 'glove.6B.200d.zip',
+ 'en-glove-6b-300d': 'glove.6B.300d.zip',
+ 'en-glove-42b-300d': 'glove.42B.300d.zip',
+ 'en-glove-840b-300d': 'glove.840B.300d.zip',
+ 'en-glove-twitter-27b-25d': 'glove.twitter.27B.25d.zip',
+ 'en-glove-twitter-27b-50d': 'glove.twitter.27B.50d.zip',
+ 'en-glove-twitter-27b-100d': 'glove.twitter.27B.100d.zip',
+ 'en-glove-twitter-27b-200d': 'glove.twitter.27B.200d.zip',
+
+ 'en-word2vec-300': "GoogleNews-vectors-negative300.txt.gz",
+
+ 'en-fasttext-wiki': "wiki-news-300d-1M.vec.zip",
+ 'en-fasttext-crawl': "crawl-300d-2M.vec.zip",
+
+ 'cn': "tencent_cn.zip",
+ 'cn-tencent': "tencent_cn.zip",
+ 'cn-fasttext': "cc.zh.300.vec.gz",
+ 'cn-sgns-literature-word': 'sgns.literature.word.txt.zip',
+}
+
+DATASET_DIR = {
+ 'aclImdb': "imdb.zip",
+ "yelp-review-full": "yelp_review_full.tar.gz",
+ "yelp-review-polarity": "yelp_review_polarity.tar.gz",
+ "mnli": "MNLI.zip",
+ "snli": "SNLI.zip",
+ "qnli": "QNLI.zip",
+ "sst-2": "SST-2.zip",
+ "sst": "SST.zip",
+ "rte": "RTE.zip",
+ "msra-ner": "MSRA_NER.zip",
+ "peopledaily": "peopledaily.zip",
+ "weibo-ner": "weibo_NER.zip",
+
+ "cws-pku": 'cws_pku.zip',
+ "cws-cityu": "cws_cityu.zip",
+ "cws-as": 'cws_as.zip',
+ "cws-msra": 'cws_msra.zip'
+}
+
+PRETRAIN_MAP = {'elmo': PRETRAINED_ELMO_MODEL_DIR,
+ "bert": PRETRAINED_BERT_MODEL_DIR,
+ "static": PRETRAIN_STATIC_FILES}
+
+# 用于扩展fastNLP的下载
+FASTNLP_EXTEND_DATASET_URL = 'fastnlp_dataset_url.txt'
+FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt',
+ 'bert':'fastnlp_bert_url.txt',
+ 'static': 'fastnlp_static_url.txt'
}
-def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
+def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path:
"""
- 给定一个url或者文件名(可以是具体的文件名,也可以是文件),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并
- 将文件放入到cache_dir中
+    给定一个url,先解析出其中的文件名filename,然后在{cache_dir}/{name}/{filename}下寻找该文件:
+
+    1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir
+    2. 如果name=None, 则没有中间的{name}这一层结构;否则中间一层的名称为{name}
+
+    如果有该文件,就直接返回路径;
+
+    如果没有该文件,则尝试用传入的url下载,下载(并解压)完成后将文件放入cache_dir中,再返回其路径。
+
+    url_or_filename也可以直接传入一个文件名(可以是具体的文件,也可以是文件夹);此时只在cache_dir下寻找该文件,
+    若不存在则抛出异常。
+
+ :param str url_or_filename: 文件的下载url或者文件名称。
+ :param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径
+ :param str name: 中间一层的名称。如embedding, dataset
+ :return:
"""
if cache_dir is None:
- dataset_cache = Path(get_defalt_path())
+ data_cache = Path(get_cache_path())
else:
- dataset_cache = cache_dir
+ data_cache = cache_dir
+
+ if name:
+ data_cache = os.path.join(data_cache, name)
parsed = urlparse(url_or_filename)
if parsed.scheme in ("http", "https"):
# URL, so get it from the cache (downloading if necessary)
- return get_from_cache(url_or_filename, dataset_cache)
- elif parsed.scheme == "" and Path(os.path.join(dataset_cache, url_or_filename)).exists():
+ return get_from_cache(url_or_filename, Path(data_cache))
+ elif parsed.scheme == "" and Path(os.path.join(data_cache, url_or_filename)).exists():
# File, and it exists.
- return Path(url_or_filename)
+ return Path(os.path.join(data_cache, url_or_filename))
elif parsed.scheme == "":
# File, but it doesn't exist.
- raise FileNotFoundError("file {} not found".format(url_or_filename))
+ raise FileNotFoundError("file {} not found in {}.".format(url_or_filename, data_cache))
else:
# Something unknown
raise ValueError(
@@ -75,48 +158,143 @@ def cached_path(url_or_filename: str, cache_dir: Path=None) -> Path:
def get_filepath(filepath):
"""
- 如果filepath中只有一个文件,则直接返回对应的全路径
- :param filepath:
+ 如果filepath为文件夹,
+
+ 如果内含多个文件, 返回filepath
+
+ 如果只有一个文件, 返回filepath + filename
+
+ 如果filepath为文件
+
+ 返回filepath
+
+ :param str filepath: 路径
:return:
"""
if os.path.isdir(filepath):
files = os.listdir(filepath)
- if len(files)==1:
+ if len(files) == 1:
return os.path.join(filepath, files[0])
else:
return filepath
- return filepath
+ elif os.path.isfile(filepath):
+ return filepath
+ else:
+ raise FileNotFoundError(f"{filepath} is not a valid file or directory.")
-def get_defalt_path():
+def get_cache_path():
"""
- 获取默认的fastNLP存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。
+ 获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。
- :return:
+ :return str: 存放路径
"""
if 'FASTNLP_CACHE_DIR' in os.environ:
fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR')
- if os.path.exists(fastnlp_cache_dir):
+ if os.path.isdir(fastnlp_cache_dir):
return fastnlp_cache_dir
- raise RuntimeError("Some errors happens on cache directory.")
- else:
- raise RuntimeError("There function is not available right now.")
+ else:
+ raise NotADirectoryError(f"{os.environ['FASTNLP_CACHE_DIR']} is not a directory.")
fastnlp_cache_dir = os.path.expanduser(os.path.join("~", ".fastNLP"))
return fastnlp_cache_dir
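For example, the cache location can be redirected through the environment variable checked above; the directory below is illustrative and must already exist::

    import os
    os.environ['FASTNLP_CACHE_DIR'] = '/data/shared/.fastNLP'   # must be an existing directory

    from fastNLP.io.file_utils import get_cache_path
    print(get_cache_path())   # -> '/data/shared/.fastNLP'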
def _get_base_url(name):
+ """
+ 根据name返回下载的url地址。
+
+ :param str name: 支持dataset和embedding两种
+ :return:
+ """
# 返回的URL结尾必须是/
- if 'FASTNLP_BASE_URL' in os.environ:
- fastnlp_base_url = os.environ['FASTNLP_BASE_URL']
- return fastnlp_base_url
- raise RuntimeError("There function is not available right now.")
+ environ_name = "FASTNLP_{}_URL".format(name.upper())
+
+ if environ_name in os.environ:
+ url = os.environ[environ_name]
+ if url.endswith('/'):
+ return url
+ else:
+ return url + '/'
+ else:
+ URLS = {
+ 'embedding': "http://dbcloud.irocn.cn:8989/api/public/dl/",
+ "dataset": "http://dbcloud.irocn.cn:8989/api/public/dl/dataset/"
+ }
+ if name.lower() not in URLS:
+ raise KeyError(f"{name} is not recognized.")
+ return URLS[name.lower()]
+
+
+def _get_embedding_url(embed_type, name):
+ """
+    给定embedding类型和名称,返回下载url
+
+ :param str embed_type: 支持static, bert, elmo。即embedding的类型
+    :param str name: embedding的名称, 例如en, cn, cn-base等
+ :return: str, 下载的url地址
+ """
+ # 从扩展中寻找下载的url
+ _filename = FASTNLP_EXTEND_EMBEDDING_URL.get(embed_type, None)
+ if _filename:
+ url = _read_extend_url_file(_filename, name)
+ if url:
+ return url
+ embed_map = PRETRAIN_MAP.get(embed_type, None)
+ if embed_map:
+ filename = embed_map.get(name, None)
+ if filename:
+ url = _get_base_url('embedding') + filename
+ return url
+ raise KeyError("There is no {}. Only supports {}.".format(name, list(embed_map.keys())))
+ else:
+ raise KeyError(f"There is no {embed_type}. Only supports bert, elmo, static")
+
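For instance, assuming no user extension file overrides the name, the resolution above works like this (the URL prefix comes from _get_base_url('embedding'))::

    from fastNLP.io.file_utils import _get_embedding_url

    # 'en' resolves through PRETRAIN_STATIC_FILES to '<base-url>/glove.840B.300d.zip'
    print(_get_embedding_url('static', 'en'))
    # an unknown name raises KeyError listing the supported keys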
+def _read_extend_url_file(filename, name)->str:
+ """
+ filename中的内容使用制表符隔开,第一列是名称,第二列是下载的url地址
+
+ :param str filename: 在默认的路径下寻找file这个文件
+ :param str name: 需要寻找的资源的名称
+ :return: str or None
+ """
+ cache_dir = get_cache_path()
+ filepath = os.path.join(cache_dir, filename)
+ if os.path.exists(filepath):
+ with open(filepath, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ if len(parts) == 2:
+ if name == parts[0]:
+ return parts[1]
+ return None
+
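A sketch of how such an extension file could be written and then consumed; the dataset name and URL are invented for illustration::

    import os
    from fastNLP.io.file_utils import get_cache_path, _read_extend_url_file

    cache_dir = get_cache_path()
    os.makedirs(cache_dir, exist_ok=True)
    # one "<name>\t<url>" entry per line
    with open(os.path.join(cache_dir, 'fastnlp_dataset_url.txt'), 'w', encoding='utf-8') as f:
        f.write('my-corpus\thttp://example.com/my_corpus.zip\n')

    # _get_dataset_url('my-corpus') will now resolve through this file
    print(_read_extend_url_file('fastnlp_dataset_url.txt', 'my-corpus'))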
+def _get_dataset_url(name):
+ """
+ 给定dataset的名称,返回下载url
+
+ :param str name: 给定dataset的名称,比如imdb, sst-2等
+ :return: str
+ """
+ # 从扩展中寻找下载的url
+ url = _read_extend_url_file(FASTNLP_EXTEND_DATASET_URL, name)
+ if url:
+ return url
+
+ filename = DATASET_DIR.get(name, None)
+ if filename:
+ url = _get_base_url('dataset') + filename
+ return url
+ else:
+ raise KeyError(f"There is no {name}.")
def split_filename_suffix(filepath):
"""
- 给定filepath返回对应的name和suffix
- :param filepath:
+ 给定filepath 返回对应的name和suffix. 如果后缀是多个点,仅支持.tar.gz类型
+
+ :param filepath: 文件路径
:return: filename, suffix
"""
filename = os.path.basename(filepath)
@@ -127,21 +305,19 @@ def split_filename_suffix(filepath):
def get_from_cache(url: str, cache_dir: Path = None) -> Path:
"""
- 尝试在cache_dir中寻找url定义的资源; 如果没有找到。则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。
- 如果从url中下载的资源解压后有多个文件,则返回directory的路径; 如果只有一个资源,则返回具体的路径。
-
+ 尝试在cache_dir中寻找url定义的资源; 如果没有找到,则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的
+ 文件解压,将解压后的文件全部放在cache_dir文件夹中。
+
+ 如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。
+
+ :param url: 资源的 url
+ :param cache_dir: cache 目录
+ :return: 路径
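+
+ Example::
+
+ # 一个假设性的调用示意(url仅为示例)
+ filepath = get_from_cache('http://example.com/data.zip', Path(get_cache_path()) / 'dataset')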
"""
cache_dir.mkdir(parents=True, exist_ok=True)
filename = re.sub(r".+/", "", url)
dir_name, suffix = split_filename_suffix(filename)
- sep_index = dir_name[::-1].index('-')
- if sep_index<0:
- check_sum = None
- else:
- check_sum = dir_name[-sep_index+1:]
- sep_index = len(dir_name) if sep_index==-1 else -sep_index-1
- dir_name = dir_name[:sep_index]
# 寻找与它名字匹配的内容, 而不关心后缀
match_dir_name = match_file(dir_name, cache_dir)
@@ -154,11 +330,11 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
return get_filepath(cache_path)
# make HEAD request to check ETag TODO ETag可以用来判断资源是否已经更新了,之后需要加上
- response = requests.head(url, headers={"User-Agent": "fastNLP"})
- if response.status_code != 200:
- raise IOError(
- f"HEAD request failed for url {url} with status code {response.status_code}."
- )
+ # response = requests.head(url, headers={"User-Agent": "fastNLP"})
+ # if response.status_code != 200:
+ # raise IOError(
+ # f"HEAD request failed for url {url} with status code {response.status_code}."
+ # )
# add ETag to filename if it exists
# etag = response.headers.get("ETag")
@@ -166,74 +342,77 @@ def get_from_cache(url: str, cache_dir: Path = None) -> Path:
if not cache_path.exists():
# Download to temporary file, then copy to cache dir once finished.
# Otherwise you get corrupt cache entries if the download gets interrupted.
- fd, temp_filename = tempfile.mkstemp()
- print("%s not found in cache, downloading to %s"%(url, temp_filename))
-
# GET file object
req = requests.get(url, stream=True, headers={"User-Agent": "fastNLP"})
- content_length = req.headers.get("Content-Length")
- total = int(content_length) if content_length is not None else None
- progress = tqdm(unit="B", total=total)
- sha256 = hashlib.sha256()
- with open(temp_filename, "wb") as temp_file:
- for chunk in req.iter_content(chunk_size=1024):
- if chunk: # filter out keep-alive new chunks
- progress.update(len(chunk))
- temp_file.write(chunk)
- sha256.update(chunk)
- # check sum
- digit = sha256.hexdigest()[:8]
- if not check_sum:
- assert digit == check_sum, "File corrupted when download."
- progress.close()
- print(f"Finish download from {url}.")
-
- # 开始解压
- delete_temp_dir = None
- if suffix in ('.zip', '.tar.gz'):
- uncompress_temp_dir = tempfile.mkdtemp()
- delete_temp_dir = uncompress_temp_dir
- print(f"Start to uncompress file to {uncompress_temp_dir}.")
- if suffix == '.zip':
- unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
- else:
- untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
- filenames = os.listdir(uncompress_temp_dir)
- if len(filenames)==1:
- if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
- uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])
-
- cache_path.mkdir(parents=True, exist_ok=True)
- print("Finish un-compressing file.")
- else:
- uncompress_temp_dir = temp_filename
- cache_path = str(cache_path) + suffix
- success = False
- try:
- # 复制到指定的位置
- print(f"Copy file to {cache_path}.")
- if os.path.isdir(uncompress_temp_dir):
- for filename in os.listdir(uncompress_temp_dir):
- shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path/filename)
- else:
- shutil.copyfile(uncompress_temp_dir, cache_path)
- success = True
- except Exception as e:
- print(e)
- raise e
- finally:
- if not success:
- if cache_path.exists():
- if cache_path.is_file():
- os.remove(cache_path)
+ if req.status_code == 200:
+ success = False
+ fd, temp_filename = tempfile.mkstemp()
+ uncompress_temp_dir = None
+ try:
+ content_length = req.headers.get("Content-Length")
+ total = int(content_length) if content_length is not None else None
+ progress = tqdm(unit="B", total=total, unit_scale=1)
+ logger.info("%s not found in cache, downloading to %s" % (url, temp_filename))
+
+ with open(temp_filename, "wb") as temp_file:
+ for chunk in req.iter_content(chunk_size=1024 * 16):
+ if chunk: # filter out keep-alive new chunks
+ progress.update(len(chunk))
+ temp_file.write(chunk)
+ progress.close()
+ logger.info(f"Finish download from {url}")
+
+ # 开始解压
+ if suffix in ('.zip', '.tar.gz', '.gz'):
+ uncompress_temp_dir = tempfile.mkdtemp()
+ logger.debug(f"Start to uncompress file to {uncompress_temp_dir}")
+ if suffix == '.zip':
+ unzip_file(Path(temp_filename), Path(uncompress_temp_dir))
+ elif suffix == '.gz':
+ ungzip_file(temp_filename, uncompress_temp_dir, dir_name)
else:
- shutil.rmtree(cache_path)
- if delete_temp_dir:
- shutil.rmtree(delete_temp_dir)
- os.close(fd)
- os.remove(temp_filename)
-
- return get_filepath(cache_path)
+ untar_gz_file(Path(temp_filename), Path(uncompress_temp_dir))
+ filenames = os.listdir(uncompress_temp_dir)
+ if len(filenames) == 1:
+ if os.path.isdir(os.path.join(uncompress_temp_dir, filenames[0])):
+ uncompress_temp_dir = os.path.join(uncompress_temp_dir, filenames[0])
+
+ cache_path.mkdir(parents=True, exist_ok=True)
+ logger.debug("Finish un-compressing file.")
+ else:
+ uncompress_temp_dir = temp_filename
+ cache_path = str(cache_path) + suffix
+
+ # 复制到指定的位置
+ logger.info(f"Copy file to {cache_path}")
+ if os.path.isdir(uncompress_temp_dir):
+ for filename in os.listdir(uncompress_temp_dir):
+ if os.path.isdir(os.path.join(uncompress_temp_dir, filename)):
+ shutil.copytree(os.path.join(uncompress_temp_dir, filename), cache_path / filename)
+ else:
+ shutil.copyfile(os.path.join(uncompress_temp_dir, filename), cache_path / filename)
+ else:
+ shutil.copyfile(uncompress_temp_dir, cache_path)
+ success = True
+ except Exception as e:
+ logger.error(e)
+ raise e
+ finally:
+ if not success:
+ if cache_path.exists():
+ if cache_path.is_file():
+ os.remove(cache_path)
+ else:
+ shutil.rmtree(cache_path)
+ os.close(fd)
+ os.remove(temp_filename)
+ if uncompress_temp_dir is None:
+ pass # 下载阶段即失败时尚未创建解压目录, 无需清理
+ elif os.path.isdir(uncompress_temp_dir):
+ shutil.rmtree(uncompress_temp_dir)
+ elif os.path.isfile(uncompress_temp_dir):
+ os.remove(uncompress_temp_dir)
+ return get_filepath(cache_path)
+ else:
+ raise HTTPError(f"Status code:{req.status_code}. Fail to download from {url}.")
def unzip_file(file: Path, to: Path):
@@ -245,55 +424,39 @@ def unzip_file(file: Path, to: Path):
zipObj.extractall(to)
-def untar_gz_file(file:Path, to:Path):
+def untar_gz_file(file: Path, to: Path):
import tarfile
with tarfile.open(file, 'r:gz') as tar:
tar.extractall(to)
-def match_file(dir_name: str, cache_dir: str) -> str:
+def ungzip_file(file: str, to: str, filename:str):
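+ """
+ 将gzip压缩的单个文件file解压到to目录下,解压后的文件名为filename。
+
+ :param str file: 需要解压的.gz文件路径
+ :param str to: 解压到的目标目录
+ :param str filename: 解压后的文件名
+ """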
+ import gzip
+
+ g_file = gzip.GzipFile(file)
+ with open(os.path.join(to, filename), 'wb+') as f:
+ f.write(g_file.read())
+ g_file.close()
+
+
+def match_file(dir_name: str, cache_dir: Path) -> str:
"""
- 匹配的原则是,在cache_dir下的文件: (1) 与dir_name完全一致; (2) 除了后缀以外和dir_name完全一致。
+ 匹配的原则是: 在cache_dir下的文件与dir_name完全一致, 或除了后缀以外和dir_name完全一致。
如果找到了两个匹配的结果将报错. 如果找到了则返回匹配的文件的名称; 没有找到返回空字符串
:param dir_name: 需要匹配的名称
:param cache_dir: 在该目录下找匹配dir_name是否存在
- :return: str
+ :return str: 做为匹配结果的字符串
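+
+ Example::
+
+ # 假设cache_dir下存在文件'glove.txt'
+ match_file('glove', cache_dir) # 返回 'glove.txt'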
"""
files = os.listdir(cache_dir)
matched_filenames = []
for file_name in files:
- if re.match(dir_name+'$', file_name) or re.match(dir_name+'\\..*', file_name):
+ if re.match(dir_name + '$', file_name) or re.match(dir_name + '\\..*', file_name):
matched_filenames.append(file_name)
- if len(matched_filenames)==0:
+ if len(matched_filenames) == 0:
return ''
- elif len(matched_filenames)==1:
+ elif len(matched_filenames) == 1:
return matched_filenames[-1]
else:
raise RuntimeError(f"Duplicate matched files:{matched_filenames}, this should be caused by a bug.")
-
-
-if __name__ == '__main__':
- cache_dir = Path('caches')
- cache_dir = None
- # 需要对cache_dir进行测试
- base_url = 'http://0.0.0.0:8888/file/download'
- # if True:
- # for filename in os.listdir(cache_dir):
- # if os.path.isdir(os.path.join(cache_dir, filename)):
- # shutil.rmtree(os.path.join(cache_dir, filename))
- # else:
- # os.remove(os.path.join(cache_dir, filename))
- # 1. 测试.txt文件
- print(cached_path(base_url + '/{}'.format('txt_test-bcb4fe65.txt'), cache_dir))
- # 2. 测试.zip文件(只有一个文件)
- print(cached_path(base_url + '/{}'.format('zip_test-40966d39.zip'), cache_dir))
- # 3. 测试.zip文件(有多个文件)
- print(cached_path(base_url + '/{}'.format('zip_pack_test-70c0b20d.zip'), cache_dir))
- # 4. 测试.tar.gz文件
- print(cached_path(base_url + '/{}'.format('tar_gz_test-3e2679cf.tar.gz'), cache_dir))
- # 5. 测试.tar.gz多个文件
- print(cached_path(base_url + '/{}'.format('tar_gz_pack_test-08dfdccd.tar.gz'), cache_dir))
-
- # 6. 测试.pkl文件
diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py
new file mode 100644
index 00000000..6c23f213
--- /dev/null
+++ b/fastNLP/io/loader/__init__.py
@@ -0,0 +1,83 @@
+"""
+Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的
+三个方法: ``__init__`` , ``_load`` , ``load`` . 其中 ``__init__(...)`` 用于声明读取参数,以及说明该Loader支持的数据格式,
+读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ;
+``load(paths)`` 用于读取一个文件、一个文件夹或多个文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象, load()方法支持以下几种类型的参数:
+
+0.传入None
+ 将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。
+
+1.传入一个文件的 path
+ 返回的 `data_bundle` 包含一个名为 `train` 的 dataset ,可以通过 ``data_bundle.datasets['train']`` 获取
+
+2.传入一个文件夹目录
+ 将读取的是这个文件夹下文件名中包含 `train` , `test` , `dev` 的文件,其它文件会被忽略。假设某个目录下的文件为::
+
+ |
+ +-train.txt
+ +-dev.txt
+ +-test.txt
+ +-other.txt
+
+ 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , ``data_bundle.datasets['dev']`` ,
+ ``data_bundle.datasets['test']`` 获取对应的 `dataset` ,其中 `other.txt` 的内容会被忽略。假设某个目录下的文件为::
+
+ |
+ +-train.txt
+ +-dev.txt
+
+ 在 Loader().load('/path/to/dir') 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` ,
+ ``data_bundle.datasets['dev']`` 获取对应的 dataset。
+
+3.传入一个字典
+ 字典的的 key 为 `dataset` 的名称,value 是该 `dataset` 的文件路径::
+
+ paths = {'train':'/path/to/train', 'dev': '/path/to/dev', 'test':'/path/to/test'}
+
+ 在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.datasets['train']`` , ``data_bundle.datasets['dev']`` ,
+ ``data_bundle.datasets['test']`` 来获取对应的 `dataset`
+
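+下面给出一个简要的使用示意(以ConllLoader为例, 其中的文件路径均为假设)::
+
+ from fastNLP.io.loader import ConllLoader
+
+ loader = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])
+ data_bundle = loader.load({'train': '/path/to/train.conll', 'dev': '/path/to/dev.conll'})
+ train_data = data_bundle.datasets['train']
+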
+fastNLP 目前提供了如下的 Loader
+
+
+
+"""
+
+__all__ = [
+ 'Loader',
+
+ 'YelpLoader',
+ 'YelpFullLoader',
+ 'YelpPolarityLoader',
+ 'IMDBLoader',
+ 'SSTLoader',
+ 'SST2Loader',
+
+ 'ConllLoader',
+ 'Conll2003Loader',
+ 'Conll2003NERLoader',
+ 'OntoNotesNERLoader',
+ 'CTBLoader',
+ "MsraNERLoader",
+ "PeopleDailyNERLoader",
+ "WeiboNERLoader",
+
+ 'CSVLoader',
+ 'JsonLoader',
+
+ 'CWSLoader',
+
+ 'MNLILoader',
+ "QuoraLoader",
+ "SNLILoader",
+ "QNLILoader",
+ "RTELoader"
+]
+from .classification import YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader
+from .conll import ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader
+from .csv import CSVLoader
+from .cws import CWSLoader
+from .json import JsonLoader
+from .loader import Loader
+from .matching import MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader
+from .conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py
new file mode 100644
index 00000000..ec00d2b4
--- /dev/null
+++ b/fastNLP/io/loader/classification.py
@@ -0,0 +1,348 @@
+"""undocumented"""
+
+__all__ = [
+ "YelpLoader",
+ "YelpFullLoader",
+ "YelpPolarityLoader",
+ "IMDBLoader",
+ "SSTLoader",
+ "SST2Loader",
+]
+
+import glob
+import os
+import random
+import shutil
+import time
+import warnings
+
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+
+
+class YelpLoader(Loader):
+ """
+ 别名::class:`fastNLP.io.YelpLoader` :class:`fastNLP.io.loader.YelpLoader`
+
+ 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。
+
+ Example::
+
+ "1","I got 'new' tires from the..."
+ "1","Don't waste your time..."
+
+ 读取YelpFull, YelpPolarity的数据。可以通过xxx下载并预处理数据。
+ 读取的DataSet将具备以下的数据结构
+
+ .. csv-table::
+ :header: "raw_words", "target"
+
+ "I got 'new' tires from them and... ", "1"
+ "Don't waste your time. We had two...", "1"
+ "...", "..."
+
+ """
+
+ def __init__(self):
+ super(YelpLoader, self).__init__()
+
+ def _load(self, path: str = None):
+ ds = DataSet()
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ sep_index = line.index(',')
+ target = line[:sep_index]
+ raw_words = line[sep_index + 1:]
+ if target.startswith("\""):
+ target = target[1:]
+ if target.endswith("\""):
+ target = target[:-1]
+ if raw_words.endswith("\""):
+ raw_words = raw_words[:-1]
+ if raw_words.startswith('"'):
+ raw_words = raw_words[1:]
+ raw_words = raw_words.replace('""', '"') # 替换双引号
+ if raw_words:
+ ds.append(Instance(raw_words=raw_words, target=target))
+ return ds
+
+
+class YelpFullLoader(YelpLoader):
+ def download(self, dev_ratio: float = 0.1, re_download: bool = False):
+ """
+ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章
+
+ Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
+ in Neural Information Processing Systems 28 (NIPS 2015)
+
+ 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.csv, test.csv,
+ dev.csv三个文件。
+
+ :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
+ :param bool re_download: 是否重新下载数据,以重新切分数据。
+ :return: str, 数据集的目录地址
+ """
+
+ dataset_name = 'yelp-review-full'
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+ modify_time = 0
+ for filepath in glob.glob(os.path.join(data_dir, '*')):
+ modify_time = os.stat(filepath).st_mtime
+ break
+ if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
+ shutil.rmtree(data_dir)
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+
+ if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
+ if dev_ratio > 0:
+ assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
+ try:
+ with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
+ open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
+ open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2:
+ for line in f:
+ if random.random() < dev_ratio:
+ f2.write(line)
+ else:
+ f1.write(line)
+ os.remove(os.path.join(data_dir, 'train.csv'))
+ os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv'))
+ finally:
+ if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
+ os.remove(os.path.join(data_dir, 'middle_file.csv'))
+
+ return data_dir
+
+
+class YelpPolarityLoader(YelpLoader):
+ def download(self, dev_ratio: float = 0.1, re_download=False):
+ """
+ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章
+
+ Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances
+ in Neural Information Processing Systems 28 (NIPS 2015)
+
+ 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据;若dev_ratio为0, 则不进行切分。
+
+ :param float dev_ratio: 如果路径中不存在dev.csv, 从train划分多少作为dev的数据。 如果为0,则不划分dev。
+ :param bool re_download: 是否重新下载数据,以重新切分数据。
+ :return: str, 数据集的目录地址
+ """
+ dataset_name = 'yelp-review-polarity'
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+ modify_time = 0
+ for filepath in glob.glob(os.path.join(data_dir, '*')):
+ modify_time = os.stat(filepath).st_mtime
+ break
+ if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
+ shutil.rmtree(data_dir)
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+
+ if not os.path.exists(os.path.join(data_dir, 'dev.csv')):
+ if dev_ratio > 0:
+ assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
+ try:
+ with open(os.path.join(data_dir, 'train.csv'), 'r', encoding='utf-8') as f, \
+ open(os.path.join(data_dir, 'middle_file.csv'), 'w', encoding='utf-8') as f1, \
+ open(os.path.join(data_dir, 'dev.csv'), 'w', encoding='utf-8') as f2:
+ for line in f:
+ if random.random() < dev_ratio:
+ f2.write(line)
+ else:
+ f1.write(line)
+ os.remove(os.path.join(data_dir, 'train.csv'))
+ os.renames(os.path.join(data_dir, 'middle_file.csv'), os.path.join(data_dir, 'train.csv'))
+ finally:
+ if os.path.exists(os.path.join(data_dir, 'middle_file.csv')):
+ os.remove(os.path.join(data_dir, 'middle_file.csv'))
+
+ return data_dir
+
+
+class IMDBLoader(Loader):
+ """
+ 别名::class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.loader.IMDBLoader`
+
+ IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签
+ DataSet具备以下的结构:
+
+ .. csv-table::
+ :header: "raw_words", "target"
+
+ "Bromwell High is a cartoon ... ", "pos"
+ "Story of a man who has ...", "neg"
+ "...", "..."
+
+ """
+
+ def __init__(self):
+ super(IMDBLoader, self).__init__()
+
+ def _load(self, path: str):
+ dataset = DataSet()
+ with open(path, 'r', encoding="utf-8") as f:
+ for line in f:
+ line = line.strip()
+ if not line:
+ continue
+ parts = line.split('\t')
+ target = parts[0]
+ words = parts[1]
+ if words:
+ dataset.append(Instance(raw_words=words, target=target))
+
+ if len(dataset) == 0:
+ raise RuntimeError(f"{path} has no valid data.")
+
+ return dataset
+
+ def download(self, dev_ratio: float = 0.1, re_download=False):
+ """
+ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章
+
+ http://www.aclweb.org/anthology/P11-1015
+
+ 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据;若dev_ratio为0, 则不进行切分。
+
+ :param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev
+ :param bool re_download: 是否重新下载数据,以重新切分数据。
+ :return: str, 数据集的目录地址
+ """
+ dataset_name = 'aclImdb'
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+ modify_time = 0
+ for filepath in glob.glob(os.path.join(data_dir, '*')):
+ modify_time = os.stat(filepath).st_mtime
+ break
+ if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
+ shutil.rmtree(data_dir)
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+
+ if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
+ if dev_ratio > 0:
+ assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
+ try:
+ with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
+ open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
+ open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
+ for line in f:
+ if random.random() < dev_ratio:
+ f2.write(line)
+ else:
+ f1.write(line)
+ os.remove(os.path.join(data_dir, 'train.txt'))
+ os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt'))
+ finally:
+ if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
+ os.remove(os.path.join(data_dir, 'middle_file.txt'))
+
+ return data_dir
+
+
+class SSTLoader(Loader):
+ """
+ 别名::class:`fastNLP.io.SSTLoader` :class:`fastNLP.io.loader.SSTLoader`
+
+ 读取之后的DataSet具有以下的结构
+
+ .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field
+ :header: "raw_words"
+
+ "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..."
+ "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
+ "..."
+
+ raw_words列是str。
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path: str):
+ """
+ 从path读取SST文件
+
+ :param str path: 文件路径
+ :return: DataSet
+ """
+ ds = DataSet()
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ ds.append(Instance(raw_words=line))
+ return ds
+
+ def download(self):
+ """
+ 自动下载数据集,如果你使用了这个数据集,请引用以下的文章
+
+ https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf
+
+ :return: str, 数据集的目录地址
+ """
+ output_dir = self._get_dataset_path(dataset_name='sst')
+ return output_dir
+
+
+class SST2Loader(Loader):
+ """
+ 数据SST2的Loader
+ 读取之后DataSet将如下所示
+
+ .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field
+ :header: "raw_words", "target"
+
+ "it 's a charming and often affecting...", "1"
+ "unflinchingly bleak and...", "0"
+ "..."
+
+ test的DataSet没有target列。
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path: str):
+ """
+ 从path读取SST2文件
+
+ :param str path: 数据路径
+ :return: DataSet
+ """
+ ds = DataSet()
+
+ with open(path, 'r', encoding='utf-8') as f:
+ f.readline() # 跳过header
+ if 'test' in os.path.split(path)[1]:
+ warnings.warn("SST2's test file has no target.")
+ for line in f:
+ line = line.strip()
+ if line:
+ sep_index = line.index('\t')
+ raw_words = line[sep_index + 1:]
+ if raw_words:
+ ds.append(Instance(raw_words=raw_words))
+ else:
+ for line in f:
+ line = line.strip()
+ if line:
+ raw_words = line[:-2]
+ target = line[-1]
+ if raw_words:
+ ds.append(Instance(raw_words=raw_words, target=target))
+ return ds
+
+ def download(self):
+ """
+ 自动下载数据集,如果你使用了该数据集,请引用以下的文章
+
+ https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf
+
+ :return:
+ """
+ output_dir = self._get_dataset_path(dataset_name='sst-2')
+ return output_dir
diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py
new file mode 100644
index 00000000..1bd1b448
--- /dev/null
+++ b/fastNLP/io/loader/conll.py
@@ -0,0 +1,455 @@
+"""undocumented"""
+
+__all__ = [
+ "ConllLoader",
+ "Conll2003Loader",
+ "Conll2003NERLoader",
+ "OntoNotesNERLoader",
+ "CTBLoader",
+ "CNNERLoader",
+ "MsraNERLoader",
+ "WeiboNERLoader",
+ "PeopleDailyNERLoader"
+]
+
+import glob
+import os
+import random
+import shutil
+import time
+
+from .loader import Loader
+from ..file_reader import _read_conll
+from ...core.const import Const
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+
+
+class ConllLoader(Loader):
+ """
+ 别名::class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.loader.ConllLoader`
+
+ ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示:
+
+ Example::
+
+ # 文件中的内容
+ Nadim NNP B-NP B-PER
+ Ladki NNP I-NP I-PER
+
+ AL-AIN NNP B-NP B-LOC
+ United NNP B-NP B-LOC
+ Arab NNP I-NP I-LOC
+ Emirates NNPS I-NP I-LOC
+ 1996-12-06 CD I-NP O
+ ...
+
+ # 如果用以下的参数读取,返回的DataSet将包含raw_words和pos两个field, 这两个field的值分别取自于第0列与第1列
+ dataset = ConllLoader(headers=['raw_words', 'pos'], indexes=[0, 1])._load('/path/to/train.conll')
+ # 如果用以下的参数读取,返回的DataSet将包含raw_words和ner两个field, 这两个field的值分别取自于第0列与第2列
+ dataset = ConllLoader(headers=['raw_words', 'ner'], indexes=[0, 3])._load('/path/to/train.conll')
+ # 如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field
+ dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll')
+
+ ConllLoader返回的DataSet的field由传入的headers确定。
+
+ 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。
+
+ :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应
+ :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None``
+ :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True``
+
+ """
+
+ def __init__(self, headers, indexes=None, dropna=True):
+ super(ConllLoader, self).__init__()
+ if not isinstance(headers, (list, tuple)):
+ raise TypeError(
+ 'invalid headers: {}, should be list of strings'.format(headers))
+ self.headers = headers
+ self.dropna = dropna
+ if indexes is None:
+ self.indexes = list(range(len(self.headers)))
+ else:
+ if len(indexes) != len(headers):
+ raise ValueError
+ self.indexes = indexes
+
+ def _load(self, path):
+ """
+ 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
+
+ :param str path: 文件的路径
+ :return: DataSet
+ """
+ ds = DataSet()
+ for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+ ins = {h: data[i] for i, h in enumerate(self.headers)}
+ ds.append(Instance(**ins))
+ return ds
+
+
+class Conll2003Loader(ConllLoader):
+ """
+ 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。
+
+ Example::
+
+ Nadim NNP B-NP B-PER
+ Ladki NNP I-NP I-PER
+
+ AL-AIN NNP B-NP B-LOC
+ United NNP B-NP B-LOC
+ Arab NNP I-NP I-LOC
+ Emirates NNPS I-NP I-LOC
+ 1996-12-06 CD I-NP O
+ ...
+
+ 返回的DataSet的内容为
+
+ .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。
+ :header: "raw_words", "pos", "chunk", "ner"
+
+ "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]"
+ "[AL-AIN, United, Arab, ...]", "[NNP, NNP, NNP, ...]", "[B-NP, B-NP, I-NP, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
+ "[...]", "[...]", "[...]", "[...]"
+
+ """
+
+ def __init__(self):
+ headers = [
+ 'raw_words', 'pos', 'chunk', 'ner',
+ ]
+ super(Conll2003Loader, self).__init__(headers=headers)
+
+ def _load(self, path):
+ """
+ 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
+
+ :param str path: 文件的路径
+ :return: DataSet
+ """
+ ds = DataSet()
+ for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+ doc_start = False
+ for i, h in enumerate(self.headers):
+ field = data[i]
+ if str(field[0]).startswith('-DOCSTART-'):
+ doc_start = True
+ break
+ if doc_start:
+ continue
+ ins = {h: data[i] for i, h in enumerate(self.headers)}
+ ds.append(Instance(**ins))
+ return ds
+
+ def download(self, output_dir=None):
+ raise RuntimeError("conll2003 cannot be downloaded automatically.")
+
+
+class Conll2003NERLoader(ConllLoader):
+ """
+ 用于读取conll2003任务的NER数据。
+
+ Example::
+
+ Nadim NNP B-NP B-PER
+ Ladki NNP I-NP I-PER
+
+ AL-AIN NNP B-NP B-LOC
+ United NNP B-NP B-LOC
+ Arab NNP I-NP I-LOC
+ Emirates NNPS I-NP I-LOC
+ 1996-12-06 CD I-NP O
+ ...
+
+ 返回的DataSet的内容为
+
+ .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码
+ :header: "raw_words", "target"
+
+ "[Nadim, Ladki]", "[B-PER, I-PER]"
+ "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
+ "[...]", "[...]"
+
+ """
+
+ def __init__(self):
+ headers = [
+ 'raw_words', 'target',
+ ]
+ super().__init__(headers=headers, indexes=[0, 3])
+
+ def _load(self, path):
+ """
+ 传入的一个文件路径,将该文件读入DataSet中,field由ConllLoader初始化时指定的headers决定。
+
+ :param str path: 文件的路径
+ :return: DataSet
+ """
+ ds = DataSet()
+ for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
+ doc_start = False
+ for i, h in enumerate(self.headers):
+ field = data[i]
+ if str(field[0]).startswith('-DOCSTART-'):
+ doc_start = True
+ break
+ if doc_start:
+ continue
+ ins = {h: data[i] for i, h in enumerate(self.headers)}
+ ds.append(Instance(**ins))
+ return ds
+
+ def download(self):
+ raise RuntimeError("conll2003 cannot be downloaded automatically.")
+
+
+class OntoNotesNERLoader(ConllLoader):
+ """
+ 用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考
+ https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。
+
+ 返回的DataSet的内容为
+
+ .. csv-table:: 下面是使用OntoNoteNERLoader读取的DataSet所具备的结构, target列是BIO编码
+ :header: "raw_words", "target"
+
+ "[Nadim, Ladki]", "[B-PER, I-PER]"
+ "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
+ "[...]", "[...]"
+
+ """
+
+ def __init__(self):
+ super().__init__(headers=[Const.RAW_WORD, Const.TARGET], indexes=[3, 10])
+
+ def _load(self, path: str):
+ dataset = super()._load(path)
+
+ def convert_to_bio(tags):
+ bio_tags = []
+ flag = None
+ for tag in tags:
+ label = tag.strip("()*")
+ if '(' in tag:
+ bio_label = 'B-' + label
+ flag = label
+ elif flag:
+ bio_label = 'I-' + flag
+ else:
+ bio_label = 'O'
+ if ')' in tag:
+ flag = None
+ bio_tags.append(bio_label)
+ return bio_tags
+
+ def convert_word(words):
+ converted_words = []
+ for word in words:
+ word = word.replace('/.', '.') # 有些结尾的.是/.形式的
+ if not word.startswith('-'):
+ converted_words.append(word)
+ continue
+ # 以下是由于这些符号被转义了,再转回来
+ tfrs = {'-LRB-': '(',
+ '-RRB-': ')',
+ '-LSB-': '[',
+ '-RSB-': ']',
+ '-LCB-': '{',
+ '-RCB-': '}'
+ }
+ if word in tfrs:
+ converted_words.append(tfrs[word])
+ else:
+ converted_words.append(word)
+ return converted_words
+
+ dataset.apply_field(convert_word, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
+ dataset.apply_field(convert_to_bio, field_name=Const.TARGET, new_field_name=Const.TARGET)
+
+ return dataset
+
+ def download(self):
+ raise RuntimeError("Ontonotes cannot be downloaded automatically, you can refer "
+ "https://github.com/yhcc/OntoNotes-5.0-NER to download and preprocess.")
+
+
+class CTBLoader(Loader):
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path: str):
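+ # CTB数据的读取在当前版本尚未实现, 此处暂以pass占位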
+ pass
+
+
+class CNNERLoader(Loader):
+ def _load(self, path: str):
+ """
+ 支持加载形如以下格式的内容,一行两列,以空格隔开两个sample
+
+ Example::
+
+ 我 O
+ 们 O
+ 变 O
+ 而 O
+ 以 O
+ 书 O
+ 会 O
+ ...
+
+ :param str path: 文件路径
+ :return: DataSet,包含raw_words列和target列
+ """
+ ds = DataSet()
+ with open(path, 'r', encoding='utf-8') as f:
+ raw_chars = []
+ target = []
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split()
+ if len(parts) == 1: # 网上下载的数据中有一些行缺少tag,默认补充O
+ parts.append('O')
+ raw_chars.append(parts[0])
+ target.append(parts[1])
+ else:
+ if raw_chars:
+ ds.append(Instance(raw_chars=raw_chars, target=target))
+ raw_chars = []
+ target = []
+ return ds
+
+
+class MsraNERLoader(CNNERLoader):
+ """
+ 读取MSRA-NER数据,数据中的格式应该类似与下列的内容
+
+ Example::
+
+ 我 O
+ 们 O
+ 变 O
+ 而 O
+ 以 O
+ 书 O
+ 会 O
+ ...
+
+ 读取后的DataSet包含以下的field
+
+ .. csv-table:: target列是基于BIO的编码方式
+ :header: "raw_chars", "target"
+
+ "[我, 们, 变...]", "[O, O, ...]"
+ "[中, 共, 中, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
+ "[...]", "[...]"
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str:
+ """
+ 自动下载MSRA-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language
+ Processing Bakeoff: Word Segmentation and Named Entity Recognition.
+
+ 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.conll, test.conll,
+ dev.conll三个文件。
+
+ :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
+ :param bool re_download: 是否重新下载数据,以重新切分数据。
+ :return: str, 数据集的目录地址
+ """
+ dataset_name = 'msra-ner'
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+ modify_time = 0
+ for filepath in glob.glob(os.path.join(data_dir, '*')):
+ modify_time = os.stat(filepath).st_mtime
+ break
+ if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
+ shutil.rmtree(data_dir)
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+
+ if not os.path.exists(os.path.join(data_dir, 'dev.conll')):
+ if dev_ratio > 0:
+ assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
+ try:
+ with open(os.path.join(data_dir, 'train.conll'), 'r', encoding='utf-8') as f, \
+ open(os.path.join(data_dir, 'middle_file.conll'), 'w', encoding='utf-8') as f1, \
+ open(os.path.join(data_dir, 'dev.conll'), 'w', encoding='utf-8') as f2:
+ lines = [] # 一个sample包含很多行
+ for line in f:
+ line = line.strip()
+ if line:
+ lines.append(line)
+ else:
+ if random.random() < dev_ratio:
+ f2.write('\n'.join(lines) + '\n\n')
+ else:
+ f1.write('\n'.join(lines) + '\n\n')
+ lines.clear()
+ os.remove(os.path.join(data_dir, 'train.conll'))
+ os.renames(os.path.join(data_dir, 'middle_file.conll'), os.path.join(data_dir, 'train.conll'))
+ finally:
+ if os.path.exists(os.path.join(data_dir, 'middle_file.conll')):
+ os.remove(os.path.join(data_dir, 'middle_file.conll'))
+
+ return data_dir
+
+
+class WeiboNERLoader(CNNERLoader):
+ def __init__(self):
+ super().__init__()
+
+ def download(self) -> str:
+ """
+ 自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for
+ Chinese Social Media with Jointly Trained Embeddings.
+
+ :return: str
+ """
+ dataset_name = 'weibo-ner'
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+
+ return data_dir
+
+
+class PeopleDailyNERLoader(CNNERLoader):
+ """
+ 支持加载的数据格式如下
+
+ Example::
+
+ 当 O
+ 希 O
+ 望 O
+ 工 O
+ 程 O
+ 救 O
+ 助 O
+ 的 O
+ 百 O
+
+ 读取后的DataSet包含以下的field
+
+ .. csv-table:: target列是基于BIO的编码方式
+ :header: "raw_chars", "target"
+
+ "[我, 们, 变...]", "[O, O, ...]"
+ "[中, 共, 中, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
+ "[...]", "[...]"
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def download(self) -> str:
+ dataset_name = 'peopledaily'
+ data_dir = self._get_dataset_path(dataset_name=dataset_name)
+
+ return data_dir
diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py
new file mode 100644
index 00000000..0d6e35fa
--- /dev/null
+++ b/fastNLP/io/loader/csv.py
@@ -0,0 +1,38 @@
+"""undocumented"""
+
+__all__ = [
+ "CSVLoader",
+]
+
+from .loader import Loader
+from ..file_reader import _read_csv
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+
+
+class CSVLoader(Loader):
+ """
+ 别名::class:`fastNLP.io.CSVLoader` :class:`fastNLP.io.loader.CSVLoader`
+
+ 读取CSV格式的数据集, 返回 ``DataSet`` 。
+
+ :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称
+ 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None``
+ :param str sep: CSV文件中列与列之间的分隔符. Default: ","
+ :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
+ Default: ``False``
+ """
+
+ def __init__(self, headers=None, sep=",", dropna=False):
+ super().__init__()
+ self.headers = headers
+ self.sep = sep
+ self.dropna = dropna
+
+ def _load(self, path):
+ ds = DataSet()
+ for idx, data in _read_csv(path, headers=self.headers,
+ sep=self.sep, dropna=self.dropna):
+ ds.append(Instance(**data))
+ return ds
+
diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py
new file mode 100644
index 00000000..2fbb1091
--- /dev/null
+++ b/fastNLP/io/loader/cws.py
@@ -0,0 +1,94 @@
+"""undocumented"""
+
+__all__ = [
+ "CWSLoader"
+]
+
+import glob
+import os
+import random
+import shutil
+import time
+
+from .loader import Loader
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+
+
+class CWSLoader(Loader):
+ """
+ CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如:
+
+ Example::
+
+ 上海 浦东 开发 与 法制 建设 同步
+ 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )
+ ...
+
+ 该Loader读取后的DataSet具有如下的结构
+
+ .. csv-table::
+ :header: "raw_words"
+
+ "上海 浦东 开发 与 法制 建设 同步"
+ "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
+ "..."
+
+ :param str dataset_name: 数据集的名称,支持pku, msra, cityu(繁体), as(繁体), None
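+
+ Example::
+
+ # 一个简要的使用示意: 传入None时将尝试自动下载并缓存pku数据
+ data_bundle = CWSLoader(dataset_name='pku').load()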
+ """
+ def __init__(self, dataset_name:str=None):
+ super().__init__()
+ datanames = {'pku': 'cws-pku', 'msra':'cws-msra', 'as':'cws-as', 'cityu':'cws-cityu'}
+ if dataset_name in datanames:
+ self.dataset_name = datanames[dataset_name]
+ else:
+ self.dataset_name = None
+
+ def _load(self, path:str):
+ ds = DataSet()
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ ds.append(Instance(raw_words=line))
+ return ds
+
+ def download(self, dev_ratio=0.1, re_download=False)->str:
+ """
+ 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff,
+ 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看
+
+ :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。
+ :param bool re_download: 是否重新下载数据,以重新切分数据。
+ :return: str
+ """
+ if self.dataset_name is None:
+ return None
+ data_dir = self._get_dataset_path(dataset_name=self.dataset_name)
+ modify_time = 0
+ for filepath in glob.glob(os.path.join(data_dir, '*')):
+ modify_time = os.stat(filepath).st_mtime
+ break
+ if time.time() - modify_time > 1 and re_download: # 通过这种比较丑陋的方式判断一下文件是否是才下载的
+ shutil.rmtree(data_dir)
+ data_dir = self._get_dataset_path(dataset_name=self.dataset_name)
+
+ if not os.path.exists(os.path.join(data_dir, 'dev.txt')):
+ if dev_ratio > 0:
+ assert 0 < dev_ratio < 1, "dev_ratio should be in range (0,1)."
+ try:
+ with open(os.path.join(data_dir, 'train.txt'), 'r', encoding='utf-8') as f, \
+ open(os.path.join(data_dir, 'middle_file.txt'), 'w', encoding='utf-8') as f1, \
+ open(os.path.join(data_dir, 'dev.txt'), 'w', encoding='utf-8') as f2:
+ for line in f:
+ if random.random() < dev_ratio:
+ f2.write(line)
+ else:
+ f1.write(line)
+ os.remove(os.path.join(data_dir, 'train.txt'))
+ os.renames(os.path.join(data_dir, 'middle_file.txt'), os.path.join(data_dir, 'train.txt'))
+ finally:
+ if os.path.exists(os.path.join(data_dir, 'middle_file.txt')):
+ os.remove(os.path.join(data_dir, 'middle_file.txt'))
+
+ return data_dir
diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py
new file mode 100644
index 00000000..012dee5a
--- /dev/null
+++ b/fastNLP/io/loader/json.py
@@ -0,0 +1,46 @@
+"""undocumented"""
+
+__all__ = [
+ "JsonLoader"
+]
+
+from .loader import Loader
+from ..file_reader import _read_json
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+
+
+class JsonLoader(Loader):
+ """
+ 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader`
+
+ 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象
+
+ :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name
+ ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` ,
+ `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名
+ ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None``
+ :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` .
+ Default: ``False``
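+
+ Example::
+
+ # 假设文件中每一行形如 {"text": "...", "label": "pos"}
+ loader = JsonLoader(fields={'text': 'raw_words', 'label': 'target'}, dropna=True)
+ # loader.load('/path/to/data.jsonl') 将返回含raw_words与target两个field的DataBundle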
+ """
+
+ def __init__(self, fields=None, dropna=False):
+ super(JsonLoader, self).__init__()
+ self.dropna = dropna
+ self.fields = None
+ self.fields_list = None
+ if fields:
+ self.fields = {}
+ for k, v in fields.items():
+ self.fields[k] = k if v is None else v
+ self.fields_list = list(self.fields.keys())
+
+ def _load(self, path):
+ ds = DataSet()
+ for idx, d in _read_json(path, fields=self.fields_list, dropna=self.dropna):
+ if self.fields:
+ ins = {self.fields[k]: v for k, v in d.items()}
+ else:
+ ins = d
+ ds.append(Instance(**ins))
+ return ds
diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py
new file mode 100644
index 00000000..22636a27
--- /dev/null
+++ b/fastNLP/io/loader/loader.py
@@ -0,0 +1,91 @@
+"""undocumented"""
+
+__all__ = [
+ "Loader"
+]
+
+from typing import Union, Dict
+
+from .. import DataBundle
+from ..file_utils import _get_dataset_url, get_cache_path, cached_path
+from ..utils import check_loader_paths
+from ...core.dataset import DataSet
+
+
+class Loader:
+ """
+ 各种数据 Loader 的基类,提供了 API 的参考.
+
+ """
+
+ def __init__(self):
+ pass
+
+ def _load(self, path: str) -> DataSet:
+ """
+ 给定一个路径,返回读取的DataSet。
+
+ :param str path: 路径
+ :return: DataSet
+ """
+ raise NotImplementedError
+
+ def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
+ """
+ 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。
+
+ 读取的field由各个Loader自身的设置决定,下面以ConllLoader为例。
+
+ :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式
+ (0) 如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。
+
+ (1) 传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件
+ 名包含'train'、 'dev'、 'test'则会报错::
+
+ data_bundle = ConllLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train、
+ # dev、 test等有所变化,可以通过以下的方式取出DataSet
+ tr_data = data_bundle.datasets['train']
+ te_data = data_bundle.datasets['test'] # 如果目录下有文件包含test这个字段
+
+ (2) 传入文件路径::
+
+ data_bundle = ConllLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train'
+ tr_data = data_bundle.datasets['train'] # 可以通过以下的方式取出DataSet
+
+ (3) 传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test::
+
+ paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"}
+ data_bundle = ConllLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test"
+ dev_data = data_bundle.datasets['dev']
+
+ :return: 返回的 :class:`~fastNLP.io.DataBundle`
+ """
+ if paths is None:
+ paths = self.download()
+ paths = check_loader_paths(paths)
+ datasets = {name: self._load(path) for name, path in paths.items()}
+ data_bundle = DataBundle(datasets=datasets)
+ return data_bundle
+
+ def download(self) -> str:
+ """
+ 自动下载该数据集
+
+ :return: 下载后解压目录
+ """
+ raise NotImplementedError(f"{self.__class__} cannot download data automatically.")
+
+ @staticmethod
+ def _get_dataset_path(dataset_name):
+ """
+ 传入dataset的名称,获取读取数据的目录。如果数据不存在,会尝试自动下载并缓存
+
+ :param str dataset_name: 数据集的名称
+ :return: str, 数据集的目录地址。直接到该目录下读取相应的数据即可。
+ """
+
+ default_cache_path = get_cache_path()
+ url = _get_dataset_url(dataset_name)
+ output_dir = cached_path(url_or_filename=url, cache_dir=default_cache_path, name='dataset')
+
+ return output_dir
diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py
new file mode 100644
index 00000000..7f03ca3e
--- /dev/null
+++ b/fastNLP/io/loader/matching.py
@@ -0,0 +1,319 @@
+"""undocumented"""
+
+__all__ = [
+ "MNLILoader",
+ "SNLILoader",
+ "QNLILoader",
+ "RTELoader",
+ "QuoraLoader",
+]
+
+import os
+import warnings
+from typing import Union, Dict
+
+from .json import JsonLoader
+from .loader import Loader
+from .. import DataBundle
+from ...core.const import Const
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+
+
+class MNLILoader(Loader):
+ """
+ 读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,raw_words1是sentence1, raw_words2是sentence2, target是gold_label, 测试集中没
+ 有target列。
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "target"
+
+ "The new rights are...", "Everyone really likes..", "neutral"
+ "This site includes a...", "The Government Executive...", "contradiction"
+ "...", "...","."
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path: str):
+ ds = DataSet()
+ with open(path, 'r', encoding='utf-8') as f:
+ f.readline() # 跳过header
+ if path.endswith("test.tsv"):
+ warnings.warn("RTE's test file has no target.")
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[8]
+ raw_words2 = parts[9]
+ if raw_words1 and raw_words2:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
+ else:
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[8]
+ raw_words2 = parts[9]
+ target = parts[-1]
+ if raw_words1 and raw_words2 and target:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
+ return ds
+
+ def load(self, paths: str = None):
+ """
+
+ :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv,
+ test_mismatched.tsv, train.tsv这几个文件
+ :return: DataBundle
+ """
+ if paths:
+ paths = os.path.abspath(os.path.expanduser(paths))
+ else:
+ paths = self.download()
+ if not os.path.isdir(paths):
+ raise NotADirectoryError(f"{paths} is not a valid directory.")
+
+ files = {'dev_matched': "dev_matched.tsv",
+ "dev_mismatched": "dev_mismatched.tsv",
+ "test_matched": "test_matched.tsv",
+ "test_mismatched": "test_mismatched.tsv",
+ "train": 'train.tsv'}
+
+ datasets = {}
+ for name, filename in files.items():
+ filepath = os.path.join(paths, filename)
+ if not os.path.isfile(filepath):
+ if 'test' not in name:
+ raise FileNotFoundError(f"{filename} is not found in {paths}.")
+ continue # 缺少test文件时跳过, 不尝试读取
+ datasets[name] = self._load(filepath)
+
+ data_bundle = DataBundle(datasets=datasets)
+
+ return data_bundle
+
+ def download(self):
+ """
+ 如果你使用了这个数据,请引用
+
+ https://www.nyu.edu/projects/bowman/multinli/paper.pdf
+ :return:
+ """
+ output_dir = self._get_dataset_path('mnli')
+ return output_dir
+
+
+class SNLILoader(JsonLoader):
+ """
+ 读取之后的DataSet中的field情况为
+
+ .. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field
+ :header: "raw_words1", "raw_words2", "target"
+
+ "The new rights are...", "Everyone really likes..", "neutral"
+ "This site includes a...", "The Government Executive...", "entailment"
+ "...", "...", "."
+
+ """
+
+ def __init__(self):
+ super().__init__(fields={
+ 'sentence1': Const.RAW_WORDS(0),
+ 'sentence2': Const.RAW_WORDS(1),
+ 'gold_label': Const.TARGET,
+ })
+
+ def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle:
+ """
+ 从指定一个或多个路径中的文件中读取数据,返回:class:`~fastNLP.io.DataBundle` 。
+
+ 读取的field由SNLILoader初始化时指定的fields决定。
+
+ :param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl
+ 和snli_1.0_test.jsonl三个文件。
+
+ :return: 返回的:class:`~fastNLP.io.DataBundle`
+ """
+ _paths = {}
+ if paths is None:
+ paths = self.download()
+ if paths:
+ if os.path.isdir(paths):
+ if not os.path.isfile(os.path.join(paths, 'snli_1.0_train.jsonl')):
+ raise FileNotFoundError(f"snli_1.0_train.jsonl is not found in {paths}")
+ _paths['train'] = os.path.join(paths, 'snli_1.0_train.jsonl')
+ for filename in ['snli_1.0_dev.jsonl', 'snli_1.0_test.jsonl']:
+ filepath = os.path.join(paths, filename)
+ _paths[filename.split('_')[-1].split('.')[0]] = filepath
+ paths = _paths
+ else:
+ raise NotADirectoryError(f"{paths} is not a valid directory.")
+
+ datasets = {name: self._load(path) for name, path in paths.items()}
+ data_bundle = DataBundle(datasets=datasets)
+ return data_bundle
+
+ def download(self):
+ """
+ 如果您的文章使用了这份数据,请引用
+
+ http://nlp.stanford.edu/pubs/snli_paper.pdf
+
+ :return: str
+ """
+ return self._get_dataset_path('snli')
+
+
+class QNLILoader(JsonLoader):
+ """
+ QNLI数据集的Loader,
+ 加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "target"
+
+ "What came into force after the new...", "As of that day...", "entailment"
+ "What is the first major...", "The most important tributaries", "not_entailment"
+ "...","."
+
+ test数据集没有target列
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path):
+ ds = DataSet()
+
+ with open(path, 'r', encoding='utf-8') as f:
+ f.readline() # 跳过header
+ if path.endswith("test.tsv"):
+ warnings.warn("QNLI's test file has no target.")
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[1]
+ raw_words2 = parts[2]
+ if raw_words1 and raw_words2:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
+ else:
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[1]
+ raw_words2 = parts[2]
+ target = parts[-1]
+ if raw_words1 and raw_words2 and target:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
+ return ds
+
+ def download(self):
+ """
+ 如果您的实验使用到了该数据,请引用
+
+ .. todo::
+ 补充
+
+ :return:
+ """
+ return self._get_dataset_path('qnli')
+
+
+class RTELoader(Loader):
+ """
+ RTE数据的loader
+ 加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "target"
+
+ "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment"
+ "Yet, we now are discovering that...", "Bacteria is winning...", "entailment"
+ "...","."
+
+ test数据集没有target列
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path: str):
+ ds = DataSet()
+
+ with open(path, 'r', encoding='utf-8') as f:
+ f.readline() # 跳过header
+ if path.endswith("test.tsv"):
+ warnings.warn("RTE's test file has no target.")
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[1]
+ raw_words2 = parts[2]
+ if raw_words1 and raw_words2:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2))
+ else:
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[1]
+ raw_words2 = parts[2]
+ target = parts[-1]
+ if raw_words1 and raw_words2 and target:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
+ return ds
+
+ def download(self):
+ return self._get_dataset_path('rte')
+
+
+class QuoraLoader(Loader):
+ """
+ Quora matching任务的数据集Loader
+
+ 支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子
+
+ Example::
+
+ 1 How do I get funding for my web based startup idea ? How do I get seed funding pre product ? 327970
+ 1 How can I stop my depression ? What can I do to stop being depressed ? 339556
+ ...
+
+ 加载的DataSet将具备以下的field
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "target"
+
+ "What should I do to avoid...", "1"
+ "How do I not sleep in a boring class...", "0"
+ "...","."
+
+ """
+
+ def __init__(self):
+ super().__init__()
+
+ def _load(self, path: str):
+ ds = DataSet()
+
+ with open(path, 'r', encoding='utf-8') as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ parts = line.split('\t')
+ raw_words1 = parts[1]
+ raw_words2 = parts[2]
+ target = parts[0]
+ if raw_words1 and raw_words2 and target:
+ ds.append(Instance(raw_words1=raw_words1, raw_words2=raw_words2, target=target))
+ return ds
+
+ def download(self):
+ raise RuntimeError("Quora cannot be downloaded automatically.")
diff --git a/fastNLP/io/model_io.py b/fastNLP/io/model_io.py
index ffaa4ef5..22ced1ce 100644
--- a/fastNLP/io/model_io.py
+++ b/fastNLP/io/model_io.py
@@ -8,7 +8,7 @@ __all__ = [
import torch
-from .base_loader import BaseLoader
+from .data_bundle import BaseLoader
class ModelLoader(BaseLoader):
diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py
new file mode 100644
index 00000000..048e4cfe
--- /dev/null
+++ b/fastNLP/io/pipe/__init__.py
@@ -0,0 +1,48 @@
+"""
+Pipe用于处理通过 Loader 读取的数据,所有的 Pipe 都包含 ``process`` 和 ``process_from_file`` 两种方法。
+``process(data_bundle)`` 传入一个 :class:`~fastNLP.io.DataBundle` 类型的对象, 在传入的 `data_bundle` 上进行原位修改,并将其返回;
+``process_from_file(paths)`` 传入的文件路径,返回一个 :class:`~fastNLP.io.DataBundle` 类型的对象。
+``process(data_bundle)`` 或者 ``process_from_file(paths)`` 的返回 `data_bundle` 中的 :class:`~fastNLP.DataSet`
+一般都包含原文与转换为index的输入以及转换为index的target;除了 :class:`~fastNLP.DataSet` 之外,
+`data_bundle` 还会包含将field转为index时所建立的词表。
+
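+下面是一个简要的使用示意(其中的路径仅为假设)::
+
+ from fastNLP.io.pipe import SST2Pipe
+
+ data_bundle = SST2Pipe().process_from_file('/path/to/sst-2') # 返回处理完成的DataBundle
+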
+"""
+__all__ = [
+ "Pipe",
+
+ "CWSPipe",
+
+ "YelpFullPipe",
+ "YelpPolarityPipe",
+ "SSTPipe",
+ "SST2Pipe",
+ "IMDBPipe",
+
+ "Conll2003NERPipe",
+ "OntoNotesNERPipe",
+ "MsraNERPipe",
+ "WeiboNERPipe",
+ "PeopleDailyPipe",
+ "Conll2003Pipe",
+
+ "MatchingBertPipe",
+ "RTEBertPipe",
+ "SNLIBertPipe",
+ "QuoraBertPipe",
+ "QNLIBertPipe",
+ "MNLIBertPipe",
+ "MatchingPipe",
+ "RTEPipe",
+ "SNLIPipe",
+ "QuoraPipe",
+ "QNLIPipe",
+ "MNLIPipe",
+]
+
+from .classification import YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe
+from .conll import Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe
+from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \
+ MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe
+from .pipe import Pipe
+from .conll import Conll2003Pipe
+from .cws import CWSPipe
diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py
new file mode 100644
index 00000000..30c591a4
--- /dev/null
+++ b/fastNLP/io/pipe/classification.py
@@ -0,0 +1,459 @@
+"""undocumented"""
+
+__all__ = [
+ "YelpFullPipe",
+ "YelpPolarityPipe",
+ "SSTPipe",
+ "SST2Pipe",
+ 'IMDBPipe'
+]
+
+import re
+
+from nltk import Tree
+
+from .pipe import Pipe
+from .utils import get_tokenizer, _indexize, _add_words_field, _drop_empty_instance
+from ..data_bundle import DataBundle
+from ..loader.classification import IMDBLoader, YelpFullLoader, SSTLoader, SST2Loader, YelpPolarityLoader
+from ...core.const import Const
+from ...core.dataset import DataSet
+from ...core.instance import Instance
+from ...core.vocabulary import Vocabulary
+
+nonalpnum = re.compile('[^0-9a-zA-Z?!\']+')
+
+
+
+class _CLSPipe(Pipe):
+ """
+ 分类问题的基类,负责对classification的数据进行tokenize操作。默认是对raw_words列操作,然后生成words列
+
+ """
+
+ def __init__(self, tokenizer: str = 'spacy', lang='en'):
+ self.tokenizer = get_tokenizer(tokenizer, lang=lang)
+
+ def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
+ """
+ 将DataBundle中的数据进行tokenize
+
+ :param DataBundle data_bundle:
+ :param str field_name:
+ :param str new_field_name:
+ :return: 传入的DataBundle对象
+ """
+ new_field_name = new_field_name or field_name
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
+
+ return data_bundle
+
+ def _granularize(self, data_bundle, tag_map):
+ """
+ 该函数对data_bundle中'target'列中的内容进行转换。
+
+ :param data_bundle:
+ :param dict tag_map: 将target列中的tag做以下的映射,比如{"0":0, "1":0, "3":1, "4":1}, 则会删除target为"2"的instance,
+ 且将"1"认为是第0类。
+ :return: 传入的data_bundle
+ """
+ for name in list(data_bundle.datasets.keys()):
+ dataset = data_bundle.get_dataset(name)
+ dataset.apply_field(lambda target: tag_map.get(target, -100), field_name=Const.TARGET,
+ new_field_name=Const.TARGET)
+ dataset.drop(lambda ins: ins[Const.TARGET] == -100)
+ data_bundle.set_dataset(dataset, name)
+ return data_bundle
+
+
+def _clean_str(words):
+ """
+ heavily borrowed from github
+ https://github.com/LukeZhuang/Hierarchical-Attention-Network/blob/master/yelp-preprocess.ipynb
+ :param sentence: is a str
+ :return:
+ """
+ words_collection = []
+ for word in words:
+ if word in ['-lrb-', '-rrb-', '', '-r', '-l', 'b-']:
+ continue
+ tt = nonalpnum.split(word)
+ t = ''.join(tt)
+ if t != '':
+ words_collection.append(t)
+
+ return words_collection
+
+
+class YelpFullPipe(_CLSPipe):
+ """
+ 处理YelpFull的数据, 处理之后DataSet中的内容如下
+
+ .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "It 's a ...", "[4, 2, 10, ...]", 0, 10
+ "Offers that ...", "[20, 40, ...]", 1, 21
+ "...", "[...]", ., .
+
+ :param bool lower: 是否对输入进行小写化。
+ :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为一类,4、5归为一类,丢掉3;若为3, 则为3分类问题,将
+ 1、2归为一类,3归为一类,4、5归为一类;若为5, 则为5分类问题。
+ :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
+ """
+
+ def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy'):
+ super().__init__(tokenizer=tokenizer, lang='en')
+ self.lower = lower
+ assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
+ self.granularity = granularity
+
+ if granularity == 2:
+ self.tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}
+ elif granularity == 3:
+ self.tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
+ else:
+ self.tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
+
+ def _tokenize(self, data_bundle, field_name=Const.INPUT, new_field_name=None):
+ """
+ 将DataBundle中的数据进行tokenize
+
+ :param DataBundle data_bundle:
+ :param str field_name:
+ :param str new_field_name:
+ :return: 传入的DataBundle对象
+ """
+ new_field_name = new_field_name or field_name
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name)
+ dataset.apply_field(_clean_str, field_name=field_name, new_field_name=new_field_name)
+ return data_bundle
+
+ def process(self, data_bundle):
+ """
+ 传入的DataSet应该具备如下的结构
+
+ .. csv-table::
+ :header: "raw_words", "target"
+
+ "I got 'new' tires from them and... ", "1"
+ "Don't waste your time. We had two...", "1"
+ "...", "..."
+
+ :param data_bundle:
+ :return:
+ """
+
+ # 复制一列words
+ data_bundle = _add_words_field(data_bundle, lower=self.lower)
+
+ # 进行tokenize
+ data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
+
+ # 根据granularity设置tag
+ data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
+
+ # 删除空行
+ data_bundle = _drop_empty_instance(data_bundle, field_name=Const.INPUT)
+
+ # index
+ data_bundle = _indexize(data_bundle=data_bundle)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+
+ data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
+ data_bundle.set_target(Const.TARGET)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None):
+ """
+
+ :param paths:
+ :return: DataBundle
+ """
+ data_bundle = YelpFullLoader().load(paths)
+ return self.process(data_bundle=data_bundle)
+
+
+class YelpPolarityPipe(_CLSPipe):
+ """
+ 处理YelpPolarity的数据, 处理之后DataSet中的内容如下
+
+ .. csv-table:: 下面是使用YelpPolarityPipe处理后的DataSet所具备的field
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "It 's a ...", "[4, 2, 10, ...]", 0, 10
+ "Offers that ...", "[20, 40, ...]", 1, 21
+ "...", "[...]", ., .
+
+ :param bool lower: 是否对输入进行小写化。
+ :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
+ """
+
+ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+ super().__init__(tokenizer=tokenizer, lang='en')
+ self.lower = lower
+
+ def process(self, data_bundle):
+ # 复制一列words
+ data_bundle = _add_words_field(data_bundle, lower=self.lower)
+
+ # 进行tokenize
+ data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
+ # index
+ data_bundle = _indexize(data_bundle=data_bundle)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+
+ data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
+ data_bundle.set_target(Const.TARGET)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None):
+ """
+
+ :param str paths:
+ :return: DataBundle
+ """
+ data_bundle = YelpPolarityLoader().load(paths)
+ return self.process(data_bundle=data_bundle)
+
+
+class SSTPipe(_CLSPipe):
+ """
+ 别名::class:`fastNLP.io.SSTPipe` :class:`fastNLP.io.pipe.SSTPipe`
+
+ 经过该Pipe之后,DataSet中具备的field如下所示
+
+ .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "It 's a ...", "[4, 2, 10, ...]", 0, 16
+ "Offers that ...", "[20, 40, ...]", 1, 18
+ "...", "[...]", ., .
+
+ :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False``
+ :param bool train_subtree: 是否将train集通过子树扩展数据。
+ :param bool lower: 是否对输入进行小写化。
+    :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为一类,3、4归为一类,并丢弃2;若为3, 则认为是3分类问题,将
+        0、1归为一类,2单独为一类,3、4归为一类;若为5, 则认为是5分类问题。
+ :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
+ """
+
+ def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy'):
+ super().__init__(tokenizer=tokenizer, lang='en')
+ self.subtree = subtree
+ self.train_tree = train_subtree
+ self.lower = lower
+ assert granularity in (2, 3, 5), "granularity can only be 2,3,5."
+ self.granularity = granularity
+
+ if granularity == 2:
+ self.tag_map = {"0": 0, "1": 0, "3": 1, "4": 1}
+ elif granularity == 3:
+ self.tag_map = {"0": 0, "1": 0, "2": 1, "3": 2, "4": 2}
+ else:
+ self.tag_map = {"0": 0, "1": 1, "2": 2, "3": 3, "4": 4}
+
+ def process(self, data_bundle: DataBundle):
+ """
+        对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似于
+
+ .. csv-table::
+ :header: "raw_words"
+
+ "(3 (2 It) (4 (4 (2 's) (4 (3 (2 a)..."
+ "(4 (4 (2 Offers) (3 (3 (2 that) (3 (3 rare)..."
+ "..."
+
+ :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象
+ :return:
+ """
+ # 先取出subtree
+ for name in list(data_bundle.datasets.keys()):
+ dataset = data_bundle.get_dataset(name)
+ ds = DataSet()
+ use_subtree = self.subtree or (name == 'train' and self.train_tree)
+ for ins in dataset:
+ raw_words = ins['raw_words']
+ tree = Tree.fromstring(raw_words)
+ if use_subtree:
+ for t in tree.subtrees():
+ raw_words = " ".join(t.leaves())
+ instance = Instance(raw_words=raw_words, target=t.label())
+ ds.append(instance)
+ else:
+ instance = Instance(raw_words=' '.join(tree.leaves()), target=tree.label())
+ ds.append(instance)
+ data_bundle.set_dataset(ds, name)
+
+ _add_words_field(data_bundle, lower=self.lower)
+
+ # 进行tokenize
+ data_bundle = self._tokenize(data_bundle=data_bundle, field_name=Const.INPUT)
+
+ # 根据granularity设置tag
+ data_bundle = self._granularize(data_bundle, tag_map=self.tag_map)
+
+ # index
+ data_bundle = _indexize(data_bundle=data_bundle)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+
+ data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
+ data_bundle.set_target(Const.TARGET)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None):
+ data_bundle = SSTLoader().load(paths)
+ return self.process(data_bundle=data_bundle)
+
+
+class SST2Pipe(_CLSPipe):
+ """
+ 加载SST2的数据, 处理完成之后DataSet将拥有以下的field
+
+ .. csv-table::
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "it 's a charming and... ", "[3, 4, 5, 6, 7,...]", 1, 43
+ "unflinchingly bleak and...", "[10, 11, 7,...]", 1, 21
+ "...", "...", ., .
+
+ :param bool lower: 是否对输入进行小写化。
+ :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。
+ """
+
+ def __init__(self, lower=False, tokenizer='spacy'):
+ super().__init__(tokenizer=tokenizer, lang='en')
+ self.lower = lower
+
+ def process(self, data_bundle: DataBundle):
+ """
+ 可以处理的DataSet应该具备如下的结构
+
+ .. csv-table::
+ :header: "raw_words", "target"
+
+ "it 's a charming and... ", 1
+ "unflinchingly bleak and...", 1
+ "...", "..."
+
+ :param data_bundle:
+ :return:
+ """
+ _add_words_field(data_bundle, self.lower)
+
+ data_bundle = self._tokenize(data_bundle=data_bundle)
+
+ src_vocab = Vocabulary()
+ src_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.INPUT,
+ no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
+ name != 'train'])
+ src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
+
+ tgt_vocab = Vocabulary(unknown=None, padding=None)
+ tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
+ datasets = []
+ for name, dataset in data_bundle.datasets.items():
+ if dataset.has_field(Const.TARGET):
+ datasets.append(dataset)
+ tgt_vocab.index_dataset(*datasets, field_name=Const.TARGET)
+
+ data_bundle.set_vocab(src_vocab, Const.INPUT)
+ data_bundle.set_vocab(tgt_vocab, Const.TARGET)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+
+ data_bundle.set_input(Const.INPUT, Const.INPUT_LEN)
+ data_bundle.set_target(Const.TARGET)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None):
+ """
+
+ :param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。
+ :return: DataBundle
+ """
+ data_bundle = SST2Loader().load(paths)
+ return self.process(data_bundle)
+
+
+class IMDBPipe(_CLSPipe):
+ """
+ 经过本Pipe处理后DataSet将如下
+
+ .. csv-table:: 输出DataSet的field
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "Bromwell High is a cartoon ... ", "[3, 5, 6, 9, ...]", 0, 20
+ "Story of a man who has ...", "[20, 43, 9, 10, ...]", 1, 31
+ "...", "[...]", ., .
+
+ 其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值;
+ words列被设置为input; target列被设置为target。
+
+ :param bool lower: 是否将words列的数据小写。
+ :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
+ """
+
+ def __init__(self, lower: bool = False, tokenizer: str = 'spacy'):
+ super().__init__(tokenizer=tokenizer, lang='en')
+ self.lower = lower
+
+ def process(self, data_bundle: DataBundle):
+ """
+        期待的DataBundle中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型
+
+ .. csv-table:: 输入DataSet的field
+ :header: "raw_words", "target"
+
+ "Bromwell High is a cartoon ... ", "pos"
+ "Story of a man who has ...", "neg"
+ "...", "..."
+
+        :param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str,
+            target列应该为str。
+ :return: DataBundle
+ """
+
+        # 将<br />替换为空格
+        def replace_br(raw_words):
+            raw_words = raw_words.replace("<br />", ' ')
+            return raw_words
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(replace_br, field_name=Const.RAW_WORD, new_field_name=Const.RAW_WORD)
+
+ _add_words_field(data_bundle, lower=self.lower)
+ self._tokenize(data_bundle, field_name=Const.INPUT, new_field_name=Const.INPUT)
+ _indexize(data_bundle)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+ dataset.set_input(Const.INPUT, Const.INPUT_LEN)
+ dataset.set_target(Const.TARGET)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None):
+ """
+
+ :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。
+ :return: DataBundle
+ """
+ # 读取数据
+ data_bundle = IMDBLoader().load(paths)
+ data_bundle = self.process(data_bundle)
+
+ return data_bundle
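
The classification pipes above all follow the same recipe: copy raw_words into words, tokenize, optionally collapse the label set through a granularity tag map, index, and mark input/target fields. Below is a minimal standalone sketch of the label-collapsing step; the tag maps mirror the ones in YelpFullPipe/SSTPipe, while dropping unmapped labels is an assumption about what the `_granularize` helper defined earlier in this file does.

```python
# Pure-Python sketch of the granularity collapse used by YelpFullPipe/SSTPipe.
def collapse_labels(labels, granularity=2):
    """Map 5-way sentiment tags ("1".."5") down to 2/3/5 classes; unmapped tags are dropped."""
    if granularity == 2:
        tag_map = {"1": 0, "2": 0, "4": 1, "5": 1}          # the neutral "3" is discarded
    elif granularity == 3:
        tag_map = {"1": 0, "2": 0, "3": 1, "4": 2, "5": 2}
    else:
        tag_map = {"1": 0, "2": 1, "3": 2, "4": 3, "5": 4}
    return [tag_map[t] for t in labels if t in tag_map]

print(collapse_labels(["1", "3", "5", "4"], granularity=2))  # [0, 1, 1] -- the "3" instance is dropped
print(collapse_labels(["1", "3", "5", "4"], granularity=3))  # [0, 1, 2, 2]
```
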
diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py
new file mode 100644
index 00000000..2efec8e0
--- /dev/null
+++ b/fastNLP/io/pipe/conll.py
@@ -0,0 +1,332 @@
+"""undocumented"""
+
+__all__ = [
+ "Conll2003NERPipe",
+ "Conll2003Pipe",
+ "OntoNotesNERPipe",
+ "MsraNERPipe",
+ "PeopleDailyPipe",
+ "WeiboNERPipe"
+]
+
+from .pipe import Pipe
+from .utils import _add_chars_field
+from .utils import _indexize, _add_words_field
+from .utils import iob2, iob2bioes
+from .. import DataBundle
+from ..loader.conll import Conll2003NERLoader, OntoNotesNERLoader
+from ..loader.conll import PeopleDailyNERLoader, WeiboNERLoader, MsraNERLoader, ConllLoader
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+
+
+class _NERPipe(Pipe):
+ """
+ NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
+ (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的
+ Vocabulary转换为index。
+
+ raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target, seq_len。
+
+    :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
+ :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
+ """
+
+ def __init__(self, encoding_type: str = 'bio', lower: bool = False):
+ if encoding_type == 'bio':
+ self.convert_tag = iob2
+ else:
+ self.convert_tag = lambda words: iob2bioes(iob2(words))
+ self.lower = lower
+
+ def process(self, data_bundle: DataBundle) -> DataBundle:
+ """
+ 支持的DataSet的field为
+
+ .. csv-table::
+ :header: "raw_words", "target"
+
+ "[Nadim, Ladki]", "[B-PER, I-PER]"
+ "[AL-AIN, United, Arab, ...]", "[B-LOC, B-LOC, I-LOC, ...]"
+ "[...]", "[...]"
+
+        :param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且两个field的内容均为List[str]。
+ 在传入DataBundle基础上原位修改。
+ :return: DataBundle
+ """
+ # 转换tag
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
+
+ _add_words_field(data_bundle, lower=self.lower)
+
+ # index
+ _indexize(data_bundle)
+
+ input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
+ target_fields = [Const.TARGET, Const.INPUT_LEN]
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+
+ data_bundle.set_input(*input_fields)
+ data_bundle.set_target(*target_fields)
+
+ return data_bundle
+
+
+class Conll2003NERPipe(_NERPipe):
+ """
+ Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表
+ (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的
+ Vocabulary转换为index。
+ 经过该Pipe过后,DataSet中的内容如下所示
+
+    .. csv-table:: Following is a demo layout of the DataSet returned by Conll2003NERPipe
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2
+ "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4,...]", 6
+ "[...]", "[...]", "[...]", .
+
+ raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。
+
+    :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
+ :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
+ """
+
+ def process_from_file(self, paths) -> DataBundle:
+ """
+
+ :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。
+ :return: DataBundle
+ """
+ # 读取数据
+ data_bundle = Conll2003NERLoader().load(paths)
+ data_bundle = self.process(data_bundle)
+
+ return data_bundle
+
+
+class Conll2003Pipe(Pipe):
+ def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False):
+ """
+ 经过该Pipe后,DataSet中的内容如下
+
+ .. csv-table::
+ :header: "raw_words", "words", "pos", "chunk", "ner", "seq_len"
+
+ "[Nadim, Ladki]", "[2, 3]", "[0, 0]", "[1, 2]", "[1, 2]", 2
+ "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", 6
+        "[...]", "[...]", "[...]", "[...]", "[...]", .
+
+ 其中words, seq_len是input; pos, chunk, ner, seq_len是target
+
+ :param str chunk_encoding_type: 支持bioes, bio。
+ :param str ner_encoding_type: 支持bioes, bio。
+ :param bool lower: 是否将words列小写化后再建立词表
+ """
+ if chunk_encoding_type == 'bio':
+ self.chunk_convert_tag = iob2
+ else:
+ self.chunk_convert_tag = lambda tags: iob2bioes(iob2(tags))
+ if ner_encoding_type == 'bio':
+ self.ner_convert_tag = iob2
+ else:
+ self.ner_convert_tag = lambda tags: iob2bioes(iob2(tags))
+ self.lower = lower
+
+ def process(self, data_bundle) -> DataBundle:
+ """
+ 输入的DataSet应该类似于如下的形式
+
+ .. csv-table::
+ :header: "raw_words", "pos", "chunk", "ner"
+
+ "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]"
+ "[AL-AIN, United, Arab, ...]", "[NNP, NNP...]", "[B-NP, B-NP, ...]", "[B-LOC, B-LOC,...]"
+        "[...]", "[...]", "[...]", "[...]"
+
+ :param data_bundle:
+ :return: 传入的DataBundle
+ """
+ # 转换tag
+ for name, dataset in data_bundle.datasets.items():
+ dataset.drop(lambda x: "-DOCSTART-" in x[Const.RAW_WORD])
+ dataset.apply_field(self.chunk_convert_tag, field_name='chunk', new_field_name='chunk')
+ dataset.apply_field(self.ner_convert_tag, field_name='ner', new_field_name='ner')
+
+ _add_words_field(data_bundle, lower=self.lower)
+
+ # index
+ _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=['pos', 'ner'])
+ # chunk中存在一些tag只在dev中出现,没在train中
+ tgt_vocab = Vocabulary(unknown=None, padding=None)
+ tgt_vocab.from_dataset(*data_bundle.datasets.values(), field_name='chunk')
+ tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name='chunk')
+ data_bundle.set_vocab(tgt_vocab, 'chunk')
+
+ input_fields = [Const.INPUT, Const.INPUT_LEN]
+ target_fields = ['pos', 'ner', 'chunk', Const.INPUT_LEN]
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+
+ data_bundle.set_input(*input_fields)
+ data_bundle.set_target(*target_fields)
+
+ return data_bundle
+
+ def process_from_file(self, paths):
+ """
+
+ :param paths:
+ :return:
+ """
+ data_bundle = ConllLoader(headers=['raw_words', 'pos', 'chunk', 'ner']).load(paths)
+ return self.process(data_bundle)
+
+
+class OntoNotesNERPipe(_NERPipe):
+ """
+ 处理OntoNotes的NER数据,处理之后DataSet中的field情况为
+
+    .. csv-table:: Following is a demo layout of the DataSet returned by OntoNotesNERPipe
+ :header: "raw_words", "words", "target", "seq_len"
+
+ "[Nadim, Ladki]", "[2, 3]", "[1, 2]", 2
+ "[AL-AIN, United, Arab, ...]", "[4, 5, 6,...]", "[3, 4]", 6
+ "[...]", "[...]", "[...]", .
+
+ raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。
+
+    :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
+ :param bool lower: 是否将words小写化后再建立词表,绝大多数情况都不需要设置为True。
+ """
+
+ def process_from_file(self, paths):
+ data_bundle = OntoNotesNERLoader().load(paths)
+ return self.process(data_bundle)
+
+
+class _CNNERPipe(Pipe):
+ """
+ 中文NER任务的处理Pipe, 该Pipe会(1)复制raw_chars列,并命名为chars; (2)在chars, target列建立词表
+ (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将chars,target列根据相应的
+ Vocabulary转换为index。
+
+ raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target, seq_len。
+
+    :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
+ """
+
+ def __init__(self, encoding_type: str = 'bio'):
+ if encoding_type == 'bio':
+ self.convert_tag = iob2
+ else:
+ self.convert_tag = lambda words: iob2bioes(iob2(words))
+
+ def process(self, data_bundle: DataBundle) -> DataBundle:
+ """
+ 支持的DataSet的field为
+
+ .. csv-table::
+ :header: "raw_chars", "target"
+
+ "[相, 比, 之, 下,...]", "[O, O, O, O, ...]"
+ "[青, 岛, 海, 牛, 队, 和, ...]", "[B-ORG, I-ORG, I-ORG, ...]"
+ "[...]", "[...]"
+
+ raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
+
+        :param DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_chars和target两个field,且两个field的内容均为List[str]。
+ 在传入DataBundle基础上原位修改。
+ :return: DataBundle
+ """
+ # 转换tag
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(self.convert_tag, field_name=Const.TARGET, new_field_name=Const.TARGET)
+
+ _add_chars_field(data_bundle, lower=False)
+
+ # index
+ _indexize(data_bundle, input_field_names=Const.CHAR_INPUT, target_field_names=Const.TARGET)
+
+ input_fields = [Const.TARGET, Const.CHAR_INPUT, Const.INPUT_LEN]
+ target_fields = [Const.TARGET, Const.INPUT_LEN]
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.CHAR_INPUT)
+
+ data_bundle.set_input(*input_fields)
+ data_bundle.set_target(*target_fields)
+
+ return data_bundle
+
+
+class MsraNERPipe(_CNNERPipe):
+ """
+ 处理MSRA-NER的数据,处理之后的DataSet的field情况为
+
+ .. csv-table::
+ :header: "raw_chars", "chars", "target", "seq_len"
+
+ "[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11
+ "[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21
+ "[...]", "[...]", "[...]", .
+
+ raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
+
+ """
+
+ def process_from_file(self, paths=None) -> DataBundle:
+ data_bundle = MsraNERLoader().load(paths)
+ return self.process(data_bundle)
+
+
+class PeopleDailyPipe(_CNNERPipe):
+ """
+ 处理people daily的ner的数据,处理之后的DataSet的field情况为
+
+ .. csv-table::
+ :header: "raw_chars", "chars", "target", "seq_len"
+
+ "[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11
+ "[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21
+ "[...]", "[...]", "[...]", .
+
+ raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
+ """
+
+ def process_from_file(self, paths=None) -> DataBundle:
+ data_bundle = PeopleDailyNERLoader().load(paths)
+ return self.process(data_bundle)
+
+
+class WeiboNERPipe(_CNNERPipe):
+ """
+ 处理weibo的ner的数据,处理之后的DataSet的field情况为
+
+ .. csv-table::
+ :header: "raw_chars", "chars", "target", "seq_len"
+
+ "[相, 比, 之, 下,...]", "[2, 3, 4, 5, ...]", "[0, 0, 0, 0, ...]", 11
+ "[青, 岛, 海, 牛, 队, 和, ...]", "[10, 21, ....]", "[1, 2, 3, ...]", 21
+ "[...]", "[...]", "[...]", .
+
+ raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的
+ target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。
+
+    :param str encoding_type: target列使用什么类型的encoding方式,支持bioes, bio两种。
+ """
+
+ def process_from_file(self, paths=None) -> DataBundle:
+ data_bundle = WeiboNERLoader().load(paths)
+ return self.process(data_bundle)
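
For reference, a typical end-to-end use of one of these NER pipes might look like the sketch below. The data path is hypothetical, and the 'words'/'target' names follow the Const values used throughout this patch.

```python
# Illustrative only: assumes the CoNLL-2003 files already exist under the given directory.
from fastNLP.io.pipe.conll import Conll2003NERPipe

pipe = Conll2003NERPipe(encoding_type='bioes', lower=False)
data_bundle = pipe.process_from_file('/path/to/conll2003')   # dir containing train/dev/test files

train_ds = data_bundle.get_dataset('train')
print(train_ds[:2])                      # words and target are already converted to indices
print(data_bundle.vocabs['words'])       # vocabularies built by _indexize
print(data_bundle.vocabs['target'])
```
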
diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py
new file mode 100644
index 00000000..748cf10a
--- /dev/null
+++ b/fastNLP/io/pipe/cws.py
@@ -0,0 +1,266 @@
+"""undocumented"""
+
+__all__ = [
+ "CWSPipe"
+]
+
+import re
+from itertools import chain
+
+from .pipe import Pipe
+from .utils import _indexize
+from .. import DataBundle
+from ..loader import CWSLoader
+from ...core.const import Const
+
+
+def _word_lens_to_bmes(word_lens):
+ """
+
+ :param list word_lens: List[int], 每个词语的长度
+ :return: List[str], BMES的序列
+ """
+ tags = []
+ for word_len in word_lens:
+ if word_len == 1:
+ tags.append('S')
+ else:
+ tags.append('B')
+ tags.extend(['M'] * (word_len - 2))
+ tags.append('E')
+ return tags
+
+
+def _word_lens_to_segapp(word_lens):
+ """
+
+ :param list word_lens: List[int], 每个词语的长度
+    :return: List[str], SEG/APP的序列
+ """
+ tags = []
+ for word_len in word_lens:
+ if word_len == 1:
+ tags.append('SEG')
+ else:
+ tags.extend(['APP'] * (word_len - 1))
+ tags.append('SEG')
+ return tags
+
+
+def _alpha_span_to_special_tag(span):
+ """
+ 将span替换成特殊的字符
+
+ :param str span:
+ :return:
+ """
+    if 'oo' == span.lower():  # special case where "OO" stands in for zeros, e.g. 2OO8
+        return span
+    if len(span) == 1:
+        return span
+    else:
+        return '<ENG>'
+
+
+def _find_and_replace_alpha_spans(line):
+ """
+ 传入原始句子,替换其中的字母为特殊标记
+
+ :param str line:原始数据
+ :return: str
+ """
+ new_line = ''
+ pattern = '[a-zA-Z]+(?=[\u4e00-\u9fff ,%,.。!<-“])'
+ prev_end = 0
+ for match in re.finditer(pattern, line):
+ start, end = match.span()
+ span = line[start:end]
+ new_line += line[prev_end:start] + _alpha_span_to_special_tag(span)
+ prev_end = end
+ new_line += line[prev_end:]
+ return new_line
+
+
+def _digit_span_to_special_tag(span):
+ """
+
+ :param str span: 需要替换的str
+ :return:
+ """
+    if span[0] == '0' and len(span) > 2:
+        return '<NUM>'
+    decimal_point_count = 0  # one might have more than one decimal point
+    for idx, char in enumerate(span):
+        if char == '.' or char == '﹒' or char == '·':
+            decimal_point_count += 1
+    if span[-1] == '.' or span[-1] == '﹒' or span[-1] == '·':
+        # a trailing decimal point means this is not a number
+        if decimal_point_count == 1:
+            return span
+        else:
+            return '<UNKDGT>'
+    if decimal_point_count == 1:
+        return '<DEC>'
+    elif decimal_point_count > 1:
+        return '<UNKDGT>'
+    else:
+        return '<NUM>'
+
+
+def _find_and_replace_digit_spans(line):
+ """
+    Only consider spans that start with a digit and may contain '.' characters.
+
+    A span that ends with a space or a Chinese character will be processed;
+    a span that ends with or contains an English character is not handled.
+
+    Floats are replaced by <DEC>, other numbers by <NUM>, and anything
+    ambiguous by <UNKDGT>.
+ """
+ new_line = ''
+ pattern = '\d[\d\\.﹒·]*(?=[\u4e00-\u9fff ,%,。!<-“])'
+ prev_end = 0
+ for match in re.finditer(pattern, line):
+ start, end = match.span()
+ span = line[start:end]
+ new_line += line[prev_end:start] + _digit_span_to_special_tag(span)
+ prev_end = end
+ new_line += line[prev_end:]
+ return new_line
+
+
+class CWSPipe(Pipe):
+ """
+ 对CWS数据进行预处理, 处理之后的数据,具备以下的结构
+
+ .. csv-table::
+ :header: "raw_words", "chars", "target", "bigrams", "trigrams", "seq_len"
+
+ "共同 创造 美好...", "[2, 3, 4...]", "[0, 2, 0, 2,...]", "[10, 4, 1,...]","[6, 4, 1,...]", 13
+ "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", "[11, 12, ...]","[3, 9, ...]", 20
+ "...", "[...]","[...]", "[...]","[...]", .
+
+    其中bigrams和trigrams列仅当初始化时对应的参数为True时才会存在
+
+ :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None
+ :param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp
+ 的tag为[seg, app, seg, app, app, app, seg, ...]
+ :param bool replace_num_alpha: 是否将数字和字母用特殊字符替换。
+ :param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]
+ :param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]
+ """
+
+ def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, bigrams=False, trigrams=False):
+ if encoding_type == 'bmes':
+ self.word_lens_to_tags = _word_lens_to_bmes
+ else:
+ self.word_lens_to_tags = _word_lens_to_segapp
+
+ self.dataset_name = dataset_name
+ self.bigrams = bigrams
+ self.trigrams = trigrams
+ self.replace_num_alpha = replace_num_alpha
+
+ def _tokenize(self, data_bundle):
+ """
+        将data_bundle中'chars'列的每个词切分为单个字符.
+ 例如输入是"共同 创造 美好.."->[[共, 同], [创, 造], [...], ]
+
+ :param data_bundle:
+ :return:
+ """
+ def split_word_into_chars(raw_chars):
+ words = raw_chars.split()
+ chars = []
+ for word in words:
+ char = []
+ subchar = []
+ for c in word:
+ if c == '<':
+ subchar.append(c)
+ continue
+                    if c == '>' and len(subchar) > 0 and subchar[0] == '<':
+                        subchar.append(c)
+                        char.append(''.join(subchar))
+                        subchar = []
+                        continue
+ if subchar:
+ subchar.append(c)
+ else:
+ char.append(c)
+ char.extend(subchar)
+ chars.append(char)
+ return chars
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
+ new_field_name=Const.CHAR_INPUT)
+ return data_bundle
+
+ def process(self, data_bundle: DataBundle) -> DataBundle:
+ """
+ 可以处理的DataSet需要包含raw_words列
+
+ .. csv-table::
+ :header: "raw_words"
+
+ "上海 浦东 开发 与 法制 建设 同步"
+ "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
+ "..."
+
+ :param data_bundle:
+ :return:
+ """
+ data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)
+
+ if self.replace_num_alpha:
+ data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
+ data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
+
+ self._tokenize(data_bundle)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(lambda chars: self.word_lens_to_tags(map(len, chars)), field_name=Const.CHAR_INPUT,
+ new_field_name=Const.TARGET)
+ dataset.apply_field(lambda chars: list(chain(*chars)), field_name=Const.CHAR_INPUT,
+ new_field_name=Const.CHAR_INPUT)
+ input_field_names = [Const.CHAR_INPUT]
+ if self.bigrams:
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(lambda chars: [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])],
+ field_name=Const.CHAR_INPUT, new_field_name='bigrams')
+ input_field_names.append('bigrams')
+ if self.trigrams:
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(lambda chars: [c1 + c2 + c3 for c1, c2, c3 in
+ zip(chars, chars[1:] + [''], chars[2:] + [''] * 2)],
+ field_name=Const.CHAR_INPUT, new_field_name='trigrams')
+ input_field_names.append('trigrams')
+
+ _indexize(data_bundle, input_field_names, Const.TARGET)
+
+ input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
+ target_fields = [Const.TARGET, Const.INPUT_LEN]
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.CHAR_INPUT)
+
+ data_bundle.set_input(*input_fields)
+ data_bundle.set_target(*target_fields)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None) -> DataBundle:
+ """
+
+ :param str paths:
+ :return:
+ """
+ if self.dataset_name is None and paths is None:
+ raise RuntimeError(
+                "You have to set `paths` when calling process_from_file() or `dataset_name` during initialization.")
+ if self.dataset_name is not None and paths is not None:
+ raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
+ data_bundle = CWSLoader(self.dataset_name).load(paths)
+ return self.process(data_bundle)
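
Two of CWSPipe's core steps are easy to see in isolation: turning word lengths into BMES tags and building character bigrams. The sketch below mirrors `_word_lens_to_bmes` and the bigram lambda from `process()`; the sample sentence is made up.

```python
# Standalone sketch of CWSPipe's BMES tagging and bigram construction.
def word_lens_to_bmes(word_lens):
    tags = []
    for n in word_lens:
        tags.extend(['S'] if n == 1 else ['B'] + ['M'] * (n - 2) + ['E'])
    return tags

raw = "共同 创造 美好 生活"
words = raw.split()
chars = [c for w in words for c in w]
tags = word_lens_to_bmes(len(w) for w in words)
print(list(zip(chars, tags)))    # [('共', 'B'), ('同', 'E'), ('创', 'B'), ('造', 'E'), ...]

# bigrams, built exactly as CWSPipe does when bigrams=True
bigrams = [c1 + c2 for c1, c2 in zip(chars, chars[1:] + [''])]
print(bigrams)                   # ['共同', '同创', '创造', ..., '活']
```
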
diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py
new file mode 100644
index 00000000..699438c8
--- /dev/null
+++ b/fastNLP/io/pipe/matching.py
@@ -0,0 +1,274 @@
+"""undocumented"""
+
+__all__ = [
+ "MatchingBertPipe",
+ "RTEBertPipe",
+ "SNLIBertPipe",
+ "QuoraBertPipe",
+ "QNLIBertPipe",
+ "MNLIBertPipe",
+ "MatchingPipe",
+ "RTEPipe",
+ "SNLIPipe",
+ "QuoraPipe",
+ "QNLIPipe",
+ "MNLIPipe",
+]
+
+from .pipe import Pipe
+from .utils import get_tokenizer
+from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+
+
+class MatchingBertPipe(Pipe):
+ """
+ Matching任务的Bert pipe,输出的DataSet将包含以下的field
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "words", "target", "seq_len"
+
+ "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", 1, 10
+ "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", 0, 5
+ "...", "...", "[...]", ., .
+
+    words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"连接起来并转换为index后的结果。
+ words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss,
+ 如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参).
+
+ :param bool lower: 是否将word小写化。
+ :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。
+ """
+
+ def __init__(self, lower=False, tokenizer: str = 'raw'):
+ super().__init__()
+
+ self.lower = bool(lower)
+ self.tokenizer = get_tokenizer(tokenizer=tokenizer)
+
+ def _tokenize(self, data_bundle, field_names, new_field_names):
+ """
+
+ :param DataBundle data_bundle: DataBundle.
+ :param list field_names: List[str], 需要tokenize的field名称
+ :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。
+ :return: 输入的DataBundle对象
+ """
+ for name, dataset in data_bundle.datasets.items():
+ for field_name, new_field_name in zip(field_names, new_field_names):
+ dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
+ new_field_name=new_field_name)
+ return data_bundle
+
+ def process(self, data_bundle):
+ for dataset in data_bundle.datasets.values():
+ if dataset.has_field(Const.TARGET):
+ dataset.drop(lambda x: x[Const.TARGET] == '-')
+
+ for name, dataset in data_bundle.datasets.items():
+            dataset.copy_field(Const.RAW_WORDS(0), Const.INPUTS(0))
+            dataset.copy_field(Const.RAW_WORDS(1), Const.INPUTS(1))
+
+ if self.lower:
+ for name, dataset in data_bundle.datasets.items():
+ dataset[Const.INPUTS(0)].lower()
+ dataset[Const.INPUTS(1)].lower()
+
+ data_bundle = self._tokenize(data_bundle, [Const.INPUTS(0), Const.INPUTS(1)],
+ [Const.INPUTS(0), Const.INPUTS(1)])
+
+ # concat两个words
+ def concat(ins):
+ words0 = ins[Const.INPUTS(0)]
+ words1 = ins[Const.INPUTS(1)]
+ words = words0 + ['[SEP]'] + words1
+ return words
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply(concat, new_field_name=Const.INPUT)
+ dataset.delete_field(Const.INPUTS(0))
+ dataset.delete_field(Const.INPUTS(1))
+
+ word_vocab = Vocabulary()
+ word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
+ field_name=Const.INPUT,
+ no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
+ 'train' not in name])
+ word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=Const.INPUT)
+
+ target_vocab = Vocabulary(padding=None, unknown=None)
+ target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
+ has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
+ dataset.has_field(Const.TARGET)]
+ target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
+
+ data_bundle.set_vocab(word_vocab, Const.INPUT)
+ data_bundle.set_vocab(target_vocab, Const.TARGET)
+
+ input_fields = [Const.INPUT, Const.INPUT_LEN]
+ target_fields = [Const.TARGET]
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUT)
+ dataset.set_input(*input_fields, flag=True)
+ for fields in target_fields:
+ if dataset.has_field(fields):
+ dataset.set_target(fields, flag=True)
+
+ return data_bundle
+
+
+class RTEBertPipe(MatchingBertPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = RTELoader().load(paths)
+ return self.process(data_bundle)
+
+
+class SNLIBertPipe(MatchingBertPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = SNLILoader().load(paths)
+ return self.process(data_bundle)
+
+
+class QuoraBertPipe(MatchingBertPipe):
+ def process_from_file(self, paths):
+ data_bundle = QuoraLoader().load(paths)
+ return self.process(data_bundle)
+
+
+class QNLIBertPipe(MatchingBertPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = QNLILoader().load(paths)
+ return self.process(data_bundle)
+
+
+class MNLIBertPipe(MatchingBertPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = MNLILoader().load(paths)
+ return self.process(data_bundle)
+
+
+class MatchingPipe(Pipe):
+ """
+ Matching任务的Pipe。输出的DataSet将包含以下的field
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "words1", "words2", "target", "seq_len1", "seq_len2"
+
+ "The new rights are...", "Everyone really likes..", "[2, 3, 4, 5, ...]", "[10, 20, 6]", 1, 10, 13
+ "This site includes a...", "The Government Executive...", "[11, 12, 13,...]", "[2, 7, ...]", 0, 6, 7
+ "...", "...", "[...]", "[...]", ., ., .
+
+ words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target
+ 和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数
+ 的形参名进行传参)。
+
+ :param bool lower: 是否将所有raw_words转为小写。
+ :param str tokenizer: 将原始数据tokenize的方式。支持spacy, raw. spacy是使用spacy切分,raw就是用空格切分。
+ """
+
+ def __init__(self, lower=False, tokenizer: str = 'raw'):
+ super().__init__()
+
+ self.lower = bool(lower)
+ self.tokenizer = get_tokenizer(tokenizer=tokenizer)
+
+ def _tokenize(self, data_bundle, field_names, new_field_names):
+ """
+
+ :param DataBundle data_bundle: DataBundle.
+ :param list field_names: List[str], 需要tokenize的field名称
+ :param list new_field_names: List[str], tokenize之后field的名称,与field_names一一对应。
+ :return: 输入的DataBundle对象
+ """
+ for name, dataset in data_bundle.datasets.items():
+ for field_name, new_field_name in zip(field_names, new_field_names):
+ dataset.apply_field(lambda words: self.tokenizer(words), field_name=field_name,
+ new_field_name=new_field_name)
+ return data_bundle
+
+ def process(self, data_bundle):
+ """
+ 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有
+
+ .. csv-table::
+ :header: "raw_words1", "raw_words2", "target"
+
+ "The new rights are...", "Everyone really likes..", "entailment"
+ "This site includes a...", "The Government Executive...", "not_entailment"
+        "...", "...", "..."
+
+ :param data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容
+ :return: data_bundle
+ """
+ data_bundle = self._tokenize(data_bundle, [Const.RAW_WORDS(0), Const.RAW_WORDS(1)],
+ [Const.INPUTS(0), Const.INPUTS(1)])
+
+ for dataset in data_bundle.datasets.values():
+ if dataset.has_field(Const.TARGET):
+ dataset.drop(lambda x: x[Const.TARGET] == '-')
+
+ if self.lower:
+ for name, dataset in data_bundle.datasets.items():
+ dataset[Const.INPUTS(0)].lower()
+ dataset[Const.INPUTS(1)].lower()
+
+ word_vocab = Vocabulary()
+ word_vocab.from_dataset(*[dataset for name, dataset in data_bundle.datasets.items() if 'train' in name],
+ field_name=[Const.INPUTS(0), Const.INPUTS(1)],
+ no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
+ 'train' not in name])
+ word_vocab.index_dataset(*data_bundle.datasets.values(), field_name=[Const.INPUTS(0), Const.INPUTS(1)])
+
+ target_vocab = Vocabulary(padding=None, unknown=None)
+ target_vocab.from_dataset(data_bundle.datasets['train'], field_name=Const.TARGET)
+ has_target_datasets = [dataset for name, dataset in data_bundle.datasets.items() if
+ dataset.has_field(Const.TARGET)]
+ target_vocab.index_dataset(*has_target_datasets, field_name=Const.TARGET)
+
+ data_bundle.set_vocab(word_vocab, Const.INPUTS(0))
+ data_bundle.set_vocab(target_vocab, Const.TARGET)
+
+ input_fields = [Const.INPUTS(0), Const.INPUTS(1), Const.INPUT_LENS(0), Const.INPUT_LENS(1)]
+ target_fields = [Const.TARGET]
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.INPUTS(0), Const.INPUT_LENS(0))
+ dataset.add_seq_len(Const.INPUTS(1), Const.INPUT_LENS(1))
+ dataset.set_input(*input_fields, flag=True)
+ for fields in target_fields:
+ if dataset.has_field(fields):
+ dataset.set_target(fields, flag=True)
+
+ return data_bundle
+
+
+class RTEPipe(MatchingPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = RTELoader().load(paths)
+ return self.process(data_bundle)
+
+
+class SNLIPipe(MatchingPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = SNLILoader().load(paths)
+ return self.process(data_bundle)
+
+
+class QuoraPipe(MatchingPipe):
+ def process_from_file(self, paths):
+ data_bundle = QuoraLoader().load(paths)
+ return self.process(data_bundle)
+
+
+class QNLIPipe(MatchingPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = QNLILoader().load(paths)
+ return self.process(data_bundle)
+
+
+class MNLIPipe(MatchingPipe):
+ def process_from_file(self, paths=None):
+ data_bundle = MNLILoader().load(paths)
+ return self.process(data_bundle)
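
The key difference between MatchingBertPipe and MatchingPipe is that the former joins the two sentences into a single words field with a "[SEP]" token, while the latter keeps words1/words2 separate. The joining step is simple enough to show on toy, already-tokenized input (the sentences below are invented).

```python
# What MatchingBertPipe's concat() produces for one instance.
premise = "The new rights are nice enough".split()
hypothesis = "Everyone really likes the newest benefits".split()

words = premise + ['[SEP]'] + hypothesis   # contents of the resulting `words` field
print(words)
print(len(words))                          # becomes the `seq_len` field
```
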
diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py
new file mode 100644
index 00000000..a1435fd3
--- /dev/null
+++ b/fastNLP/io/pipe/pipe.py
@@ -0,0 +1,30 @@
+"""undocumented"""
+
+__all__ = [
+ "Pipe",
+]
+
+from .. import DataBundle
+
+
+class Pipe:
+ """
+ 别名::class:`fastNLP.io.Pipe` :class:`fastNLP.io.pipe.Pipe`
+ """
+ def process(self, data_bundle: DataBundle) -> DataBundle:
+ """
+ 对输入的DataBundle进行处理,然后返回该DataBundle。
+
+ :param data_bundle: 需要处理的DataBundle对象
+ :return:
+ """
+ raise NotImplementedError
+
+ def process_from_file(self, paths) -> DataBundle:
+ """
+ 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 `fastNLP.io.loader.Loader.load()`
+
+ :param paths:
+ :return: DataBundle
+ """
+ raise NotImplementedError
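
Pipe is deliberately tiny: subclasses only need process() and, usually, process_from_file(). A toy subclass is sketched below; the field name and sample data are illustrative, and Pipe is imported from the exact module added in this patch.

```python
# Minimal custom Pipe built against the interface above (toy example).
from fastNLP import DataSet
from fastNLP.io import DataBundle
from fastNLP.io.pipe.pipe import Pipe


class LowerCasePipe(Pipe):
    """Lower-cases the raw_words field of every DataSet in the bundle."""

    def process(self, data_bundle: DataBundle) -> DataBundle:
        for name, dataset in data_bundle.datasets.items():
            dataset.apply_field(str.lower, field_name='raw_words', new_field_name='raw_words')
        return data_bundle

    def process_from_file(self, paths) -> DataBundle:
        raise NotImplementedError  # a real pipe would delegate to a Loader here


bundle = DataBundle(datasets={'train': DataSet({'raw_words': ["Hello World", "FastNLP Pipe"]})})
bundle = LowerCasePipe().process(bundle)
print(bundle.get_dataset('train'))
```
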
diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py
new file mode 100644
index 00000000..f32f58b7
--- /dev/null
+++ b/fastNLP/io/pipe/utils.py
@@ -0,0 +1,176 @@
+"""undocumented"""
+
+__all__ = [
+ "iob2",
+ "iob2bioes",
+ "get_tokenizer",
+]
+
+from typing import List
+
+from ...core.const import Const
+from ...core.vocabulary import Vocabulary
+
+
+def iob2(tags: List[str]) -> List[str]:
+ """
+ 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见
+ https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format
+
+ :param tags: 需要转换的tags
+ """
+ for i, tag in enumerate(tags):
+ if tag == "O":
+ continue
+ split = tag.split("-")
+ if len(split) != 2 or split[0] not in ["I", "B"]:
+ raise TypeError("The encoding schema is not a valid IOB type.")
+ if split[0] == "B":
+ continue
+ elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2
+ tags[i] = "B" + tag[1:]
+ elif tags[i - 1][1:] == tag[1:]:
+ continue
+ else: # conversion IOB1 to IOB2
+ tags[i] = "B" + tag[1:]
+ return tags
+
+
+def iob2bioes(tags: List[str]) -> List[str]:
+ """
+    将iob的tag转换为bioes编码
+
+    :param tags: List[str], 需要转换的IOB格式的tags
+    :return: List[str], 转换为BIOES格式的tags
+ """
+ new_tags = []
+ for i, tag in enumerate(tags):
+ if tag == 'O':
+ new_tags.append(tag)
+ else:
+ split = tag.split('-')[0]
+ if split == 'B':
+ if i + 1 != len(tags) and tags[i + 1].split('-')[0] == 'I':
+ new_tags.append(tag)
+ else:
+ new_tags.append(tag.replace('B-', 'S-'))
+ elif split == 'I':
+ if i + 1 < len(tags) and tags[i + 1].split('-')[0] == 'I':
+ new_tags.append(tag)
+ else:
+ new_tags.append(tag.replace('I-', 'E-'))
+ else:
+ raise TypeError("Invalid IOB format.")
+ return new_tags
+
+
+def get_tokenizer(tokenizer: str, lang='en'):
+ """
+
+    :param str tokenizer: tokenizer的类型,支持'spacy'和'raw'
+ :param str lang: 语言,当前仅支持en
+ :return: 返回tokenize函数
+ """
+ if tokenizer == 'spacy':
+ import spacy
+ spacy.prefer_gpu()
+ if lang != 'en':
+            raise RuntimeError("Spacy only supports en right now.")
+ en = spacy.load(lang)
+ tokenizer = lambda x: [w.text for w in en.tokenizer(x)]
+ elif tokenizer == 'raw':
+ tokenizer = _raw_split
+ else:
+ raise RuntimeError("Only support `spacy`, `raw` tokenizer.")
+ return tokenizer
+
+
+def _raw_split(sent):
+ return sent.split()
+
+
+def _indexize(data_bundle, input_field_names=Const.INPUT, target_field_names=Const.TARGET):
+ """
+    在dataset的input_field_names列和target_field_names列上分别建立词表,并把词表加入到data_bundle中。
+
+    :param data_bundle:
+    :param str,list input_field_names: 需要建立词表并转换为index的input列
+    :param str,list target_field_names: 需要建立词表并转换为index的target列,这一列的vocabulary没有unknown和padding
+ :return:
+ """
+ if isinstance(input_field_names, str):
+ input_field_names = [input_field_names]
+ if isinstance(target_field_names, str):
+ target_field_names = [target_field_names]
+ for input_field_name in input_field_names:
+ src_vocab = Vocabulary()
+ src_vocab.from_dataset(data_bundle.datasets['train'], field_name=input_field_name,
+ no_create_entry_dataset=[dataset for name, dataset in data_bundle.datasets.items() if
+ name != 'train'])
+ src_vocab.index_dataset(*data_bundle.datasets.values(), field_name=input_field_name)
+ data_bundle.set_vocab(src_vocab, input_field_name)
+
+ for target_field_name in target_field_names:
+ tgt_vocab = Vocabulary(unknown=None, padding=None)
+ tgt_vocab.from_dataset(data_bundle.datasets['train'], field_name=target_field_name)
+ tgt_vocab.index_dataset(*data_bundle.datasets.values(), field_name=target_field_name)
+ data_bundle.set_vocab(tgt_vocab, target_field_name)
+
+ return data_bundle
+
+
+def _add_words_field(data_bundle, lower=False):
+ """
+ 给data_bundle中的dataset中复制一列words. 并根据lower参数判断是否需要小写化
+
+ :param data_bundle:
+ :param bool lower:是否要小写化
+ :return: 传入的DataBundle
+ """
+ data_bundle.copy_field(field_name=Const.RAW_WORD, new_field_name=Const.INPUT, ignore_miss_dataset=True)
+
+ if lower:
+ for name, dataset in data_bundle.datasets.items():
+ dataset[Const.INPUT].lower()
+ return data_bundle
+
+
+def _add_chars_field(data_bundle, lower=False):
+ """
+ 给data_bundle中的dataset中复制一列chars. 并根据lower参数判断是否需要小写化
+
+ :param data_bundle:
+ :param bool lower:是否要小写化
+ :return: 传入的DataBundle
+ """
+ data_bundle.copy_field(field_name=Const.RAW_CHAR, new_field_name=Const.CHAR_INPUT, ignore_miss_dataset=True)
+
+ if lower:
+ for name, dataset in data_bundle.datasets.items():
+ dataset[Const.CHAR_INPUT].lower()
+ return data_bundle
+
+
+def _drop_empty_instance(data_bundle, field_name):
+ """
+ 删除data_bundle的DataSet中存在的某个field为空的情况
+
+ :param data_bundle: DataBundle
+ :param str field_name: 对哪个field进行检查,如果为None,则任意field为空都会删掉
+ :return: 传入的DataBundle
+ """
+
+ def empty_instance(ins):
+ if field_name:
+ field_value = ins[field_name]
+ if field_value in ((), {}, [], ''):
+ return True
+ return False
+ for _, field_value in ins.items():
+ if field_value in ((), {}, [], ''):
+ return True
+ return False
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.drop(empty_instance)
+
+ return data_bundle
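
A quick worked example of the two tag conversions defined above, which is also what _NERPipe applies when encoding_type='bioes' (the sample tags are IOB1-style):

```python
# IOB1 -> IOB2 -> BIOES, using the helpers defined above.
from fastNLP.io.pipe.utils import iob2, iob2bioes

iob1_tags = ['I-PER', 'I-PER', 'O', 'I-LOC']
iob2_tags = iob2(list(iob1_tags))        # ['B-PER', 'I-PER', 'O', 'B-LOC']
bioes_tags = iob2bioes(iob2_tags)        # ['B-PER', 'E-PER', 'O', 'S-LOC']
print(iob2_tags)
print(bioes_tags)
```
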
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index a7d2de85..e1de2ae7 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -1,23 +1,37 @@
-import os
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "check_loader_paths"
+]
+import os
+from pathlib import Path
from typing import Union, Dict
+from ..core import logger
+
-def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
+def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
"""
- 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果
- {
- 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
- 'test': 'xxx' # 可能有,也可能没有
- ...
- }
- 如果paths为不合法的,将直接进行raise相应的错误
+ 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果::
- :param paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
+ {
+ 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。
+ 'test': 'xxx' # 可能有,也可能没有
+ ...
+ }
+
+    如果paths不合法,将直接raise相应的错误; 如果paths内不包含train也会报错。
+
+ :param str paths: 路径. 可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找train(文件名
中包含train这个字段), test.txt, dev.txt; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。
:return:
"""
- if isinstance(paths, str):
+ if isinstance(paths, (str, Path)):
+ paths = os.path.abspath(os.path.expanduser(paths))
if os.path.isfile(paths):
return {'train': paths}
elif os.path.isdir(paths):
@@ -29,26 +43,32 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
path_pair = ('train', filename)
if 'dev' in filename:
if path_pair:
- raise Exception("File:{} in {} contains bot `{}` and `dev`.".format(filename, paths, path_pair[0]))
+ raise Exception(
+                            "File:{} in {} contains both `{}` and `dev`.".format(filename, paths, path_pair[0]))
path_pair = ('dev', filename)
if 'test' in filename:
if path_pair:
- raise Exception("File:{} in {} contains bot `{}` and `test`.".format(filename, paths, path_pair[0]))
+ raise Exception(
+                            "File:{} in {} contains both `{}` and `test`.".format(filename, paths, path_pair[0]))
path_pair = ('test', filename)
if path_pair:
files[path_pair[0]] = os.path.join(paths, path_pair[1])
+ if 'train' not in files:
+ raise KeyError(f"There is no train file in {paths}.")
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
-
+
elif isinstance(paths, dict):
if paths:
if 'train' not in paths:
raise KeyError("You have to include `train` in your dict.")
for key, value in paths.items():
if isinstance(key, str) and isinstance(value, str):
+ value = os.path.abspath(os.path.expanduser(value))
if not os.path.isfile(value):
raise TypeError(f"{value} is not a valid file.")
+ paths[key] = value
else:
raise TypeError("All keys and values in paths should be str.")
return paths
@@ -57,13 +77,14 @@ def check_dataloader_paths(paths:Union[str, Dict[str, str]])->Dict[str, str]:
else:
raise TypeError(f"paths only supports str and dict. not {type(paths)}.")
+
def get_tokenizer():
try:
import spacy
spacy.prefer_gpu()
en = spacy.load('en')
- print('use spacy tokenizer')
+ logger.info('use spacy tokenizer')
return lambda x: [w.text for w in en.tokenizer(x)]
except Exception as e:
- print('use raw tokenizer')
+ logger.error('use raw tokenizer')
return lambda x: x.split()
diff --git a/fastNLP/models/base_model.py b/fastNLP/models/base_model.py
index 2646d580..61edb91f 100644
--- a/fastNLP/models/base_model.py
+++ b/fastNLP/models/base_model.py
@@ -1,3 +1,7 @@
+"""undocumented"""
+
+__all__ = []
+
import torch
from ..modules.decoder.mlp import MLP
diff --git a/fastNLP/models/bert.py b/fastNLP/models/bert.py
index fb186ce4..0a89b765 100644
--- a/fastNLP/models/bert.py
+++ b/fastNLP/models/bert.py
@@ -1,14 +1,20 @@
-"""
+"""undocumented
bert.py is modified from huggingface/pytorch-pretrained-BERT, which is licensed under the Apache License 2.0.
"""
+
+__all__ = []
+
+import os
+
import torch
from torch import nn
from .base_model import BaseModel
from ..core.const import Const
+from ..core.utils import seq_len_to_mask
from ..modules.encoder import BertModel
-from ..modules.encoder._bert import BertConfig
+from ..modules.encoder.bert import BertConfig, CONFIG_FILE
class BertForSequenceClassification(BaseModel):
@@ -54,6 +60,7 @@ class BertForSequenceClassification(BaseModel):
self.num_labels = num_labels
if bert_dir is not None:
self.bert = BertModel.from_pretrained(bert_dir)
+ config = BertConfig(os.path.join(bert_dir, CONFIG_FILE))
else:
if config is None:
config = BertConfig(30522)
@@ -67,20 +74,24 @@ class BertForSequenceClassification(BaseModel):
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
return model
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
- _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+ def forward(self, words, seq_len=None, target=None):
+ if seq_len is None:
+ seq_len = torch.ones_like(words, dtype=words.dtype, device=words.device)
+ if len(seq_len.size()) + 1 == len(words.size()):
+ seq_len = seq_len_to_mask(seq_len, max_len=words.size(-1))
+ _, pooled_output = self.bert(words, attention_mask=seq_len, output_all_encoded_layers=False)
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
- if labels is not None:
+ if target is not None:
loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ loss = loss_fct(logits, target)
return {Const.OUTPUT: logits, Const.LOSS: loss}
else:
return {Const.OUTPUT: logits}
- def predict(self, input_ids, token_type_ids=None, attention_mask=None):
- logits = self.forward(input_ids, token_type_ids, attention_mask)
+ def predict(self, words, seq_len=None):
+ logits = self.forward(words, seq_len=seq_len)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
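
The reworked forward() now takes fastNLP-style words/seq_len arguments and converts a 1D length tensor into an attention mask via seq_len_to_mask, as used above. A small sketch of that conversion (tensor values are arbitrary; the exact mask dtype may differ between versions):

```python
import torch
from fastNLP.core.utils import seq_len_to_mask

words = torch.tensor([[3, 7, 9, 0],    # two padded sequences of word indices
                      [5, 2, 0, 0]])
seq_len = torch.tensor([3, 2])

mask = seq_len_to_mask(seq_len, max_len=words.size(-1))
print(mask)
# tensor([[1, 1, 1, 0],
#         [1, 1, 0, 0]])   (shown as 0/1; dtype may be uint8 or bool)
```
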
@@ -140,7 +151,8 @@ class BertForMultipleChoice(BaseModel):
model = cls(num_choices=num_choices, config=config, bert_dir=pretrained_model_dir)
return model
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
+ def forward(self, words, seq_len1=None, seq_len2=None, target=None):
+ input_ids, token_type_ids, attention_mask = words, seq_len1, seq_len2
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1))
@@ -149,15 +161,15 @@ class BertForMultipleChoice(BaseModel):
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
- if labels is not None:
+ if target is not None:
loss_fct = nn.CrossEntropyLoss()
- loss = loss_fct(reshaped_logits, labels)
+ loss = loss_fct(reshaped_logits, target)
return {Const.OUTPUT: reshaped_logits, Const.LOSS: loss}
else:
return {Const.OUTPUT: reshaped_logits}
- def predict(self, input_ids, token_type_ids=None, attention_mask=None):
- logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+    def predict(self, words, seq_len1=None, seq_len2=None):
+ logits = self.forward(words, seq_len1=seq_len1, seq_len2=seq_len2)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
@@ -219,27 +231,27 @@ class BertForTokenClassification(BaseModel):
model = cls(num_labels=num_labels, config=config, bert_dir=pretrained_model_dir)
return model
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
- sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+ def forward(self, words, seq_len1=None, seq_len2=None, target=None):
+ sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
- if labels is not None:
+ if target is not None:
loss_fct = nn.CrossEntropyLoss()
# Only keep active parts of the loss
- if attention_mask is not None:
- active_loss = attention_mask.view(-1) == 1
+ if seq_len2 is not None:
+ active_loss = seq_len2.view(-1) == 1
active_logits = logits.view(-1, self.num_labels)[active_loss]
- active_labels = labels.view(-1)[active_loss]
+ active_labels = target.view(-1)[active_loss]
loss = loss_fct(active_logits, active_labels)
else:
- loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+ loss = loss_fct(logits.view(-1, self.num_labels), target.view(-1))
return {Const.OUTPUT: logits, Const.LOSS: loss}
else:
return {Const.OUTPUT: logits}
- def predict(self, input_ids, token_type_ids=None, attention_mask=None):
- logits = self.forward(input_ids, token_type_ids, attention_mask)[Const.OUTPUT]
+ def predict(self, words, seq_len1=None, seq_len2=None):
+ logits = self.forward(words, seq_len1, seq_len2)[Const.OUTPUT]
return {Const.OUTPUT: torch.argmax(logits, dim=-1)}
@@ -304,34 +316,34 @@ class BertForQuestionAnswering(BaseModel):
model = cls(config=config, bert_dir=pretrained_model_dir)
return model
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
- sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
+ def forward(self, words, seq_len1=None, seq_len2=None, target1=None, target2=None):
+ sequence_output, _ = self.bert(words, seq_len1, seq_len2, output_all_encoded_layers=False)
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
end_logits = end_logits.squeeze(-1)
- if start_positions is not None and end_positions is not None:
+ if target1 is not None and target2 is not None:
# If we are on multi-GPU, split add a dimension
- if len(start_positions.size()) > 1:
- start_positions = start_positions.squeeze(-1)
- if len(end_positions.size()) > 1:
- end_positions = end_positions.squeeze(-1)
+ if len(target1.size()) > 1:
+ target1 = target1.squeeze(-1)
+ if len(target2.size()) > 1:
+ target2 = target2.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1)
- start_positions.clamp_(0, ignored_index)
- end_positions.clamp_(0, ignored_index)
+ target1.clamp_(0, ignored_index)
+ target2.clamp_(0, ignored_index)
loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
- start_loss = loss_fct(start_logits, start_positions)
- end_loss = loss_fct(end_logits, end_positions)
+ start_loss = loss_fct(start_logits, target1)
+ end_loss = loss_fct(end_logits, target2)
total_loss = (start_loss + end_loss) / 2
return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits, Const.LOSS: total_loss}
else:
return {Const.OUTPUTS(0): start_logits, Const.OUTPUTS(1): end_logits}
- def predict(self, input_ids, token_type_ids=None, attention_mask=None, **kwargs):
- logits = self.forward(input_ids, token_type_ids, attention_mask)
+ def predict(self, words, seq_len1=None, seq_len2=None):
+ logits = self.forward(words, seq_len1, seq_len2)
start_logits = logits[Const.OUTPUTS(0)]
end_logits = logits[Const.OUTPUTS(1)]
return {Const.OUTPUTS(0): torch.argmax(start_logits, dim=-1),
diff --git a/fastNLP/models/biaffine_parser.py b/fastNLP/models/biaffine_parser.py
index 8533e7af..bead09fc 100644
--- a/fastNLP/models/biaffine_parser.py
+++ b/fastNLP/models/biaffine_parser.py
@@ -20,7 +20,7 @@ from ..modules.dropout import TimestepDropout
from ..modules.encoder.transformer import TransformerEncoder
from ..modules.encoder.variational_rnn import VarLSTM
from ..modules.utils import initial_parameter
-from ..modules.utils import get_embeddings
+from ..embeddings.utils import get_embeddings
from .base_model import BaseModel
from ..core.utils import seq_len_to_mask
@@ -130,6 +130,8 @@ def _find_cycle(vertices, edges):
class GraphParser(BaseModel):
"""
+    别名::class:`fastNLP.models.GraphParser` :class:`fastNLP.models.biaffine_parser.GraphParser`
+
基于图的parser base class, 支持贪婪解码和最大生成树解码
"""
@@ -148,7 +150,7 @@ class GraphParser(BaseModel):
"""
_, seq_len, _ = arc_matrix.shape
matrix = arc_matrix + torch.diag(arc_matrix.new(seq_len).fill_(-np.inf))
- flip_mask = (mask == 0).byte()
+ flip_mask = mask.eq(0)
matrix.masked_fill_(flip_mask.unsqueeze(1), -np.inf)
_, heads = torch.max(matrix, dim=2)
if mask is not None:
diff --git a/fastNLP/models/cnn_text_classification.py b/fastNLP/models/cnn_text_classification.py
index 081dd510..37a60c35 100644
--- a/fastNLP/models/cnn_text_classification.py
+++ b/fastNLP/models/cnn_text_classification.py
@@ -1,3 +1,8 @@
+"""
+.. todo::
+ doc
+"""
+
__all__ = [
"CNNText"
]
@@ -6,8 +11,9 @@ import torch
import torch.nn as nn
from ..core.const import Const as C
+from ..core.utils import seq_len_to_mask
+from ..embeddings import embedding
from ..modules import encoder
-from fastNLP import seq_len_to_mask
class CNNText(torch.nn.Module):
@@ -24,23 +30,23 @@ class CNNText(torch.nn.Module):
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param float dropout: Dropout的大小
"""
-
+
def __init__(self, init_embed,
num_classes,
kernel_nums=(30, 40, 50),
kernel_sizes=(1, 3, 5),
dropout=0.5):
super(CNNText, self).__init__()
-
+
# no support for pre-trained embedding currently
- self.embed = encoder.Embedding(init_embed)
+ self.embed = embedding.Embedding(init_embed)
self.conv_pool = encoder.ConvMaxpool(
in_channels=self.embed.embedding_dim,
out_channels=kernel_nums,
kernel_sizes=kernel_sizes)
self.dropout = nn.Dropout(dropout)
self.fc = nn.Linear(sum(kernel_nums), num_classes)
-
+
def forward(self, words, seq_len=None):
"""
@@ -57,7 +63,7 @@ class CNNText(torch.nn.Module):
x = self.dropout(x)
x = self.fc(x) # [N,C] -> [N, N_class]
return {C.OUTPUT: x}
-
+
def predict(self, words, seq_len=None):
"""
:param torch.LongTensor words: [batch_size, seq_len],句子中word的index
diff --git a/fastNLP/models/enas_controller.py b/fastNLP/models/enas_controller.py
index e83c6b51..eec820e4 100644
--- a/fastNLP/models/enas_controller.py
+++ b/fastNLP/models/enas_controller.py
@@ -1,5 +1,10 @@
-# Code Modified from https://github.com/carpedm20/ENAS-pytorch
-"""A module with NAS controller-related code."""
+"""undocumented
+Code Modified from https://github.com/carpedm20/ENAS-pytorch
+A module with NAS controller-related code.
+"""
+
+__all__ = []
+
import collections
import os
diff --git a/fastNLP/models/enas_model.py b/fastNLP/models/enas_model.py
index b6b683c0..2e8ca713 100644
--- a/fastNLP/models/enas_model.py
+++ b/fastNLP/models/enas_model.py
@@ -1,7 +1,10 @@
-"""
+"""undocumented
Module containing the shared RNN model.
Code Modified from https://github.com/carpedm20/ENAS-pytorch
"""
+
+__all__ = []
+
import collections
import numpy as np
diff --git a/fastNLP/models/enas_trainer.py b/fastNLP/models/enas_trainer.py
index ef596b03..98d778cd 100644
--- a/fastNLP/models/enas_trainer.py
+++ b/fastNLP/models/enas_trainer.py
@@ -1,11 +1,15 @@
-# Code Modified from https://github.com/carpedm20/ENAS-pytorch
+"""undocumented
+Code Modified from https://github.com/carpedm20/ENAS-pytorch
+"""
+
+__all__ = []
+
import math
-import numpy as np
import time
-import torch
-
from datetime import datetime, timedelta
+import numpy as np
+import torch
from torch.optim import Adam
try:
@@ -14,8 +18,8 @@ except:
from ..core.utils import _pseudo_tqdm as tqdm
from ..core.trainer import Trainer
-from ..core.batch import Batch
-from ..core.callback import CallbackManager, CallbackException
+from ..core.batch import DataSetIter
+from ..core.callback import CallbackException
from ..core.dataset import DataSet
from ..core.utils import _move_dict_value_to_device
from . import enas_utils as utils
@@ -124,8 +128,8 @@ class ENASTrainer(Trainer):
len(self.train_data) % self.batch_size != 0)) * self.n_epochs
with inner_tqdm(total=total_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar:
avg_loss = 0
- data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
- prefetch=self.prefetch)
+ data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+ prefetch=self.prefetch)
for epoch in range(1, self.n_epochs + 1):
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
last_stage = (epoch > self.n_epochs + 1 - self.final_epochs)
@@ -209,8 +213,8 @@ class ENASTrainer(Trainer):
total_loss = 0
train_idx = 0
avg_loss = 0
- data_iterator = Batch(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
- prefetch=self.prefetch)
+ data_iterator = DataSetIter(self.train_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+ prefetch=self.prefetch)
for batch_x, batch_y in data_iterator:
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
@@ -262,8 +266,8 @@ class ENASTrainer(Trainer):
if not isinstance(entropies, np.ndarray):
entropies = entropies.data.cpu().numpy()
- data_iterator = Batch(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
- prefetch=self.prefetch)
+ data_iterator = DataSetIter(self.dev_data, batch_size=self.batch_size, sampler=self.sampler, as_numpy=False,
+ prefetch=self.prefetch)
for inputs, targets in data_iterator:
valid_loss, hidden, _ = self.get_loss(inputs, targets, hidden, dag)
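
The three hunks above swap the removed `Batch` iterator for `DataSetIter` while keeping the same constructor arguments. A minimal sketch of the new call pattern, assuming a small in-memory `DataSet` (field names and values are illustrative; the `prefetch` flag shown in the hunks is left at its default here):

```python
from fastNLP import DataSet, RandomSampler
from fastNLP.core.batch import DataSetIter

# a toy DataSet; in ENASTrainer the data comes from self.train_data / self.dev_data
ds = DataSet({"words": [[1, 2, 3], [4, 5, 6, 7]], "seq_len": [3, 4]})
ds.set_input("words", "seq_len")

# DataSetIter is used as a drop-in replacement for the old Batch class
data_iterator = DataSetIter(ds, batch_size=2, sampler=RandomSampler(), as_numpy=False)
for batch_x, batch_y in data_iterator:
    print(batch_x["words"].shape)  # padded LongTensor, e.g. torch.Size([2, 4])
```
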
diff --git a/fastNLP/models/enas_utils.py b/fastNLP/models/enas_utils.py
index 4e402a9a..cd6c2503 100644
--- a/fastNLP/models/enas_utils.py
+++ b/fastNLP/models/enas_utils.py
@@ -1,7 +1,11 @@
-# Code Modified from https://github.com/carpedm20/ENAS-pytorch
+"""undocumented
+Code Modified from https://github.com/carpedm20/ENAS-pytorch
+"""
+
+__all__ = []
-from collections import defaultdict
import collections
+from collections import defaultdict
import numpy as np
import torch
diff --git a/fastNLP/models/sequence_labeling.py b/fastNLP/models/sequence_labeling.py
index 8e6a5db1..0dff21f0 100644
--- a/fastNLP/models/sequence_labeling.py
+++ b/fastNLP/models/sequence_labeling.py
@@ -1,19 +1,82 @@
"""
- 本模块实现了两种序列标注模型
+本模块实现了几种序列标注模型
"""
__all__ = [
"SeqLabeling",
- "AdvSeqLabel"
+ "AdvSeqLabel",
+ # "BiLSTMCRF"
]
import torch
import torch.nn as nn
+import torch.nn.functional as F
from .base_model import BaseModel
+from ..core.const import Const as C
+from ..core.utils import seq_len_to_mask
+from ..embeddings import embedding
+from ..embeddings import get_embeddings
+from ..modules import ConditionalRandomField
+from ..modules import LSTM
from ..modules import decoder, encoder
from ..modules.decoder.crf import allowed_transitions
-from ..core.utils import seq_len_to_mask
-from ..core.const import Const as C
+
+
+class BiLSTMCRF(BaseModel):
+ """
+ 结构为BiLSTM + FC + Dropout + CRF.
+
+ .. todo::
+ 继续补充文档
+
+    :param embed: 支持传入Embedding对象,或可被get_embeddings接受的初始化参数,如(num_embeddings, embedding_dim)的tuple
+    :param num_classes: 标签(tag)的种类数
+    :param num_layers: BiLSTM的层数
+    :param hidden_size: BiLSTM每个方向的hidden维度
+    :param dropout: dropout概率
+    :param target_vocab: 标签对应的Vocabulary,与encoding_type一起用于限制CRF的合法转移
+    :param encoding_type: 标签的编码方式,如"bio"、"bmes"等
+ """
+ def __init__(self, embed, num_classes, num_layers=1, hidden_size=100, dropout=0.5,
+ target_vocab=None, encoding_type=None):
+ super().__init__()
+ self.embed = get_embeddings(embed)
+
+        if num_layers>1:
+            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
+                             batch_first=True, dropout=dropout)
+        else:
+            self.lstm = LSTM(self.embed.embedding_dim, num_layers=num_layers, hidden_size=hidden_size, bidirectional=True,
+                             batch_first=True)
+
+        self.dropout = nn.Dropout(dropout)
+        self.fc = nn.Linear(hidden_size * 2, num_classes)  # bidirectional LSTM outputs hidden_size*2 features
+
+ trans = None
+ if target_vocab is not None and encoding_type is not None:
+ trans = allowed_transitions(target_vocab.idx2word, encoding_type=encoding_type, include_start_end=True)
+
+ self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans)
+
+ def _forward(self, words, seq_len=None, target=None):
+ words = self.embed(words)
+ feats = self.lstm(words, seq_len=seq_len)
+ feats = self.fc(feats)
+ feats = self.dropout(feats)
+ logits = F.log_softmax(feats, dim=-1)
+ mask = seq_len_to_mask(seq_len)
+ if target is None:
+ pred, _ = self.crf.viterbi_decode(logits, mask)
+ return {C.OUTPUT:pred}
+ else:
+ loss = self.crf(logits, target, mask).mean()
+ return {C.LOSS:loss}
+
+ def forward(self, words, seq_len, target):
+ return self._forward(words, seq_len, target)
+
+ def predict(self, words, seq_len):
+ return self._forward(words, seq_len)
class SeqLabeling(BaseModel):
@@ -32,10 +95,10 @@ class SeqLabeling(BaseModel):
def __init__(self, init_embed, hidden_size, num_classes):
super(SeqLabeling, self).__init__()
- self.Embedding = encoder.embedding.Embedding(init_embed)
- self.Rnn = encoder.lstm.LSTM(self.Embedding.embedding_dim, hidden_size)
+ self.Embedding = embedding.Embedding(init_embed)
+ self.Rnn = encoder.LSTM(self.Embedding.embedding_dim, hidden_size)
self.Linear = nn.Linear(hidden_size, num_classes)
- self.Crf = decoder.crf.ConditionalRandomField(num_classes)
+ self.Crf = decoder.ConditionalRandomField(num_classes)
self.mask = None
def forward(self, words, seq_len, target):
@@ -129,7 +192,7 @@ class AdvSeqLabel(nn.Module):
super().__init__()
- self.Embedding = encoder.embedding.Embedding(init_embed)
+ self.Embedding = embedding.Embedding(init_embed)
self.norm1 = torch.nn.LayerNorm(self.Embedding.embedding_dim)
self.Rnn = encoder.LSTM(input_size=self.Embedding.embedding_dim, hidden_size=hidden_size, num_layers=2,
dropout=dropout,
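
The newly added `BiLSTMCRF` follows the usual fastNLP model contract: `forward` returns a dict holding the CRF loss during training, while `predict` returns Viterbi-decoded tag ids. A minimal sketch of driving it, assuming the `(num_embeddings, embedding_dim)` tuple form accepted by `get_embeddings`; since the class is not yet listed in `__all__`, it is imported from its defining module (vocabulary and tensor contents are illustrative):

```python
import torch
from fastNLP import Vocabulary
from fastNLP.models.sequence_labeling import BiLSTMCRF

tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.add_word_lst(["B-PER", "I-PER", "O"])

# embed as (num_embeddings, embedding_dim) -> a randomly initialized embedding
model = BiLSTMCRF(embed=(1000, 50), num_classes=len(tag_vocab),
                  target_vocab=tag_vocab, encoding_type="bio")

words = torch.randint(0, 1000, (2, 7))             # [batch_size, seq_len]
seq_len = torch.LongTensor([7, 5])
target = torch.randint(0, len(tag_vocab), (2, 7))

print(model(words, seq_len, target)["loss"])       # training: mean CRF negative log-likelihood
print(model.predict(words, seq_len)["pred"])       # inference: Viterbi-decoded tag ids
```
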
diff --git a/fastNLP/models/snli.py b/fastNLP/models/snli.py
index d12524cc..5ca4052d 100644
--- a/fastNLP/models/snli.py
+++ b/fastNLP/models/snli.py
@@ -1,3 +1,7 @@
+"""
+.. todo::
+ doc
+"""
__all__ = [
"ESIM"
]
@@ -5,32 +9,36 @@ __all__ = [
import torch
import torch.nn as nn
import torch.nn.functional as F
-
from torch.nn import CrossEntropyLoss
-from fastNLP.models import BaseModel
-from fastNLP.modules.encoder.embedding import TokenEmbedding
-from fastNLP.modules.encoder.lstm import LSTM
-from fastNLP.core.const import Const
-from fastNLP.core.utils import seq_len_to_mask
+from .base_model import BaseModel
+from ..core.const import Const
+from ..core.utils import seq_len_to_mask
+from ..embeddings.embedding import TokenEmbedding, Embedding
class ESIM(BaseModel):
- """ESIM model的一个PyTorch实现
+ """
+ 别名::class:`fastNLP.models.ESIM` :class:`fastNLP.models.snli.ESIM`
+
+ ESIM model的一个PyTorch实现
论文参见: https://arxiv.org/pdf/1609.06038.pdf
- :param fastNLP.TokenEmbedding init_embedding: 初始化的TokenEmbedding
+ :param init_embedding: 初始化的Embedding
:param int hidden_size: 隐藏层大小,默认值为Embedding的维度
:param int num_labels: 目标标签种类数量,默认值为3
:param float dropout_rate: dropout的比率,默认值为0.3
:param float dropout_embed: 对Embedding的dropout比率,默认值为0.1
"""
- def __init__(self, init_embedding: TokenEmbedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
+ def __init__(self, init_embedding, hidden_size=None, num_labels=3, dropout_rate=0.3,
dropout_embed=0.1):
super(ESIM, self).__init__()
- self.embedding = init_embedding
+ if isinstance(init_embedding, TokenEmbedding) or isinstance(init_embedding, Embedding):
+ self.embedding = init_embedding
+ else:
+ self.embedding = Embedding(init_embedding)
self.dropout_embed = EmbedDropout(p=dropout_embed)
if hidden_size is None:
hidden_size = self.embedding.embed_size
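
With the relaxed constructor above, `ESIM` accepts either a ready-made `TokenEmbedding`/`Embedding` or any raw initializer that `Embedding` understands. A minimal sketch of both styles, assuming the 0.5.0-era `fastNLP.embeddings` package (sizes are illustrative):

```python
from fastNLP.models.snli import ESIM

# a raw (num_embeddings, embedding_dim) tuple is now wrapped in Embedding internally
model = ESIM((5000, 100), num_labels=3, dropout_rate=0.3)

# a prebuilt TokenEmbedding/Embedding instance is still used as-is, e.g.
#   from fastNLP.embeddings import StaticEmbedding
#   model = ESIM(StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=100))
```
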
diff --git a/fastNLP/models/star_transformer.py b/fastNLP/models/star_transformer.py
index bb91a5b6..b95d1c25 100644
--- a/fastNLP/models/star_transformer.py
+++ b/fastNLP/models/star_transformer.py
@@ -13,7 +13,7 @@ from torch import nn
from ..modules.encoder.star_transformer import StarTransformer
from ..core.utils import seq_len_to_mask
-from ..modules.utils import get_embeddings
+from ..embeddings.utils import get_embeddings
from ..core.const import Const
@@ -34,7 +34,7 @@ class StarTransEnc(nn.Module):
:param emb_dropout: 词嵌入的dropout概率.
:param dropout: 模型除词嵌入外的dropout概率.
"""
-
+
def __init__(self, init_embed,
hidden_size,
num_layers,
@@ -54,7 +54,7 @@ class StarTransEnc(nn.Module):
head_dim=head_dim,
dropout=dropout,
max_len=max_len)
-
+
def forward(self, x, mask):
"""
:param FloatTensor x: [batch, length, hidden] 输入的序列
@@ -79,7 +79,7 @@ class _Cls(nn.Module):
nn.Dropout(dropout),
nn.Linear(hid_dim, num_cls),
)
-
+
def forward(self, x):
h = self.fc(x)
return h
@@ -95,7 +95,7 @@ class _NLICls(nn.Module):
nn.Dropout(dropout),
nn.Linear(hid_dim, num_cls),
)
-
+
def forward(self, x1, x2):
x = torch.cat([x1, x2, torch.abs(x1 - x2), x1 * x2], 1)
h = self.fc(x)
@@ -121,7 +121,7 @@ class STSeqLabel(nn.Module):
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""
-
+
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
@@ -141,7 +141,7 @@ class STSeqLabel(nn.Module):
emb_dropout=emb_dropout,
dropout=dropout)
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size)
-
+
def forward(self, words, seq_len):
"""
@@ -154,7 +154,7 @@ class STSeqLabel(nn.Module):
output = self.cls(nodes)
output = output.transpose(1, 2) # make hidden to be dim 1
return {Const.OUTPUT: output} # [bsz, n_cls, seq_len]
-
+
def predict(self, words, seq_len):
"""
@@ -186,7 +186,7 @@ class STSeqCls(nn.Module):
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""
-
+
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
@@ -206,7 +206,7 @@ class STSeqCls(nn.Module):
emb_dropout=emb_dropout,
dropout=dropout)
self.cls = _Cls(hidden_size, num_cls, cls_hidden_size, dropout=dropout)
-
+
def forward(self, words, seq_len):
"""
@@ -219,7 +219,7 @@ class STSeqCls(nn.Module):
y = 0.5 * (relay + nodes.max(1)[0])
output = self.cls(y) # [bsz, n_cls]
return {Const.OUTPUT: output}
-
+
def predict(self, words, seq_len):
"""
@@ -251,7 +251,7 @@ class STNLICls(nn.Module):
:param emb_dropout: 词嵌入的dropout概率. Default: 0.1
:param dropout: 模型除词嵌入外的dropout概率. Default: 0.1
"""
-
+
def __init__(self, init_embed, num_cls,
hidden_size=300,
num_layers=4,
@@ -271,7 +271,7 @@ class STNLICls(nn.Module):
emb_dropout=emb_dropout,
dropout=dropout)
self.cls = _NLICls(hidden_size, num_cls, cls_hidden_size)
-
+
def forward(self, words1, words2, seq_len1, seq_len2):
"""
@@ -283,16 +283,16 @@ class STNLICls(nn.Module):
"""
mask1 = seq_len_to_mask(seq_len1)
mask2 = seq_len_to_mask(seq_len2)
-
+
def enc(seq, mask):
nodes, relay = self.enc(seq, mask)
return 0.5 * (relay + nodes.max(1)[0])
-
+
y1 = enc(words1, mask1)
y2 = enc(words2, mask2)
output = self.cls(y1, y2) # [bsz, n_cls]
return {Const.OUTPUT: output}
-
+
def predict(self, words1, words2, seq_len1, seq_len2):
"""
diff --git a/fastNLP/modules/__init__.py b/fastNLP/modules/__init__.py
index 2cd2216c..7959e454 100644
--- a/fastNLP/modules/__init__.py
+++ b/fastNLP/modules/__init__.py
@@ -1,46 +1,52 @@
"""
-大部分用于的 NLP 任务神经网络都可以看做由编码 :mod:`~fastNLP.modules.encoder` 、
-解码 :mod:`~fastNLP.modules.decoder` 两种模块组成。
.. image:: figures/text_classification.png
-:mod:`~fastNLP.modules` 中实现了 fastNLP 提供的诸多模块组件,可以帮助用户快速搭建自己所需的网络。
-两种模块的功能和常见组件如下:
+大部分用于 NLP 任务的神经网络都可以看做由 :mod:`embedding` 、 :mod:`~fastNLP.modules.encoder` 、
+:mod:`~fastNLP.modules.decoder` 三种模块组成。 本模块中实现了 fastNLP 提供的诸多模块组件,
+可以帮助用户快速搭建自己所需的网络。几种模块的功能和常见组件如下:
+
+.. csv-table::
+ :header: "类型", "功能", "常见组件"
+
+ "embedding", 参见 :doc:`/fastNLP.embeddings` , "Elmo, Bert"
+ "encoder", "将输入编码为具有表示能力的向量", "CNN, LSTM, Transformer"
+ "decoder", "将具有某种表示意义的向量解码为需要的输出形式 ", "MLP, CRF"
+ "其它", "配合其它组件使用的组件", "Dropout"
-+-----------------------+-----------------------+-----------------------+
-| module type | functionality | example |
-+=======================+=======================+=======================+
-| encoder | 将输入编码为具有具 | embedding, RNN, CNN, |
-| | 有表示能力的向量 | transformer |
-+-----------------------+-----------------------+-----------------------+
-| decoder | 将具有某种表示意义的 | MLP, CRF |
-| | 向量解码为需要的输出 | |
-| | 形式 | |
-+-----------------------+-----------------------+-----------------------+
"""
__all__ = [
# "BertModel",
+
"ConvolutionCharEncoder",
"LSTMCharEncoder",
+
"ConvMaxpool",
- "Embedding",
+
"LSTM",
+
"StarTransformer",
+
"TransformerEncoder",
+
"VarRNN",
"VarLSTM",
"VarGRU",
-
+
"MaxPool",
"MaxPoolWithMask",
"AvgPool",
+ "AvgPoolWithMask",
+
"MultiHeadAttention",
-
+
"MLP",
"ConditionalRandomField",
"viterbi_decode",
"allowed_transitions",
+
+ "TimestepDropout",
]
from . import decoder
@@ -48,4 +54,3 @@ from . import encoder
from .decoder import *
from .dropout import TimestepDropout
from .encoder import *
-from .utils import get_embeddings
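
The new table splits the toolkit into embeddings, encoders and decoders. As a quick illustration of composing the two module families that remain in `fastNLP.modules`, here is a sketch on random tensors (not a trained model; sizes are arbitrary):

```python
import torch
from fastNLP.modules import LSTM, MLP, ConditionalRandomField

encoder = LSTM(input_size=100, hidden_size=64, bidirectional=True, batch_first=True)
decoder = MLP([128, 5])                       # project the 2*64 LSTM features to 5 tag scores
crf = ConditionalRandomField(num_tags=5)

x = torch.randn(2, 7, 100)                    # pre-embedded input [batch, seq_len, dim]
feats, _ = encoder(x)                         # [batch, seq_len, 128]
scores = decoder(feats)                       # [batch, seq_len, 5]
mask = torch.ones(2, 7)                       # every position is a real token
pred, _ = crf.viterbi_decode(scores, mask)    # best tag sequence per sentence
```
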
diff --git a/fastNLP/modules/decoder/__init__.py b/fastNLP/modules/decoder/__init__.py
index 664618b2..57acb172 100644
--- a/fastNLP/modules/decoder/__init__.py
+++ b/fastNLP/modules/decoder/__init__.py
@@ -1,3 +1,7 @@
+"""
+.. todo::
+ doc
+"""
__all__ = [
"MLP",
"ConditionalRandomField",
@@ -6,6 +10,6 @@ __all__ = [
]
from .crf import ConditionalRandomField
+from .crf import allowed_transitions
from .mlp import MLP
from .utils import viterbi_decode
-from .crf import allowed_transitions
diff --git a/fastNLP/modules/decoder/crf.py b/fastNLP/modules/decoder/crf.py
index c0717d6f..b47d0162 100644
--- a/fastNLP/modules/decoder/crf.py
+++ b/fastNLP/modules/decoder/crf.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"ConditionalRandomField",
"allowed_transitions"
@@ -7,15 +9,16 @@ import torch
from torch import nn
from ..utils import initial_parameter
+from ...core import Vocabulary
def allowed_transitions(id2target, encoding_type='bio', include_start_end=False):
"""
- 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.crf.allowed_transitions`
+ 别名::class:`fastNLP.modules.allowed_transitions` :class:`fastNLP.modules.decoder.allowed_transitions`
给定一个id到label的映射表,返回所有可以跳转的(from_tag_id, to_tag_id)列表。
- :param dict id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
+ :param dict, ~fastNLP.Vocabulary id2target: key是label的indices,value是str类型的tag或tag-label。value可以是只有tag的, 比如"B", "M"; 也可以是
"B-NN", "M-NN", tag和label之间一定要用"-"隔开。一般可以通过Vocabulary.idx2word得到id2label。
:param str encoding_type: 支持"bio", "bmes", "bmeso", "bioes"。
:param bool include_start_end: 是否包含开始与结尾的转换。比如在bio中,b/o可以在开头,但是i不能在开头;
@@ -23,6 +26,8 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
start_idx=len(id2label), end_idx=len(id2label)+1。为False, 返回的结果中不含与开始结尾相关的内容
:return: List[Tuple(int, int)]], 内部的Tuple是可以进行跳转的(from_tag_id, to_tag_id)。
"""
+ if isinstance(id2target, Vocabulary):
+ id2target = id2target.idx2word
num_tags = len(id2target)
start_idx = num_tags
end_idx = num_tags + 1
@@ -31,7 +36,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
id_label_lst = list(id2target.items())
if include_start_end:
id_label_lst += [(start_idx, 'start'), (end_idx, 'end')]
-
+
def split_tag_label(from_label):
from_label = from_label.lower()
if from_label in ['start', 'end']:
@@ -41,7 +46,7 @@ def allowed_transitions(id2target, encoding_type='bio', include_start_end=False)
from_tag = from_label[:1]
from_label = from_label[2:]
return from_tag, from_label
-
+
for from_id, from_label in id_label_lst:
        if from_label in ['<pad>', '<unk>']:
continue
@@ -93,7 +98,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
return to_tag in ['end', 'b', 'o']
else:
raise ValueError("Unexpect tag {}. Expect only 'B', 'I', 'O'.".format(from_tag))
-
+
elif encoding_type == 'bmes':
"""
第一行是to_tag, 第一列是from_tag,y任意条件下可转,-只有在label相同时可转,n不可转
@@ -151,7 +156,7 @@ def _is_transition_allowed(encoding_type, from_tag, from_label, to_tag, to_label
class ConditionalRandomField(nn.Module):
"""
- 别名::class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.crf.ConditionalRandomField`
+ 别名::class:`fastNLP.modules.ConditionalRandomField` :class:`fastNLP.modules.decoder.ConditionalRandomField`
条件随机场。
提供forward()以及viterbi_decode()两个方法,分别用于训练与inference。
@@ -163,21 +168,21 @@ class ConditionalRandomField(nn.Module):
allowed_transitions()函数得到;如果为None,则所有跃迁均为合法
:param str initial_method: 初始化方法。见initial_parameter
"""
-
+
def __init__(self, num_tags, include_start_end_trans=False, allowed_transitions=None,
initial_method=None):
-
+
super(ConditionalRandomField, self).__init__()
-
+
self.include_start_end_trans = include_start_end_trans
self.num_tags = num_tags
-
+
# the meaning of entry in this matrix is (from_tag_id, to_tag_id) score
self.trans_m = nn.Parameter(torch.randn(num_tags, num_tags))
if self.include_start_end_trans:
self.start_scores = nn.Parameter(torch.randn(num_tags))
self.end_scores = nn.Parameter(torch.randn(num_tags))
-
+
if allowed_transitions is None:
constrain = torch.zeros(num_tags + 2, num_tags + 2)
else:
@@ -185,9 +190,9 @@ class ConditionalRandomField(nn.Module):
for from_tag_id, to_tag_id in allowed_transitions:
constrain[from_tag_id, to_tag_id] = 0
self._constrain = nn.Parameter(constrain, requires_grad=False)
-
+
initial_parameter(self, initial_method)
-
+
def _normalizer_likelihood(self, logits, mask):
"""Computes the (batch_size,) denominator term for the log-likelihood, which is the
sum of the likelihoods across all possible state sequences.
@@ -200,21 +205,21 @@ class ConditionalRandomField(nn.Module):
alpha = logits[0]
if self.include_start_end_trans:
alpha = alpha + self.start_scores.view(1, -1)
-
+
flip_mask = mask.eq(0)
-
+
for i in range(1, seq_len):
emit_score = logits[i].view(batch_size, 1, n_tags)
trans_score = self.trans_m.view(1, n_tags, n_tags)
tmp = alpha.view(batch_size, n_tags, 1) + emit_score + trans_score
alpha = torch.logsumexp(tmp, 1).masked_fill(flip_mask[i].view(batch_size, 1), 0) + \
- alpha.masked_fill(mask[i].byte().view(batch_size, 1), 0)
-
+ alpha.masked_fill(mask[i].eq(1).view(batch_size, 1), 0)
+
if self.include_start_end_trans:
alpha = alpha + self.end_scores.view(1, -1)
-
+
return torch.logsumexp(alpha, 1)
-
+
def _gold_score(self, logits, tags, mask):
"""
Compute the score for the gold path.
@@ -226,9 +231,9 @@ class ConditionalRandomField(nn.Module):
seq_len, batch_size, _ = logits.size()
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
-
+
# trans_socre [L-1, B]
- mask = mask.byte()
+ mask = mask.eq(1)
flip_mask = mask.eq(0)
trans_score = self.trans_m[tags[:seq_len - 1], tags[1:]].masked_fill(flip_mask[1:, :], 0)
# emit_score [L, B]
@@ -243,7 +248,7 @@ class ConditionalRandomField(nn.Module):
score = score + st_scores + ed_scores
# return [B,]
return score
-
+
def forward(self, feats, tags, mask):
"""
用于计算CRF的前向loss,返回值为一个batch_size的FloatTensor,可能需要mean()求得loss。
@@ -258,9 +263,9 @@ class ConditionalRandomField(nn.Module):
mask = mask.transpose(0, 1).float()
all_path_score = self._normalizer_likelihood(feats, mask)
gold_path_score = self._gold_score(feats, tags, mask)
-
+
return all_path_score - gold_path_score
-
+
def viterbi_decode(self, logits, mask, unpad=False):
"""给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数
@@ -276,8 +281,8 @@ class ConditionalRandomField(nn.Module):
"""
batch_size, seq_len, n_tags = logits.size()
logits = logits.transpose(0, 1).data # L, B, H
- mask = mask.transpose(0, 1).data.byte() # L, B
-
+ mask = mask.transpose(0, 1).data.eq(1) # L, B
+
# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
@@ -286,7 +291,7 @@ class ConditionalRandomField(nn.Module):
if self.include_start_end_trans:
transitions[n_tags, :n_tags] += self.start_scores.data
transitions[:n_tags, n_tags + 1] += self.end_scores.data
-
+
vscore += transitions[n_tags, :n_tags]
trans_score = transitions[:n_tags, :n_tags].view(1, n_tags, n_tags).data
for i in range(1, seq_len):
@@ -297,17 +302,17 @@ class ConditionalRandomField(nn.Module):
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
-
+
if self.include_start_end_trans:
vscore += transitions[:n_tags, n_tags + 1].view(1, -1)
-
+
# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
-
+
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
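
`allowed_transitions` now also accepts a `Vocabulary` directly (the new `isinstance` branch above), so callers no longer need to pass `idx2word` themselves. A minimal sketch with an illustrative BIO tag vocabulary:

```python
from fastNLP import Vocabulary
from fastNLP.modules import allowed_transitions, ConditionalRandomField

tag_vocab = Vocabulary(unknown=None, padding=None)
tag_vocab.add_word_lst(["B-LOC", "I-LOC", "O"])

# previously allowed_transitions(tag_vocab.idx2word, ...); the Vocabulary now works directly
trans = allowed_transitions(tag_vocab, encoding_type="bio", include_start_end=True)
crf = ConditionalRandomField(num_tags=len(tag_vocab), include_start_end_trans=True,
                             allowed_transitions=trans)
```
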
diff --git a/fastNLP/modules/decoder/mlp.py b/fastNLP/modules/decoder/mlp.py
index 418b3a77..f6e687a7 100644
--- a/fastNLP/modules/decoder/mlp.py
+++ b/fastNLP/modules/decoder/mlp.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"MLP"
]
@@ -10,7 +12,7 @@ from ..utils import initial_parameter
class MLP(nn.Module):
"""
- 别名::class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.mlp.MLP`
+ 别名::class:`fastNLP.modules.MLP` :class:`fastNLP.modules.decoder.MLP`
多层感知器
@@ -40,7 +42,7 @@ class MLP(nn.Module):
>>> print(x)
>>> print(y)
"""
-
+
def __init__(self, size_layer, activation='relu', output_activation=None, initial_method=None, dropout=0.0):
super(MLP, self).__init__()
self.hiddens = nn.ModuleList()
@@ -51,9 +53,9 @@ class MLP(nn.Module):
self.output = nn.Linear(size_layer[i - 1], size_layer[i])
else:
self.hiddens.append(nn.Linear(size_layer[i - 1], size_layer[i]))
-
+
self.dropout = nn.Dropout(p=dropout)
-
+
actives = {
'relu': nn.ReLU(),
'tanh': nn.Tanh(),
@@ -82,7 +84,7 @@ class MLP(nn.Module):
else:
raise ValueError("should set activation correctly: {}".format(activation))
initial_parameter(self, initial_method)
-
+
def forward(self, x):
"""
:param torch.Tensor x: MLP接受的输入
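
The `MLP` decoder documented above takes a list of layer sizes plus an activation. A minimal usage sketch (sizes are arbitrary):

```python
import torch
from fastNLP.modules import MLP

# a 3 -> 8 -> 8 -> 2 perceptron with ReLU on the hidden layers
mlp = MLP([3, 8, 8, 2], activation="relu", dropout=0.1)
y = mlp(torch.randn(5, 3))
print(y.size())    # torch.Size([5, 2])
```
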
diff --git a/fastNLP/modules/decoder/utils.py b/fastNLP/modules/decoder/utils.py
index 249f3ff6..118b1414 100644
--- a/fastNLP/modules/decoder/utils.py
+++ b/fastNLP/modules/decoder/utils.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"viterbi_decode"
]
@@ -6,7 +8,7 @@ import torch
def viterbi_decode(logits, transitions, mask=None, unpad=False):
r"""
- 别名::class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.utils.viterbi_decode`
+ 别名::class:`fastNLP.modules.viterbi_decode` :class:`fastNLP.modules.decoder.viterbi_decode`
给定一个特征矩阵以及转移分数矩阵,计算出最佳的路径以及对应的分数
@@ -27,14 +29,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
"compatible."
logits = logits.transpose(0, 1).data # L, B, H
if mask is not None:
- mask = mask.transpose(0, 1).data.byte() # L, B
+ mask = mask.transpose(0, 1).data.eq(1) # L, B
else:
mask = logits.new_ones((seq_len, batch_size), dtype=torch.uint8)
-
+
# dp
vpath = logits.new_zeros((seq_len, batch_size, n_tags), dtype=torch.long)
vscore = logits[0]
-
+
trans_score = transitions.view(1, n_tags, n_tags).data
for i in range(1, seq_len):
prev_score = vscore.view(batch_size, n_tags, 1)
@@ -44,14 +46,14 @@ def viterbi_decode(logits, transitions, mask=None, unpad=False):
vpath[i] = best_dst
vscore = best_score.masked_fill(mask[i].eq(0).view(batch_size, 1), 0) + \
vscore.masked_fill(mask[i].view(batch_size, 1), 0)
-
+
# backtrace
batch_idx = torch.arange(batch_size, dtype=torch.long, device=logits.device)
seq_idx = torch.arange(seq_len, dtype=torch.long, device=logits.device)
lens = (mask.long().sum(0) - 1)
# idxes [L, B], batched idx from seq_len-1 to 0
idxes = (lens.view(1, -1) - seq_idx.view(-1, 1)) % seq_len
-
+
ans = logits.new_empty((seq_len, batch_size), dtype=torch.long)
ans_score, last_tags = vscore.max(1)
ans[idxes[0], batch_idx] = last_tags
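
The functional `viterbi_decode` above takes emission scores plus a caller-supplied transition matrix, which is useful when transitions are not learned inside a `ConditionalRandomField`. A minimal sketch on random scores (shapes are illustrative):

```python
import torch
from fastNLP.modules import viterbi_decode

batch_size, seq_len, n_tags = 2, 5, 4
logits = torch.randn(batch_size, seq_len, n_tags)    # emission scores
transitions = torch.randn(n_tags, n_tags)            # transition scores [from_tag, to_tag]
mask = torch.ones(batch_size, seq_len)               # 1 = real token, 0 = padding

paths, scores = viterbi_decode(logits, transitions, mask=mask, unpad=True)
# with unpad=True, paths is a list of per-sentence tag id lists; scores is [batch_size]
```
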
diff --git a/fastNLP/modules/dropout.py b/fastNLP/modules/dropout.py
index 1363165c..24c20cc6 100644
--- a/fastNLP/modules/dropout.py
+++ b/fastNLP/modules/dropout.py
@@ -1,14 +1,16 @@
-__all__ = []
+"""undocumented"""
+
+__all__ = [
+ "TimestepDropout"
+]
import torch
class TimestepDropout(torch.nn.Dropout):
"""
- 别名::class:`fastNLP.modules.TimestepDropout`
-
- 接受的参数shape为``[batch_size, num_timesteps, embedding_dim)]`` 使用同一个mask(shape为``(batch_size, embedding_dim)``)
- 在每个timestamp上做dropout。
+ 传入参数的shape为 ``(batch_size, num_timesteps, embedding_dim)``
+ 使用同一个shape为 ``(batch_size, embedding_dim)`` 的mask在每个timestamp上做dropout。
"""
def forward(self, x):
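
`TimestepDropout` is now exported from `fastNLP.modules`, and the rewritten docstring above describes its behaviour: a single `(batch_size, embedding_dim)` mask reused at every timestep. A minimal sketch illustrating the shared mask (input values are arbitrary):

```python
import torch
from fastNLP.modules import TimestepDropout

drop = TimestepDropout(p=0.5)
drop.train()                             # dropout is only active in training mode

x = torch.ones(2, 4, 8)                  # [batch_size, num_timesteps, embedding_dim]
y = drop(x)

# the same mask is applied at every timestep, so a zeroed embedding dimension
# stays zero across the whole sequence
print(torch.equal(y[:, 0], y[:, 1]))     # True
```
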
diff --git a/fastNLP/modules/encoder/__init__.py b/fastNLP/modules/encoder/__init__.py
index 7b5bc070..0dfc18de 100644
--- a/fastNLP/modules/encoder/__init__.py
+++ b/fastNLP/modules/encoder/__init__.py
@@ -1,25 +1,22 @@
+"""
+.. todo::
+ doc
+"""
+
__all__ = [
# "BertModel",
-
+
"ConvolutionCharEncoder",
"LSTMCharEncoder",
-
+
"ConvMaxpool",
-
- "Embedding",
- "StaticEmbedding",
- "ElmoEmbedding",
- "BertEmbedding",
- "StackEmbedding",
- "LSTMCharEmbedding",
- "CNNCharEmbedding",
-
+
"LSTM",
-
+
"StarTransformer",
-
+
"TransformerEncoder",
-
+
"VarRNN",
"VarLSTM",
"VarGRU",
@@ -31,16 +28,13 @@ __all__ = [
"MultiHeadAttention",
]
-from ._bert import BertModel
-from .bert import BertWordPieceEncoder
+
+from .attention import MultiHeadAttention
+from .bert import BertModel
from .char_encoder import ConvolutionCharEncoder, LSTMCharEncoder
from .conv_maxpool import ConvMaxpool
-from .embedding import Embedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, \
- StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding
from .lstm import LSTM
+from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
from .star_transformer import StarTransformer
from .transformer import TransformerEncoder
from .variational_rnn import VarRNN, VarLSTM, VarGRU
-
-from .pooling import MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask
-from .attention import MultiHeadAttention
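
With the hunk above, the embedding classes are gone from `fastNLP.modules.encoder` and `BertModel` is re-exported from the new `bert` module; embeddings now live in the dedicated `fastNLP.embeddings` package, matching the import changes in `sequence_labeling.py` and `snli.py` earlier in this diff. A minimal sketch of the updated import style, assuming the 0.5.0 package layout (toy vocabulary, random initialization so nothing is downloaded):

```python
from fastNLP import Vocabulary
# previously: from fastNLP.modules.encoder.embedding import StaticEmbedding
from fastNLP.embeddings import StaticEmbedding

vocab = Vocabulary()
vocab.add_word_lst("a tiny example vocabulary".split())

# model_dir_or_name=None plus an explicit embedding_dim gives a randomly initialized
# embedding; passing a pretrained name (e.g. a GloVe model) would trigger a download
embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)
```
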
diff --git a/fastNLP/modules/encoder/_bert.py b/fastNLP/modules/encoder/_bert.py
deleted file mode 100644
index 61a5d7d1..00000000
--- a/fastNLP/modules/encoder/_bert.py
+++ /dev/null
@@ -1,1069 +0,0 @@
-
-
-
-"""
-这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你
- 有用,也请引用一下他们。
-"""
-
-
-
-from ...core.vocabulary import Vocabulary
-import collections
-
-import unicodedata
-import numpy as np
-from itertools import chain
-import copy
-import json
-import math
-import os
-
-import torch
-from torch import nn
-import glob
-import sys
-
-CONFIG_FILE = 'bert_config.json'
-
-
-class BertConfig(object):
- """Configuration class to store the configuration of a `BertModel`.
- """
- def __init__(self,
- vocab_size_or_config_json_file,
- hidden_size=768,
- num_hidden_layers=12,
- num_attention_heads=12,
- intermediate_size=3072,
- hidden_act="gelu",
- hidden_dropout_prob=0.1,
- attention_probs_dropout_prob=0.1,
- max_position_embeddings=512,
- type_vocab_size=2,
- initializer_range=0.02,
- layer_norm_eps=1e-12):
- """Constructs BertConfig.
-
- Args:
- vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
- hidden_size: Size of the encoder layers and the pooler layer.
- num_hidden_layers: Number of hidden layers in the Transformer encoder.
- num_attention_heads: Number of attention heads for each attention layer in
- the Transformer encoder.
- intermediate_size: The size of the "intermediate" (i.e., feed-forward)
- layer in the Transformer encoder.
- hidden_act: The non-linear activation function (function or string) in the
- encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
- hidden_dropout_prob: The dropout probabilitiy for all fully connected
- layers in the embeddings, encoder, and pooler.
- attention_probs_dropout_prob: The dropout ratio for the attention
- probabilities.
- max_position_embeddings: The maximum sequence length that this model might
- ever be used with. Typically set this to something large just in case
- (e.g., 512 or 1024 or 2048).
- type_vocab_size: The vocabulary size of the `token_type_ids` passed into
- `BertModel`.
- initializer_range: The sttdev of the truncated_normal_initializer for
- initializing all weight matrices.
- layer_norm_eps: The epsilon used by LayerNorm.
- """
- if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
- and isinstance(vocab_size_or_config_json_file, unicode)):
- with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
- json_config = json.loads(reader.read())
- for key, value in json_config.items():
- self.__dict__[key] = value
- elif isinstance(vocab_size_or_config_json_file, int):
- self.vocab_size = vocab_size_or_config_json_file
- self.hidden_size = hidden_size
- self.num_hidden_layers = num_hidden_layers
- self.num_attention_heads = num_attention_heads
- self.hidden_act = hidden_act
- self.intermediate_size = intermediate_size
- self.hidden_dropout_prob = hidden_dropout_prob
- self.attention_probs_dropout_prob = attention_probs_dropout_prob
- self.max_position_embeddings = max_position_embeddings
- self.type_vocab_size = type_vocab_size
- self.initializer_range = initializer_range
- self.layer_norm_eps = layer_norm_eps
- else:
- raise ValueError("First argument must be either a vocabulary size (int)"
- "or the path to a pretrained model config file (str)")
-
- @classmethod
- def from_dict(cls, json_object):
- """Constructs a `BertConfig` from a Python dictionary of parameters."""
- config = BertConfig(vocab_size_or_config_json_file=-1)
- for key, value in json_object.items():
- config.__dict__[key] = value
- return config
-
- @classmethod
- def from_json_file(cls, json_file):
- """Constructs a `BertConfig` from a json file of parameters."""
- with open(json_file, "r", encoding='utf-8') as reader:
- text = reader.read()
- return cls.from_dict(json.loads(text))
-
- def __repr__(self):
- return str(self.to_json_string())
-
- def to_dict(self):
- """Serializes this instance to a Python dictionary."""
- output = copy.deepcopy(self.__dict__)
- return output
-
- def to_json_string(self):
- """Serializes this instance to a JSON string."""
- return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
-
- def to_json_file(self, json_file_path):
- """ Save this instance to a json file."""
- with open(json_file_path, "w", encoding='utf-8') as writer:
- writer.write(self.to_json_string())
-
-
-def gelu(x):
- return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
-
-
-def swish(x):
- return x * torch.sigmoid(x)
-
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
-
-
-class BertLayerNorm(nn.Module):
- def __init__(self, hidden_size, eps=1e-12):
- """Construct a layernorm module in the TF style (epsilon inside the square root).
- """
- super(BertLayerNorm, self).__init__()
- self.weight = nn.Parameter(torch.ones(hidden_size))
- self.bias = nn.Parameter(torch.zeros(hidden_size))
- self.variance_epsilon = eps
-
- def forward(self, x):
- u = x.mean(-1, keepdim=True)
- s = (x - u).pow(2).mean(-1, keepdim=True)
- x = (x - u) / torch.sqrt(s + self.variance_epsilon)
- return self.weight * x + self.bias
-
-
-class BertEmbeddings(nn.Module):
- """Construct the embeddings from word, position and token_type embeddings.
- """
- def __init__(self, config):
- super(BertEmbeddings, self).__init__()
- self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
- self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
- self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
-
- # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
- # any TensorFlow checkpoint file
- self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
- def forward(self, input_ids, token_type_ids=None):
- seq_length = input_ids.size(1)
- position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
- position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
- if token_type_ids is None:
- token_type_ids = torch.zeros_like(input_ids)
-
- words_embeddings = self.word_embeddings(input_ids)
- position_embeddings = self.position_embeddings(position_ids)
- token_type_embeddings = self.token_type_embeddings(token_type_ids)
-
- embeddings = words_embeddings + position_embeddings + token_type_embeddings
- embeddings = self.LayerNorm(embeddings)
- embeddings = self.dropout(embeddings)
- return embeddings
-
-
-class BertSelfAttention(nn.Module):
- def __init__(self, config):
- super(BertSelfAttention, self).__init__()
- if config.hidden_size % config.num_attention_heads != 0:
- raise ValueError(
- "The hidden size (%d) is not a multiple of the number of attention "
- "heads (%d)" % (config.hidden_size, config.num_attention_heads))
- self.num_attention_heads = config.num_attention_heads
- self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
- self.all_head_size = self.num_attention_heads * self.attention_head_size
-
- self.query = nn.Linear(config.hidden_size, self.all_head_size)
- self.key = nn.Linear(config.hidden_size, self.all_head_size)
- self.value = nn.Linear(config.hidden_size, self.all_head_size)
-
- self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
-
- def transpose_for_scores(self, x):
- new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
- x = x.view(*new_x_shape)
- return x.permute(0, 2, 1, 3)
-
- def forward(self, hidden_states, attention_mask):
- mixed_query_layer = self.query(hidden_states)
- mixed_key_layer = self.key(hidden_states)
- mixed_value_layer = self.value(hidden_states)
-
- query_layer = self.transpose_for_scores(mixed_query_layer)
- key_layer = self.transpose_for_scores(mixed_key_layer)
- value_layer = self.transpose_for_scores(mixed_value_layer)
-
- # Take the dot product between "query" and "key" to get the raw attention scores.
- attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
- attention_scores = attention_scores / math.sqrt(self.attention_head_size)
- # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
- attention_scores = attention_scores + attention_mask
-
- # Normalize the attention scores to probabilities.
- attention_probs = nn.Softmax(dim=-1)(attention_scores)
-
- # This is actually dropping out entire tokens to attend to, which might
- # seem a bit unusual, but is taken from the original Transformer paper.
- attention_probs = self.dropout(attention_probs)
-
- context_layer = torch.matmul(attention_probs, value_layer)
- context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
- new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
- context_layer = context_layer.view(*new_context_layer_shape)
- return context_layer
-
-
-class BertSelfOutput(nn.Module):
- def __init__(self, config):
- super(BertSelfOutput, self).__init__()
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
- def forward(self, hidden_states, input_tensor):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
-
-
-class BertAttention(nn.Module):
- def __init__(self, config):
- super(BertAttention, self).__init__()
- self.self = BertSelfAttention(config)
- self.output = BertSelfOutput(config)
-
- def forward(self, input_tensor, attention_mask):
- self_output = self.self(input_tensor, attention_mask)
- attention_output = self.output(self_output, input_tensor)
- return attention_output
-
-
-class BertIntermediate(nn.Module):
- def __init__(self, config):
- super(BertIntermediate, self).__init__()
- self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
- if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
- self.intermediate_act_fn = ACT2FN[config.hidden_act]
- else:
- self.intermediate_act_fn = config.hidden_act
-
- def forward(self, hidden_states):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.intermediate_act_fn(hidden_states)
- return hidden_states
-
-
-class BertOutput(nn.Module):
- def __init__(self, config):
- super(BertOutput, self).__init__()
- self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
- self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
- self.dropout = nn.Dropout(config.hidden_dropout_prob)
-
- def forward(self, hidden_states, input_tensor):
- hidden_states = self.dense(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.LayerNorm(hidden_states + input_tensor)
- return hidden_states
-
-
-class BertLayer(nn.Module):
- def __init__(self, config):
- super(BertLayer, self).__init__()
- self.attention = BertAttention(config)
- self.intermediate = BertIntermediate(config)
- self.output = BertOutput(config)
-
- def forward(self, hidden_states, attention_mask):
- attention_output = self.attention(hidden_states, attention_mask)
- intermediate_output = self.intermediate(attention_output)
- layer_output = self.output(intermediate_output, attention_output)
- return layer_output
-
-
-class BertEncoder(nn.Module):
- def __init__(self, config):
- super(BertEncoder, self).__init__()
- layer = BertLayer(config)
- self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
-
- def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
- all_encoder_layers = []
- for layer_module in self.layer:
- hidden_states = layer_module(hidden_states, attention_mask)
- if output_all_encoded_layers:
- all_encoder_layers.append(hidden_states)
- if not output_all_encoded_layers:
- all_encoder_layers.append(hidden_states)
- return all_encoder_layers
-
-
-class BertPooler(nn.Module):
- def __init__(self, config):
- super(BertPooler, self).__init__()
- self.dense = nn.Linear(config.hidden_size, config.hidden_size)
- self.activation = nn.Tanh()
-
- def forward(self, hidden_states):
- # We "pool" the model by simply taking the hidden state corresponding
- # to the first token.
- first_token_tensor = hidden_states[:, 0]
- pooled_output = self.dense(first_token_tensor)
- pooled_output = self.activation(pooled_output)
- return pooled_output
-
-
-class BertModel(nn.Module):
- """BERT(Bidirectional Embedding Representations from Transformers).
-
- 如果你想使用预训练好的权重矩阵,请在以下网址下载.
- sources::
-
- 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
- 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
- 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
- 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
- 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
- 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
- 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
- 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin",
- 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
- 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
- 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
- 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
- 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin"
-
-
- 用预训练权重矩阵来建立BERT模型::
-
- model = BertModel.from_pretrained("path/to/weights/directory")
-
- 用随机初始化权重矩阵来建立BERT模型::
-
- model = BertModel()
-
- :param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小
- :param int hidden_size: 隐层大小,默认值为768,为BERT base的版本
- :param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本
- :param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本
- :param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本
- :param str hidden_act: FFN隐藏层激活函数,默认值为``gelu``
- :param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1
- :param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1
- :param int max_position_embeddings: 最大的序列长度,默认值为512,
- :param int type_vocab_size: 最大segment数量,默认值为2
- :param int initializer_range: 初始化权重范围,默认值为0.02
- """
-
- def __init__(self, config, *inputs, **kwargs):
- super(BertModel, self).__init__()
- if not isinstance(config, BertConfig):
- raise ValueError(
- "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
- "To create a model from a Google pretrained model use "
- "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
- self.__class__.__name__, self.__class__.__name__
- ))
- super(BertModel, self).__init__()
- self.config = config
- self.hidden_size = self.config.hidden_size
- self.embeddings = BertEmbeddings(config)
- self.encoder = BertEncoder(config)
- self.pooler = BertPooler(config)
- self.apply(self.init_bert_weights)
-
- def init_bert_weights(self, module):
- """ Initialize the weights.
- """
- if isinstance(module, (nn.Linear, nn.Embedding)):
- # Slightly different from the TF version which uses truncated_normal for initialization
- # cf https://github.com/pytorch/pytorch/pull/5617
- module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
- elif isinstance(module, BertLayerNorm):
- module.bias.data.zero_()
- module.weight.data.fill_(1.0)
- if isinstance(module, nn.Linear) and module.bias is not None:
- module.bias.data.zero_()
-
- def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
- if attention_mask is None:
- attention_mask = torch.ones_like(input_ids)
- if token_type_ids is None:
- token_type_ids = torch.zeros_like(input_ids)
-
- # We create a 3D attention mask from a 2D tensor mask.
- # Sizes are [batch_size, 1, 1, to_seq_length]
- # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
- # this attention mask is more simple than the triangular masking of causal attention
- # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
- extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
-
- # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
- # masked positions, this operation will create a tensor which is 0.0 for
- # positions we want to attend and -10000.0 for masked positions.
- # Since we are adding it to the raw scores before the softmax, this is
- # effectively the same as removing these entirely.
- extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
- extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
-
- embedding_output = self.embeddings(input_ids, token_type_ids)
- encoded_layers = self.encoder(embedding_output,
- extended_attention_mask,
- output_all_encoded_layers=output_all_encoded_layers)
- sequence_output = encoded_layers[-1]
- pooled_output = self.pooler(sequence_output)
- if not output_all_encoded_layers:
- encoded_layers = encoded_layers[-1]
- return encoded_layers, pooled_output
-
- @classmethod
- def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs):
- state_dict = kwargs.get('state_dict', None)
- kwargs.pop('state_dict', None)
- cache_dir = kwargs.get('cache_dir', None)
- kwargs.pop('cache_dir', None)
- from_tf = kwargs.get('from_tf', False)
- kwargs.pop('from_tf', None)
- # Load config
- config_file = os.path.join(pretrained_model_dir, CONFIG_FILE)
- config = BertConfig.from_json_file(config_file)
- # logger.info("Model config {}".format(config))
- # Instantiate model.
- model = cls(config, *inputs, **kwargs)
- if state_dict is None:
- files = glob.glob(os.path.join(pretrained_model_dir, '*.bin'))
- if len(files)==0:
- raise FileNotFoundError(f"There is no *.bin file in {pretrained_model_dir}")
- elif len(files)>1:
- raise FileExistsError(f"There are multiple *.bin files in {pretrained_model_dir}")
- weights_path = files[0]
- state_dict = torch.load(weights_path, map_location='cpu')
-
- old_keys = []
- new_keys = []
- for key in state_dict.keys():
- new_key = None
- if 'gamma' in key:
- new_key = key.replace('gamma', 'weight')
- if 'beta' in key:
- new_key = key.replace('beta', 'bias')
- if new_key:
- old_keys.append(key)
- new_keys.append(new_key)
- for old_key, new_key in zip(old_keys, new_keys):
- state_dict[new_key] = state_dict.pop(old_key)
-
- missing_keys = []
- unexpected_keys = []
- error_msgs = []
- # copy state_dict so _load_from_state_dict can modify it
- metadata = getattr(state_dict, '_metadata', None)
- state_dict = state_dict.copy()
- if metadata is not None:
- state_dict._metadata = metadata
-
- def load(module, prefix=''):
- local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
- module._load_from_state_dict(
- state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
- for name, child in module._modules.items():
- if child is not None:
- load(child, prefix + name + '.')
-
- load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
- if len(missing_keys) > 0:
- print("Weights of {} not initialized from pretrained model: {}".format(
- model.__class__.__name__, missing_keys))
- if len(unexpected_keys) > 0:
- print("Weights from pretrained model not used in {}: {}".format(
- model.__class__.__name__, unexpected_keys))
- return model
-
-
-def whitespace_tokenize(text):
- """Runs basic whitespace cleaning and splitting on a piece of text."""
- text = text.strip()
- if not text:
- return []
- tokens = text.split()
- return tokens
-
-
-class WordpieceTokenizer(object):
- """Runs WordPiece tokenization."""
-
- def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
- self.vocab = vocab
- self.unk_token = unk_token
- self.max_input_chars_per_word = max_input_chars_per_word
-
- def tokenize(self, text):
- """Tokenizes a piece of text into its word pieces.
-
- This uses a greedy longest-match-first algorithm to perform tokenization
- using the given vocabulary.
-
- For example:
- input = "unaffable"
- output = ["un", "##aff", "##able"]
-
- Args:
- text: A single token or whitespace separated tokens. This should have
- already been passed through `BasicTokenizer`.
-
- Returns:
- A list of wordpiece tokens.
- """
-
- output_tokens = []
- for token in whitespace_tokenize(text):
- chars = list(token)
- if len(chars) > self.max_input_chars_per_word:
- output_tokens.append(self.unk_token)
- continue
-
- is_bad = False
- start = 0
- sub_tokens = []
- while start < len(chars):
- end = len(chars)
- cur_substr = None
- while start < end:
- substr = "".join(chars[start:end])
- if start > 0:
- substr = "##" + substr
- if substr in self.vocab:
- cur_substr = substr
- break
- end -= 1
- if cur_substr is None:
- is_bad = True
- break
- sub_tokens.append(cur_substr)
- start = end
-
- if is_bad:
- output_tokens.append(self.unk_token)
- else:
- output_tokens.extend(sub_tokens)
- return output_tokens
-
-
-def load_vocab(vocab_file):
- """Loads a vocabulary file into a dictionary."""
- vocab = collections.OrderedDict()
- index = 0
- with open(vocab_file, "r", encoding="utf-8") as reader:
- while True:
- token = reader.readline()
- if not token:
- break
- token = token.strip()
- vocab[token] = index
- index += 1
- return vocab
-
-class BasicTokenizer(object):
- """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
-
- def __init__(self,
- do_lower_case=True,
- never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
- """Constructs a BasicTokenizer.
-
- Args:
- do_lower_case: Whether to lower case the input.
- """
- self.do_lower_case = do_lower_case
- self.never_split = never_split
-
- def tokenize(self, text):
- """Tokenizes a piece of text."""
- text = self._clean_text(text)
- # This was added on November 1st, 2018 for the multilingual and Chinese
- # models. This is also applied to the English models now, but it doesn't
- # matter since the English models were not trained on any Chinese data
- # and generally don't have any Chinese data in them (there are Chinese
- # characters in the vocabulary because Wikipedia does have some Chinese
- # words in the English Wikipedia.).
- text = self._tokenize_chinese_chars(text)
- orig_tokens = whitespace_tokenize(text)
- split_tokens = []
- for token in orig_tokens:
- if self.do_lower_case and token not in self.never_split:
- token = token.lower()
- token = self._run_strip_accents(token)
- split_tokens.extend(self._run_split_on_punc(token))
-
- output_tokens = whitespace_tokenize(" ".join(split_tokens))
- return output_tokens
-
- def _run_strip_accents(self, text):
- """Strips accents from a piece of text."""
- text = unicodedata.normalize("NFD", text)
- output = []
- for char in text:
- cat = unicodedata.category(char)
- if cat == "Mn":
- continue
- output.append(char)
- return "".join(output)
-
- def _run_split_on_punc(self, text):
- """Splits punctuation on a piece of text."""
- if text in self.never_split:
- return [text]
- chars = list(text)
- i = 0
- start_new_word = True
- output = []
- while i < len(chars):
- char = chars[i]
- if _is_punctuation(char):
- output.append([char])
- start_new_word = True
- else:
- if start_new_word:
- output.append([])
- start_new_word = False
- output[-1].append(char)
- i += 1
-
- return ["".join(x) for x in output]
-
- def _tokenize_chinese_chars(self, text):
- """Adds whitespace around any CJK character."""
- output = []
- for char in text:
- cp = ord(char)
- if self._is_chinese_char(cp):
- output.append(" ")
- output.append(char)
- output.append(" ")
- else:
- output.append(char)
- return "".join(output)
-
- def _is_chinese_char(self, cp):
- """Checks whether CP is the codepoint of a CJK character."""
- # This defines a "chinese character" as anything in the CJK Unicode block:
- # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
- #
- # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
- # despite its name. The modern Korean Hangul alphabet is a different block,
- # as is Japanese Hiragana and Katakana. Those alphabets are used to write
- # space-separated words, so they are not treated specially and handled
- # like the all of the other languages.
- if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
- (cp >= 0x3400 and cp <= 0x4DBF) or #
- (cp >= 0x20000 and cp <= 0x2A6DF) or #
- (cp >= 0x2A700 and cp <= 0x2B73F) or #
- (cp >= 0x2B740 and cp <= 0x2B81F) or #
- (cp >= 0x2B820 and cp <= 0x2CEAF) or
- (cp >= 0xF900 and cp <= 0xFAFF) or #
- (cp >= 0x2F800 and cp <= 0x2FA1F)): #
- return True
-
- return False
-
- def _clean_text(self, text):
- """Performs invalid character removal and whitespace cleanup on text."""
- output = []
- for char in text:
- cp = ord(char)
- if cp == 0 or cp == 0xfffd or _is_control(char):
- continue
- if _is_whitespace(char):
- output.append(" ")
- else:
- output.append(char)
- return "".join(output)
-
-
-def _is_whitespace(char):
- """Checks whether `chars` is a whitespace character."""
- # \t, \n, and \r are technically contorl characters but we treat them
- # as whitespace since they are generally considered as such.
- if char == " " or char == "\t" or char == "\n" or char == "\r":
- return True
- cat = unicodedata.category(char)
- if cat == "Zs":
- return True
- return False
-
-
-def _is_control(char):
- """Checks whether `chars` is a control character."""
- # These are technically control characters but we count them as whitespace
- # characters.
- if char == "\t" or char == "\n" or char == "\r":
- return False
- cat = unicodedata.category(char)
- if cat.startswith("C"):
- return True
- return False
-
-
-def _is_punctuation(char):
- """Checks whether `chars` is a punctuation character."""
- cp = ord(char)
- # We treat all non-letter/number ASCII as punctuation.
- # Characters such as "^", "$", and "`" are not in the Unicode
- # Punctuation class but we treat them as punctuation anyways, for
- # consistency.
- if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
- (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
- return True
- cat = unicodedata.category(char)
- if cat.startswith("P"):
- return True
- return False
-
-
-class BertTokenizer(object):
- """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
-
- def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
- never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
- """Constructs a BertTokenizer.
-
- Args:
- vocab_file: Path to a one-wordpiece-per-line vocabulary file
- do_lower_case: Whether to lower case the input
- Only has an effect when do_wordpiece_only=False
- do_basic_tokenize: Whether to do basic tokenization before wordpiece.
- max_len: An artificial maximum length to truncate tokenized sequences to;
- Effective maximum length is always the minimum of this
- value (if specified) and the underlying BERT model's
- sequence length.
- never_split: List of tokens which will never be split during tokenization.
- Only has an effect when do_wordpiece_only=False
- """
- if not os.path.isfile(vocab_file):
- raise ValueError(
- "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
- "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
- self.vocab = load_vocab(vocab_file)
- self.ids_to_tokens = collections.OrderedDict(
- [(ids, tok) for tok, ids in self.vocab.items()])
- self.do_basic_tokenize = do_basic_tokenize
- if do_basic_tokenize:
- self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
- never_split=never_split)
- self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
- self.max_len = max_len if max_len is not None else int(1e12)
-
- def _reinit_on_new_vocab(self, vocab):
- """
- 在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质
-
- :param vocab:
- :return:
- """
- self.vocab = vocab
- self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
-
- def tokenize(self, text):
- split_tokens = []
- if self.do_basic_tokenize:
- for token in self.basic_tokenizer.tokenize(text):
- for sub_token in self.wordpiece_tokenizer.tokenize(token):
- split_tokens.append(sub_token)
- else:
- split_tokens = self.wordpiece_tokenizer.tokenize(text)
- return split_tokens
-
- def convert_tokens_to_ids(self, tokens):
- """Converts a sequence of tokens into ids using the vocab."""
- ids = []
- for token in tokens:
- ids.append(self.vocab[token])
- if len(ids) > self.max_len:
- print(
- "Token indices sequence length is longer than the specified maximum "
- " sequence length for this BERT model ({} > {}). Running this"
- " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
- )
- return ids
-
- def convert_ids_to_tokens(self, ids):
- """Converts a sequence of ids in wordpiece tokens using the vocab."""
- tokens = []
- for i in ids:
- tokens.append(self.ids_to_tokens[i])
- return tokens
-
- def save_vocabulary(self, vocab_path):
- """Save the tokenizer vocabulary to a directory or file."""
- index = 0
- if os.path.isdir(vocab_path):
- vocab_file = os.path.join(vocab_path, VOCAB_NAME)
- else:
- vocab_file = vocab_path
- with open(vocab_file, "w", encoding="utf-8") as writer:
- for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
- if index != token_index:
- print("Saving vocabulary to {}: vocabulary indices are not consecutive."
- " Please check that the vocabulary is not corrupted!".format(vocab_file))
- index = token_index
- writer.write(token + u'\n')
- index += 1
- return vocab_file
-
- @classmethod
- def from_pretrained(cls, model_dir, *inputs, **kwargs):
- """
- 给定path,直接读取vocab.
-
- """
- pretrained_model_name_or_path = os.path.join(model_dir, VOCAB_NAME)
- print("loading vocabulary file {}".format(pretrained_model_name_or_path))
- max_len = 512
- kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
- # Instantiate tokenizer.
- tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
- return tokenizer
-
-VOCAB_NAME = 'vocab.txt'
-
-class _WordBertModel(nn.Module):
- def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False):
- super().__init__()
-
- self.tokenzier = BertTokenizer.from_pretrained(model_dir)
- self.encoder = BertModel.from_pretrained(model_dir)
- # 检查encoder_layer_number是否合理
- encoder_layer_number = len(self.encoder.encoder.layer)
- self.layers = list(map(int, layers.split(',')))
- for layer in self.layers:
- if layer<0:
- assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \
- f"a bert model with {encoder_layer_number} layers."
- else:
-                assert layer<encoder_layer_number, f"The layer index:{layer} is out of scope for " \
-                    f"a bert model with {encoder_layer_number} layers."
-        if config_count > 1 or weight_count > 1:
- raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.")
- elif config_count == 0 or weight_count == 0:
- raise Exception(f"No config file or weight file found in {model_dir}")
-
- config = json.load(open(os.path.join(model_dir, config_file), 'r'))
- self.weight_file = os.path.join(model_dir, weight_file)
- self.config = config
-
-        OOV_TAG = '<oov>'
-        PAD_TAG = '<pad>'
-        BOS_TAG = '<bos>'
-        EOS_TAG = '<eos>'
-        BOW_TAG = '<bow>'
-        EOW_TAG = '<eow>'
-
- # For the model trained with character-based word encoder.
- char_lexicon = {}
- with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi:
- for line in fpi:
- tokens = line.strip().split('\t')
- if len(tokens) == 1:
- tokens.insert(0, '\u3000')
- token, i = tokens
- char_lexicon[token] = int(i)
-
- # 做一些sanity check
- for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]:
- assert special_word in char_lexicon, f"{special_word} not found in char.dic."
-
- # 从vocab中构建char_vocab
- char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG)
- # 需要保证<bow>与<eow>在里面
- char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG])
-
- for word, index in vocab:
- char_vocab.add_word_lst(list(word))
-
- self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx
- # 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示)
- char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']),
- padding_idx=len(char_vocab))
-
- # 读入预训练权重 这里的elmo_model 包含char_cnn和 lstm 的 state_dict
- elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu')
-
- char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight']
-
- found_char_count = 0
- for char, index in char_vocab: # 调整character embedding
- if char in char_lexicon:
- index_in_pre = char_lexicon.get(char)
- found_char_count += 1
- else:
- index_in_pre = char_lexicon[OOV_TAG]
- char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre]
-
- print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.")
- # 生成words到chars的映射
- max_chars = config['char_cnn']['max_characters_per_token']
-
- self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars),
- fill_value=len(char_vocab),
- dtype=torch.long),
- requires_grad=False)
- for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]:
- if len(word) + 2 > max_chars:
- word = word[:max_chars - 2]
- if index == self._pad_index:
- continue
- elif word == BOS_TAG or word == EOS_TAG:
- char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(word)] + [
- char_vocab.to_index(EOW_TAG)]
- char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
- else:
- char_ids = [char_vocab.to_index(BOW_TAG)] + [char_vocab.to_index(c) for c in word] + [
- char_vocab.to_index(EOW_TAG)]
- char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids))
- self.words_to_chars_embedding[index] = torch.LongTensor(char_ids)
-
- self.char_vocab = char_vocab
-
- self.token_embedder = ConvTokenEmbedder(
- config, self.weight_file, None, char_emb_layer)
- elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight
- self.token_embedder.load_state_dict(elmo_model["char_cnn"])
-
- self.output_dim = config['lstm']['projection_dim']
-
- # lstm encoder
- self.encoder = ElmobiLm(config)
- self.encoder.load_state_dict(elmo_model["lstm"])
-
- if cache_word_reprs:
- if config['char_cnn']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用
- print("Start to generate cache word representations.")
- batch_size = 320
- # bos eos
- word_size = self.words_to_chars_embedding.size(0)
- num_batches = word_size // batch_size + \
- int(word_size % batch_size != 0)
-
- self.cached_word_embedding = nn.Embedding(word_size,
- config['lstm']['projection_dim'])
- with torch.no_grad():
- for i in range(num_batches):
- words = torch.arange(i * batch_size,
- min((i + 1) * batch_size, word_size)).long()
- chars = self.words_to_chars_embedding[words].unsqueeze(1) # batch_size x 1 x max_chars
- word_reprs = self.token_embedder(words.unsqueeze(1),
- chars).detach() # batch_size x 1 x config['encoder']['projection_dim']
- self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1)
-
- print("Finish generating cached word representations. Going to delete the character encoder.")
- del self.token_embedder, self.words_to_chars_embedding
- else:
- print("There is no need to cache word representations, since no character information is used.")
-
- def forward(self, words):
- """
-
- :param words: batch_size x max_len
- :return: num_layers x batch_size x max_len x hidden_size
- """
- # 扩展<bos>, <eos>
- batch_size, max_len = words.size()
- expanded_words = words.new_zeros(batch_size, max_len + 2) # 因为pad一定为0,
- seq_len = words.ne(self._pad_index).sum(dim=-1)
- expanded_words[:, 1:-1] = words
- expanded_words[:, 0].fill_(self.bos_index)
- expanded_words[torch.arange(batch_size).to(words), seq_len + 1] = self.eos_index
- seq_len = seq_len + 2
- zero_tensor = expanded_words.new_zeros(expanded_words.shape)
- mask = (expanded_words == zero_tensor).unsqueeze(-1)
- if hasattr(self, 'cached_word_embedding'):
- token_embedding = self.cached_word_embedding(expanded_words)
- else:
- if hasattr(self, 'words_to_chars_embedding'):
- chars = self.words_to_chars_embedding[expanded_words]
- else:
- chars = None
- token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim
-
- encoder_output = self.encoder(token_embedding, seq_len)
- if encoder_output.size(2) < max_len + 2:
- num_layers, _, output_len, hidden_size = encoder_output.size()
- dummy_tensor = encoder_output.new_zeros(num_layers, batch_size,
- max_len + 2 - output_len, hidden_size)
- encoder_output = torch.cat((encoder_output, dummy_tensor), 2)
- sz = encoder_output.size() # 2, batch_size, max_len, hidden_size
- token_embedding = token_embedding.masked_fill(mask, 0)
- token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3])
- encoder_output = torch.cat((token_embedding, encoder_output), dim=0)
-
- # 删除<bos>, <eos>. 这里没有精确地删除,但应该也不会影响最后的结果了。
- encoder_output = encoder_output[:, :, 1:-1]
- return encoder_output
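
The deleted `_ElmoModel` code above fills `words_to_chars_embedding` one row per word: the word is truncated to `max_characters_per_token`, wrapped in `<bow>`/`<eow>`, and right-padded with `<pad>`. A minimal sketch of that mapping, assuming a toy character vocabulary (`char_to_id` and `word_to_char_ids` are illustrative names, not fastNLP API):

# Sketch only: mirrors the words_to_chars_embedding construction above with a toy vocab.
BOW, EOW, PAD = '<bow>', '<eow>', '<pad>'
char_to_id = {PAD: 0, BOW: 1, EOW: 2}          # toy character vocabulary

def word_to_char_ids(word, max_chars=8):
    """Wrap a word in <bow>/<eow>, map chars to ids, pad to a fixed width."""
    if len(word) + 2 > max_chars:               # leave room for <bow>/<eow>
        word = word[:max_chars - 2]
    for c in word:                              # grow the toy vocab on the fly
        char_to_id.setdefault(c, len(char_to_id))
    ids = [char_to_id[BOW]] + [char_to_id[c] for c in word] + [char_to_id[EOW]]
    ids += [char_to_id[PAD]] * (max_chars - len(ids))
    return ids

print(word_to_char_ids("elmo"))   # [1, 3, 4, 5, 6, 2, 0, 0]
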
diff --git a/fastNLP/modules/encoder/attention.py b/fastNLP/modules/encoder/attention.py
index 0a42d889..02bd078a 100644
--- a/fastNLP/modules/encoder/attention.py
+++ b/fastNLP/modules/encoder/attention.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"MultiHeadAttention"
]
@@ -8,8 +10,6 @@ import torch
import torch.nn.functional as F
from torch import nn
-from fastNLP.modules.dropout import TimestepDropout
-
from fastNLP.modules.utils import initial_parameter
@@ -18,7 +18,7 @@ class DotAttention(nn.Module):
.. todo::
补上文档
"""
-
+
def __init__(self, key_size, value_size, dropout=0.0):
super(DotAttention, self).__init__()
self.key_size = key_size
@@ -26,7 +26,7 @@ class DotAttention(nn.Module):
self.scale = math.sqrt(key_size)
self.drop = nn.Dropout(dropout)
self.softmax = nn.Softmax(dim=2)
-
+
def forward(self, Q, K, V, mask_out=None):
"""
@@ -45,7 +45,7 @@ class DotAttention(nn.Module):
class MultiHeadAttention(nn.Module):
"""
- 别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.attention.MultiHeadAttention`
+ 别名::class:`fastNLP.modules.MultiHeadAttention` :class:`fastNLP.modules.encoder.MultiHeadAttention`
:param input_size: int, 输入维度的大小。同时也是输出维度的大小。
:param key_size: int, 每个head的维度大小。
@@ -53,14 +53,14 @@ class MultiHeadAttention(nn.Module):
:param num_head: int,head的数量。
:param dropout: float。
"""
-
+
def __init__(self, input_size, key_size, value_size, num_head, dropout=0.1):
super(MultiHeadAttention, self).__init__()
self.input_size = input_size
self.key_size = key_size
self.value_size = value_size
self.num_head = num_head
-
+
in_size = key_size * num_head
self.q_in = nn.Linear(input_size, in_size)
self.k_in = nn.Linear(input_size, in_size)
@@ -69,14 +69,14 @@ class MultiHeadAttention(nn.Module):
self.attention = DotAttention(key_size=key_size, value_size=value_size, dropout=dropout)
self.out = nn.Linear(value_size * num_head, input_size)
self.reset_parameters()
-
+
def reset_parameters(self):
sqrt = math.sqrt
nn.init.normal_(self.q_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
nn.init.normal_(self.k_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.key_size)))
nn.init.normal_(self.v_in.weight, mean=0, std=sqrt(2.0 / (self.input_size + self.value_size)))
nn.init.xavier_normal_(self.out.weight)
-
+
def forward(self, Q, K, V, atte_mask_out=None):
"""
@@ -92,7 +92,7 @@ class MultiHeadAttention(nn.Module):
q = self.q_in(Q).view(batch, sq, n_head, d_k)
k = self.k_in(K).view(batch, sk, n_head, d_k)
v = self.v_in(V).view(batch, sk, n_head, d_v)
-
+
# transpose q, k and v to do batch attention
q = q.permute(2, 0, 1, 3).contiguous().view(-1, sq, d_k)
k = k.permute(2, 0, 1, 3).contiguous().view(-1, sk, d_k)
@@ -100,7 +100,7 @@ class MultiHeadAttention(nn.Module):
if atte_mask_out is not None:
atte_mask_out = atte_mask_out.repeat(n_head, 1, 1)
atte = self.attention(q, k, v, atte_mask_out).view(n_head, batch, sq, d_v)
-
+
# concat all heads, do output linear
atte = atte.permute(1, 2, 0, 3).contiguous().view(batch, sq, -1)
output = self.out(atte)
@@ -124,11 +124,11 @@ class BiAttention(nn.Module):
\end{array}
"""
-
+
def __init__(self):
super(BiAttention, self).__init__()
self.inf = 10e12
-
+
def forward(self, in_x1, in_x2, x1_len, x2_len):
"""
:param torch.Tensor in_x1: [batch_size, x1_seq_len, hidden_size] 第一句的特征表示
@@ -139,36 +139,36 @@ class BiAttention(nn.Module):
torch.Tensor out_x2: [batch_size, x2_seq_len, hidden_size] 第一句attend到的特征表示
"""
-
+
assert in_x1.size()[0] == in_x2.size()[0]
assert in_x1.size()[2] == in_x2.size()[2]
# The batch size and hidden size must be equal.
assert in_x1.size()[1] == x1_len.size()[1] and in_x2.size()[1] == x2_len.size()[1]
# The seq len in in_x and x_len must be equal.
assert in_x1.size()[0] == x1_len.size()[0] and x1_len.size()[0] == x2_len.size()[0]
-
+
batch_size = in_x1.size()[0]
x1_max_len = in_x1.size()[1]
x2_max_len = in_x2.size()[1]
-
+
in_x2_t = torch.transpose(in_x2, 1, 2) # [batch_size, hidden_size, x2_seq_len]
-
+
attention_matrix = torch.bmm(in_x1, in_x2_t) # [batch_size, x1_seq_len, x2_seq_len]
-
+
a_mask = x1_len.le(0.5).float() * -self.inf # [batch_size, x1_seq_len]
a_mask = a_mask.view(batch_size, x1_max_len, -1)
a_mask = a_mask.expand(-1, -1, x2_max_len) # [batch_size, x1_seq_len, x2_seq_len]
b_mask = x2_len.le(0.5).float() * -self.inf
b_mask = b_mask.view(batch_size, -1, x2_max_len)
b_mask = b_mask.expand(-1, x1_max_len, -1) # [batch_size, x1_seq_len, x2_seq_len]
-
+
attention_a = F.softmax(attention_matrix + a_mask, dim=2) # [batch_size, x1_seq_len, x2_seq_len]
attention_b = F.softmax(attention_matrix + b_mask, dim=1) # [batch_size, x1_seq_len, x2_seq_len]
-
+
out_x1 = torch.bmm(attention_a, in_x2) # [batch_size, x1_seq_len, hidden_size]
attention_b_t = torch.transpose(attention_b, 1, 2)
out_x2 = torch.bmm(attention_b_t, in_x1) # [batch_size, x2_seq_len, hidden_size]
-
+
return out_x1, out_x2
@@ -182,10 +182,10 @@ class SelfAttention(nn.Module):
:param float drop: dropout概率,默认值为0.5
:param str initial_method: 初始化参数方法
"""
-
+
def __init__(self, input_size, attention_unit=300, attention_hops=10, drop=0.5, initial_method=None, ):
super(SelfAttention, self).__init__()
-
+
self.attention_hops = attention_hops
self.ws1 = nn.Linear(input_size, attention_unit, bias=False)
self.ws2 = nn.Linear(attention_unit, attention_hops, bias=False)
@@ -194,7 +194,7 @@ class SelfAttention(nn.Module):
self.drop = nn.Dropout(drop)
self.tanh = nn.Tanh()
initial_parameter(self, initial_method)
-
+
def _penalization(self, attention):
"""
compute the penalization term for attention module
@@ -208,7 +208,7 @@ class SelfAttention(nn.Module):
mat = torch.bmm(attention, attention_t) - self.I[:attention.size(0)]
ret = (torch.sum(torch.sum((mat ** 2), 2), 1).squeeze() + 1e-10) ** 0.5
return torch.sum(ret) / size[0]
-
+
def forward(self, input, input_origin):
"""
:param torch.Tensor input: [baz, senLen, h_dim] 要做attention的矩阵
@@ -218,14 +218,14 @@ class SelfAttention(nn.Module):
"""
input = input.contiguous()
size = input.size() # [bsz, len, nhid]
-
+
input_origin = input_origin.expand(self.attention_hops, -1, -1) # [hops,baz, len]
input_origin = input_origin.transpose(0, 1).contiguous() # [baz, hops,len]
-
+
y1 = self.tanh(self.ws1(self.drop(input))) # [baz,len,dim] -->[bsz,len, attention-unit]
attention = self.ws2(y1).transpose(1, 2).contiguous()
# [bsz,len, attention-unit]--> [bsz, len, hop]--> [baz,hop,len]
-
+
attention = attention + (-999999 * (input_origin == 0).float()) # remove the weight on padding token.
attention = F.softmax(attention, 2) # [baz ,hop, len]
return torch.bmm(attention, input), self._penalization(attention) # output1 --> [baz ,hop ,nhid]
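
For reference, the reshaping in `MultiHeadAttention.forward` above projects Q/K/V, splits them into `num_head` slices, folds the head axis into the batch axis so a single `bmm` covers all heads, then unfolds and concatenates. A minimal sketch of that bookkeeping for the self-attention case (Q=K=V), with illustrative sizes and untrained weights:

import math
import torch
from torch import nn

batch, seq, input_size, n_head, d_k = 2, 5, 16, 4, 8
q_in = nn.Linear(input_size, n_head * d_k)
k_in = nn.Linear(input_size, n_head * d_k)
v_in = nn.Linear(input_size, n_head * d_k)
out = nn.Linear(n_head * d_k, input_size)

x = torch.randn(batch, seq, input_size)
# project and split into heads: [batch, seq, n_head, d_k]
q = q_in(x).view(batch, seq, n_head, d_k)
k = k_in(x).view(batch, seq, n_head, d_k)
v = v_in(x).view(batch, seq, n_head, d_k)
# fold heads into the batch dimension: [n_head*batch, seq, d_k]
q = q.permute(2, 0, 1, 3).reshape(-1, seq, d_k)
k = k.permute(2, 0, 1, 3).reshape(-1, seq, d_k)
v = v.permute(2, 0, 1, 3).reshape(-1, seq, d_k)
# scaled dot-product attention, all heads at once
scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(d_k)
atte = torch.softmax(scores, dim=-1).bmm(v)            # [n_head*batch, seq, d_k]
# unfold heads and concatenate them back: [batch, seq, n_head*d_k]
atte = atte.view(n_head, batch, seq, d_k).permute(1, 2, 0, 3).reshape(batch, seq, -1)
print(out(atte).shape)                                  # torch.Size([2, 5, 16])
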
diff --git a/fastNLP/modules/encoder/bert.py b/fastNLP/modules/encoder/bert.py
index 1819cc69..5026f48a 100644
--- a/fastNLP/modules/encoder/bert.py
+++ b/fastNLP/modules/encoder/bert.py
@@ -1,79 +1,925 @@
+"""undocumented
+这个页面的代码很大程度上参考(复制粘贴)了https://github.com/huggingface/pytorch-pretrained-BERT的代码, 如果你发现该代码对你
+ 有用,也请引用一下他们。
+"""
+__all__ = [
+ "BertModel"
+]
+
+import collections
+import copy
+import json
+import math
import os
-from torch import nn
+import unicodedata
+
import torch
-from ...io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from ._bert import _WordPieceBertModel, BertModel
+from torch import nn
+
+from ..utils import _get_file_name_base_on_postfix
+from ...core import logger
+CONFIG_FILE = 'bert_config.json'
+VOCAB_NAME = 'vocab.txt'
-class BertWordPieceEncoder(nn.Module):
+
+
+class BertConfig(object):
+ """Configuration class to store the configuration of a `BertModel`.
"""
- 读取bert模型,读取之后调用index_dataset方法在dataset中生成word_pieces这一列。
- :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为``en-base-uncased``
- :param str layers:最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层
- :param bool requires_grad: 是否需要gradient。
+ def __init__(self,
+ vocab_size_or_config_json_file,
+ hidden_size=768,
+ num_hidden_layers=12,
+ num_attention_heads=12,
+ intermediate_size=3072,
+ hidden_act="gelu",
+ hidden_dropout_prob=0.1,
+ attention_probs_dropout_prob=0.1,
+ max_position_embeddings=512,
+ type_vocab_size=2,
+ initializer_range=0.02,
+ layer_norm_eps=1e-12):
+ """Constructs BertConfig.
+
+ Args:
+ vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `BertModel`.
+ hidden_size: Size of the encoder layers and the pooler layer.
+ num_hidden_layers: Number of hidden layers in the Transformer encoder.
+ num_attention_heads: Number of attention heads for each attention layer in
+ the Transformer encoder.
+ intermediate_size: The size of the "intermediate" (i.e., feed-forward)
+ layer in the Transformer encoder.
+ hidden_act: The non-linear activation function (function or string) in the
+ encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
+ hidden_dropout_prob: The dropout probability for all fully connected
+ layers in the embeddings, encoder, and pooler.
+ attention_probs_dropout_prob: The dropout ratio for the attention
+ probabilities.
+ max_position_embeddings: The maximum sequence length that this model might
+ ever be used with. Typically set this to something large just in case
+ (e.g., 512 or 1024 or 2048).
+ type_vocab_size: The vocabulary size of the `token_type_ids` passed into
+ `BertModel`.
+ initializer_range: The stddev of the truncated_normal_initializer for
+ initializing all weight matrices.
+ layer_norm_eps: The epsilon used by LayerNorm.
+ """
+ if isinstance(vocab_size_or_config_json_file, str):
+ with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
+ json_config = json.loads(reader.read())
+ for key, value in json_config.items():
+ self.__dict__[key] = value
+ elif isinstance(vocab_size_or_config_json_file, int):
+ self.vocab_size = vocab_size_or_config_json_file
+ self.hidden_size = hidden_size
+ self.num_hidden_layers = num_hidden_layers
+ self.num_attention_heads = num_attention_heads
+ self.hidden_act = hidden_act
+ self.intermediate_size = intermediate_size
+ self.hidden_dropout_prob = hidden_dropout_prob
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
+ self.max_position_embeddings = max_position_embeddings
+ self.type_vocab_size = type_vocab_size
+ self.initializer_range = initializer_range
+ self.layer_norm_eps = layer_norm_eps
+ else:
+ raise ValueError("First argument must be either a vocabulary size (int)"
+ "or the path to a pretrained model config file (str)")
+
+ @classmethod
+ def from_dict(cls, json_object):
+ """Constructs a `BertConfig` from a Python dictionary of parameters."""
+ config = BertConfig(vocab_size_or_config_json_file=-1)
+ for key, value in json_object.items():
+ config.__dict__[key] = value
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `BertConfig` from a json file of parameters."""
+ with open(json_file, "r", encoding='utf-8') as reader:
+ text = reader.read()
+ return cls.from_dict(json.loads(text))
+
+ def __repr__(self):
+ return str(self.to_json_string())
+
+ def to_dict(self):
+ """Serializes this instance to a Python dictionary."""
+ output = copy.deepcopy(self.__dict__)
+ return output
+
+ def to_json_string(self):
+ """Serializes this instance to a JSON string."""
+ return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
+
+ def to_json_file(self, json_file_path):
+ """ Save this instance to a json file."""
+ with open(json_file_path, "w", encoding='utf-8') as writer:
+ writer.write(self.to_json_string())
+
+
+def gelu(x):
+ return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+
+
+def swish(x):
+ return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
+class BertLayerNorm(nn.Module):
+ def __init__(self, hidden_size, eps=1e-12):
+ """Construct a layernorm module in the TF style (epsilon inside the square root).
+ """
+ super(BertLayerNorm, self).__init__()
+ self.weight = nn.Parameter(torch.ones(hidden_size))
+ self.bias = nn.Parameter(torch.zeros(hidden_size))
+ self.variance_epsilon = eps
+
+ def forward(self, x):
+ u = x.mean(-1, keepdim=True)
+ s = (x - u).pow(2).mean(-1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.variance_epsilon)
+ return self.weight * x + self.bias
+
+
+class BertEmbeddings(nn.Module):
+ """Construct the embeddings from word, position and token_type embeddings.
"""
- def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1',
- requires_grad: bool=False):
- super().__init__()
- PRETRAIN_URL = _get_base_url('bert')
-
- if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR:
- model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
- model_url = PRETRAIN_URL + model_name
- model_dir = cached_path(model_url)
- # 检查是否存在
- elif os.path.isdir(model_dir_or_name):
- model_dir = model_dir_or_name
+
+ def __init__(self, config):
+ super(BertEmbeddings, self).__init__()
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
+ # any TensorFlow checkpoint file
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, input_ids, token_type_ids=None):
+ seq_length = input_ids.size(1)
+ position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+ position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ words_embeddings = self.word_embeddings(input_ids)
+ position_embeddings = self.position_embeddings(position_ids)
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
+
+ embeddings = words_embeddings + position_embeddings + token_type_embeddings
+ embeddings = self.LayerNorm(embeddings)
+ embeddings = self.dropout(embeddings)
+ return embeddings
+
+
+class BertSelfAttention(nn.Module):
+ def __init__(self, config):
+ super(BertSelfAttention, self).__init__()
+ if config.hidden_size % config.num_attention_heads != 0:
+ raise ValueError(
+ "The hidden size (%d) is not a multiple of the number of attention "
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+ self.num_attention_heads = config.num_attention_heads
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+ def transpose_for_scores(self, x):
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+ x = x.view(*new_x_shape)
+ return x.permute(0, 2, 1, 3)
+
+ def forward(self, hidden_states, attention_mask):
+ mixed_query_layer = self.query(hidden_states)
+ mixed_key_layer = self.key(hidden_states)
+ mixed_value_layer = self.value(hidden_states)
+
+ query_layer = self.transpose_for_scores(mixed_query_layer)
+ key_layer = self.transpose_for_scores(mixed_key_layer)
+ value_layer = self.transpose_for_scores(mixed_value_layer)
+
+ # Take the dot product between "query" and "key" to get the raw attention scores.
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+ # Apply the attention mask (precomputed for all layers in BertModel forward() function)
+ attention_scores = attention_scores + attention_mask
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+ # This is actually dropping out entire tokens to attend to, which might
+ # seem a bit unusual, but is taken from the original Transformer paper.
+ attention_probs = self.dropout(attention_probs)
+
+ context_layer = torch.matmul(attention_probs, value_layer)
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+ context_layer = context_layer.view(*new_context_layer_shape)
+ return context_layer
+
+
+class BertSelfOutput(nn.Module):
+ def __init__(self, config):
+ super(BertSelfOutput, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertAttention(nn.Module):
+ def __init__(self, config):
+ super(BertAttention, self).__init__()
+ self.self = BertSelfAttention(config)
+ self.output = BertSelfOutput(config)
+
+ def forward(self, input_tensor, attention_mask):
+ self_output = self.self(input_tensor, attention_mask)
+ attention_output = self.output(self_output, input_tensor)
+ return attention_output
+
+
+class BertIntermediate(nn.Module):
+ def __init__(self, config):
+ super(BertIntermediate, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
+ if isinstance(config.hidden_act, str):
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
else:
- raise ValueError(f"Cannot recognize {model_dir_or_name}.")
+ self.intermediate_act_fn = config.hidden_act
+
+ def forward(self, hidden_states):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.intermediate_act_fn(hidden_states)
+ return hidden_states
+
+
+class BertOutput(nn.Module):
+ def __init__(self, config):
+ super(BertOutput, self).__init__()
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
+ self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
+
+ def forward(self, hidden_states, input_tensor):
+ hidden_states = self.dense(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
+ return hidden_states
+
+
+class BertLayer(nn.Module):
+ def __init__(self, config):
+ super(BertLayer, self).__init__()
+ self.attention = BertAttention(config)
+ self.intermediate = BertIntermediate(config)
+ self.output = BertOutput(config)
+
+ def forward(self, hidden_states, attention_mask):
+ attention_output = self.attention(hidden_states, attention_mask)
+ intermediate_output = self.intermediate(attention_output)
+ layer_output = self.output(intermediate_output, attention_output)
+ return layer_output
+
+
+class BertEncoder(nn.Module):
+ def __init__(self, config):
+ super(BertEncoder, self).__init__()
+ layer = BertLayer(config)
+ self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(config.num_hidden_layers)])
+
+ def forward(self, hidden_states, attention_mask, output_all_encoded_layers=True):
+ all_encoder_layers = []
+ for layer_module in self.layer:
+ hidden_states = layer_module(hidden_states, attention_mask)
+ if output_all_encoded_layers:
+ all_encoder_layers.append(hidden_states)
+ if not output_all_encoded_layers:
+ all_encoder_layers.append(hidden_states)
+ return all_encoder_layers
+
+
+class BertPooler(nn.Module):
+ def __init__(self, config):
+ super(BertPooler, self).__init__()
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+ self.activation = nn.Tanh()
+
+ def forward(self, hidden_states):
+ # We "pool" the model by simply taking the hidden state corresponding
+ # to the first token.
+ first_token_tensor = hidden_states[:, 0]
+ pooled_output = self.dense(first_token_tensor)
+ pooled_output = self.activation(pooled_output)
+ return pooled_output
+
+
+class BertModel(nn.Module):
+ """
+ 别名::class:`fastNLP.modules.BertModel` :class:`fastNLP.modules.encoder.BertModel`
+
+ BERT(Bidirectional Embedding Representations from Transformers).
+
+ 如果你想使用预训练好的权重矩阵,请在以下网址下载.
+ sources::
+
+ 'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
+ 'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
+ 'bert-base-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
+ 'bert-large-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
+ 'bert-base-multilingual-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
+ 'bert-base-multilingual-cased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
+ 'bert-base-chinese': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
+ 'bert-base-german-cased': "https://int-deepset-models-bert.s3.eu-central-1.amazonaws.com/pytorch/bert-base-german-cased-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
+ 'bert-large-uncased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
+ 'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin"
+
+
+ 用预训练权重矩阵来建立BERT模型::
+
+ model = BertModel.from_pretrained("path/to/weights/directory")
- self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers)
- self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size
- self.requires_grad = requires_grad
+ 用随机初始化权重矩阵来建立BERT模型::
- @property
- def requires_grad(self):
+ model = BertModel()
+
+ :param int vocab_size: 词表大小,默认值为30522,为BERT English uncase版本的词表大小
+ :param int hidden_size: 隐层大小,默认值为768,为BERT base的版本
+ :param int num_hidden_layers: 隐藏层数,默认值为12,为BERT base的版本
+ :param int num_attention_heads: 多头注意力头数,默认值为12,为BERT base的版本
+ :param int intermediate_size: FFN隐藏层大小,默认值是3072,为BERT base的版本
+ :param str hidden_act: FFN隐藏层激活函数,默认值为``gelu``
+ :param float hidden_dropout_prob: FFN隐藏层dropout,默认值为0.1
+ :param float attention_probs_dropout_prob: Attention层的dropout,默认值为0.1
+ :param int max_position_embeddings: 最大的序列长度,默认值为512,
+ :param int type_vocab_size: 最大segment数量,默认值为2
+ :param int initializer_range: 初始化权重范围,默认值为0.02
+ """
+
+ def __init__(self, config, *inputs, **kwargs):
+ super(BertModel, self).__init__()
+ if not isinstance(config, BertConfig):
+ raise ValueError(
+ "Parameter config in `{}(config)` should be an instance of class `BertConfig`. "
+ "To create a model from a Google pretrained model use "
+ "`model = {}.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
+ self.__class__.__name__, self.__class__.__name__
+ ))
+ super(BertModel, self).__init__()
+ self.config = config
+ self.hidden_size = self.config.hidden_size
+ self.embeddings = BertEmbeddings(config)
+ self.encoder = BertEncoder(config)
+ self.pooler = BertPooler(config)
+ self.apply(self.init_bert_weights)
+
+ def init_bert_weights(self, module):
+ """ Initialize the weights.
+ """
+ if isinstance(module, (nn.Linear, nn.Embedding)):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, BertLayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ if isinstance(module, nn.Linear) and module.bias is not None:
+ module.bias.data.zero_()
+
+ def forward(self, input_ids, token_type_ids=None, attention_mask=None, output_all_encoded_layers=True):
+ if attention_mask is None:
+ attention_mask = torch.ones_like(input_ids)
+ if token_type_ids is None:
+ token_type_ids = torch.zeros_like(input_ids)
+
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is simpler than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+ embedding_output = self.embeddings(input_ids, token_type_ids)
+ encoded_layers = self.encoder(embedding_output,
+ extended_attention_mask,
+ output_all_encoded_layers=output_all_encoded_layers)
+ sequence_output = encoded_layers[-1]
+ pooled_output = self.pooler(sequence_output)
+ if not output_all_encoded_layers:
+ encoded_layers = encoded_layers[-1]
+ return encoded_layers, pooled_output
+
+ @classmethod
+ def from_pretrained(cls, pretrained_model_dir, *inputs, **kwargs):
+ state_dict = kwargs.get('state_dict', None)
+ kwargs.pop('state_dict', None)
+ kwargs.pop('cache_dir', None)
+ kwargs.pop('from_tf', None)
+ # Load config
+ config_file = _get_file_name_base_on_postfix(pretrained_model_dir, '.json')
+ config = BertConfig.from_json_file(config_file)
+ # logger.info("Model config {}".format(config))
+ # Instantiate model.
+ model = cls(config, *inputs, **kwargs)
+ if state_dict is None:
+ weights_path = _get_file_name_base_on_postfix(pretrained_model_dir, '.bin')
+ state_dict = torch.load(weights_path, map_location='cpu')
+
+ old_keys = []
+ new_keys = []
+ for key in state_dict.keys():
+ new_key = None
+ if 'gamma' in key:
+ new_key = key.replace('gamma', 'weight')
+ if 'beta' in key:
+ new_key = key.replace('beta', 'bias')
+ if new_key:
+ old_keys.append(key)
+ new_keys.append(new_key)
+ for old_key, new_key in zip(old_keys, new_keys):
+ state_dict[new_key] = state_dict.pop(old_key)
+
+ missing_keys = []
+ unexpected_keys = []
+ error_msgs = []
+ # copy state_dict so _load_from_state_dict can modify it
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ def load(module, prefix=''):
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
+ module._load_from_state_dict(
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(model, prefix='' if hasattr(model, 'bert') else 'bert.')
+ if len(missing_keys) > 0:
+ logger.warn("Weights of {} not initialized from pretrained model: {}".format(
+ model.__class__.__name__, missing_keys))
+ if len(unexpected_keys) > 0:
+ logger.warn("Weights from pretrained model not used in {}: {}".format(
+ model.__class__.__name__, unexpected_keys))
+ return model
+
+
+def whitespace_tokenize(text):
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
+ text = text.strip()
+ if not text:
+ return []
+ tokens = text.split()
+ return tokens
+
+
+class WordpieceTokenizer(object):
+ """Runs WordPiece tokenization."""
+
+ def __init__(self, vocab, unk_token="[UNK]", max_input_chars_per_word=100):
+ self.vocab = vocab
+ self.unk_token = unk_token
+ self.max_input_chars_per_word = max_input_chars_per_word
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text into its word pieces.
+
+ This uses a greedy longest-match-first algorithm to perform tokenization
+ using the given vocabulary.
+
+ For example:
+ input = "unaffable"
+ output = ["un", "##aff", "##able"]
+
+ Args:
+ text: A single token or whitespace separated tokens. This should have
+ already been passed through `BasicTokenizer`.
+
+ Returns:
+ A list of wordpiece tokens.
+ """
+
+ output_tokens = []
+ for token in whitespace_tokenize(text):
+ chars = list(token)
+ if len(chars) > self.max_input_chars_per_word:
+ output_tokens.append(self.unk_token)
+ continue
+
+ is_bad = False
+ start = 0
+ sub_tokens = []
+ while start < len(chars):
+ end = len(chars)
+ cur_substr = None
+ while start < end:
+ substr = "".join(chars[start:end])
+ if start > 0:
+ substr = "##" + substr
+ if substr in self.vocab:
+ cur_substr = substr
+ break
+ end -= 1
+ if cur_substr is None:
+ is_bad = True
+ break
+ sub_tokens.append(cur_substr)
+ start = end
+
+ if is_bad:
+ output_tokens.append(self.unk_token)
+ else:
+ output_tokens.extend(sub_tokens)
+ if len(output_tokens)==0: #防止里面全是空格或者回车符号
+ return [self.unk_token]
+ return output_tokens
+
+
+def load_vocab(vocab_file):
+ """Loads a vocabulary file into a dictionary."""
+ vocab = collections.OrderedDict()
+ index = 0
+ with open(vocab_file, "r", encoding="utf-8") as reader:
+ while True:
+ token = reader.readline()
+ if not token:
+ break
+ token = token.strip()
+ vocab[token] = index
+ index += 1
+ return vocab
+
+
+class BasicTokenizer(object):
+ """Runs basic tokenization (punctuation splitting, lower casing, etc.)."""
+
+ def __init__(self,
+ do_lower_case=True,
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+ """Constructs a BasicTokenizer.
+
+ Args:
+ do_lower_case: Whether to lower case the input.
+ """
+ self.do_lower_case = do_lower_case
+ self.never_split = never_split
+
+ def tokenize(self, text):
+ """Tokenizes a piece of text."""
+ text = self._clean_text(text)
+ # This was added on November 1st, 2018 for the multilingual and Chinese
+ # models. This is also applied to the English models now, but it doesn't
+ # matter since the English models were not trained on any Chinese data
+ # and generally don't have any Chinese data in them (there are Chinese
+ # characters in the vocabulary because Wikipedia does have some Chinese
+ # words in the English Wikipedia.).
+ text = self._tokenize_chinese_chars(text)
+ orig_tokens = whitespace_tokenize(text)
+ split_tokens = []
+ for token in orig_tokens:
+ if self.do_lower_case and token not in self.never_split:
+ token = token.lower()
+ token = self._run_strip_accents(token)
+ split_tokens.extend(self._run_split_on_punc(token))
+
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
+ return output_tokens
+
+ def _run_strip_accents(self, text):
+ """Strips accents from a piece of text."""
+ text = unicodedata.normalize("NFD", text)
+ output = []
+ for char in text:
+ cat = unicodedata.category(char)
+ if cat == "Mn":
+ continue
+ output.append(char)
+ return "".join(output)
+
+ def _run_split_on_punc(self, text):
+ """Splits punctuation on a piece of text."""
+ if text in self.never_split:
+ return [text]
+ chars = list(text)
+ i = 0
+ start_new_word = True
+ output = []
+ while i < len(chars):
+ char = chars[i]
+ if _is_punctuation(char):
+ output.append([char])
+ start_new_word = True
+ else:
+ if start_new_word:
+ output.append([])
+ start_new_word = False
+ output[-1].append(char)
+ i += 1
+
+ return ["".join(x) for x in output]
+
+ def _tokenize_chinese_chars(self, text):
+ """Adds whitespace around any CJK character."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if self._is_chinese_char(cp):
+ output.append(" ")
+ output.append(char)
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+ def _is_chinese_char(self, cp):
+ """Checks whether CP is the codepoint of a CJK character."""
+ # This defines a "chinese character" as anything in the CJK Unicode block:
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
+ #
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
+ # despite its name. The modern Korean Hangul alphabet is a different block,
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
+ # space-separated words, so they are not treated specially and handled
+ # like all of the other languages.
+ if ((cp >= 0x4E00 and cp <= 0x9FFF) or #
+ (cp >= 0x3400 and cp <= 0x4DBF) or #
+ (cp >= 0x20000 and cp <= 0x2A6DF) or #
+ (cp >= 0x2A700 and cp <= 0x2B73F) or #
+ (cp >= 0x2B740 and cp <= 0x2B81F) or #
+ (cp >= 0x2B820 and cp <= 0x2CEAF) or
+ (cp >= 0xF900 and cp <= 0xFAFF) or #
+ (cp >= 0x2F800 and cp <= 0x2FA1F)): #
+ return True
+
+ return False
+
+ def _clean_text(self, text):
+ """Performs invalid character removal and whitespace cleanup on text."""
+ output = []
+ for char in text:
+ cp = ord(char)
+ if cp == 0 or cp == 0xfffd or _is_control(char):
+ continue
+ if _is_whitespace(char):
+ output.append(" ")
+ else:
+ output.append(char)
+ return "".join(output)
+
+
+def _is_whitespace(char):
+ """Checks whether `chars` is a whitespace character."""
+ # \t, \n, and \r are technically control characters but we treat them
+ # as whitespace since they are generally considered as such.
+ if char == " " or char == "\t" or char == "\n" or char == "\r":
+ return True
+ cat = unicodedata.category(char)
+ if cat == "Zs":
+ return True
+ return False
+
+
+def _is_control(char):
+ """Checks whether `chars` is a control character."""
+ # These are technically control characters but we count them as whitespace
+ # characters.
+ if char == "\t" or char == "\n" or char == "\r":
+ return False
+ cat = unicodedata.category(char)
+ if cat.startswith("C"):
+ return True
+ return False
+
+
+def _is_punctuation(char):
+ """Checks whether `chars` is a punctuation character."""
+ cp = ord(char)
+ # We treat all non-letter/number ASCII as punctuation.
+ # Characters such as "^", "$", and "`" are not in the Unicode
+ # Punctuation class but we treat them as punctuation anyways, for
+ # consistency.
+ if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
+ (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
+ return True
+ cat = unicodedata.category(char)
+ if cat.startswith("P"):
+ return True
+ return False
+
+
+class BertTokenizer(object):
+ """Runs end-to-end tokenization: punctuation splitting + wordpiece"""
+
+ def __init__(self, vocab_file, do_lower_case=True, max_len=None, do_basic_tokenize=True,
+ never_split=("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]")):
+ """Constructs a BertTokenizer.
+
+ Args:
+ vocab_file: Path to a one-wordpiece-per-line vocabulary file
+ do_lower_case: Whether to lower case the input
+ Only has an effect when do_wordpiece_only=False
+ do_basic_tokenize: Whether to do basic tokenization before wordpiece.
+ max_len: An artificial maximum length to truncate tokenized sequences to;
+ Effective maximum length is always the minimum of this
+ value (if specified) and the underlying BERT model's
+ sequence length.
+ never_split: List of tokens which will never be split during tokenization.
+ Only has an effect when do_wordpiece_only=False
+ """
+ if not os.path.isfile(vocab_file):
+ raise ValueError(
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
+ self.vocab = load_vocab(vocab_file)
+ self.ids_to_tokens = collections.OrderedDict(
+ [(ids, tok) for tok, ids in self.vocab.items()])
+ self.do_basic_tokenize = do_basic_tokenize
+ if do_basic_tokenize:
+ self.basic_tokenizer = BasicTokenizer(do_lower_case=do_lower_case,
+ never_split=never_split)
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+ self.max_len = max_len if max_len is not None else int(1e12)
+
+ def _reinit_on_new_vocab(self, vocab):
"""
- Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
+ 在load bert之后,可能会对vocab进行重新排列。重新排列之后调用这个函数重新初始化与vocab相关的性质
+
+ :param vocab:
:return:
"""
- requires_grads = set([param.requires_grad for name, param in self.named_parameters()])
- if len(requires_grads)==1:
- return requires_grads.pop()
+ self.vocab = vocab
+ self.wordpiece_tokenizer = WordpieceTokenizer(vocab=self.vocab)
+
+ def tokenize(self, text):
+ split_tokens = []
+ if self.do_basic_tokenize:
+ for token in self.basic_tokenizer.tokenize(text):
+ for sub_token in self.wordpiece_tokenizer.tokenize(token):
+ split_tokens.append(sub_token)
else:
- return None
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
+ return split_tokens
+
+ def convert_tokens_to_ids(self, tokens):
+ """Converts a sequence of tokens into ids using the vocab."""
+ ids = []
+ for token in tokens:
+ ids.append(self.vocab[token])
+ if len(ids) > self.max_len:
+ logger.warn(
+ "Token indices sequence length is longer than the specified maximum "
+ " sequence length for this BERT model ({} > {}). Running this"
+ " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+ )
+ return ids
+
+ def convert_ids_to_tokens(self, ids):
+ """Converts a sequence of ids in wordpiece tokens using the vocab."""
+ tokens = []
+ for i in ids:
+ tokens.append(self.ids_to_tokens[i])
+ return tokens
+
+ def save_vocabulary(self, vocab_path):
+ """Save the tokenizer vocabulary to a directory or file."""
+ index = 0
+ if os.path.isdir(vocab_path):
+ vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+ else:
+ vocab_file = vocab_path
+ with open(vocab_file, "w", encoding="utf-8") as writer:
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+ if index != token_index:
+ logger.warn("Saving vocabulary to {}: vocabulary indices are not consecutive."
+ " Please check that the vocabulary is not corrupted!".format(vocab_file))
+ index = token_index
+ writer.write(token + u'\n')
+ index += 1
+ return vocab_file
+
+ @classmethod
+ def from_pretrained(cls, model_dir, *inputs, **kwargs):
+ """
+ 给定path,直接读取vocab.
+
+ """
+ pretrained_model_name_or_path = _get_file_name_base_on_postfix(model_dir, '.txt')
+ logger.info("loading vocabulary file {}".format(pretrained_model_name_or_path))
+ max_len = 512
+ kwargs['max_len'] = min(kwargs.get('max_position_embeddings', int(1e12)), max_len)
+ # Instantiate tokenizer.
+ tokenizer = cls(pretrained_model_name_or_path, *inputs, **kwargs)
+ return tokenizer
- @requires_grad.setter
- def requires_grad(self, value):
- for name, param in self.named_parameters():
- param.requires_grad = value
+class _WordPieceBertModel(nn.Module):
+ """
+ 这个模块用于直接计算word_piece的结果.
- @property
- def embed_size(self):
- return self._embed_size
+ """
+
+ def __init__(self, model_dir: str, layers: str = '-1', pooled_cls:bool=False):
+ super().__init__()
- def index_datasets(self, *datasets, field_name):
+ self.tokenzier = BertTokenizer.from_pretrained(model_dir)
+ self.encoder = BertModel.from_pretrained(model_dir)
+ # 检查encoder_layer_number是否合理
+ encoder_layer_number = len(self.encoder.encoder.layer)
+ self.layers = list(map(int, layers.split(',')))
+ for layer in self.layers:
+ if layer < 0:
+ assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a bert model with {encoder_layer_number} layers."
+ else:
+ assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \
+ f"a bert model with {encoder_layer_number} layers."
+
+ self._cls_index = self.tokenzier.vocab['[CLS]']
+ self._sep_index = self.tokenzier.vocab['[SEP]']
+ self._wordpiece_unknown_index = self.tokenzier.vocab['[UNK]']
+ self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece
+ self.pooled_cls = pooled_cls
+
+ def index_dataset(self, *datasets, field_name, add_cls_sep=True):
"""
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。
:param datasets: DataSet对象
- :param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。
+ :param field_name: 基于哪一列index
:return:
"""
- self.model.index_dataset(*datasets, field_name=field_name)
+
+ def convert_words_to_word_pieces(words):
+ word_pieces = []
+ for word in words:
+ tokens = self.tokenzier.wordpiece_tokenizer.tokenize(word)
+ word_piece_ids = self.tokenzier.convert_tokens_to_ids(tokens)
+ word_pieces.extend(word_piece_ids)
+ if add_cls_sep:
+ if word_pieces[0] != self._cls_index:
+ word_pieces.insert(0, self._cls_index)
+ if word_pieces[-1] != self._sep_index:
+ word_pieces.insert(-1, self._sep_index)
+ return word_pieces
+
+ for index, dataset in enumerate(datasets):
+ try:
+ dataset.apply_field(convert_words_to_word_pieces, field_name=field_name, new_field_name='word_pieces',
+ is_input=True)
+ dataset.set_pad_val('word_pieces', self._wordpiece_pad_index)
+ except Exception as e:
+ logger.error(f"Exception happens when processing the {index} dataset.")
+ raise e
def forward(self, word_pieces, token_type_ids=None):
"""
- 计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。
- :param words: batch_size x max_len
- :param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
+ :param word_pieces: torch.LongTensor, batch_size x max_len
+ :param token_type_ids: torch.LongTensor, batch_size x max_len
+ :return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size
"""
- outputs = self.model(word_pieces, token_type_ids)
- outputs = torch.cat([*outputs], dim=-1)
+ batch_size, max_len = word_pieces.size()
+ attn_masks = word_pieces.ne(self._wordpiece_pad_index)
+ bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks,
+ output_all_encoded_layers=True)
+ # output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size
+ outputs = bert_outputs[0].new_zeros((len(self.layers), batch_size, max_len, bert_outputs[0].size(-1)))
+ for l_index, l in enumerate(self.layers):
+ bert_output = bert_outputs[l]
+ if l in (len(bert_outputs)-1, -1) and self.pooled_cls:
+ bert_output[:, 0] = pooled_cls
+ outputs[l_index] = bert_output
return outputs
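
`WordpieceTokenizer.tokenize` above runs a greedy longest-match-first scan: at each position it keeps the longest matching vocabulary entry (non-initial pieces get a `##` prefix) and falls back to `[UNK]` if some position cannot be matched. A minimal sketch, assuming a toy vocabulary:

# Toy vocabulary for illustration only; a real BERT vocab comes from vocab.txt.
vocab = {"un", "##aff", "##able", "runn", "##ing", "[UNK]"}

def wordpiece(token, vocab, unk="[UNK]"):
    """Greedy longest-match-first split of one token, as in WordpieceTokenizer.tokenize."""
    pieces, start = [], 0
    while start < len(token):
        end, cur = len(token), None
        while start < end:                       # shrink the candidate from the right
            sub = token[start:end]
            if start > 0:
                sub = "##" + sub
            if sub in vocab:
                cur = sub
                break
            end -= 1
        if cur is None:                          # no piece matched at this position
            return [unk]
        pieces.append(cur)
        start = end
    return pieces

print(wordpiece("unaffable", vocab))   # ['un', '##aff', '##able']
print(wordpiece("running", vocab))     # ['runn', '##ing']
print(wordpiece("xyz", vocab))         # ['[UNK]']
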
diff --git a/fastNLP/modules/encoder/char_encoder.py b/fastNLP/modules/encoder/char_encoder.py
index 6ce63d1a..e40bd0dd 100644
--- a/fastNLP/modules/encoder/char_encoder.py
+++ b/fastNLP/modules/encoder/char_encoder.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"ConvolutionCharEncoder",
"LSTMCharEncoder"
@@ -11,7 +13,7 @@ from ..utils import initial_parameter
# from torch.nn.init import xavier_uniform
class ConvolutionCharEncoder(nn.Module):
"""
- 别名::class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.ConvolutionCharEncoder`
+ 别名::class:`fastNLP.modules.ConvolutionCharEncoder` :class:`fastNLP.modules.encoder.ConvolutionCharEncoder`
char级别的卷积编码器.
@@ -21,15 +23,16 @@ class ConvolutionCharEncoder(nn.Module):
:param tuple kernels: 一个由int组成的tuple. tuple的长度是char级别卷积操作的数目, 第`i`个int表示第`i`个卷积操作的卷积核.
:param initial_method: 初始化参数的方式, 默认为`xavier normal`
"""
-
+
def __init__(self, char_emb_size=50, feature_maps=(40, 30, 30), kernels=(1, 3, 5), initial_method=None):
super(ConvolutionCharEncoder, self).__init__()
self.convs = nn.ModuleList([
- nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True, padding=(0, kernels[i]//2))
+ nn.Conv2d(1, feature_maps[i], kernel_size=(char_emb_size, kernels[i]), bias=True,
+ padding=(0, kernels[i] // 2))
for i in range(len(kernels))])
-
+
initial_parameter(self, initial_method)
-
+
def forward(self, x):
"""
:param torch.Tensor x: ``[batch_size * sent_length, word_length, char_emb_size]`` 输入字符的embedding
@@ -40,7 +43,7 @@ class ConvolutionCharEncoder(nn.Module):
x = x.transpose(2, 3)
# [batch_size*sent_length, channel, height, width]
return self._convolute(x).unsqueeze(2)
-
+
def _convolute(self, x):
feats = []
for conv in self.convs:
@@ -57,13 +60,13 @@ class ConvolutionCharEncoder(nn.Module):
class LSTMCharEncoder(nn.Module):
"""
- 别名::class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.char_encoder.LSTMCharEncoder`
+ 别名::class:`fastNLP.modules.LSTMCharEncoder` :class:`fastNLP.modules.encoder.LSTMCharEncoder`
char级别基于LSTM的encoder.
"""
-
+
def __init__(self, char_emb_size=50, hidden_size=None, initial_method=None):
"""
:param int char_emb_size: char级别embedding的维度. Default: 50
@@ -73,14 +76,14 @@ class LSTMCharEncoder(nn.Module):
"""
super(LSTMCharEncoder, self).__init__()
self.hidden_size = char_emb_size if hidden_size is None else hidden_size
-
+
self.lstm = nn.LSTM(input_size=char_emb_size,
hidden_size=self.hidden_size,
num_layers=1,
bias=True,
batch_first=True)
initial_parameter(self, initial_method)
-
+
def forward(self, x):
"""
:param torch.Tensor x: ``[ n_batch*n_word, word_length, char_emb_size]`` 输入字符的embedding
@@ -91,6 +94,6 @@ class LSTMCharEncoder(nn.Module):
h0 = nn.init.orthogonal_(h0)
c0 = torch.empty(1, batch_size, self.hidden_size)
c0 = nn.init.orthogonal_(c0)
-
+
_, hidden = self.lstm(x, (h0, c0))
return hidden[0].squeeze().unsqueeze(2)
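
`ConvolutionCharEncoder` above collapses a word's character embeddings into one vector by running several 2-D convolutions across the character positions and pooling over them. A minimal standalone sketch of the shape flow, with illustrative sizes (ReLU is used here for brevity and may differ from the module's own activation):

import torch
from torch import nn

# Illustrative sizes: 6 words, 7 chars per word, 50-dim char embeddings.
n_words, word_len, char_emb = 6, 7, 50
feature_maps, kernels = (40, 30, 30), (1, 3, 5)

convs = [
    nn.Conv2d(1, fm, kernel_size=(char_emb, k), padding=(0, k // 2))
    for fm, k in zip(feature_maps, kernels)
]

x = torch.randn(n_words, word_len, char_emb)
x = x.unsqueeze(1).transpose(2, 3)            # [n_words, 1, char_emb, word_len]
feats = []
for conv in convs:
    y = torch.relu(conv(x))                   # [n_words, fm, 1, word_len]
    y = y.squeeze(2)                          # drop the collapsed height dim
    feats.append(torch.max(y, dim=2)[0])      # max-pool over characters -> [n_words, fm]
word_repr = torch.cat(feats, dim=1)           # [n_words, sum(feature_maps)]
print(word_repr.shape)                        # torch.Size([6, 100])
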
diff --git a/fastNLP/modules/encoder/conv_maxpool.py b/fastNLP/modules/encoder/conv_maxpool.py
index 68605c98..68415189 100644
--- a/fastNLP/modules/encoder/conv_maxpool.py
+++ b/fastNLP/modules/encoder/conv_maxpool.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"ConvMaxpool"
]
@@ -5,9 +7,10 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
+
class ConvMaxpool(nn.Module):
"""
- 别名::class:`fastNLP.modules.ConvMaxpool` :class:`fastNLP.modules.encoder.conv_maxpool.ConvMaxpool`
+ 别名::class:`fastNLP.modules.ConvMaxpool` :class:`fastNLP.modules.encoder.ConvMaxpool`
集合了Convolution和Max-Pooling于一体的层。给定一个batch_size x max_len x input_size的输入,返回batch_size x
sum(output_channels) 大小的matrix。在内部,是先使用CNN给输入做卷积,然后经过activation激活层,在通过在长度(max_len)
@@ -18,12 +21,12 @@ class ConvMaxpool(nn.Module):
:param int,tuple(int) kernel_sizes: 输出channel的kernel大小。
:param str activation: Convolution后的结果将通过该activation后再经过max-pooling。支持relu, sigmoid, tanh
"""
-
+
def __init__(self, in_channels, out_channels, kernel_sizes, activation="relu"):
super(ConvMaxpool, self).__init__()
for kernel_size in kernel_sizes:
- assert kernel_size%2==1, "kernel size has to be odd numbers."
+ assert kernel_size % 2 == 1, "kernel size has to be odd numbers."
# convolution
if isinstance(kernel_sizes, (list, tuple, int)):
@@ -36,22 +39,22 @@ class ConvMaxpool(nn.Module):
" of kernel_sizes."
else:
raise ValueError("The type of out_channels and kernel_sizes should be the same.")
-
+
self.convs = nn.ModuleList([nn.Conv1d(
in_channels=in_channels,
out_channels=oc,
kernel_size=ks,
stride=1,
- padding=ks//2,
+ padding=ks // 2,
dilation=1,
groups=1,
bias=None)
for oc, ks in zip(out_channels, kernel_sizes)])
-
+
else:
raise Exception(
'Incorrect kernel sizes: should be list, tuple or int')
-
+
# activation function
if activation == 'relu':
self.activation = F.relu
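
`ConvMaxpool` above applies the same idea at the sentence level: 1-D convolutions with odd kernel sizes and same-padding over the token axis, an activation, then a max over the sequence length. A minimal sketch with illustrative sizes:

import torch
from torch import nn
import torch.nn.functional as F

# Illustrative sizes: batch of 2 sentences, 9 tokens, 32-dim token embeddings.
batch, max_len, in_channels = 2, 9, 32
out_channels, kernel_sizes = (20, 20), (3, 5)       # odd kernels, as asserted above

convs = [
    nn.Conv1d(in_channels, oc, kernel_size=ks, padding=ks // 2)
    for oc, ks in zip(out_channels, kernel_sizes)
]

x = torch.randn(batch, max_len, in_channels).transpose(1, 2)   # Conv1d wants [B, C, L]
pooled = [F.relu(conv(x)).max(dim=2)[0] for conv in convs]     # max over the length dim
sent_repr = torch.cat(pooled, dim=1)                           # [batch, sum(out_channels)]
print(sent_repr.shape)                                         # torch.Size([2, 40])
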
diff --git a/fastNLP/modules/encoder/embedding.py b/fastNLP/modules/encoder/embedding.py
deleted file mode 100644
index 050a423a..00000000
--- a/fastNLP/modules/encoder/embedding.py
+++ /dev/null
@@ -1,1083 +0,0 @@
-__all__ = [
- "Embedding",
- "StaticEmbedding",
- "ElmoEmbedding",
- "BertEmbedding",
- "StackEmbedding",
- "LSTMCharEmbedding",
- "CNNCharEmbedding",
-]
-import torch.nn as nn
-from ..utils import get_embeddings
-from .lstm import LSTM
-from ...core.vocabulary import Vocabulary
-from abc import abstractmethod
-import torch
-import numpy as np
-import torch.nn.functional as F
-import os
-from ._elmo import _ElmoModel
-from ...io.file_utils import cached_path, _get_base_url
-from ._bert import _WordBertModel
-from typing import List
-
-import warnings
-from ...core.dataset import DataSet
-from ...core.batch import DataSetIter
-from ...core.sampler import SequentialSampler
-from ...core.utils import _move_model_to_device, _get_model_device
-from ...io.file_utils import PRETRAINED_BERT_MODEL_DIR, PRETRAINED_ELMO_MODEL_DIR, PRETRAIN_STATIC_FILES
-
-
-class Embedding(nn.Module):
- """
- 别名::class:`fastNLP.modules.Embedding` :class:`fastNLP.modules.encoder.embedding.Embedding`
-
- Embedding组件. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度"""
-
- def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None):
- """
-
- :param tuple(int,int),torch.FloatTensor,nn.Embedding,numpy.ndarray init_embed: Embedding的大小(传入tuple(int, int),
- 第一个int为vocab_zie, 第二个int为embed_dim); 如果为Tensor, Embedding, ndarray等则直接使用该值初始化Embedding;
- :param float word_dropout: 按照一定概率随机将word设置为unk_index,这样可以使得unk这个token得到足够的训练, 且会对网络有
- 一定的regularize的作用。
- :param float dropout: 对Embedding的输出的dropout。
- :param int unk_index: drop word时替换为的index。fastNLP的Vocabulary的unk_index默认为1。
- """
- super(Embedding, self).__init__()
-
- self.embed = get_embeddings(init_embed)
-
- self.dropout = nn.Dropout(dropout)
- if not isinstance(self.embed, TokenEmbedding):
- self._embed_size = self.embed.weight.size(1)
- if word_dropout>0 and not isinstance(unk_index, int):
- raise ValueError("When drop word is set, you need to pass in the unk_index.")
- else:
- self._embed_size = self.embed.embed_size
- unk_index = self.embed.get_word_vocab().unknown_idx
- self.unk_index = unk_index
- self.word_dropout = word_dropout
-
- def forward(self, x):
- """
- :param torch.LongTensor x: [batch, seq_len]
- :return: torch.Tensor : [batch, seq_len, embed_dim]
- """
- if self.word_dropout>0 and self.training:
- mask = torch.ones_like(x).float() * self.word_dropout
- mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1
- x = x.masked_fill(mask, self.unk_index)
- x = self.embed(x)
- return self.dropout(x)
-
- @property
- def num_embedding(self)->int:
- if isinstance(self.embed, nn.Embedding):
- return self.embed.weight.size(0)
- else:
- return self.embed.num_embedding
-
- def __len__(self):
- return len(self.embed)
-
- @property
- def embed_size(self) -> int:
- return self._embed_size
-
- @property
- def embedding_dim(self) -> int:
- return self._embed_size
-
- @property
- def requires_grad(self):
- """
- Embedding的参数是否允许优化。True: 所有参数运行优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- if not isinstance(self.embed, TokenEmbedding):
- return self.embed.weight.requires_grad
- else:
- return self.embed.requires_grad
-
- @requires_grad.setter
- def requires_grad(self, value):
- if not isinstance(self.embed, TokenEmbedding):
- self.embed.weight.requires_grad = value
- else:
- self.embed.requires_grad = value
-
- @property
- def size(self):
- if isinstance(self.embed, TokenEmbedding):
- return self.embed.size
- else:
- return self.embed.weight.size()
-
-
-class TokenEmbedding(nn.Module):
- def __init__(self, vocab, word_dropout=0.0, dropout=0.0):
- super(TokenEmbedding, self).__init__()
- assert vocab.padding is not None, "Vocabulary must have a padding entry."
- self._word_vocab = vocab
- self._word_pad_index = vocab.padding_idx
- if word_dropout>0:
- assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word."
- self.word_dropout = word_dropout
- self._word_unk_index = vocab.unknown_idx
- self.dropout_layer = nn.Dropout(dropout)
-
- def drop_word(self, words):
- """
- 按照设定随机将words设置为unknown_index。
-
- :param torch.LongTensor words: batch_size x max_len
- :return:
- """
- if self.word_dropout > 0 and self.training:
- mask = torch.ones_like(words).float() * self.word_dropout
- mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1
- words = words.masked_fill(mask, self._word_unk_index)
- return words
-
- def dropout(self, words):
- """
- 对embedding后的word表示进行drop。
-
- :param torch.FloatTensor words: batch_size x max_len x embed_size
- :return:
- """
- return self.dropout_layer(words)
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- requires_grads = set([param.requires_grad for param in self.parameters()])
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for param in self.parameters():
- param.requires_grad = value
-
- def __len__(self):
- return len(self._word_vocab)
-
- @property
- def embed_size(self) -> int:
- return self._embed_size
-
- @property
- def embedding_dim(self) -> int:
- return self._embed_size
-
- @property
- def num_embedding(self) -> int:
- """
- 这个值可能会大于实际的embedding矩阵的大小。
- :return:
- """
- return len(self._word_vocab)
-
- def get_word_vocab(self):
- """
- 返回embedding的词典。
-
- :return: Vocabulary
- """
- return self._word_vocab
-
- @property
- def size(self):
-        return torch.Size([self.num_embedding, self._embed_size])
-
- @abstractmethod
- def forward(self, *input):
- raise NotImplementedError
-
-class StaticEmbedding(TokenEmbedding):
- """
- 别名::class:`fastNLP.modules.StaticEmbedding` :class:`fastNLP.modules.encoder.embedding.StaticEmbedding`
-
-    StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的词向量。之后该Embedding就可以像普通的embedding一样使用了。
-
- Example::
-
- >>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')
-
-
- :param vocab: Vocabulary. 若该项为None则会读取所有的embedding。
- :param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding的文件名,第二种是传入embedding
- 的名称。目前支持的embedding包括{`en` 或者 `en-glove-840b-300` : glove.840B.300d, `en-glove-6b-50` : glove.6B.50d,
- `en-word2vec-300` : GoogleNews-vectors-negative300}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。
- :param bool requires_grad: 是否需要gradient. 默认为True
- :param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。
- :param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独
- 为大写的词语开辟一个vector表示,则将lower设置为False。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
-    :param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。
- """
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', requires_grad: bool=True, init_method=None,
- lower=False, dropout=0, word_dropout=0, normalize=False):
- super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- # 得到cache_path
- if model_dir_or_name.lower() in PRETRAIN_STATIC_FILES:
- PRETRAIN_URL = _get_base_url('static')
- model_name = PRETRAIN_STATIC_FILES[model_dir_or_name]
- model_url = PRETRAIN_URL + model_name
- model_path = cached_path(model_url)
- # 检查是否存在
- elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))):
- model_path = model_dir_or_name
- else:
- raise ValueError(f"Cannot recognize {model_dir_or_name}.")
-
- # 读取embedding
- if lower:
- lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown)
- for word, index in vocab:
- if not vocab._is_word_no_create_entry(word):
- lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的
- for word in vocab._no_create_word.keys(): # 不需要创建entry的
- if word in vocab:
- lowered_word = word.lower()
- if lowered_word not in lowered_vocab.word_count:
- lowered_vocab.add_word(lowered_word)
- lowered_vocab._no_create_word[lowered_word] += 1
- print(f"All word in vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered "
- f"words.")
- embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method,
- normalize=normalize)
- # 需要适配一下
- if not hasattr(self, 'words_to_words'):
- self.words_to_words = torch.arange(len(lowered_vocab, )).long()
- if lowered_vocab.unknown:
- unknown_idx = lowered_vocab.unknown_idx
- else:
-                    unknown_idx = embedding.size(0) - 1 # 否则最后一个位置为unknown
- words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
- requires_grad=False)
- for word, index in vocab:
- if word not in lowered_vocab:
- word = word.lower()
- if lowered_vocab._is_word_no_create_entry(word): # 如果不需要创建entry,已经默认unknown了
- continue
- words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)]
- self.words_to_words = words_to_words
- else:
- embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method,
- normalize=normalize)
- self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1],
- padding_idx=vocab.padding_idx,
- max_norm=None, norm_type=2, scale_grad_by_freq=False,
- sparse=False, _weight=embedding)
- self._embed_size = self.embedding.weight.size(1)
- self.requires_grad = requires_grad
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- requires_grads = set([param.requires_grad for name, param in self.named_parameters()
- if 'words_to_words' not in name])
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for name, param in self.named_parameters():
- if 'words_to_words' in name:
- continue
- param.requires_grad = value
-
-    def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>',
- normalize=True, error='ignore', init_method=None):
- """
- 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是
- word2vec(第一行只有两个元素)还是glove格式的数据。
-
- :param str embed_filepath: 预训练的embedding的路径。
- :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。
- 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。
- :param dtype: 读出的embedding的类型
- :param str padding: 词表中padding的token
- :param str unknown: 词表中unknown的token
- :param bool normalize: 是否将每个vector归一化到norm为1
- :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。
- 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。
- :param init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。默认使用torch.nn.init.zeros_
- :return torch.tensor: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。
- """
- assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported."
- if not os.path.exists(embed_filepath):
- raise FileNotFoundError("`{}` does not exist.".format(embed_filepath))
- with open(embed_filepath, 'r', encoding='utf-8') as f:
- line = f.readline().strip()
- parts = line.split()
- start_idx = 0
- if len(parts) == 2:
- dim = int(parts[1])
- start_idx += 1
- else:
- dim = len(parts) - 1
- f.seek(0)
- matrix = {}
- found_count = 0
- for idx, line in enumerate(f, start_idx):
- try:
- parts = line.strip().split()
- word = ''.join(parts[:-dim])
- nums = parts[-dim:]
- # 对齐unk与pad
- if word == padding and vocab.padding is not None:
- word = vocab.padding
- elif word == unknown and vocab.unknown is not None:
- word = vocab.unknown
- if word in vocab:
- index = vocab.to_index(word)
- matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim))
- found_count += 1
- except Exception as e:
- if error == 'ignore':
- warnings.warn("Error occurred at the {} line.".format(idx))
- else:
- print("Error occurred at the {} line.".format(idx))
- raise e
- print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab)))
- for word, index in vocab:
- if index not in matrix and not vocab._is_word_no_create_entry(word):
-                if vocab.unknown_idx in matrix: # 如果有unknown,用unknown初始化
- matrix[index] = matrix[vocab.unknown_idx]
- else:
- matrix[index] = None
-
- vectors = torch.zeros(len(matrix), dim)
- if init_method:
- init_method(vectors)
- else:
- nn.init.uniform_(vectors, -np.sqrt(3/dim), np.sqrt(3/dim))
-
- if vocab._no_create_word_length>0:
- if vocab.unknown is None: # 创建一个专门的unknown
- unknown_idx = len(matrix)
- vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous()
- else:
- unknown_idx = vocab.unknown_idx
- words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(),
- requires_grad=False)
- for order, (index, vec) in enumerate(matrix.items()):
- if vec is not None:
- vectors[order] = vec
- words_to_words[index] = order
- self.words_to_words = words_to_words
- else:
- for index, vec in matrix.items():
- if vec is not None:
- vectors[index] = vec
-
- if normalize:
- vectors /= (torch.norm(vectors, dim=1, keepdim=True) + 1e-12)
-
- return vectors
-
- def forward(self, words):
- """
- 传入words的index
-
- :param words: torch.LongTensor, [batch_size, max_len]
- :return: torch.FloatTensor, [batch_size, max_len, embed_size]
- """
- if hasattr(self, 'words_to_words'):
- words = self.words_to_words[words]
- words = self.drop_word(words)
- words = self.embedding(words)
- words = self.dropout(words)
- return words
-
-
-class ContextualEmbedding(TokenEmbedding):
- def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0):
- super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True):
- """
- 由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。
-
- :param datasets: DataSet对象
- :param batch_size: int, 生成cache的sentence表示时使用的batch的大小
-        :param device: 参考 :class:`fastNLP.Trainer` 的device
-        :param delete_weights: 是否在生成了cache之后删除权重。在不需要finetune动态模型的情况下,删除权重会大量减少内存占用。
- :return:
- """
- for index, dataset in enumerate(datasets):
- try:
- assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed."
- assert 'words' in dataset.get_input_name(), "`words` field has to be set as input."
- except Exception as e:
- print(f"Exception happens at {index} dataset.")
- raise e
-
- sent_embeds = {}
- _move_model_to_device(self, device=device)
- device = _get_model_device(self)
- pad_index = self._word_vocab.padding_idx
- print("Start to calculate sentence representations.")
- with torch.no_grad():
- for index, dataset in enumerate(datasets):
- try:
- batch = DataSetIter(dataset, batch_size=batch_size, sampler=SequentialSampler())
- for batch_x, batch_y in batch:
- words = batch_x['words'].to(device)
- words_list = words.tolist()
- seq_len = words.ne(pad_index).sum(dim=-1)
- max_len = words.size(1)
- # 因为有些情况可能包含CLS, SEP, 从后面往前计算比较安全。
- seq_len_from_behind = (max_len - seq_len).tolist()
- word_embeds = self(words).detach().cpu().numpy()
- for b in range(words.size(0)):
- length = seq_len_from_behind[b]
- if length==0:
- sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b]
- else:
- sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length]
- except Exception as e:
- print(f"Exception happens at {index} dataset.")
- raise e
- print("Finish calculating sentence representations.")
- self.sent_embeds = sent_embeds
- if delete_weights:
- self._delete_model_weights()
-
- def _get_sent_reprs(self, words):
- """
- 获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None
-
- :param words: torch.LongTensor
- :return:
- """
- if hasattr(self, 'sent_embeds'):
- words_list = words.tolist()
- seq_len = words.ne(self._word_pad_index).sum(dim=-1)
- _embeds = []
- for b in range(len(words)):
- words_i = tuple(words_list[b][:seq_len[b]])
- embed = self.sent_embeds[words_i]
- _embeds.append(embed)
- max_sent_len = max(map(len, _embeds))
- embeds = words.new_zeros(len(_embeds), max_sent_len, self.embed_size, dtype=torch.float,
- device=words.device)
- for i, embed in enumerate(_embeds):
- embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device)
- return embeds
- return None
-
- @abstractmethod
- def _delete_model_weights(self):
- """删除计算表示的模型以节省资源"""
- raise NotImplementedError
-
- def remove_sentence_cache(self):
- """
- 删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。
-
- :return:
- """
- del self.sent_embeds
-
-
-class ElmoEmbedding(ContextualEmbedding):
- """
- 别名::class:`fastNLP.modules.ElmoEmbedding` :class:`fastNLP.modules.encoder.embedding.ElmoEmbedding`
-
- 使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。
- 我们提供的ELMo预训练模型来自 https://github.com/HIT-SCIR/ELMoForManyLangs
-
- Example::
-
- >>> embedding = ElmoEmbedding(vocab, model_dir_or_name='en', layers='2', requires_grad=True)
-
- :param vocab: 词表
- :param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo权重的文件名,第二种是传入ELMo版本的名称,
- 目前支持的ELMo包括{`en` : 英文版本的ELMo, `cn` : 中文版本的ELMo,}。第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载
- :param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果
- 按照这个顺序concat起来。默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致,
- 初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。)
- :param requires_grad: bool, 该层是否需要gradient, 默认为False.
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding,
- 并删除character encoder,之后将直接使用cache的embedding。默认为False。
- """
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', layers: str='2', requires_grad: bool=False,
- word_dropout=0.0, dropout=0.0, cache_word_reprs: bool=False):
- super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- # 根据model_dir_or_name检查是否存在并下载
- if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR:
- PRETRAIN_URL = _get_base_url('elmo')
- model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name]
- model_url = PRETRAIN_URL + model_name
- model_dir = cached_path(model_url)
- # 检查是否存在
- elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
- model_dir = model_dir_or_name
- else:
- raise ValueError(f"Cannot recognize {model_dir_or_name}.")
- self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs)
-
- if layers=='mix':
- self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers']+1),
- requires_grad=requires_grad)
- self.gamma = nn.Parameter(torch.ones(1), requires_grad=requires_grad)
- self._get_outputs = self._get_mixed_outputs
- self._embed_size = self.model.config['lstm']['projection_dim'] * 2
- else:
- layers = list(map(int, layers.split(',')))
- assert len(layers) > 0, "Must choose one output"
- for layer in layers:
- assert 0 <= layer <= 2, "Layer index should be in range [0, 2]."
- self.layers = layers
- self._get_outputs = self._get_layer_outputs
- self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2
-
- self.requires_grad = requires_grad
-
- def _get_mixed_outputs(self, outputs):
- # outputs: num_layers x batch_size x max_len x hidden_size
- # return: batch_size x max_len x hidden_size
- weights = F.softmax(self.layer_weights+1/len(outputs), dim=0).to(outputs)
- outputs = torch.einsum('l,lbij->bij', weights, outputs)
- return self.gamma.to(outputs)*outputs
-
- def set_mix_weights_requires_grad(self, flag=True):
- """
- 当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用
- 该方法没有用。
- :param bool flag: 混合不同层表示的结果是否可以训练。
- :return:
- """
- if hasattr(self, 'layer_weights'):
- self.layer_weights.requires_grad = flag
- self.gamma.requires_grad = flag
-
- def _get_layer_outputs(self, outputs):
- if len(self.layers) == 1:
- outputs = outputs[self.layers[0]]
- else:
- outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1)
-
- return outputs
-
- def forward(self, words: torch.LongTensor):
- """
-        计算words的elmo embedding表示。根据ELMo论文的介绍,ELMo实际上有2L+1层结果,但是为了让结果比较容易拆分,token的
-        embedding被重复了一次,使得实际上layer=0的结果是[token_embedding;token_embedding], 而layer=1的结果是[forward_hiddens;
-        backward_hiddens].
-
- :param words: batch_size x max_len
- :return: torch.FloatTensor. batch_size x max_len x (512*len(self.layers))
- """
- words = self.drop_word(words)
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
- return self.dropout(outputs)
- outputs = self.model(words)
- outputs = self._get_outputs(outputs)
- return self.dropout(outputs)
-
- def _delete_model_weights(self):
- for name in ['layers', 'model', 'layer_weights', 'gamma']:
- if hasattr(self, name):
- delattr(self, name)
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
-
- :return:
- """
- requires_grads = set([param.requires_grad for name, param in self.named_parameters()
- if 'words_to_chars_embedding' not in name and 'words_to_words' not in name])
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for name, param in self.named_parameters():
- if 'words_to_chars_embedding' in name or 'words_to_words' in name: # 这个不能加入到requires_grad中
- continue
- param.requires_grad = value
-
-
-class BertEmbedding(ContextualEmbedding):
- """
- 别名::class:`fastNLP.modules.BertEmbedding` :class:`fastNLP.modules.encoder.embedding.BertEmbedding`
-
- 使用BERT对words进行encode的Embedding。建议将输入的words长度限制在450以内,而不要使用512。这是由于预训练的bert模型长
- 度限制为512个token,而因为输入的word是未进行word piece分割的,在分割之后长度可能会超过最大长度限制。
-
- Example::
-
- >>> embedding = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1')
-
-
- :param fastNLP.Vocabulary vocab: 词表
- :param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased``.
-    :param str layers: 最终结果中包含BERT哪些层的表示。以','隔开层数,可以用负数索引倒数几层
- :param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces
- 中计算得到它对应的表示。支持``last``, ``first``, ``avg``, ``max``。
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
- :param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样
-        会使得word embedding的结果比输入的结果长两个token。在与 :class:`StackEmbedding` 一起使用时可能会遇到问题。
- :param bool requires_grad: 是否需要gradient。
- """
- def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1',
- pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False,
- include_cls_sep: bool=False):
- super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- # 根据model_dir_or_name检查是否存在并下载
- if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR:
- PRETRAIN_URL = _get_base_url('bert')
- model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name]
- model_url = PRETRAIN_URL + model_name
- model_dir = cached_path(model_url)
- # 检查是否存在
- elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))):
- model_dir = model_dir_or_name
- else:
- raise ValueError(f"Cannot recognize {model_dir_or_name}.")
-
- self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers,
- pool_method=pool_method, include_cls_sep=include_cls_sep)
-
- self.requires_grad = requires_grad
- self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size
-
- def _delete_model_weights(self):
- del self.model
-
- def forward(self, words):
- """
- 计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要
- 删除这两个token的表示。
-
- :param torch.LongTensor words: [batch_size, max_len]
- :return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers))
- """
- words = self.drop_word(words)
- outputs = self._get_sent_reprs(words)
- if outputs is not None:
-            return self.dropout(outputs)
- outputs = self.model(words)
- outputs = torch.cat([*outputs], dim=-1)
-
-        return self.dropout(outputs)
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- requires_grads = set([param.requires_grad for name, param in self.named_parameters()
- if 'word_pieces_lengths' not in name])
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for name, param in self.named_parameters():
- if 'word_pieces_lengths' in name: # 这个不能加入到requires_grad中
- continue
- param.requires_grad = value
-
-
-def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1):
- """
- 给定一个word的vocabulary生成character的vocabulary.
-
-    :param vocab: 用于生成character vocabulary的word级别 Vocabulary
-    :param min_freq: character的最小出现次数
-    :return: character级别的Vocabulary
- """
- char_vocab = Vocabulary(min_freq=min_freq)
- for word, index in vocab:
- if not vocab._is_word_no_create_entry(word):
- char_vocab.add_word_lst(list(word))
- return char_vocab
-
-
-class CNNCharEmbedding(TokenEmbedding):
- """
- 别名::class:`fastNLP.modules.CNNCharEmbedding` :class:`fastNLP.modules.encoder.embedding.CNNCharEmbedding`
-
- 使用CNN生成character embedding。CNN的结果为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout.
-    不同的kernel大小的filter结果是concat起来的。
-
- Example::
-
- >>> cnn_char_embed = CNNCharEmbedding(vocab)
-
-
- :param vocab: 词表
- :param embed_size: 该word embedding的大小,默认值为50.
- :param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50.
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param float dropout: 以多大的概率drop
-    :param filter_nums: filter的数量. 长度需要和kernel_sizes一致。默认值为[40, 30, 20].
- :param kernel_sizes: kernel的大小. 默认值为[5, 3, 1].
- :param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'.
- :param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数.
- :param min_char_freq: character的最少出现次数。默认值为2.
- """
- def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
- dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1),
- pool_method: str='max', activation='relu', min_char_freq: int=2):
- super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
- for kernel in kernel_sizes:
- assert kernel % 2 == 1, "Only odd kernel is allowed."
-
- assert pool_method in ('max', 'avg')
- self.dropout = nn.Dropout(dropout)
- self.pool_method = pool_method
- # activation function
- if isinstance(activation, str):
- if activation.lower() == 'relu':
- self.activation = F.relu
- elif activation.lower() == 'sigmoid':
- self.activation = F.sigmoid
- elif activation.lower() == 'tanh':
- self.activation = F.tanh
- elif activation is None:
- self.activation = lambda x: x
- elif callable(activation):
- self.activation = activation
- else:
- raise Exception(
- "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
-
- print("Start constructing character vocabulary.")
- # 建立char的词表
- self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
- self.char_pad_index = self.char_vocab.padding_idx
- print(f"In total, there are {len(self.char_vocab)} distinct characters.")
- # 对vocab进行index
- max_word_len = max(map(lambda x: len(x[0]), vocab))
- self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len),
- fill_value=self.char_pad_index, dtype=torch.long),
- requires_grad=False)
- self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
- for word, index in vocab:
- # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的也是同一个embed
- self.words_to_chars_embedding[index, :len(word)] = \
- torch.LongTensor([self.char_vocab.to_index(c) for c in word])
- self.word_lengths[index] = len(word)
- self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
-
- self.convs = nn.ModuleList([nn.Conv1d(
- char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2)
- for i in range(len(kernel_sizes))])
- self._embed_size = embed_size
- self.fc = nn.Linear(sum(filter_nums), embed_size)
- self.init_param()
-
- def forward(self, words):
- """
- 输入words的index后,生成对应的words的表示。
-
- :param words: [batch_size, max_len]
- :return: [batch_size, max_len, embed_size]
- """
- words = self.drop_word(words)
- batch_size, max_len = words.size()
- chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
- word_lengths = self.word_lengths[words] # batch_size x max_len
- max_word_len = word_lengths.max()
- chars = chars[:, :, :max_word_len]
- # 为1的地方为mask
-        chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为1, 说明是padding的位置
- chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
- chars = self.dropout(chars)
- reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1)
- reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M
- conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1)
- for conv in self.convs]
- conv_chars = torch.cat(conv_chars, dim=-1).contiguous() # B x max_len x max_word_len x sum(filters)
- conv_chars = self.activation(conv_chars)
- if self.pool_method == 'max':
- conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
- chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters)
- else:
- conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
- chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
- chars = self.fc(chars)
- return self.dropout(chars)
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- params = []
- for name, param in self.named_parameters():
- if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
- params.append(param.requires_grad)
- requires_grads = set(params)
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for name, param in self.named_parameters():
- if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
- continue
- param.requires_grad = value
-
- def init_param(self):
- for name, param in self.named_parameters():
- if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset
- continue
- if param.data.dim()>1:
- nn.init.xavier_uniform_(param, 1)
- else:
- nn.init.uniform_(param, -1, 1)
-
-class LSTMCharEmbedding(TokenEmbedding):
- """
- 别名::class:`fastNLP.modules.LSTMCharEmbedding` :class:`fastNLP.modules.encoder.embedding.LSTMCharEmbedding`
-
- 使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool
-
- Example::
-
- >>> lstm_char_embed = LSTMCharEmbedding(vocab)
-
- :param vocab: 词表
- :param embed_size: embedding的大小。默认值为50.
- :param char_emb_size: character的embedding的大小。默认值为50.
- :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。
- :param dropout: 以多大概率drop
- :param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50.
- :param pool_method: 支持'max', 'avg'
- :param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数.
- :param min_char_freq: character的最小出现次数。默认值为2.
- :param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。
- """
- def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0,
- dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2,
- bidirectional=True):
-        super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout)
-
-        assert hidden_size % 2 == 0, "Only even hidden_size is allowed."
-
- assert pool_method in ('max', 'avg')
- self.pool_method = pool_method
- self.dropout = nn.Dropout(dropout)
- # activation function
- if isinstance(activation, str):
- if activation.lower() == 'relu':
- self.activation = F.relu
- elif activation.lower() == 'sigmoid':
- self.activation = F.sigmoid
- elif activation.lower() == 'tanh':
- self.activation = F.tanh
- elif activation is None:
- self.activation = lambda x: x
- elif callable(activation):
- self.activation = activation
- else:
- raise Exception(
- "Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]")
-
- print("Start constructing character vocabulary.")
- # 建立char的词表
- self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq)
- self.char_pad_index = self.char_vocab.padding_idx
- print(f"In total, there are {len(self.char_vocab)} distinct characters.")
- # 对vocab进行index
- self.max_word_len = max(map(lambda x: len(x[0]), vocab))
- self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len),
- fill_value=self.char_pad_index, dtype=torch.long),
- requires_grad=False)
- self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False)
- for word, index in vocab:
- # if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否
- self.words_to_chars_embedding[index, :len(word)] = \
- torch.LongTensor([self.char_vocab.to_index(c) for c in word])
- self.word_lengths[index] = len(word)
- self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size)
-
- self.fc = nn.Linear(hidden_size, embed_size)
- hidden_size = hidden_size // 2 if bidirectional else hidden_size
-
- self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True)
- self._embed_size = embed_size
- self.bidirectional = bidirectional
-
- def forward(self, words):
- """
- 输入words的index后,生成对应的words的表示。
-
- :param words: [batch_size, max_len]
- :return: [batch_size, max_len, embed_size]
- """
- words = self.drop_word(words)
- batch_size, max_len = words.size()
- chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len
- word_lengths = self.word_lengths[words] # batch_size x max_len
- max_word_len = word_lengths.max()
- chars = chars[:, :, :max_word_len]
- # 为mask的地方为1
-        chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为1, 说明是padding的位置
- chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size
- chars = self.dropout(chars)
- reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1)
- char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len)
- lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1)
- # B x M x M x H
-
- lstm_chars = self.activation(lstm_chars)
- if self.pool_method == 'max':
- lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf'))
- chars, _ = torch.max(lstm_chars, dim=-2) # batch_size x max_len x H
- else:
- lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0)
- chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float()
-
- chars = self.fc(chars)
-
- return self.dropout(chars)
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
- params = []
- for name, param in self.named_parameters():
- if 'words_to_chars_embedding' not in name and 'word_lengths' not in name:
-                params.append(param.requires_grad)
- requires_grads = set(params)
- if len(requires_grads) == 1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
- for name, param in self.named_parameters():
- if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中
- continue
- param.requires_grad = value
-
-
-class StackEmbedding(TokenEmbedding):
- """
- 别名::class:`fastNLP.modules.StackEmbedding` :class:`fastNLP.modules.encoder.embedding.StackEmbedding`
-
- 支持将多个embedding集合成一个embedding。
-
- Example::
-
- >>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
-        >>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True)
-        >>> stack_embed = StackEmbedding([embed_1, embed_2])
-
-
- :param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致
-    :param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。不同embedding会在相同的位置
- 被设置为unknown。如果这里设置了dropout,则组成的embedding就不要再设置dropout了。
- :param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。
-
- """
- def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0):
- vocabs = []
- for embed in embeds:
- if hasattr(embed, 'get_word_vocab'):
- vocabs.append(embed.get_word_vocab())
- _vocab = vocabs[0]
- for vocab in vocabs[1:]:
- assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary."
-
- super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout)
- assert isinstance(embeds, list)
- for embed in embeds:
- assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported."
- self.embeds = nn.ModuleList(embeds)
- self._embed_size = sum([embed.embed_size for embed in self.embeds])
-
- def append(self, embed: TokenEmbedding):
- """
- 添加一个embedding到结尾。
- :param embed:
- :return:
- """
- assert isinstance(embed, TokenEmbedding)
- self.embeds.append(embed)
-
- def pop(self):
- """
- 弹出最后一个embed
- :return:
- """
- return self.embeds.pop()
-
- @property
- def embed_size(self):
- return self._embed_size
-
- @property
- def requires_grad(self):
- """
-        Embedding的参数是否允许优化。True: 所有参数允许优化; False: 所有参数不允许优化; None: 部分允许优化、部分不允许
- :return:
- """
-        requires_grads = set([embed.requires_grad for embed in self.embeds])
- if len(requires_grads)==1:
- return requires_grads.pop()
- else:
- return None
-
- @requires_grad.setter
- def requires_grad(self, value):
-        for embed in self.embeds:
- embed.requires_grad = value
-
- def forward(self, words):
- """
- 得到多个embedding的结果,并把结果按照顺序concat起来。
-
- :param words: batch_size x max_len
- :return: 返回的shape和当前这个stack embedding中embedding的组成有关
- """
- outputs = []
- words = self.drop_word(words)
- for embed in self.embeds:
- outputs.append(embed(words))
- outputs = self.dropout(torch.cat(outputs, dim=-1))
- return outputs
-
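
For reference, the classes removed above share a common interface: build a Vocabulary, construct one or more TokenEmbedding subclasses over it, and call them on a LongTensor of word indices. Below is a minimal sketch based on the constructor and forward signatures shown; the import path follows the aliases in the docstrings above and may need adjusting if this PR relocates the classes, and the vocabulary and model names are purely illustrative:

    import torch
    from fastNLP import Vocabulary
    # aliases per the docstrings above; adjust the import if the classes were moved by this PR
    from fastNLP.modules import StaticEmbedding, CNNCharEmbedding, StackEmbedding

    vocab = Vocabulary()
    vocab.add_word_lst("this is a tiny example".split())

    word_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50')  # loads/downloads GloVe vectors
    char_embed = CNNCharEmbedding(vocab, embed_size=50)                      # CNN over characters
    embed = StackEmbedding([word_embed, char_embed])                         # concatenated along the last dim

    words = torch.LongTensor([[vocab.to_index(w) for w in "this is a tiny example".split()]])
    reprs = embed(words)  # [batch_size, max_len, word_embed.embed_size + char_embed.embed_size]
    # contextual embeddings follow the same calling convention, e.g.
    # bert_embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', layers='-1')
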
diff --git a/fastNLP/modules/encoder/lstm.py b/fastNLP/modules/encoder/lstm.py
index 5e599a65..1f3eae6d 100644
--- a/fastNLP/modules/encoder/lstm.py
+++ b/fastNLP/modules/encoder/lstm.py
@@ -1,7 +1,8 @@
-"""
+"""undocumented
轻量封装的 Pytorch LSTM 模块.
可在 forward 时传入序列的长度, 自动对padding做合适的处理.
"""
+
__all__ = [
"LSTM"
]
@@ -10,13 +11,10 @@ import torch
import torch.nn as nn
import torch.nn.utils.rnn as rnn
-from ..utils import initial_parameter
-from torch import autograd
-
class LSTM(nn.Module):
"""
- 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.lstm.LSTM`
+ 别名::class:`fastNLP.modules.LSTM` :class:`fastNLP.modules.encoder.LSTM`
LSTM 模块, 轻量封装的Pytorch LSTM. 在提供seq_len的情况下,将自动使用pack_padded_sequence; 同时默认将forget gate的bias初始化
为1; 且可以应对DataParallel中LSTM的使用问题。
@@ -30,7 +28,7 @@ class LSTM(nn.Module):
:(batch, seq, feature). Default: ``False``
:param bias: 如果为 ``False``, 模型将不会使用bias. Default: ``True``
"""
-
+
def __init__(self, input_size, hidden_size=100, num_layers=1, dropout=0.0, batch_first=True,
bidirectional=False, bias=True):
super(LSTM, self).__init__()
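
Per the docstring above, the wrapper is used like a plain nn.LSTM except that true sequence lengths can be passed to forward so padding is packed away automatically. A rough sketch, assuming forward returns (output, (h_n, c_n)) as with the underlying nn.LSTM; shapes are illustrative:

    import torch
    from fastNLP.modules import LSTM

    lstm = LSTM(input_size=50, hidden_size=100, bidirectional=True, batch_first=True)
    x = torch.randn(4, 7, 50)                 # [batch, seq_len, input_size]
    seq_len = torch.LongTensor([7, 5, 3, 2])  # true lengths; padded steps handled via pack_padded_sequence
    output, (h_n, c_n) = lstm(x, seq_len)     # output: [batch, seq_len, 2 * hidden_size]
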
diff --git a/fastNLP/modules/encoder/pooling.py b/fastNLP/modules/encoder/pooling.py
index 8337fe32..b1272284 100644
--- a/fastNLP/modules/encoder/pooling.py
+++ b/fastNLP/modules/encoder/pooling.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"MaxPool",
"MaxPoolWithMask",
@@ -10,7 +12,7 @@ import torch.nn as nn
class MaxPool(nn.Module):
"""
- 别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.pooling.MaxPool`
+ 别名::class:`fastNLP.modules.MaxPool` :class:`fastNLP.modules.encoder.MaxPool`
Max-pooling模块。
@@ -21,9 +23,9 @@ class MaxPool(nn.Module):
:param kernel_size: max pooling的窗口大小,默认为tensor最后k维,其中k为dimension
:param ceil_mode:
"""
-
+
def __init__(self, stride=None, padding=0, dilation=1, dimension=1, kernel_size=None, ceil_mode=False):
-
+
super(MaxPool, self).__init__()
assert (1 <= dimension) and (dimension <= 3)
self.dimension = dimension
@@ -32,7 +34,7 @@ class MaxPool(nn.Module):
self.dilation = dilation
self.kernel_size = kernel_size
self.ceil_mode = ceil_mode
-
+
def forward(self, x):
if self.dimension == 1:
pooling = nn.MaxPool1d(
@@ -59,15 +61,15 @@ class MaxPool(nn.Module):
class MaxPoolWithMask(nn.Module):
"""
- 别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.pooling.MaxPoolWithMask`
+ 别名::class:`fastNLP.modules.MaxPoolWithMask` :class:`fastNLP.modules.encoder.MaxPoolWithMask`
带mask矩阵的max pooling。在做max-pooling的时候不会考虑mask值为0的位置。
"""
-
+
def __init__(self):
super(MaxPoolWithMask, self).__init__()
self.inf = 10e12
-
+
def forward(self, tensor, mask, dim=1):
"""
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor
@@ -82,11 +84,11 @@ class MaxPoolWithMask(nn.Module):
class KMaxPool(nn.Module):
"""K max-pooling module."""
-
+
def __init__(self, k=1):
super(KMaxPool, self).__init__()
self.k = k
-
+
def forward(self, x):
"""
:param torch.Tensor x: [N, C, L] 初始tensor
@@ -99,16 +101,16 @@ class KMaxPool(nn.Module):
class AvgPool(nn.Module):
"""
- 别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.pooling.AvgPool`
+ 别名::class:`fastNLP.modules.AvgPool` :class:`fastNLP.modules.encoder.AvgPool`
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size]
"""
-
+
def __init__(self, stride=None, padding=0):
super(AvgPool, self).__init__()
self.stride = stride
self.padding = padding
-
+
def forward(self, x):
"""
:param torch.Tensor x: [N, C, L] 初始tensor
@@ -126,16 +128,16 @@ class AvgPool(nn.Module):
class AvgPoolWithMask(nn.Module):
"""
- 别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.pooling.AvgPoolWithMask`
+ 别名::class:`fastNLP.modules.AvgPoolWithMask` :class:`fastNLP.modules.encoder.AvgPoolWithMask`
给定形如[batch_size, max_len, hidden_size]的输入,在最后一维进行avg pooling. 输出为[batch_size, hidden_size], pooling
的时候只会考虑mask为1的位置
"""
-
+
def __init__(self):
super(AvgPoolWithMask, self).__init__()
self.inf = 10e12
-
+
def forward(self, tensor, mask, dim=1):
"""
:param torch.FloatTensor tensor: [batch_size, seq_len, channels] 初始tensor
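
For reference, a minimal sketch of the masked pooling modules defined in this file: the mask marks real tokens with 1 and padding with 0, and masked positions are excluded from the pooling. Tensor shapes are illustrative:

    import torch
    from fastNLP.modules import MaxPoolWithMask

    pool = MaxPoolWithMask()
    tensor = torch.randn(2, 5, 8)               # [batch_size, seq_len, channels]
    mask = torch.LongTensor([[1, 1, 1, 0, 0],
                             [1, 1, 1, 1, 1]])  # 1 = real token, 0 = padding
    pooled = pool(tensor, mask)                 # [batch_size, channels], padded positions ignored
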
diff --git a/fastNLP/modules/encoder/star_transformer.py b/fastNLP/modules/encoder/star_transformer.py
index 097fbebb..02d7a6a0 100644
--- a/fastNLP/modules/encoder/star_transformer.py
+++ b/fastNLP/modules/encoder/star_transformer.py
@@ -1,6 +1,7 @@
-"""
+"""undocumented
Star-Transformer 的encoder部分的 Pytorch 实现
"""
+
__all__ = [
"StarTransformer"
]
@@ -13,7 +14,7 @@ from torch.nn import functional as F
class StarTransformer(nn.Module):
"""
- 别名::class:`fastNLP.modules.StarTransformer` :class:`fastNLP.modules.encoder.star_transformer.StarTransformer`
+ 别名::class:`fastNLP.modules.StarTransformer` :class:`fastNLP.modules.encoder.StarTransformer`
Star-Transformer 的encoder部分。 输入3d的文本输入, 返回相同长度的文本编码
@@ -29,11 +30,11 @@ class StarTransformer(nn.Module):
模型会为输入序列加上position embedding。
若为`None`,忽略加上position embedding的步骤. Default: `None`
"""
-
+
def __init__(self, hidden_size, num_layers, num_head, head_dim, dropout=0.1, max_len=None):
super(StarTransformer, self).__init__()
self.iters = num_layers
-
+
self.norm = nn.ModuleList([nn.LayerNorm(hidden_size, eps=1e-6) for _ in range(self.iters)])
# self.emb_fc = nn.Conv2d(hidden_size, hidden_size, 1)
self.emb_drop = nn.Dropout(dropout)
@@ -43,12 +44,12 @@ class StarTransformer(nn.Module):
self.star_att = nn.ModuleList(
[_MSA2(hidden_size, nhead=num_head, head_dim=head_dim, dropout=0.0)
for _ in range(self.iters)])
-
+
if max_len is not None:
self.pos_emb = nn.Embedding(max_len, hidden_size)
else:
self.pos_emb = None
-
+
def forward(self, data, mask):
"""
:param FloatTensor data: [batch, length, hidden] 输入的序列
@@ -58,15 +59,15 @@ class StarTransformer(nn.Module):
[batch, hidden] 全局 relay 节点, 详见论文
"""
-
+
def norm_func(f, x):
# B, H, L, 1
return f(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
+
B, L, H = data.size()
mask = (mask == 0) # flip the mask for masked_fill_
smask = torch.cat([torch.zeros(B, 1, ).byte().to(mask), mask], 1)
-
+
embs = data.permute(0, 2, 1)[:, :, :, None] # B H L 1
if self.pos_emb and False:
P = self.pos_emb(torch.arange(L, dtype=torch.long, device=embs.device) \
@@ -80,13 +81,13 @@ class StarTransformer(nn.Module):
for i in range(self.iters):
ax = torch.cat([r_embs, relay.expand(B, H, 1, L)], 2)
nodes = F.leaky_relu(self.ring_att[i](norm_func(self.norm[i], nodes), ax=ax))
- #nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
+ # nodes = F.leaky_relu(self.ring_att[i](nodes, ax=ax))
relay = F.leaky_relu(self.star_att[i](relay, torch.cat([relay, nodes], 2), smask))
-
+
nodes = nodes.masked_fill_(ex_mask, 0)
-
+
nodes = nodes.view(B, H, L).permute(0, 2, 1)
-
+
return nodes, relay.view(B, H)
@@ -99,19 +100,19 @@ class _MSA1(nn.Module):
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
-
+
self.drop = nn.Dropout(dropout)
-
+
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
-
+
def forward(self, x, ax=None):
# x: B, H, L, 1, ax : B, H, X, L append features
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = x.shape
-
+
q, k, v = self.WQ(x), self.WK(x), self.WV(x) # x: (B,H,L,1)
-
+
if ax is not None:
aL = ax.shape[2]
ak = self.WK(ax).view(B, nhead, head_dim, aL, L)
@@ -124,12 +125,12 @@ class _MSA1(nn.Module):
if ax is not None:
k = torch.cat([k, ak], 3)
v = torch.cat([v, av], 3)
-
+
alphas = self.drop(F.softmax((q * k).sum(2, keepdim=True) / NP.sqrt(head_dim), 3)) # B N L 1 U
att = (alphas * v).sum(3).view(B, nhead * head_dim, L, 1)
-
+
ret = self.WO(att)
-
+
return ret
@@ -141,19 +142,19 @@ class _MSA2(nn.Module):
self.WK = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WV = nn.Conv2d(nhid, nhead * head_dim, 1)
self.WO = nn.Conv2d(nhead * head_dim, nhid, 1)
-
+
self.drop = nn.Dropout(dropout)
-
+
# print('NUM_HEAD', nhead, 'DIM_HEAD', head_dim)
self.nhid, self.nhead, self.head_dim, self.unfold_size = nhid, nhead, head_dim, 3
-
+
def forward(self, x, y, mask=None):
# x: B, H, 1, 1, 1 y: B H L 1
nhid, nhead, head_dim, unfold_size = self.nhid, self.nhead, self.head_dim, self.unfold_size
B, H, L, _ = y.shape
-
+
q, k, v = self.WQ(x), self.WK(y), self.WV(y)
-
+
q = q.view(B, nhead, 1, head_dim) # B, H, 1, 1 -> B, N, 1, h
k = k.view(B, nhead, head_dim, L) # B, H, L, 1 -> B, N, h, L
v = v.view(B, nhead, head_dim, L).permute(0, 1, 3, 2) # B, H, L, 1 -> B, N, L, h
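
As the class docstring notes, the encoder consumes a 3d batch of token representations plus a mask and returns both per-token encodings and a global relay node. A hedged sketch with illustrative sizes:

    import torch
    from fastNLP.modules import StarTransformer

    encoder = StarTransformer(hidden_size=64, num_layers=2, num_head=4, head_dim=16)
    data = torch.randn(2, 10, 64)        # [batch, length, hidden]
    mask = torch.ones(2, 10).byte()      # 1 for real tokens, 0 for padding
    nodes, relay = encoder(data, mask)   # nodes: [batch, length, hidden], relay: [batch, hidden]
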
diff --git a/fastNLP/modules/encoder/transformer.py b/fastNLP/modules/encoder/transformer.py
index d6bf2f1e..ce9172d5 100644
--- a/fastNLP/modules/encoder/transformer.py
+++ b/fastNLP/modules/encoder/transformer.py
@@ -1,3 +1,5 @@
+"""undocumented"""
+
__all__ = [
"TransformerEncoder"
]
@@ -9,7 +11,7 @@ from ..dropout import TimestepDropout
class TransformerEncoder(nn.Module):
"""
- 别名::class:`fastNLP.modules.TransformerEncoder` :class:`fastNLP.modules.encoder.transformer.TransformerEncoder`
+ 别名::class:`fastNLP.modules.TransformerEncoder` :class:`fastNLP.modules.encoder.TransformerEncoder`
transformer的encoder模块,不包含embedding层
@@ -22,7 +24,7 @@ class TransformerEncoder(nn.Module):
:param int num_head: head的数量。
:param float dropout: dropout概率. Default: 0.1
"""
-
+
class SubLayer(nn.Module):
def __init__(self, model_size, inner_size, key_size, value_size, num_head, dropout=0.1):
super(TransformerEncoder.SubLayer, self).__init__()
@@ -33,7 +35,7 @@ class TransformerEncoder(nn.Module):
nn.Linear(inner_size, model_size),
TimestepDropout(dropout), )
self.norm2 = nn.LayerNorm(model_size)
-
+
def forward(self, input, seq_mask=None, atte_mask_out=None):
"""
@@ -48,11 +50,11 @@ class TransformerEncoder(nn.Module):
output = self.norm2(output + norm_atte)
output *= seq_mask
return output
-
+
def __init__(self, num_layers, **kargs):
super(TransformerEncoder, self).__init__()
self.layers = nn.ModuleList([self.SubLayer(**kargs) for _ in range(num_layers)])
-
+
def forward(self, x, seq_mask=None):
"""
:param x: [batch, seq_len, model_size] 输入序列
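
The keyword arguments of TransformerEncoder are forwarded to its SubLayer, so the constructor takes the sizes shown in the sub-layer signature above. A rough usage sketch; the parameter values are illustrative:

    import torch
    from fastNLP.modules import TransformerEncoder

    encoder = TransformerEncoder(num_layers=2, model_size=128, inner_size=256,
                                 key_size=16, value_size=16, num_head=8, dropout=0.1)
    x = torch.randn(2, 10, 128)   # [batch, seq_len, model_size]
    seq_mask = torch.ones(2, 10)  # 1 for real tokens, 0 for padding
    out = encoder(x, seq_mask)    # [batch, seq_len, model_size]
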
diff --git a/fastNLP/modules/encoder/variational_rnn.py b/fastNLP/modules/encoder/variational_rnn.py
index 29b728e5..933555c8 100644
--- a/fastNLP/modules/encoder/variational_rnn.py
+++ b/fastNLP/modules/encoder/variational_rnn.py
@@ -1,6 +1,7 @@
-"""
+"""undocumented
Variational RNN 的 Pytorch 实现
"""
+
__all__ = [
"VarRNN",
"VarLSTM",
@@ -28,14 +29,14 @@ class VarRnnCellWrapper(nn.Module):
"""
Wrapper for normal RNN Cells, make it support variational dropout
"""
-
+
def __init__(self, cell, hidden_size, input_p, hidden_p):
super(VarRnnCellWrapper, self).__init__()
self.cell = cell
self.hidden_size = hidden_size
self.input_p = input_p
self.hidden_p = hidden_p
-
+
def forward(self, input_x, hidden, mask_x, mask_h, is_reversed=False):
"""
:param PackedSequence input_x: [seq_len, batch_size, input_size]
@@ -47,13 +48,13 @@ class VarRnnCellWrapper(nn.Module):
hidden: for LSTM, tuple of (h_n, c_n), [batch_size, hidden_size]
for other RNN, h_n, [batch_size, hidden_size]
"""
-
+
def get_hi(hi, h0, size):
h0_size = size - hi.size(0)
if h0_size > 0:
return torch.cat([hi, h0[:h0_size]], dim=0)
return hi[:size]
-
+
is_lstm = isinstance(hidden, tuple)
input, batch_sizes = input_x.data, input_x.batch_sizes
output = []
@@ -64,7 +65,7 @@ class VarRnnCellWrapper(nn.Module):
else:
batch_iter = batch_sizes
idx = 0
-
+
if is_lstm:
hn = (hidden[0].clone(), hidden[1].clone())
else:
@@ -91,7 +92,7 @@ class VarRnnCellWrapper(nn.Module):
hi = cell(input_i, hi)
hn[:size] = hi
output.append(hi)
-
+
if is_reversed:
output = list(reversed(output))
output = torch.cat(output, dim=0)
@@ -117,7 +118,7 @@ class VarRNNBase(nn.Module):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
-
+
def __init__(self, mode, Cell, input_size, hidden_size, num_layers=1,
bias=True, batch_first=False,
input_dropout=0, hidden_dropout=0, bidirectional=False):
@@ -141,7 +142,7 @@ class VarRNNBase(nn.Module):
cell, self.hidden_size, input_dropout, hidden_dropout))
initial_parameter(self)
self.is_lstm = (self.mode == "LSTM")
-
+
def _forward_one(self, n_layer, n_direction, input, hx, mask_x, mask_h):
is_lstm = self.is_lstm
idx = self.num_directions * n_layer + n_direction
@@ -150,7 +151,7 @@ class VarRNNBase(nn.Module):
output_x, hidden_x = cell(
input, hi, mask_x, mask_h, is_reversed=(n_direction == 1))
return output_x, hidden_x
-
+
def forward(self, x, hx=None):
"""
@@ -170,13 +171,13 @@ class VarRNNBase(nn.Module):
else:
max_batch_size = int(x.batch_sizes[0])
x, batch_sizes = x.data, x.batch_sizes
-
+
if hx is None:
hx = x.new_zeros(self.num_layers * self.num_directions,
max_batch_size, self.hidden_size, requires_grad=True)
if is_lstm:
hx = (hx, hx.new_zeros(hx.size(), requires_grad=True))
-
+
mask_x = x.new_ones((max_batch_size, self.input_size))
mask_out = x.new_ones(
(max_batch_size, self.hidden_size * self.num_directions))
@@ -185,7 +186,7 @@ class VarRNNBase(nn.Module):
training=self.training, inplace=True)
nn.functional.dropout(mask_out, p=self.hidden_dropout,
training=self.training, inplace=True)
-
+
hidden = x.new_zeros(
(self.num_layers * self.num_directions, max_batch_size, self.hidden_size))
if is_lstm:
@@ -207,22 +208,22 @@ class VarRNNBase(nn.Module):
else:
hidden[idx] = hidden_x
x = torch.cat(output_list, dim=-1)
-
+
if is_lstm:
hidden = (hidden, cellstate)
-
+
if is_packed:
output = PackedSequence(x, batch_sizes)
else:
x = PackedSequence(x, batch_sizes)
output, _ = pad_packed_sequence(x, batch_first=self.batch_first)
-
+
return output, hidden
class VarLSTM(VarRNNBase):
"""
- 别名::class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.variational_rnn.VarLSTM`
+ 别名::class:`fastNLP.modules.VarLSTM` :class:`fastNLP.modules.encoder.VarLSTM`
Variational Dropout LSTM.
@@ -236,18 +237,18 @@ class VarLSTM(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的LSTM. Default: ``False``
"""
-
+
def __init__(self, *args, **kwargs):
super(VarLSTM, self).__init__(
mode="LSTM", Cell=nn.LSTMCell, *args, **kwargs)
-
+
def forward(self, x, hx=None):
return super(VarLSTM, self).forward(x, hx)
class VarRNN(VarRNNBase):
"""
- 别名::class:`fastNLP.modules.VarRNN` :class:`fastNLP.modules.encoder.variational_rnn.VarRNN`
+ 别名::class:`fastNLP.modules.VarRNN` :class:`fastNLP.modules.encoder.VarRNN`
Variational Dropout RNN.
@@ -261,18 +262,18 @@ class VarRNN(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的RNN. Default: ``False``
"""
-
+
def __init__(self, *args, **kwargs):
super(VarRNN, self).__init__(
mode="RNN", Cell=nn.RNNCell, *args, **kwargs)
-
+
def forward(self, x, hx=None):
return super(VarRNN, self).forward(x, hx)
class VarGRU(VarRNNBase):
"""
- 别名::class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.variational_rnn.VarGRU`
+ 别名::class:`fastNLP.modules.VarGRU` :class:`fastNLP.modules.encoder.VarGRU`
Variational Dropout GRU.
@@ -286,10 +287,10 @@ class VarGRU(VarRNNBase):
:param hidden_dropout: 对每个隐状态的dropout概率. Default: 0
:param bidirectional: 若为 ``True``, 使用双向的GRU. Default: ``False``
"""
-
+
def __init__(self, *args, **kwargs):
super(VarGRU, self).__init__(
mode="GRU", Cell=nn.GRUCell, *args, **kwargs)
-
+
def forward(self, x, hx=None):
return super(VarGRU, self).forward(x, hx)
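
The variational RNNs share the constructor of VarRNNBase shown above (note that batch_first defaults to False here). A minimal sketch, assuming forward returns (output, hidden) like the standard RNN modules:

    import torch
    from fastNLP.modules import VarLSTM

    lstm = VarLSTM(input_size=50, hidden_size=100, num_layers=2, batch_first=True,
                   input_dropout=0.3, hidden_dropout=0.3, bidirectional=True)
    x = torch.randn(4, 7, 50)     # [batch, seq_len, input_size] because batch_first=True
    output, (h_n, c_n) = lstm(x)  # output: [batch, seq_len, 2 * hidden_size]
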
diff --git a/fastNLP/modules/utils.py b/fastNLP/modules/utils.py
index 3c6a3d27..09574782 100644
--- a/fastNLP/modules/utils.py
+++ b/fastNLP/modules/utils.py
@@ -1,6 +1,16 @@
+"""
+.. todo::
+ doc
+"""
+
+__all__ = [
+ "initial_parameter",
+ "summary"
+]
+
+import os
from functools import reduce
-import numpy as np
import torch
import torch.nn as nn
import torch.nn.init as init
@@ -40,7 +50,7 @@ def initial_parameter(net, initial_method=None):
init_method = init.uniform_
else:
init_method = init.xavier_normal_
-
+
def weights_init(m):
# classname = m.__class__.__name__
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Conv1d) or isinstance(m, nn.Conv3d): # for all the cnn
@@ -66,37 +76,10 @@ def initial_parameter(net, initial_method=None):
else:
init.normal_(w.data) # bias
# print("init else")
-
+
net.apply(weights_init)
-def get_embeddings(init_embed):
- """
- 根据输入的init_embed生成nn.Embedding对象。
-
-    :param init_embed: 可以是 tuple:(num_embeddings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入
-        nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为Embedding初始
-        化; 传入torch.Tensor, 将使用传入的值作为Embedding初始化。
- :return nn.Embedding embeddings:
- """
- if isinstance(init_embed, tuple):
- res = nn.Embedding(
- num_embeddings=init_embed[0], embedding_dim=init_embed[1])
- nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)),
- b=np.sqrt(3/res.weight.data.size(1)))
- elif isinstance(init_embed, nn.Module):
- res = init_embed
- elif isinstance(init_embed, torch.Tensor):
- res = nn.Embedding.from_pretrained(init_embed, freeze=False)
- elif isinstance(init_embed, np.ndarray):
- init_embed = torch.tensor(init_embed, dtype=torch.float32)
- res = nn.Embedding.from_pretrained(init_embed, freeze=False)
- else:
- raise TypeError(
- 'invalid init_embed type: {}'.format((type(init_embed))))
- return res
-
-
def summary(model: nn.Module):
"""
得到模型的总参数量
@@ -106,11 +89,11 @@ def summary(model: nn.Module):
"""
train = []
nontrain = []
-
+
def layer_summary(module: nn.Module):
def count_size(sizes):
- return reduce(lambda x, y: x*y, sizes)
-
+ return reduce(lambda x, y: x * y, sizes)
+
for p in module.parameters(recurse=False):
if p.requires_grad:
train.append(count_size(p.shape))
@@ -118,7 +101,7 @@ def summary(model: nn.Module):
nontrain.append(count_size(p.shape))
for subm in module.children():
layer_summary(subm)
-
+
layer_summary(model)
total_train = sum(train)
total_nontrain = sum(nontrain)
@@ -128,7 +111,7 @@ def summary(model: nn.Module):
strings.append('Trainable params: {:,}'.format(total_train))
strings.append('Non-trainable params: {:,}'.format(total_nontrain))
max_len = len(max(strings, key=len))
- bar = '-'*(max_len + 3)
+ bar = '-' * (max_len + 3)
strings = [bar] + strings + [bar]
print('\n'.join(strings))
return total, total_train, total_nontrain
@@ -139,10 +122,25 @@ def get_dropout_mask(drop_p: float, tensor: torch.Tensor):
根据tensor的形状,生成一个mask
:param drop_p: float, 以多大的概率置为0。
- :param tensor:torch.Tensor
+ :param tensor: torch.Tensor
:return: torch.FloatTensor. 与tensor一样的shape
"""
mask_x = torch.ones_like(tensor)
nn.functional.dropout(mask_x, p=drop_p,
training=False, inplace=True)
- return mask_x
\ No newline at end of file
+ return mask_x
+
+
+def _get_file_name_base_on_postfix(dir_path, postfix):
+ """
+ 在dir_path中寻找后缀为postfix的文件.
+ :param dir_path: str, 文件夹
+ :param postfix: 形如".bin", ".json"等
+ :return: str,文件的路径
+ """
+ files = list(filter(lambda filename: filename.endswith(postfix), os.listdir(os.path.join(dir_path))))
+ if len(files) == 0:
+        raise FileNotFoundError(f"There is no file ending with *{postfix} in {dir_path}")
+ elif len(files) > 1:
+ raise FileExistsError(f"There are multiple *{postfix} files in {dir_path}")
+ return os.path.join(dir_path, files[0])
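
The summary helper added above prints a small parameter report and returns the totals, which is handy when comparing model variants. A quick sketch; the model here is only illustrative:

    import torch.nn as nn
    from fastNLP.modules.utils import summary

    model = nn.Sequential(nn.Linear(10, 20), nn.ReLU(), nn.Linear(20, 2))
    total, trainable, non_trainable = summary(model)  # also prints the formatted report
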
diff --git a/reproduction/Star_transformer/train.py b/reproduction/Star_transformer/train.py
index f1e5c2f9..d8e2576b 100644
--- a/reproduction/Star_transformer/train.py
+++ b/reproduction/Star_transformer/train.py
@@ -1,7 +1,7 @@
-from util import get_argparser, set_gpu, set_rng_seeds, add_model_args
+from reproduction.Star_transformer.util import get_argparser, set_gpu, set_rng_seeds, add_model_args
seed = set_rng_seeds(15360)
print('RNG SEED {}'.format(seed))
-from datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
+from reproduction.Star_transformer.datasets import load_seqtag, load_sst, load_snli, EmbedLoader, MAX_LEN
import torch.nn as nn
import torch
import numpy as np
diff --git a/reproduction/Summarization/Baseline/data/dataloader.py b/reproduction/Summarization/Baseline/data/dataloader.py
index 47cd0856..dcb294b0 100644
--- a/reproduction/Summarization/Baseline/data/dataloader.py
+++ b/reproduction/Summarization/Baseline/data/dataloader.py
@@ -1,188 +1,188 @@
-import pickle
-import numpy as np
-
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataBundle
-from fastNLP.io.dataset_loader import JsonLoader
-from fastNLP.core.const import Const
-
-from tools.logger import *
-
-WORD_PAD = "[PAD]"
-WORD_UNK = "[UNK]"
-DOMAIN_UNK = "X"
-TAG_UNK = "X"
-
-
-class SummarizationLoader(JsonLoader):
- """
- 读取summarization数据集,读取的DataSet包含fields::
-
- text: list(str),document
- summary: list(str), summary
- text_wd: list(list(str)),tokenized document
- summary_wd: list(list(str)), tokenized summary
- labels: list(int),
- flatten_label: list(int), 0 or 1, flatten labels
- domain: str, optional
- tag: list(str), optional
-
- 数据来源: CNN_DailyMail Newsroom DUC
- """
-
- def __init__(self):
- super(SummarizationLoader, self).__init__()
-
- def _load(self, path):
- ds = super(SummarizationLoader, self)._load(path)
-
- def _lower_text(text_list):
- return [text.lower() for text in text_list]
-
- def _split_list(text_list):
- return [text.split() for text in text_list]
-
- def _convert_label(label, sent_len):
- np_label = np.zeros(sent_len, dtype=int)
- if label != []:
- np_label[np.array(label)] = 1
- return np_label.tolist()
-
- ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
- ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
- ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd')
- ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd')
- ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
-
- return ds
-
- def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
- """
- :param paths: dict path for each dataset
- :param vocab_size: int max_size for vocab
- :param vocab_path: str vocab path
- :param sent_max_len: int max token number of the sentence
- :param doc_max_timesteps: int max sentence number of the document
- :param domain: bool build vocab for publication, use 'X' for unknown
- :param tag: bool build vocab for tag, use 'X' for unknown
- :param load_vocab_file: bool build vocab (False) or load vocab (True)
- :return: DataBundle
- datasets: dict keys correspond to the paths dict
- vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
- embeddings: optional
- """
-
- def _pad_sent(text_wd):
- pad_text_wd = []
- for sent_wd in text_wd:
- if len(sent_wd) < sent_max_len:
- pad_num = sent_max_len - len(sent_wd)
- sent_wd.extend([WORD_PAD] * pad_num)
- else:
- sent_wd = sent_wd[:sent_max_len]
- pad_text_wd.append(sent_wd)
- return pad_text_wd
-
- def _token_mask(text_wd):
- token_mask_list = []
- for sent_wd in text_wd:
- token_num = len(sent_wd)
- if token_num < sent_max_len:
- mask = [1] * token_num + [0] * (sent_max_len - token_num)
- else:
- mask = [1] * sent_max_len
- token_mask_list.append(mask)
- return token_mask_list
-
- def _pad_label(label):
- text_len = len(label)
- if text_len < doc_max_timesteps:
- pad_label = label + [0] * (doc_max_timesteps - text_len)
- else:
- pad_label = label[:doc_max_timesteps]
- return pad_label
-
- def _pad_doc(text_wd):
- text_len = len(text_wd)
- if text_len < doc_max_timesteps:
- padding = [WORD_PAD] * sent_max_len
- pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
- else:
- pad_text = text_wd[:doc_max_timesteps]
- return pad_text
-
- def _sent_mask(text_wd):
- text_len = len(text_wd)
- if text_len < doc_max_timesteps:
- sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
- else:
- sent_mask = [1] * doc_max_timesteps
- return sent_mask
-
-
- datasets = {}
- train_ds = None
- for key, value in paths.items():
- ds = self.load(value)
- # pad sent
- ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
- ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask")
- # pad document
- ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
- ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
- ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label")
-
- # rename field
- ds.rename_field("pad_text", Const.INPUT)
- ds.rename_field("seq_len", Const.INPUT_LEN)
- ds.rename_field("pad_label", Const.TARGET)
-
- # set input and target
- ds.set_input(Const.INPUT, Const.INPUT_LEN)
- ds.set_target(Const.TARGET, Const.INPUT_LEN)
-
- datasets[key] = ds
- if "train" in key:
- train_ds = datasets[key]
-
- vocab_dict = {}
- if load_vocab_file == False:
- logger.info("[INFO] Build new vocab from training dataset!")
- if train_ds == None:
- raise ValueError("Lack train file to build vocabulary!")
-
- vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
- vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"])
- vocab_dict["vocab"] = vocabs
- else:
- logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
- word_list = []
- with open(vocab_path, 'r', encoding='utf8') as vocab_f:
- cnt = 2 # pad and unk
- for line in vocab_f:
- pieces = line.split("\t")
- word_list.append(pieces[0])
- cnt += 1
- if cnt > vocab_size:
- break
- vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
- vocabs.add_word_lst(word_list)
- vocabs.build_vocab()
- vocab_dict["vocab"] = vocabs
-
- if domain == True:
- domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
- domaindict.from_dataset(train_ds, field_name="publication")
- vocab_dict["domain"] = domaindict
- if tag == True:
- tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
- tagdict.from_dataset(train_ds, field_name="tag")
- vocab_dict["tag"] = tagdict
-
- for ds in datasets.values():
- vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
-
- return DataBundle(vocabs=vocab_dict, datasets=datasets)
-
-
-
+import pickle
+import numpy as np
+
+from fastNLP.core.vocabulary import Vocabulary
+from fastNLP.io.data_bundle import DataBundle
+from fastNLP.io.dataset_loader import JsonLoader
+from fastNLP.core.const import Const
+
+from tools.logger import *
+
+WORD_PAD = "[PAD]"
+WORD_UNK = "[UNK]"
+DOMAIN_UNK = "X"
+TAG_UNK = "X"
+
+
+class SummarizationLoader(JsonLoader):
+ """
+ Reads a summarization dataset. The loaded DataSet contains the fields::
+
+ text: list(str), document
+ summary: list(str), summary
+ text_wd: list(list(str)), tokenized document
+ summary_wd: list(list(str)), tokenized summary
+ labels: list(int), indices of the sentences selected as the summary
+ flatten_label: list(int), 0 or 1, flattened labels
+ domain: str, optional
+ tag: list(str), optional
+
+ Data sources: CNN_DailyMail, Newsroom, DUC
+ """
+
+ def __init__(self):
+ super(SummarizationLoader, self).__init__()
+
+ def _load(self, path):
+ ds = super(SummarizationLoader, self)._load(path)
+
+ def _lower_text(text_list):
+ return [text.lower() for text in text_list]
+
+ def _split_list(text_list):
+ return [text.split() for text in text_list]
+
+ def _convert_label(label, sent_len):
+ np_label = np.zeros(sent_len, dtype=int)
+ if label != []:
+ np_label[np.array(label)] = 1
+ return np_label.tolist()
+
+ ds.apply(lambda x: _lower_text(x['text']), new_field_name='text')
+ ds.apply(lambda x: _lower_text(x['summary']), new_field_name='summary')
+ ds.apply(lambda x:_split_list(x['text']), new_field_name='text_wd')
+ ds.apply(lambda x:_split_list(x['summary']), new_field_name='summary_wd')
+ ds.apply(lambda x:_convert_label(x["label"], len(x["text"])), new_field_name="flatten_label")
+
+ return ds
+
+ def process(self, paths, vocab_size, vocab_path, sent_max_len, doc_max_timesteps, domain=False, tag=False, load_vocab_file=True):
+ """
+ :param paths: dict path for each dataset
+ :param vocab_size: int max_size for vocab
+ :param vocab_path: str vocab path
+ :param sent_max_len: int max token number of the sentence
+ :param doc_max_timesteps: int max sentence number of the document
+ :param domain: bool build vocab for publication, use 'X' for unknown
+ :param tag: bool build vocab for tag, use 'X' for unknown
+ :param load_vocab_file: bool build vocab (False) or load vocab (True)
+ :return: DataBundle
+ datasets: dict keys correspond to the paths dict
+ vocabs: dict key: vocab(if "train" in paths), domain(if domain=True), tag(if tag=True)
+ embeddings: optional
+ """
+
+ def _pad_sent(text_wd):
+ pad_text_wd = []
+ for sent_wd in text_wd:
+ if len(sent_wd) < sent_max_len:
+ pad_num = sent_max_len - len(sent_wd)
+ sent_wd.extend([WORD_PAD] * pad_num)
+ else:
+ sent_wd = sent_wd[:sent_max_len]
+ pad_text_wd.append(sent_wd)
+ return pad_text_wd
+
+ def _token_mask(text_wd):
+ token_mask_list = []
+ for sent_wd in text_wd:
+ token_num = len(sent_wd)
+ if token_num < sent_max_len:
+ mask = [1] * token_num + [0] * (sent_max_len - token_num)
+ else:
+ mask = [1] * sent_max_len
+ token_mask_list.append(mask)
+ return token_mask_list
+
+ def _pad_label(label):
+ text_len = len(label)
+ if text_len < doc_max_timesteps:
+ pad_label = label + [0] * (doc_max_timesteps - text_len)
+ else:
+ pad_label = label[:doc_max_timesteps]
+ return pad_label
+
+ def _pad_doc(text_wd):
+ text_len = len(text_wd)
+ if text_len < doc_max_timesteps:
+ padding = [WORD_PAD] * sent_max_len
+ pad_text = text_wd + [padding] * (doc_max_timesteps - text_len)
+ else:
+ pad_text = text_wd[:doc_max_timesteps]
+ return pad_text
+
+ def _sent_mask(text_wd):
+ text_len = len(text_wd)
+ if text_len < doc_max_timesteps:
+ sent_mask = [1] * text_len + [0] * (doc_max_timesteps - text_len)
+ else:
+ sent_mask = [1] * doc_max_timesteps
+ return sent_mask
+
+
+ datasets = {}
+ train_ds = None
+ for key, value in paths.items():
+ ds = self.load(value)
+ # pad sent
+ ds.apply(lambda x:_pad_sent(x["text_wd"]), new_field_name="pad_text_wd")
+ ds.apply(lambda x:_token_mask(x["text_wd"]), new_field_name="pad_token_mask")
+ # pad document
+ ds.apply(lambda x:_pad_doc(x["pad_text_wd"]), new_field_name="pad_text")
+ ds.apply(lambda x:_sent_mask(x["pad_text_wd"]), new_field_name="seq_len")
+ ds.apply(lambda x:_pad_label(x["flatten_label"]), new_field_name="pad_label")
+
+ # rename field
+ ds.rename_field("pad_text", Const.INPUT)
+ ds.rename_field("seq_len", Const.INPUT_LEN)
+ ds.rename_field("pad_label", Const.TARGET)
+
+ # set input and target
+ ds.set_input(Const.INPUT, Const.INPUT_LEN)
+ ds.set_target(Const.TARGET, Const.INPUT_LEN)
+
+ datasets[key] = ds
+ if "train" in key:
+ train_ds = datasets[key]
+
+ vocab_dict = {}
+ if not load_vocab_file:
+ logger.info("[INFO] Build new vocab from training dataset!")
+ if train_ds is None:
+ raise ValueError("A training set is required to build the vocabulary!")
+
+ vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
+ vocabs.from_dataset(train_ds, field_name=["text_wd","summary_wd"])
+ vocab_dict["vocab"] = vocabs
+ else:
+ logger.info("[INFO] Load existing vocab from %s!" % vocab_path)
+ word_list = []
+ with open(vocab_path, 'r', encoding='utf8') as vocab_f:
+ cnt = 2 # pad and unk
+ for line in vocab_f:
+ pieces = line.split("\t")
+ word_list.append(pieces[0])
+ cnt += 1
+ if cnt > vocab_size:
+ break
+ vocabs = Vocabulary(max_size=vocab_size, padding=WORD_PAD, unknown=WORD_UNK)
+ vocabs.add_word_lst(word_list)
+ vocabs.build_vocab()
+ vocab_dict["vocab"] = vocabs
+
+ if domain:
+ domaindict = Vocabulary(padding=None, unknown=DOMAIN_UNK)
+ domaindict.from_dataset(train_ds, field_name="publication")
+ vocab_dict["domain"] = domaindict
+ if tag:
+ tagdict = Vocabulary(padding=None, unknown=TAG_UNK)
+ tagdict.from_dataset(train_ds, field_name="tag")
+ vocab_dict["tag"] = tagdict
+
+ for ds in datasets.values():
+ vocab_dict["vocab"].index_dataset(ds, field_name=Const.INPUT, new_field_name=Const.INPUT)
+
+ return DataBundle(vocabs=vocab_dict, datasets=datasets)
+
+
+
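Apart from the `DataBundle` import moving from `fastNLP.io.base_loader` to `fastNLP.io.data_bundle`, the rewritten `SummarizationLoader` above is functionally unchanged. A minimal usage sketch under assumed inputs (the file names, vocabulary file and sizes below are illustrative, not taken from the repository):

```python
from reproduction.Summarization.Baseline.data.dataloader import SummarizationLoader

# Hypothetical paths and sizes, for illustration only; each *.jsonl file is
# expected to provide the 'text', 'summary' and 'label' fields read by _load().
paths = {'train': 'train.jsonl', 'dev': 'dev.jsonl', 'test': 'test.jsonl'}
loader = SummarizationLoader()
data_bundle = loader.process(paths=paths, vocab_size=50000, vocab_path='vocab.txt',
                             sent_max_len=100, doc_max_timesteps=50,
                             load_vocab_file=True)

# The returned DataBundle mirrors what process() builds above: one DataSet per
# key in `paths`, plus a "vocab" entry in its vocabs dict.
print(data_bundle.datasets.keys(), data_bundle.vocabs.keys())
```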
diff --git a/reproduction/Summarization/BertSum/dataloader.py b/reproduction/Summarization/BertSum/dataloader.py
index c5201261..6af797e4 100644
--- a/reproduction/Summarization/BertSum/dataloader.py
+++ b/reproduction/Summarization/BertSum/dataloader.py
@@ -3,7 +3,7 @@ from datetime import timedelta
from fastNLP.io.dataset_loader import JsonLoader
from fastNLP.modules.encoder._bert import BertTokenizer
-from fastNLP.io.base_loader import DataBundle
+from fastNLP.io.data_bundle import DataBundle
from fastNLP.core.const import Const
class BertData(JsonLoader):
diff --git a/reproduction/Summarization/BertSum/model.py b/reproduction/Summarization/BertSum/model.py
index 655ad16e..1ee821fc 100644
--- a/reproduction/Summarization/BertSum/model.py
+++ b/reproduction/Summarization/BertSum/model.py
@@ -2,7 +2,7 @@ import torch
from torch import nn
from torch.nn import init
-from fastNLP.modules.encoder._bert import BertModel
+from fastNLP.modules.encoder.bert import BertModel
class Classifier(nn.Module):
diff --git a/reproduction/coreference_resolution/data_load/cr_loader.py b/reproduction/coreference_resolution/data_load/cr_loader.py
index a424b0d1..5ed73473 100644
--- a/reproduction/coreference_resolution/data_load/cr_loader.py
+++ b/reproduction/coreference_resolution/data_load/cr_loader.py
@@ -1,7 +1,7 @@
from fastNLP.io.dataset_loader import JsonLoader,DataSet,Instance
from fastNLP.io.file_reader import _read_json
from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataBundle
+from fastNLP.io.data_bundle import DataBundle
from reproduction.coreference_resolution.model.config import Config
import reproduction.coreference_resolution.model.preprocess as preprocess
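These hunks repeat the same renames across the reproduction scripts: `fastNLP.io.base_loader` becomes `fastNLP.io.data_bundle`, and the private `fastNLP.modules.encoder._bert` becomes the public `fastNLP.modules.encoder.bert`. For code living outside this repository that has to work both before and after the rename, a small compatibility shim (not part of this patch, purely a suggestion) could look like:

```python
# Sketch only: tolerate either module layout when importing DataBundle.
try:
    from fastNLP.io.data_bundle import DataBundle  # layout introduced by this patch
except ImportError:
    from fastNLP.io.base_loader import DataBundle  # pre-patch layout
```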
diff --git a/reproduction/joint_cws_parse/readme.md b/reproduction/joint_cws_parse/README.md
similarity index 100%
rename from reproduction/joint_cws_parse/readme.md
rename to reproduction/joint_cws_parse/README.md
diff --git a/reproduction/joint_cws_parse/data/data_loader.py b/reproduction/joint_cws_parse/data/data_loader.py
index 3e6fec4b..4df46b04 100644
--- a/reproduction/joint_cws_parse/data/data_loader.py
+++ b/reproduction/joint_cws_parse/data/data_loader.py
@@ -1,6 +1,6 @@
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
+from fastNLP.io.data_bundle import DataSetLoader, DataBundle
from fastNLP.io.data_loader import ConllLoader
import numpy as np
diff --git a/reproduction/joint_cws_parse/models/CharParser.py b/reproduction/joint_cws_parse/models/CharParser.py
index 1ed5ea2d..7d89cacb 100644
--- a/reproduction/joint_cws_parse/models/CharParser.py
+++ b/reproduction/joint_cws_parse/models/CharParser.py
@@ -12,7 +12,7 @@ from torch.nn import functional as F
from fastNLP.modules.dropout import TimestepDropout
from fastNLP.modules.encoder.variational_rnn import VarLSTM
from fastNLP import seq_len_to_mask
-from fastNLP.modules import Embedding
+from fastNLP.embeddings import Embedding
def drop_input_independent(word_embeddings, dropout_emb):
@@ -224,11 +224,11 @@ class CharBiaffineParser(BiaffineParser):
batch_size, seq_len, _ = arc_pred.shape
flip_mask = (mask == 0)
- _arc_pred = arc_pred.clone()
- _arc_pred.masked_fill_(flip_mask.unsqueeze(1), -float('inf'))
+ # _arc_pred = arc_pred.clone()
+ _arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))
- arc_true[:, 0].fill_(-1)
- label_true[:, 0].fill_(-1)
+ arc_true.data[:, 0].fill_(-1)
+ label_true.data[:, 0].fill_(-1)
arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
label_nll = F.cross_entropy(label_pred.view(-1, label_pred.size(-1)), label_true.view(-1), ignore_index=-1)
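The CharParser hunk above replaces an in-place `masked_fill_` on a cloned tensor with the out-of-place `masked_fill`, and writes the `-1` ignore index through `.data` so that filling the ROOT column of `arc_true`/`label_true` is not treated as an in-place modification of a tensor tracked by autograd. A self-contained sketch of the same pattern with made-up shapes (only the masking and loss logic is taken from the hunk):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len = 2, 5
arc_pred = torch.randn(batch_size, seq_len, seq_len, requires_grad=True)
arc_true = torch.randint(0, seq_len - 1, (batch_size, seq_len))  # never points at the padded head
mask = torch.ones(batch_size, seq_len, dtype=torch.long)
mask[:, -1] = 0  # pretend the last position is padding

flip_mask = (mask == 0)
# Out-of-place masking keeps arc_pred itself untouched for the backward pass.
_arc_pred = arc_pred.masked_fill(flip_mask.unsqueeze(1), -float('inf'))

# Mark the artificial ROOT position so cross_entropy ignores it.
arc_true.data[:, 0].fill_(-1)
arc_nll = F.cross_entropy(_arc_pred.view(-1, seq_len), arc_true.view(-1), ignore_index=-1)
arc_nll.backward()
```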
diff --git a/reproduction/joint_cws_parse/train.py b/reproduction/joint_cws_parse/train.py
index 2f8b0d04..ed4b07f0 100644
--- a/reproduction/joint_cws_parse/train.py
+++ b/reproduction/joint_cws_parse/train.py
@@ -2,18 +2,19 @@ import sys
sys.path.append('../..')
from reproduction.joint_cws_parse.data.data_loader import CTBxJointLoader
-from fastNLP.modules.encoder.embedding import StaticEmbedding
+from fastNLP.embeddings.static_embedding import StaticEmbedding
from torch import nn
from functools import partial
from reproduction.joint_cws_parse.models.CharParser import CharParser
from reproduction.joint_cws_parse.models.metrics import SegAppCharParseF1Metric, CWSMetric
-from fastNLP import cache_results, BucketSampler, Trainer
+from fastNLP import BucketSampler, Trainer
from torch import optim
-from reproduction.joint_cws_parse.models.callbacks import DevCallback, OptimizerCallback
-from torch.optim.lr_scheduler import LambdaLR, StepLR
+from reproduction.joint_cws_parse.models.callbacks import DevCallback
+from torch.optim.lr_scheduler import StepLR
from fastNLP import Tester
from fastNLP import GradientClipCallback, LRScheduler
import os
+from fastNLP import cache_results
def set_random_seed(random_seed=666):
import random, numpy, torch
@@ -39,43 +40,42 @@ label_mlp_size = 100
batch_size = 32
update_every = 4
n_epochs = 100
-data_folder = '' # 填写在数据所在文件夹, 文件夹下应该有train, dev, test等三个文件
-vector_folder = '' # 预训练的vector,下面应该包含三个文件: 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt, 3grams_t3_m50_corpus.txt
+data_name = 'new_ctb7'
####################################################
+data_folder = f'/remote-home/hyan01/exps/JointCwsPosParser/data/{data_name}/output' # folder containing the data; it should hold the train, dev and test files
+vector_folder = '/remote-home/hyan01/exps/CWS/pretrain/vectors' # pretrained vectors; the folder should contain 1grams_t3_m50_corpus.txt, 2grams_t3_m50_corpus.txt and 3grams_t3_m50_corpus.txt
set_random_seed(1234)
device = 0
-# @cache_results('caches/{}.pkl'.format(data_name))
-# def get_data():
-data = CTBxJointLoader().process(data_folder)
-
-char_labels_vocab = data.vocabs['char_labels']
-
-pre_chars_vocab = data.vocabs['pre_chars']
-pre_bigrams_vocab = data.vocabs['pre_bigrams']
-pre_trigrams_vocab = data.vocabs['pre_trigrams']
-
-chars_vocab = data.vocabs['chars']
-bigrams_vocab = data.vocabs['bigrams']
-trigrams_vocab = data.vocabs['trigrams']
-
-pre_chars_embed = StaticEmbedding(pre_chars_vocab,
- model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
- init_method=uniform_init, normalize=False)
-pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data/pre_chars_embed.embedding.weight.data.std()
-pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
- model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
- init_method=uniform_init, normalize=False)
-pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data/pre_bigrams_embed.embedding.weight.data.std()
-pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
- model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
- init_method=uniform_init, normalize=False)
-pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data/pre_trigrams_embed.embedding.weight.data.std()
-
- # return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
-
-# chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
+@cache_results('caches/{}.pkl'.format(data_name))
+def get_data():
+ data = CTBxJointLoader().process(data_folder)
+ char_labels_vocab = data.vocabs['char_labels']
+
+ pre_chars_vocab = data.vocabs['pre_chars']
+ pre_bigrams_vocab = data.vocabs['pre_bigrams']
+ pre_trigrams_vocab = data.vocabs['pre_trigrams']
+
+ chars_vocab = data.vocabs['chars']
+ bigrams_vocab = data.vocabs['bigrams']
+ trigrams_vocab = data.vocabs['trigrams']
+ pre_chars_embed = StaticEmbedding(pre_chars_vocab,
+ model_dir_or_name=os.path.join(vector_folder, '1grams_t3_m50_corpus.txt'),
+ init_method=uniform_init, normalize=False)
+ pre_chars_embed.embedding.weight.data = pre_chars_embed.embedding.weight.data / pre_chars_embed.embedding.weight.data.std()
+ pre_bigrams_embed = StaticEmbedding(pre_bigrams_vocab,
+ model_dir_or_name=os.path.join(vector_folder, '2grams_t3_m50_corpus.txt'),
+ init_method=uniform_init, normalize=False)
+ pre_bigrams_embed.embedding.weight.data = pre_bigrams_embed.embedding.weight.data / pre_bigrams_embed.embedding.weight.data.std()
+ pre_trigrams_embed = StaticEmbedding(pre_trigrams_vocab,
+ model_dir_or_name=os.path.join(vector_folder, '3grams_t3_m50_corpus.txt'),
+ init_method=uniform_init, normalize=False)
+ pre_trigrams_embed.embedding.weight.data = pre_trigrams_embed.embedding.weight.data / pre_trigrams_embed.embedding.weight.data.std()
+
+ return chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data
+
+chars_vocab, bigrams_vocab, trigrams_vocab, char_labels_vocab, pre_chars_embed, pre_bigrams_embed, pre_trigrams_embed, data = get_data()
print(data)
model = CharParser(char_vocab_size=len(chars_vocab),
@@ -104,11 +104,24 @@ optimizer = optim.Adam([param for param in model.parameters() if param.requires_
sampler = BucketSampler(seq_len_field_name='seq_lens')
callbacks = []
+
+from fastNLP.core.callback import Callback
+from torch.optim.lr_scheduler import LambdaLR
+class SchedulerCallback(Callback):
+ def __init__(self, scheduler):
+ super().__init__()
+ self.scheduler = scheduler
+
+ def on_backward_end(self):
+ if self.step % self.update_every == 0:
+ self.scheduler.step()
+
+scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
# scheduler = LambdaLR(optimizer, lr_lambda=lambda step:(0.75)**(step//5000))
-scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
-# optim_callback = OptimizerCallback(optimizer, scheduler, update_every)
+# scheduler = StepLR(optimizer, step_size=18, gamma=0.75)
+scheduler_callback = SchedulerCallback(scheduler)
# callbacks.append(optim_callback)
-scheduler_callback = LRScheduler(scheduler)
+# scheduler_callback = LRScheduler(scheduler)
callbacks.append(scheduler_callback)
callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
@@ -119,6 +132,6 @@ callbacks.append(dev_callback)
trainer = Trainer(data.datasets['train'], model, loss=None, metrics=metrics, n_epochs=n_epochs, batch_size=batch_size, print_every=3,
validate_every=-1, dev_data=data.datasets['dev'], save_path=None, optimizer=optimizer,
- check_code_level=0, metric_key='u_f1', sampler=sampler, prefetch=True, use_tqdm=True,
+ check_code_level=0, metric_key='u_f1', sampler=sampler, num_workers=2, use_tqdm=True,
device=device, callbacks=callbacks, update_every=update_every)
trainer.train()
\ No newline at end of file
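Two changes in this file are worth noting: data preparation is now wrapped in a function decorated with `@cache_results('caches/{}.pkl'.format(data_name))`, so the processed vocabularies, embeddings and DataBundle are pickled on the first run and reloaded afterwards, and the `LRScheduler` wrapper is replaced by a hand-rolled `SchedulerCallback` that steps a `LambdaLR` only once every `update_every` backward passes, matching the Trainer's gradient-accumulation setting. A toy sketch of the caching pattern (the function and cache file below are hypothetical):

```python
from fastNLP import cache_results

@cache_results('caches/toy.pkl')  # assumption: a caches/ directory is available
def preprocess():
    # expensive preprocessing would go here; the return value is pickled to caches/toy.pkl
    return {'vocab_size': 12345}

data = preprocess()  # first call: runs the function and writes the cache
data = preprocess()  # later calls: load the cached result instead
```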
diff --git a/reproduction/Biaffine_parser/cfg.cfg b/reproduction/legacy/Biaffine_parser/cfg.cfg
similarity index 100%
rename from reproduction/Biaffine_parser/cfg.cfg
rename to reproduction/legacy/Biaffine_parser/cfg.cfg
diff --git a/reproduction/Biaffine_parser/infer.py b/reproduction/legacy/Biaffine_parser/infer.py
similarity index 100%
rename from reproduction/Biaffine_parser/infer.py
rename to reproduction/legacy/Biaffine_parser/infer.py
diff --git a/reproduction/Biaffine_parser/main.py b/reproduction/legacy/Biaffine_parser/main.py
similarity index 100%
rename from reproduction/Biaffine_parser/main.py
rename to reproduction/legacy/Biaffine_parser/main.py
diff --git a/reproduction/Biaffine_parser/run.py b/reproduction/legacy/Biaffine_parser/run.py
similarity index 100%
rename from reproduction/Biaffine_parser/run.py
rename to reproduction/legacy/Biaffine_parser/run.py
diff --git a/reproduction/Biaffine_parser/util.py b/reproduction/legacy/Biaffine_parser/util.py
similarity index 100%
rename from reproduction/Biaffine_parser/util.py
rename to reproduction/legacy/Biaffine_parser/util.py
diff --git a/reproduction/Chinese_word_segmentation/__init__.py b/reproduction/legacy/Chinese_word_segmentation/__init__.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/__init__.py
rename to reproduction/legacy/Chinese_word_segmentation/__init__.py
diff --git a/reproduction/Chinese_word_segmentation/cws.cfg b/reproduction/legacy/Chinese_word_segmentation/cws.cfg
similarity index 100%
rename from reproduction/Chinese_word_segmentation/cws.cfg
rename to reproduction/legacy/Chinese_word_segmentation/cws.cfg
diff --git a/reproduction/Chinese_word_segmentation/cws_io/__init__.py b/reproduction/legacy/Chinese_word_segmentation/cws_io/__init__.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/cws_io/__init__.py
rename to reproduction/legacy/Chinese_word_segmentation/cws_io/__init__.py
diff --git a/reproduction/Chinese_word_segmentation/cws_io/cws_reader.py b/reproduction/legacy/Chinese_word_segmentation/cws_io/cws_reader.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/cws_io/cws_reader.py
rename to reproduction/legacy/Chinese_word_segmentation/cws_io/cws_reader.py
diff --git a/reproduction/Chinese_word_segmentation/models/__init__.py b/reproduction/legacy/Chinese_word_segmentation/models/__init__.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/models/__init__.py
rename to reproduction/legacy/Chinese_word_segmentation/models/__init__.py
diff --git a/reproduction/Chinese_word_segmentation/models/cws_model.py b/reproduction/legacy/Chinese_word_segmentation/models/cws_model.py
similarity index 98%
rename from reproduction/Chinese_word_segmentation/models/cws_model.py
rename to reproduction/legacy/Chinese_word_segmentation/models/cws_model.py
index b41ad87d..0d10d2e5 100644
--- a/reproduction/Chinese_word_segmentation/models/cws_model.py
+++ b/reproduction/legacy/Chinese_word_segmentation/models/cws_model.py
@@ -4,7 +4,7 @@ from torch import nn
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.decoder.mlp import MLP
-from reproduction.Chinese_word_segmentation.utils import seq_lens_to_mask
+from reproduction.legacy.Chinese_word_segmentation.utils import seq_lens_to_mask
class CWSBiLSTMEncoder(BaseModel):
diff --git a/reproduction/Chinese_word_segmentation/models/cws_transformer.py b/reproduction/legacy/Chinese_word_segmentation/models/cws_transformer.py
similarity index 97%
rename from reproduction/Chinese_word_segmentation/models/cws_transformer.py
rename to reproduction/legacy/Chinese_word_segmentation/models/cws_transformer.py
index e8ae5ecc..ae8a5a7f 100644
--- a/reproduction/Chinese_word_segmentation/models/cws_transformer.py
+++ b/reproduction/legacy/Chinese_word_segmentation/models/cws_transformer.py
@@ -9,7 +9,7 @@
from torch import nn
import torch
# from fastNLP.modules.encoder.transformer import TransformerEncoder
-from reproduction.Chinese_word_segmentation.models.transformer import TransformerEncoder
+from reproduction.legacy.Chinese_word_segmentation.models import TransformerEncoder
from fastNLP.modules.decoder.crf import ConditionalRandomField,seq_len_to_byte_mask
from fastNLP.modules.decoder.crf import allowed_transitions
@@ -79,7 +79,7 @@ class TransformerCWS(nn.Module):
return {'pred': probs, 'seq_lens':seq_lens}
-from reproduction.Chinese_word_segmentation.models.dilated_transformer import TransformerDilateEncoder
+from reproduction.legacy.Chinese_word_segmentation.models import TransformerDilateEncoder
class TransformerDilatedCWS(nn.Module):
def __init__(self, vocab_num, embed_dim=100, bigram_vocab_num=None, bigram_embed_dim=100, num_bigram_per_char=None,
diff --git a/reproduction/Chinese_word_segmentation/process/__init__.py b/reproduction/legacy/Chinese_word_segmentation/process/__init__.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/process/__init__.py
rename to reproduction/legacy/Chinese_word_segmentation/process/__init__.py
diff --git a/reproduction/Chinese_word_segmentation/process/cws_processor.py b/reproduction/legacy/Chinese_word_segmentation/process/cws_processor.py
similarity index 99%
rename from reproduction/Chinese_word_segmentation/process/cws_processor.py
rename to reproduction/legacy/Chinese_word_segmentation/process/cws_processor.py
index 614d9ef5..1f64bed2 100644
--- a/reproduction/Chinese_word_segmentation/process/cws_processor.py
+++ b/reproduction/legacy/Chinese_word_segmentation/process/cws_processor.py
@@ -4,7 +4,7 @@ import re
from fastNLP.api.processor import Processor
from fastNLP.core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
-from reproduction.Chinese_word_segmentation.process.span_converter import SpanConverter
+from reproduction.legacy.Chinese_word_segmentation.process.span_converter import SpanConverter
_SPECIAL_TAG_PATTERN = '<[a-zA-Z]+>'
diff --git a/reproduction/Chinese_word_segmentation/process/span_converter.py b/reproduction/legacy/Chinese_word_segmentation/process/span_converter.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/process/span_converter.py
rename to reproduction/legacy/Chinese_word_segmentation/process/span_converter.py
diff --git a/reproduction/Chinese_word_segmentation/utils.py b/reproduction/legacy/Chinese_word_segmentation/utils.py
similarity index 100%
rename from reproduction/Chinese_word_segmentation/utils.py
rename to reproduction/legacy/Chinese_word_segmentation/utils.py
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/README.md b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/README.md
similarity index 94%
rename from reproduction/LSTM+self_attention_sentiment_analysis/README.md
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/README.md
index 2dff7caa..dfb337ec 100644
--- a/reproduction/LSTM+self_attention_sentiment_analysis/README.md
+++ b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/README.md
@@ -1,5 +1,7 @@
# Prototype
+This reproduction targets a very old version of fastNLP and is pending revision.
+
## Word2Idx.py
A mapping model between words and indexes
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/Word2Idx.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/Word2Idx.py
similarity index 100%
rename from reproduction/LSTM+self_attention_sentiment_analysis/Word2Idx.py
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/Word2Idx.py
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/config.cfg b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/config.cfg
similarity index 100%
rename from reproduction/LSTM+self_attention_sentiment_analysis/config.cfg
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/config.cfg
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/dataloader.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/dataloader.py
similarity index 100%
rename from reproduction/LSTM+self_attention_sentiment_analysis/dataloader.py
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/dataloader.py
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/example.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/example.py
similarity index 100%
rename from reproduction/LSTM+self_attention_sentiment_analysis/example.py
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/example.py
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/main.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/main.py
similarity index 90%
rename from reproduction/LSTM+self_attention_sentiment_analysis/main.py
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/main.py
index 871dc476..05077530 100644
--- a/reproduction/LSTM+self_attention_sentiment_analysis/main.py
+++ b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/main.py
@@ -1,6 +1,9 @@
+# This code was written for a very old version of fastNLP.
+
+"""
import torch.nn.functional as F
-from fastNLP.core.trainer import ClassificationTrainer
+from fastNLP.core.trainer import Trainer
from fastNLP.core.utils import ClassPreprocess as Preprocess
from fastNLP.io.config_io import ConfigLoader
from fastNLP.io.config_io import ConfigSection
@@ -8,7 +11,7 @@ from fastNLP.io.dataset_loader import DummyClassificationReader as Dataset_loade
from fastNLP.models.base_model import BaseModel
from fastNLP.modules.aggregator.self_attention import SelfAttention
from fastNLP.modules.decoder.mlp import MLP
-from fastNLP.modules.encoder.embedding import Embedding as Embedding
+from fastNLP.embeddings.embedding import Embedding as Embedding
from fastNLP.modules.encoder.lstm import LSTM
train_data_path = 'small_train_data.txt'
@@ -61,12 +64,13 @@ class SELF_ATTENTION_YELP_CLASSIFICATION(BaseModel):
train_args = ConfigSection()
ConfigLoader("good path").load_config('config.cfg',{"train": train_args})
-train_args['vocab'] = len(word2index)
+# train_args['vocab'] = len(word2index)
-trainer = ClassificationTrainer(**train_args.data)
+trainer = Trainer(**train_args.data)
# for k in train_args.__dict__.keys():
# print(k, train_args[k])
model = SELF_ATTENTION_YELP_CLASSIFICATION(train_args)
-trainer.train(model,train_data , dev_data)
+trainer.train()
+"""
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/predict.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/predict.py
similarity index 100%
rename from reproduction/LSTM+self_attention_sentiment_analysis/predict.py
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/predict.py
diff --git a/reproduction/LSTM+self_attention_sentiment_analysis/prepare.py b/reproduction/legacy/LSTM+self_attention_sentiment_analysis/prepare.py
similarity index 100%
rename from reproduction/LSTM+self_attention_sentiment_analysis/prepare.py
rename to reproduction/legacy/LSTM+self_attention_sentiment_analysis/prepare.py
diff --git a/reproduction/POS_tagging/pos_processor.py b/reproduction/legacy/POS_tagging/pos_processor.py
similarity index 100%
rename from reproduction/POS_tagging/pos_processor.py
rename to reproduction/legacy/POS_tagging/pos_processor.py
diff --git a/reproduction/POS_tagging/pos_reader.py b/reproduction/legacy/POS_tagging/pos_reader.py
similarity index 100%
rename from reproduction/POS_tagging/pos_reader.py
rename to reproduction/legacy/POS_tagging/pos_reader.py
diff --git a/reproduction/POS_tagging/pos_tag.cfg b/reproduction/legacy/POS_tagging/pos_tag.cfg
similarity index 100%
rename from reproduction/POS_tagging/pos_tag.cfg
rename to reproduction/legacy/POS_tagging/pos_tag.cfg
diff --git a/reproduction/POS_tagging/train_pos_tag.py b/reproduction/legacy/POS_tagging/train_pos_tag.py
similarity index 100%
rename from reproduction/POS_tagging/train_pos_tag.py
rename to reproduction/legacy/POS_tagging/train_pos_tag.py
diff --git a/reproduction/POS_tagging/utils.py b/reproduction/legacy/POS_tagging/utils.py
similarity index 100%
rename from reproduction/POS_tagging/utils.py
rename to reproduction/legacy/POS_tagging/utils.py
diff --git a/reproduction/matching/data/MatchingDataLoader.py b/reproduction/matching/data/MatchingDataLoader.py
deleted file mode 100644
index 67fa4c8d..00000000
--- a/reproduction/matching/data/MatchingDataLoader.py
+++ /dev/null
@@ -1,431 +0,0 @@
-
-import os
-
-from typing import Union, Dict
-
-from fastNLP.core.const import Const
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.io.base_loader import DataBundle, DataSetLoader
-from fastNLP.io.dataset_loader import JsonLoader, CSVLoader
-from fastNLP.io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR
-from fastNLP.modules.encoder._bert import BertTokenizer
-
-
-class MatchingLoader(DataSetLoader):
- """
- 别名::class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.dataset_loader.MatchingLoader`
-
- 读取Matching任务的数据集
-
- :param dict paths: key是数据集名称(如train、dev、test),value是对应的文件名
- """
-
- def __init__(self, paths: dict=None):
- self.paths = paths
-
- def _load(self, path):
- """
- :param str path: 待读取数据集的路径名
- :return: fastNLP.DataSet ds: 返回一个DataSet对象,里面必须包含3个field:其中两个分别为两个句子
- 的原始字符串文本,第三个为标签
- """
- raise NotImplementedError
-
- def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
- to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
- cut_text: int = None, get_index=True, auto_pad_length: int=None,
- auto_pad_token: str='', set_input: Union[list, str, bool]=True,
- set_target: Union[list, str, bool] = True, concat: Union[str, list, bool]=None, ) -> DataBundle:
- """
- :param paths: str或者Dict[str, str]。如果是str,则为数据集所在的文件夹或者是全路径文件名:如果是文件夹,
- 则会从self.paths里面找对应的数据集名称与文件名。如果是Dict,则为数据集名称(如train、dev、test)和
- 对应的全路径文件名。
- :param str dataset_name: 如果在paths里传入的是一个数据集的全路径文件名,那么可以用dataset_name来定义
- 这个数据集的名字,如果不定义则默认为train。
- :param bool to_lower: 是否将文本自动转为小写。默认值为False。
- :param str seq_len_type: 提供的seq_len类型,支持 ``seq_len`` :提供一个数字作为句子长度; ``mask`` :
- 提供一个0/1的mask矩阵作为句子长度; ``bert`` :提供segment_type_id(第一个句子为0,第二个句子为1)和
- attention mask矩阵(0/1的mask矩阵)。默认值为None,即不提供seq_len
- :param str bert_tokenizer: bert tokenizer所使用的词表所在的文件夹路径
- :param int cut_text: 将长于cut_text的内容截掉。默认为None,即不截。
- :param bool get_index: 是否需要根据词表将文本转为index
- :param int auto_pad_length: 是否需要将文本自动pad到一定长度(超过这个长度的文本将会被截掉),默认为不会自动pad
- :param str auto_pad_token: 自动pad的内容
- :param set_input: 如果为True,则会自动将相关的field(名字里含有Const.INPUT的)设置为input,如果为False
- 则不会将任何field设置为input。如果传入str或者List[str],则会根据传入的内容将相对应的field设置为input,
- 于此同时其他field不会被设置为input。默认值为True。
- :param set_target: set_target将控制哪些field可以被设置为target,用法与set_input一致。默认值为True。
- :param concat: 是否需要将两个句子拼接起来。如果为False则不会拼接。如果为True则会在两个句子之间插入一个。
- 如果传入一个长度为4的list,则分别表示插在第一句开始前、第一句结束后、第二句开始前、第二句结束后的标识符。如果
- 传入字符串 ``bert`` ,则会采用bert的拼接方式,等价于['[CLS]', '[SEP]', '', '[SEP]'].
- :return:
- """
- if isinstance(set_input, str):
- set_input = [set_input]
- if isinstance(set_target, str):
- set_target = [set_target]
- if isinstance(set_input, bool):
- auto_set_input = set_input
- else:
- auto_set_input = False
- if isinstance(set_target, bool):
- auto_set_target = set_target
- else:
- auto_set_target = False
- if isinstance(paths, str):
- if os.path.isdir(paths):
- path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
- else:
- path = {dataset_name if dataset_name is not None else 'train': paths}
- else:
- path = paths
-
- data_info = DataBundle()
- for data_name in path.keys():
- data_info.datasets[data_name] = self._load(path[data_name])
-
- for data_name, data_set in data_info.datasets.items():
- if auto_set_input:
- data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
- if auto_set_target:
- if Const.TARGET in data_set.get_field_names():
- data_set.set_target(Const.TARGET)
-
- if to_lower:
- for data_name, data_set in data_info.datasets.items():
- data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
- is_input=auto_set_input)
- data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
- is_input=auto_set_input)
-
- if bert_tokenizer is not None:
- if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
- PRETRAIN_URL = _get_base_url('bert')
- model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
- model_url = PRETRAIN_URL + model_name
- model_dir = cached_path(model_url)
- # 检查是否存在
- elif os.path.isdir(bert_tokenizer):
- model_dir = bert_tokenizer
- else:
- raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")
-
- words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
- with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
- lines = f.readlines()
- lines = [line.strip() for line in lines]
- words_vocab.add_word_lst(lines)
- words_vocab.build_vocab()
-
- tokenizer = BertTokenizer.from_pretrained(model_dir)
-
- for data_name, data_set in data_info.datasets.items():
- for fields in data_set.get_field_names():
- if Const.INPUT in fields:
- data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
- is_input=auto_set_input)
-
- if isinstance(concat, bool):
- concat = 'default' if concat else None
- if concat is not None:
- if isinstance(concat, str):
- CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
- 'default': ['', '', '', '']}
- if concat.lower() in CONCAT_MAP:
- concat = CONCAT_MAP[concat]
- else:
- concat = 4 * [concat]
- assert len(concat) == 4, \
- f'Please choose a list with 4 symbols which at the beginning of first sentence ' \
- f'the end of first sentence, the begin of second sentence, and the end of second' \
- f'sentence. Your input is {concat}'
-
- for data_name, data_set in data_info.datasets.items():
- data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
- x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
- data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
- is_input=auto_set_input)
-
- if seq_len_type is not None:
- if seq_len_type == 'seq_len': #
- for data_name, data_set in data_info.datasets.items():
- for fields in data_set.get_field_names():
- if Const.INPUT in fields:
- data_set.apply(lambda x: len(x[fields]),
- new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
- is_input=auto_set_input)
- elif seq_len_type == 'mask':
- for data_name, data_set in data_info.datasets.items():
- for fields in data_set.get_field_names():
- if Const.INPUT in fields:
- data_set.apply(lambda x: [1] * len(x[fields]),
- new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
- is_input=auto_set_input)
- elif seq_len_type == 'bert':
- for data_name, data_set in data_info.datasets.items():
- if Const.INPUT not in data_set.get_field_names():
- raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
- f'got {data_set.get_field_names()}')
- data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
- new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
- data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
- new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)
-
- if auto_pad_length is not None:
- cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
-
- if cut_text is not None:
- for data_name, data_set in data_info.datasets.items():
- for fields in data_set.get_field_names():
- if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
- data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
- is_input=auto_set_input)
-
- data_set_list = [d for n, d in data_info.datasets.items()]
- assert len(data_set_list) > 0, f'There are NO data sets in data info!'
-
- if bert_tokenizer is None:
- words_vocab = Vocabulary(padding=auto_pad_token)
- words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
- field_name=[n for n in data_set_list[0].get_field_names()
- if (Const.INPUT in n)],
- no_create_entry_dataset=[d for n, d in data_info.datasets.items()
- if 'train' not in n])
- target_vocab = Vocabulary(padding=None, unknown=None)
- target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
- field_name=Const.TARGET)
- data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}
-
- if get_index:
- for data_name, data_set in data_info.datasets.items():
- for fields in data_set.get_field_names():
- if Const.INPUT in fields:
- data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
- is_input=auto_set_input)
-
- if Const.TARGET in data_set.get_field_names():
- data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
- is_input=auto_set_input, is_target=auto_set_target)
-
- if auto_pad_length is not None:
- if seq_len_type == 'seq_len':
- raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
- f'so the seq_len_type cannot be `{seq_len_type}`!')
- for data_name, data_set in data_info.datasets.items():
- for fields in data_set.get_field_names():
- if Const.INPUT in fields:
- data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
- (auto_pad_length - len(x[fields])), new_field_name=fields,
- is_input=auto_set_input)
- elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
- data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
- new_field_name=fields, is_input=auto_set_input)
-
- for data_name, data_set in data_info.datasets.items():
- if isinstance(set_input, list):
- data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
- if isinstance(set_target, list):
- data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
-
- return data_info
-
-
-class SNLILoader(MatchingLoader, JsonLoader):
- """
- 别名::class:`fastNLP.io.SNLILoader` :class:`fastNLP.io.dataset_loader.SNLILoader`
-
- 读取SNLI数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源: https://nlp.stanford.edu/projects/snli/snli_1.0.zip
- """
-
- def __init__(self, paths: dict=None):
- fields = {
- 'sentence1_binary_parse': Const.INPUTS(0),
- 'sentence2_binary_parse': Const.INPUTS(1),
- 'gold_label': Const.TARGET,
- }
- paths = paths if paths is not None else {
- 'train': 'snli_1.0_train.jsonl',
- 'dev': 'snli_1.0_dev.jsonl',
- 'test': 'snli_1.0_test.jsonl'}
- MatchingLoader.__init__(self, paths=paths)
- JsonLoader.__init__(self, fields=fields)
-
- def _load(self, path):
- ds = JsonLoader._load(self, path)
-
- parentheses_table = str.maketrans({'(': None, ')': None})
-
- ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
- new_field_name=Const.INPUTS(0))
- ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
- new_field_name=Const.INPUTS(1))
- ds.drop(lambda x: x[Const.TARGET] == '-')
- return ds
-
-
-class RTELoader(MatchingLoader, CSVLoader):
- """
- 别名::class:`fastNLP.io.RTELoader` :class:`fastNLP.io.dataset_loader.RTELoader`
-
- 读取RTE数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源:
- """
-
- def __init__(self, paths: dict=None):
- paths = paths if paths is not None else {
- 'train': 'train.tsv',
- 'dev': 'dev.tsv',
- 'test': 'test.tsv' # test set has not label
- }
- MatchingLoader.__init__(self, paths=paths)
- self.fields = {
- 'sentence1': Const.INPUTS(0),
- 'sentence2': Const.INPUTS(1),
- 'label': Const.TARGET,
- }
- CSVLoader.__init__(self, sep='\t')
-
- def _load(self, path):
- ds = CSVLoader._load(self, path)
-
- for k, v in self.fields.items():
- if v in ds.get_field_names():
- ds.rename_field(k, v)
- for fields in ds.get_all_fields():
- if Const.INPUT in fields:
- ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
-
- return ds
-
-
-class QNLILoader(MatchingLoader, CSVLoader):
- """
- 别名::class:`fastNLP.io.QNLILoader` :class:`fastNLP.io.dataset_loader.QNLILoader`
-
- 读取QNLI数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源:
- """
-
- def __init__(self, paths: dict=None):
- paths = paths if paths is not None else {
- 'train': 'train.tsv',
- 'dev': 'dev.tsv',
- 'test': 'test.tsv' # test set has not label
- }
- MatchingLoader.__init__(self, paths=paths)
- self.fields = {
- 'question': Const.INPUTS(0),
- 'sentence': Const.INPUTS(1),
- 'label': Const.TARGET,
- }
- CSVLoader.__init__(self, sep='\t')
-
- def _load(self, path):
- ds = CSVLoader._load(self, path)
-
- for k, v in self.fields.items():
- if v in ds.get_field_names():
- ds.rename_field(k, v)
- for fields in ds.get_all_fields():
- if Const.INPUT in fields:
- ds.apply(lambda x: x[fields].strip().split(), new_field_name=fields)
-
- return ds
-
-
-class MNLILoader(MatchingLoader, CSVLoader):
- """
- 别名::class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.dataset_loader.MNLILoader`
-
- 读取MNLI数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源:
- """
-
- def __init__(self, paths: dict=None):
- paths = paths if paths is not None else {
- 'train': 'train.tsv',
- 'dev_matched': 'dev_matched.tsv',
- 'dev_mismatched': 'dev_mismatched.tsv',
- 'test_matched': 'test_matched.tsv',
- 'test_mismatched': 'test_mismatched.tsv',
- # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
- # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
-
- # test_0.9_mathed与mismatched是MNLI0.9版本的(数据来源:kaggle)
- }
- MatchingLoader.__init__(self, paths=paths)
- CSVLoader.__init__(self, sep='\t')
- self.fields = {
- 'sentence1_binary_parse': Const.INPUTS(0),
- 'sentence2_binary_parse': Const.INPUTS(1),
- 'gold_label': Const.TARGET,
- }
-
- def _load(self, path):
- ds = CSVLoader._load(self, path)
-
- for k, v in self.fields.items():
- if k in ds.get_field_names():
- ds.rename_field(k, v)
-
- if Const.TARGET in ds.get_field_names():
- if ds[0][Const.TARGET] == 'hidden':
- ds.delete_field(Const.TARGET)
-
- parentheses_table = str.maketrans({'(': None, ')': None})
-
- ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(parentheses_table).strip().split(),
- new_field_name=Const.INPUTS(0))
- ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(parentheses_table).strip().split(),
- new_field_name=Const.INPUTS(1))
- if Const.TARGET in ds.get_field_names():
- ds.drop(lambda x: x[Const.TARGET] == '-')
- return ds
-
-
-class QuoraLoader(MatchingLoader, CSVLoader):
- """
- 别名::class:`fastNLP.io.QuoraLoader` :class:`fastNLP.io.dataset_loader.QuoraLoader`
-
- 读取MNLI数据集,读取的DataSet包含fields::
-
- words1: list(str),第一句文本, premise
- words2: list(str), 第二句文本, hypothesis
- target: str, 真实标签
-
- 数据来源:
- """
-
- def __init__(self, paths: dict=None):
- paths = paths if paths is not None else {
- 'train': 'train.tsv',
- 'dev': 'dev.tsv',
- 'test': 'test.tsv',
- }
- MatchingLoader.__init__(self, paths=paths)
- CSVLoader.__init__(self, sep='\t', headers=(Const.TARGET, Const.INPUTS(0), Const.INPUTS(1), 'pairID'))
-
- def _load(self, path):
- ds = CSVLoader._load(self, path)
- return ds
diff --git a/reproduction/matching/matching_bert.py b/reproduction/matching/matching_bert.py
index 75112d5a..323d81a3 100644
--- a/reproduction/matching/matching_bert.py
+++ b/reproduction/matching/matching_bert.py
@@ -2,10 +2,13 @@ import random
import numpy as np
import torch
-from fastNLP.core import Trainer, Tester, AccuracyMetric, Const, Adam
+from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
+from fastNLP.core.callback import WarmupCallback, EvaluateCallback
+from fastNLP.core.optimizer import AdamW
+from fastNLP.embeddings import BertEmbedding
+from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, MNLIBertPipe,\
+ QNLIBertPipe, QuoraBertPipe
-from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
- MNLILoader, QNLILoader, QuoraLoader
from reproduction.matching.model.bert import BertForNLI
@@ -13,16 +16,22 @@ from reproduction.matching.model.bert import BertForNLI
class BERTConfig:
task = 'snli'
+
batch_size_per_gpu = 6
n_epochs = 6
lr = 2e-5
- seq_len_type = 'bert'
+ warm_up_rate = 0.1
seed = 42
+ save_path = None # where to save the model; None means the model is not saved.
+
train_dataset_name = 'train'
dev_dataset_name = 'dev'
test_dataset_name = 'test'
- save_path = None # 模型存储的位置,None表示不存储模型。
- bert_dir = 'path/to/bert/dir' # 预训练BERT参数文件的文件夹
+
+ to_lower = True # lower-case the input text
+ tokenizer = 'spacy' # tokenize with spacy
+
+ bert_model_dir_or_name = 'bert-base-uncased'
arg = BERTConfig()
@@ -38,58 +47,52 @@ if n_gpu > 0:
# load data set
if arg.task == 'snli':
- data_info = SNLILoader().process(
- paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type,
- bert_tokenizer=arg.bert_dir, cut_text=512,
- get_index=True, concat='bert',
- )
+ data_bundle = SNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'rte':
- data_info = RTELoader().process(
- paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type,
- bert_tokenizer=arg.bert_dir, cut_text=512,
- get_index=True, concat='bert',
- )
+ data_bundle = RTEBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'qnli':
- data_info = QNLILoader().process(
- paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
- bert_tokenizer=arg.bert_dir, cut_text=512,
- get_index=True, concat='bert',
- )
+ data_bundle = QNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'mnli':
- data_info = MNLILoader().process(
- paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type,
- bert_tokenizer=arg.bert_dir, cut_text=512,
- get_index=True, concat='bert',
- )
+ data_bundle = MNLIBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'quora':
- data_info = QuoraLoader().process(
- paths='path/to/quora/data', to_lower=True, seq_len_type=arg.seq_len_type,
- bert_tokenizer=arg.bert_dir, cut_text=512,
- get_index=True, concat='bert',
- )
+ data_bundle = QuoraBertPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
else:
raise RuntimeError(f'NOT support {arg.task} task yet!')
+print(data_bundle) # print details in data_bundle
+
+# load embedding
+embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name=arg.bert_model_dir_or_name)
+
# define model
-model = BertForNLI(class_num=len(data_info.vocabs[Const.TARGET]), bert_dir=arg.bert_dir)
+model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
+
+# define optimizer and callback
+optimizer = AdamW(lr=arg.lr, params=model.parameters())
+callbacks = [WarmupCallback(warmup=arg.warm_up_rate, schedule='linear'), ]
+
+if arg.task in ['snli']:
+ callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
+ # evaluate the test set at the end of every epoch when the task is snli.
# define trainer
-trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
- optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
+ optimizer=optimizer,
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1,
- dev_data=data_info.datasets[arg.dev_dataset_name],
+ dev_data=data_bundle.datasets[arg.dev_dataset_name],
metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1,
- save_path=arg.save_path)
+ save_path=arg.save_path,
+ callbacks=callbacks)
# train model
trainer.train(load_best_model=True)
# define tester
tester = Tester(
- data=data_info.datasets[arg.test_dataset_name],
+ data=data_bundle.datasets[arg.test_dataset_name],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
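The rewritten matching_bert.py drops the hand-rolled `MatchingDataLoader` in favour of the library's matching Pipes, builds a `BertEmbedding` from the resulting vocabulary, and trains with `AdamW` plus a linear warm-up callback. Condensed from the hunk above, the new loading path looks roughly like this (the `'bert-base-uncased'` name and spacy tokenizer mirror the defaults set in `BERTConfig`; no explicit path is passed, so the Pipe resolves its default data location):

```python
from fastNLP.core import Const
from fastNLP.embeddings import BertEmbedding
from fastNLP.io.pipe.matching import SNLIBertPipe

from reproduction.matching.model.bert import BertForNLI

# Tokenize, lower-case and index SNLI for BERT-style input.
data_bundle = SNLIBertPipe(lower=True, tokenizer='spacy').process_from_file()
print(data_bundle)  # shows the DataSets and vocabularies it produced

embed = BertEmbedding(data_bundle.vocabs[Const.INPUT], model_dir_or_name='bert-base-uncased')
model = BertForNLI(embed, class_num=len(data_bundle.vocabs[Const.TARGET]))
```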
diff --git a/reproduction/matching/matching_cntn.py b/reproduction/matching/matching_cntn.py
index d813164d..9be716ba 100644
--- a/reproduction/matching/matching_cntn.py
+++ b/reproduction/matching/matching_cntn.py
@@ -1,11 +1,10 @@
import argparse
import torch
-import os
-from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
-from fastNLP.modules.encoder.embedding import StaticEmbedding
+from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const, CrossEntropyLoss
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe
-from reproduction.matching.data.MatchingDataLoader import QNLILoader, RTELoader, SNLILoader, MNLILoader
from reproduction.matching.model.cntn import CNTNModel
# define hyper-parameters
@@ -14,14 +13,12 @@ argument.add_argument('--embedding', choices=['glove', 'word2vec'], default='glo
argument.add_argument('--batch-size-per-gpu', type=int, default=256)
argument.add_argument('--n-epochs', type=int, default=200)
argument.add_argument('--lr', type=float, default=1e-5)
-argument.add_argument('--seq-len-type', choices=['mask', 'seq_len'], default='mask')
argument.add_argument('--save-dir', type=str, default=None)
argument.add_argument('--cntn-depth', type=int, default=1)
argument.add_argument('--cntn-ns', type=int, default=200)
argument.add_argument('--cntn-k-top', type=int, default=10)
argument.add_argument('--cntn-r', type=int, default=5)
argument.add_argument('--dataset', choices=['qnli', 'rte', 'snli', 'mnli'], default='qnli')
-argument.add_argument('--max-len', type=int, default=50)
arg = argument.parse_args()
# dataset dict
@@ -46,30 +43,25 @@ else:
num_labels = 3
# load data set
-if arg.dataset == 'qnli':
- data_info = QNLILoader().process(
- paths='path/to/qnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
- get_index=True, concat=False, auto_pad_length=arg.max_len)
+if arg.dataset == 'snli':
+ data_bundle = SNLIPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'rte':
- data_info = RTELoader().process(
- paths='path/to/rte/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
- get_index=True, concat=False, auto_pad_length=arg.max_len)
-elif arg.dataset == 'snli':
- data_info = SNLILoader().process(
- paths='path/to/snli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
- get_index=True, concat=False, auto_pad_length=arg.max_len)
+ data_bundle = RTEPipe(lower=True, tokenizer='raw').process_from_file()
+elif arg.dataset == 'qnli':
+ data_bundle = QNLIPipe(lower=True, tokenizer='raw').process_from_file()
elif arg.dataset == 'mnli':
- data_info = MNLILoader().process(
- paths='path/to/mnli/data', to_lower=True, seq_len_type=arg.seq_len_type, bert_tokenizer=None,
- get_index=True, concat=False, auto_pad_length=arg.max_len)
+ data_bundle = MNLIPipe(lower=True, tokenizer='raw').process_from_file()
else:
- raise ValueError(f'now we only support [qnli,rte,snli,mnli] dataset for cntn model!')
+ raise RuntimeError(f'NOT support {arg.dataset} dataset yet!')
+
+print(data_bundle) # print details in data_bundle
# load embedding
if arg.embedding == 'word2vec':
- embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-word2vec-300', requires_grad=True)
+ embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-word2vec-300',
+ requires_grad=True)
elif arg.embedding == 'glove':
- embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], model_dir_or_name='en-glove-840b-300',
+ embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
requires_grad=True)
else:
raise ValueError(f'now we only support word2vec or glove embedding for cntn model!')
@@ -80,11 +72,12 @@ model = CNTNModel(embedding, ns=arg.cntn_ns, k_top=arg.cntn_k_top, num_labels=nu
print(model)
# define trainer
-trainer = Trainer(train_data=data_info.datasets['train'], model=model,
+trainer = Trainer(train_data=data_bundle.datasets['train'], model=model,
optimizer=Adam(lr=arg.lr, model_params=model.parameters()),
+ loss=CrossEntropyLoss(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1,
- dev_data=data_info.datasets[dev_dict[arg.dataset]],
+ dev_data=data_bundle.datasets[dev_dict[arg.dataset]],
metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1)
@@ -94,7 +87,7 @@ trainer.train(load_best_model=True)
# define tester
tester = Tester(
- data=data_info.datasets[test_dict[arg.dataset]],
+ data=data_bundle.datasets[test_dict[arg.dataset]],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
diff --git a/reproduction/matching/matching_esim.py b/reproduction/matching/matching_esim.py
index d878608f..9d50c0fb 100644
--- a/reproduction/matching/matching_esim.py
+++ b/reproduction/matching/matching_esim.py
@@ -6,30 +6,33 @@ from torch.optim import Adamax
from torch.optim.lr_scheduler import StepLR
from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
-from fastNLP.core.callback import GradientClipCallback, LRScheduler
-from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
-
-from reproduction.matching.data.MatchingDataLoader import SNLILoader, RTELoader, \
- MNLILoader, QNLILoader, QuoraLoader
-from reproduction.matching.model.esim import ESIMModel
+from fastNLP.core.callback import GradientClipCallback, LRScheduler, EvaluateCallback
+from fastNLP.core.losses import CrossEntropyLoss
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.embeddings import ElmoEmbedding
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
+from fastNLP.models.snli import ESIM
# define hyper-parameters
class ESIMConfig:
task = 'snli'
+
embedding = 'glove'
+
batch_size_per_gpu = 196
n_epochs = 30
lr = 2e-3
- seq_len_type = 'seq_len'
- # seq_len表示在process的时候用len(words)来表示长度信息;
- # mask表示用0/1掩码矩阵来表示长度信息;
seed = 42
+ save_path = None # where to save the model; None means the model is not saved.
+
train_dataset_name = 'train'
dev_dataset_name = 'dev'
test_dataset_name = 'test'
- save_path = None # 模型存储的位置,None表示不存储模型。
+
+ to_lower = True # lowercase all tokens
+ tokenizer = 'spacy' # tokenize with spaCy
arg = ESIMConfig()
@@ -45,43 +48,32 @@ if n_gpu > 0:
# load data set
if arg.task == 'snli':
- data_info = SNLILoader().process(
- paths='path/to/snli/data', to_lower=False, seq_len_type=arg.seq_len_type,
- get_index=True, concat=False,
- )
+ data_bundle = SNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'rte':
- data_info = RTELoader().process(
- paths='path/to/rte/data', to_lower=False, seq_len_type=arg.seq_len_type,
- get_index=True, concat=False,
- )
+ data_bundle = RTEPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'qnli':
- data_info = QNLILoader().process(
- paths='path/to/qnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
- get_index=True, concat=False,
- )
+ data_bundle = QNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'mnli':
- data_info = MNLILoader().process(
- paths='path/to/mnli/data', to_lower=False, seq_len_type=arg.seq_len_type,
- get_index=True, concat=False,
- )
+ data_bundle = MNLIPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
elif arg.task == 'quora':
- data_info = QuoraLoader().process(
- paths='path/to/quora/data', to_lower=False, seq_len_type=arg.seq_len_type,
- get_index=True, concat=False,
- )
+ data_bundle = QuoraPipe(lower=arg.to_lower, tokenizer=arg.tokenizer).process_from_file()
else:
raise RuntimeError(f'NOT support {arg.task} task yet!')
+print(data_bundle) # print details in data_bundle
+
# load embedding
if arg.embedding == 'elmo':
- embedding = ElmoEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True)
+ embedding = ElmoEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-medium',
+ requires_grad=True)
elif arg.embedding == 'glove':
- embedding = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=True, normalize=False)
+ embedding = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], model_dir_or_name='en-glove-840b-300d',
+ requires_grad=True, normalize=False)
else:
raise RuntimeError(f'NOT support {arg.embedding} embedding yet!')
# define model
-model = ESIMModel(embedding, num_labels=len(data_info.vocabs[Const.TARGET]))
+model = ESIM(embedding, num_labels=len(data_bundle.vocabs[Const.TARGET]))
# define optimizer and callback
optimizer = Adamax(lr=arg.lr, params=model.parameters())
@@ -92,23 +84,29 @@ callbacks = [
LRScheduler(scheduler),
]
+if arg.task in ['snli']:
+ callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.test_dataset_name]))
+ # evaluate on the test set after every epoch when the task is snli.
+
# define trainer
-trainer = Trainer(train_data=data_info.datasets[arg.train_dataset_name], model=model,
+trainer = Trainer(train_data=data_bundle.datasets[arg.train_dataset_name], model=model,
optimizer=optimizer,
+ loss=CrossEntropyLoss(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
n_epochs=arg.n_epochs, print_every=-1,
- dev_data=data_info.datasets[arg.dev_dataset_name],
+ dev_data=data_bundle.datasets[arg.dev_dataset_name],
metrics=AccuracyMetric(), metric_key='acc',
device=[i for i in range(torch.cuda.device_count())],
check_code_level=-1,
- save_path=arg.save_path)
+ save_path=arg.save_path,
+ callbacks=callbacks)
# train model
trainer.train(load_best_model=True)
# define tester
tester = Tester(
- data=data_info.datasets[arg.test_dataset_name],
+ data=data_bundle.datasets[arg.test_dataset_name],
model=model,
metrics=AccuracyMetric(),
batch_size=torch.cuda.device_count() * arg.batch_size_per_gpu,
diff --git a/reproduction/matching/matching_mwan.py b/reproduction/matching/matching_mwan.py
index e96ee0c9..026ea7b4 100644
--- a/reproduction/matching/matching_mwan.py
+++ b/reproduction/matching/matching_mwan.py
@@ -1,23 +1,16 @@
-import sys
-
-import os
import random
import numpy as np
import torch
-from torch.optim import Adadelta, SGD
+from torch.optim import Adadelta
from torch.optim.lr_scheduler import StepLR
-from tqdm import tqdm
-
from fastNLP import CrossEntropyLoss
-from fastNLP import cache_results
-from fastNLP.core import Trainer, Tester, Adam, AccuracyMetric, Const
-from fastNLP.core.predictor import Predictor
-from fastNLP.core.callback import GradientClipCallback, LRScheduler, FitlogCallback
-from fastNLP.modules.encoder.embedding import ElmoEmbedding, StaticEmbedding
+from fastNLP.core import Trainer, Tester, AccuracyMetric, Const
+from fastNLP.core.callback import LRScheduler, EvaluateCallback
+from fastNLP.embeddings import StaticEmbedding
-from fastNLP.io.data_loader import MNLILoader, QNLILoader, QuoraLoader, SNLILoader, RTELoader
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, MNLIPipe, QNLIPipe, QuoraPipe
from reproduction.matching.model.mwan import MwanModel
import fitlog
@@ -52,47 +45,25 @@ for k in arg.__dict__:
# load data set
if arg.task == 'snli':
- @cache_results(f'snli_mwan.pkl')
- def read_snli():
- data_info = SNLILoader().process(
- paths='path/to/snli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
- get_index=True, concat=False, extra_split=['/','%','-'],
- )
- return data_info
- data_info = read_snli()
+ data_bundle = SNLIPipe(lower=True, tokenizer='spacy').process_from_file()
elif arg.task == 'rte':
- @cache_results(f'rte_mwan.pkl')
- def read_rte():
- data_info = RTELoader().process(
- paths='path/to/rte/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
- get_index=True, concat=False, extra_split=['/','%','-'],
- )
- return data_info
- data_info = read_rte()
+ data_bundle = RTEPipe(lower=True, tokenizer='spacy').process_from_file()
elif arg.task == 'qnli':
- data_info = QNLILoader().process(
- paths='path/to/qnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
- get_index=True, concat=False , cut_text=512, extra_split=['/','%','-'],
- )
+ data_bundle = QNLIPipe(lower=True, tokenizer='spacy').process_from_file()
elif arg.task == 'mnli':
- @cache_results(f'mnli_v0.9_mwan.pkl')
- def read_mnli():
- data_info = MNLILoader().process(
- paths='path/to/mnli/data', to_lower=True, seq_len_type=None, bert_tokenizer=None,
- get_index=True, concat=False, extra_split=['/','%','-'],
- )
- return data_info
- data_info = read_mnli()
+ data_bundle = MNLIPipe(lower=True, tokenizer='spacy').process_from_file()
+elif arg.task == 'quora':
+ data_bundle = QuoraPipe(lower=True, tokenizer='spacy').process_from_file()
else:
raise RuntimeError(f'NOT support {arg.task} task yet!')
-print(data_info)
-print(len(data_info.vocabs['words']))
+print(data_bundle)
+print(len(data_bundle.vocabs[Const.INPUTS(0)]))
model = MwanModel(
- num_class = len(data_info.vocabs[Const.TARGET]),
- EmbLayer = StaticEmbedding(data_info.vocabs[Const.INPUT], requires_grad=False, normalize=False),
+ num_class = len(data_bundle.vocabs[Const.TARGET]),
+ EmbLayer = StaticEmbedding(data_bundle.vocabs[Const.INPUTS(0)], requires_grad=False, normalize=False),
ElmoLayer = None,
args_of_imm = {
"input_size" : 300 ,
@@ -111,21 +82,20 @@ callbacks = [
]
if arg.task in ['snli']:
- callbacks.append(FitlogCallback(data_info.datasets[arg.testset_name], verbose=1))
+ callbacks.append(EvaluateCallback(data=data_bundle.datasets[arg.testset_name]))
elif arg.task == 'mnli':
- callbacks.append(FitlogCallback({'dev_matched': data_info.datasets['dev_matched'],
- 'dev_mismatched': data_info.datasets['dev_mismatched']},
- verbose=1))
+ callbacks.append(EvaluateCallback(data={'dev_matched': data_bundle.datasets['dev_matched'],
+ 'dev_mismatched': data_bundle.datasets['dev_mismatched']},))
trainer = Trainer(
- train_data = data_info.datasets['train'],
+ train_data = data_bundle.datasets['train'],
model = model,
optimizer = optimizer,
num_workers = 0,
batch_size = arg.batch_size,
n_epochs = arg.n_epochs,
print_every = -1,
- dev_data = data_info.datasets[arg.devset_name],
+ dev_data = data_bundle.datasets[arg.devset_name],
metrics = AccuracyMetric(pred = "pred" , target = "target"),
metric_key = 'acc',
device = [i for i in range(torch.cuda.device_count())],
@@ -136,7 +106,7 @@ trainer = Trainer(
trainer.train(load_best_model=True)
tester = Tester(
- data=data_info.datasets[arg.testset_name],
+ data=data_bundle.datasets[arg.testset_name],
model=model,
metrics=AccuracyMetric(),
batch_size=arg.batch_size,
diff --git a/reproduction/matching/model/bert.py b/reproduction/matching/model/bert.py
index 9b3a78b2..73a0c533 100644
--- a/reproduction/matching/model/bert.py
+++ b/reproduction/matching/model/bert.py
@@ -3,39 +3,28 @@ import torch
import torch.nn as nn
from fastNLP.core.const import Const
-from fastNLP.models import BaseModel
-from fastNLP.modules.encoder.bert import BertModel
+from fastNLP.models.base_model import BaseModel
+from fastNLP.embeddings import BertEmbedding
class BertForNLI(BaseModel):
- # TODO: still in progress
- def __init__(self, class_num=3, bert_dir=None):
+ def __init__(self, bert_embed: BertEmbedding, class_num=3):
super(BertForNLI, self).__init__()
- if bert_dir is not None:
- self.bert = BertModel.from_pretrained(bert_dir)
- else:
- self.bert = BertModel()
- hidden_size = self.bert.pooler.dense._parameters['bias'].size(-1)
- self.classifier = nn.Linear(hidden_size, class_num)
-
- def forward(self, words, seq_len1, seq_len2, target=None):
+ self.embed = bert_embed
+ self.classifier = nn.Linear(self.embed.embedding_dim, class_num)
+
+ def forward(self, words):
"""
:param torch.Tensor words: [batch_size, seq_len] input_ids
- :param torch.Tensor seq_len1: [batch_size, seq_len] token_type_ids
- :param torch.Tensor seq_len2: [batch_size, seq_len] attention_mask
- :param torch.Tensor target: [batch]
:return:
"""
- _, pooled_output = self.bert(words, seq_len1, seq_len2)
- logits = self.classifier(pooled_output)
+ hidden = self.embed(words)
+ logits = self.classifier(hidden)
- if target is not None:
- loss_func = torch.nn.CrossEntropyLoss()
- loss = loss_func(logits, target)
- return {Const.OUTPUT: logits, Const.LOSS: loss}
return {Const.OUTPUT: logits}
- def predict(self, words, seq_len1, seq_len2, target=None):
- return self.forward(words, seq_len1, seq_len2)
+ def predict(self, words):
+ logits = self.forward(words)[Const.OUTPUT]
+ return {Const.OUTPUT: logits.argmax(dim=-1)}
diff --git a/reproduction/matching/model/cntn.py b/reproduction/matching/model/cntn.py
index 0b4803fa..cfa5e5a8 100644
--- a/reproduction/matching/model/cntn.py
+++ b/reproduction/matching/model/cntn.py
@@ -3,10 +3,8 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
-from torch.nn import CrossEntropyLoss
-
-from fastNLP.models import BaseModel
-from fastNLP.modules.encoder.embedding import TokenEmbedding
+from fastNLP.models.base_model import BaseModel
+from fastNLP.embeddings import TokenEmbedding
from fastNLP.core.const import Const
@@ -83,13 +81,12 @@ class CNTNModel(BaseModel):
self.weight_V = nn.Linear(2 * ns, r)
self.weight_u = nn.Sequential(nn.Dropout(p=dropout_rate), nn.Linear(r, num_labels))
- def forward(self, words1, words2, seq_len1, seq_len2, target=None):
+ def forward(self, words1, words2, seq_len1, seq_len2):
"""
:param words1: [batch, seq_len, emb_size] Question.
:param words2: [batch, seq_len, emb_size] Answer.
:param seq_len1: [batch]
:param seq_len2: [batch]
- :param target: [batch] Glod labels.
:return:
"""
in_q = self.embedding(words1)
@@ -109,12 +106,7 @@ class CNTNModel(BaseModel):
in_a = self.fc_q(in_a.view(in_a.size(0), -1))
score = torch.tanh(self.weight_u(self.weight_M(in_q, in_a) + self.weight_V(torch.cat((in_q, in_a), -1))))
- if target is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(score, target)
- return {Const.LOSS: loss, Const.OUTPUT: score}
- else:
- return {Const.OUTPUT: score}
+ return {Const.OUTPUT: score}
- def predict(self, **kwargs):
- return self.forward(**kwargs)
+ def predict(self, words1, words2, seq_len1, seq_len2):
+ return self.forward(words1, words2, seq_len1, seq_len2)
diff --git a/reproduction/matching/model/esim.py b/reproduction/matching/model/esim.py
index 187e565d..d704e2f8 100644
--- a/reproduction/matching/model/esim.py
+++ b/reproduction/matching/model/esim.py
@@ -2,11 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn import CrossEntropyLoss
-
-from fastNLP.models import BaseModel
-from fastNLP.modules.encoder.embedding import TokenEmbedding
-from fastNLP.modules.encoder.lstm import LSTM
+from fastNLP.models.base_model import BaseModel
+from fastNLP.embeddings import TokenEmbedding
from fastNLP.core.const import Const
from fastNLP.core.utils import seq_len_to_mask
@@ -43,13 +40,12 @@ class ESIMModel(BaseModel):
nn.init.xavier_uniform_(self.classifier[1].weight.data)
nn.init.xavier_uniform_(self.classifier[4].weight.data)
- def forward(self, words1, words2, seq_len1, seq_len2, target=None):
+ def forward(self, words1, words2, seq_len1, seq_len2):
"""
:param words1: [batch, seq_len]
:param words2: [batch, seq_len]
:param seq_len1: [batch]
:param seq_len2: [batch]
- :param target:
:return:
"""
mask1 = seq_len_to_mask(seq_len1, words1.size(1))
@@ -83,16 +79,10 @@ class ESIMModel(BaseModel):
logits = torch.tanh(self.classifier(out))
# logits = self.classifier(out)
- if target is not None:
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(logits, target)
-
- return {Const.LOSS: loss, Const.OUTPUT: logits}
- else:
- return {Const.OUTPUT: logits}
+ return {Const.OUTPUT: logits}
- def predict(self, **kwargs):
- pred = self.forward(**kwargs)[Const.OUTPUT].argmax(-1)
+ def predict(self, words1, words2, seq_len1, seq_len2):
+ pred = self.forward(words1, words2, seq_len1, seq_len2)[Const.OUTPUT].argmax(-1)
return {Const.OUTPUT: pred}
# input [batch_size, len , hidden]
diff --git a/reproduction/matching/test/test_snlidataloader.py b/reproduction/matching/test/test_snlidataloader.py
deleted file mode 100644
index 60b3ad59..00000000
--- a/reproduction/matching/test/test_snlidataloader.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import unittest
-from ..data import MatchingDataLoader
-from fastNLP.core.vocabulary import Vocabulary
-
-
-class TestCWSDataLoader(unittest.TestCase):
- def test_case1(self):
- snli_loader = MatchingDataLoader()
- # TODO: still in progress
-
diff --git a/reproduction/seqence_labelling/chinese_ner/readme.md b/reproduction/seqence_labelling/chinese_ner/readme.md
new file mode 100644
index 00000000..3a9d37d8
--- /dev/null
+++ b/reproduction/seqence_labelling/chinese_ner/readme.md
@@ -0,0 +1,30 @@
+Statistics of the datasets that are downloaded automatically by the following Chinese NER Pipes.
+
+| MsraNERPipe | # of sents | # of tokens |
+| ----------- | ---------- | ----------- |
+| train | 41747 | 1954374 |
+| dev | 4617 | 215505 |
+| test | 4365 | 172601 |
+| total | 50729 | 2342480 |
+The statistics reported here are consistent with those reported in [https://arxiv.org/pdf/1805.02023.pdf](https://arxiv.org/pdf/1805.02023.pdf).
+
+
+
+| WeiboNERPipe | # of sents | # of tokens |
+| ------------ | ---------- | ----------- |
+| train | 1350 | 73778 |
+| dev | 270 | 14509 |
+| test | 270 | 14842 |
+| total        | 1890       | 103129      |
+The statistics reported here are consistent with those in [https://www.cs.cmu.edu/~ark/EMNLP-2015/proceedings/EMNLP/pdf/EMNLP064.pdf](https://www.cs.cmu.edu/~ark/EMNLP-2015/proceedings/EMNLP/pdf/EMNLP064.pdf).
+
+
+
+
+| PeopleDailyPipe | # of sents | # of tokens |
+| --------------- | ---------- | ----------- |
+| train | 50658 | 2169879 |
+| dev | 4631 | 172601 |
+| test | 68 | 2270 |
+| total | 55357 | 2344750 |
+The data used here is consistent with that of [https://arxiv.org/pdf/1906.08101.pdf](https://arxiv.org/pdf/1906.08101.pdf).
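+
+A minimal usage sketch (mirroring the training scripts in this folder; `WeiboNERPipe` can be swapped for `MsraNERPipe` or `PeopleDailyPipe`):
+
+```python
+from fastNLP.io import WeiboNERPipe
+
+# download (if necessary) and preprocess the data into a DataBundle
+data_bundle = WeiboNERPipe(encoding_type='bio').process_from_file()
+print(data_bundle)  # shows the datasets and vocabularies that were built
+```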
diff --git a/reproduction/seqence_labelling/chinese_ner/train_bert.py b/reproduction/seqence_labelling/chinese_ner/train_bert.py
new file mode 100644
index 00000000..b12c8f75
--- /dev/null
+++ b/reproduction/seqence_labelling/chinese_ner/train_bert.py
@@ -0,0 +1,81 @@
+
+
+"""
+Chinese named entity recognition (NER) with BERT.
+
+"""
+
+import sys
+
+sys.path.append('../../../')
+
+from torch import nn
+
+from fastNLP.embeddings import BertEmbedding, Embedding
+from fastNLP import Trainer, Const
+from fastNLP import BucketSampler, SpanFPreRecMetric, GradientClipCallback
+from fastNLP.modules import MLP
+from fastNLP.core.callback import WarmupCallback
+from fastNLP import CrossEntropyLoss
+from fastNLP.core.optimizer import AdamW
+from fastNLP.io import MsraNERPipe, MsraNERLoader, WeiboNERPipe
+
+from fastNLP import cache_results
+
+encoding_type = 'bio'
+
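+# cache the processed DataBundle to caches/weibo.pkl so that repeated runs skip preprocessing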
+@cache_results('caches/weibo.pkl', _refresh=False)
+def get_data():
+ # data_dir = MsraNERLoader().download(dev_ratio=0)
+ # data = MsraNERPipe(encoding_type=encoding_type, target_pad_val=-100).process_from_file(data_dir)
+ data = WeiboNERPipe(encoding_type=encoding_type).process_from_file()
+ return data
+data = get_data()
+print(data)
+
+class BertCNNER(nn.Module):
+ def __init__(self, embed, tag_size):
+ super().__init__()
+ self.embedding = embed
+ self.tag_size = tag_size
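+ # a single linear layer that maps the BERT features to tag logits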
+ self.mlp = MLP(size_layer=[self.embedding.embedding_dim, tag_size])
+
+ def forward(self, chars):
+ # batch_size, max_len = words.size()
+ chars = self.embedding(chars)
+ outputs = self.mlp(chars)
+
+ return {Const.OUTPUT: outputs}
+
+ def predict(self, chars):
+ # batch_size, max_len = words.size()
+ chars = self.embedding(chars)
+ outputs = self.mlp(chars)
+
+ return {Const.OUTPUT: outputs}
+
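+# BertEmbedding over the character vocabulary: uses the output of BERT layer 11 and represents each
+# character by its first word piece (pool_method='first'); [CLS]/[SEP] vectors are not returned (include_cls_sep=False)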
+embed = BertEmbedding(data.get_vocab(Const.CHAR_INPUT), model_dir_or_name='cn-wwm-ext',
+ pool_method='first', requires_grad=True, layers='11', include_cls_sep=False, dropout=0.5)
+
+callbacks = [
+ GradientClipCallback(clip_type='norm', clip_value=1),
+ WarmupCallback(warmup=0.1, schedule='linear')
+ ]
+
+model = BertCNNER(embed, len(data.vocabs[Const.TARGET]))
+optimizer = AdamW(model.parameters(), lr=3e-5)
+
+for name, dataset in data.datasets.items():
+ original_len = len(dataset)
+ dataset.drop(lambda x:x['seq_len']>256, inplace=True)
+ clipped_len = len(dataset)
+ print("Delete {} instances in {}.".format(original_len-clipped_len, name))
+
+trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, sampler=BucketSampler(),
+ device=0, dev_data=data.datasets['test'], batch_size=6,
+ metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type),
+ loss=CrossEntropyLoss(reduction='sum'),
+ callbacks=callbacks, num_workers=2, n_epochs=5,
+ check_code_level=0, update_every=3)
+trainer.train()
+
diff --git a/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py b/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py
new file mode 100644
index 00000000..58b32265
--- /dev/null
+++ b/reproduction/seqence_labelling/chinese_ner/train_cn_ner.py
@@ -0,0 +1,136 @@
+import sys
+sys.path.append('../../..')
+
+from fastNLP.embeddings import StaticEmbedding
+
+from torch import nn
+import torch
+from fastNLP.embeddings.utils import get_embeddings
+from fastNLP.modules import LSTM
+from fastNLP.modules import ConditionalRandomField
+from fastNLP.modules import allowed_transitions
+import torch.nn.functional as F
+from fastNLP import seq_len_to_mask
+from fastNLP.core.const import Const as C
+from fastNLP import SpanFPreRecMetric, Trainer
+from fastNLP import cache_results, Vocabulary
+from fastNLP.io.pipe.utils import _add_chars_field, _indexize
+
+from fastNLP.io.pipe import Pipe
+from fastNLP.core.utils import iob2bioes, iob2
+from fastNLP.io import MsraNERLoader, WeiboNERLoader
+
+class ChineseNERPipe(Pipe):
+ def __init__(self, encoding_type: str = 'bio', target_pad_val=0, bigram=False):
+ if encoding_type == 'bio':
+ self.convert_tag = iob2
+ else:
+ self.convert_tag = lambda words: iob2bioes(iob2(words))
+ self.target_pad_val = int(target_pad_val)
+ self.bigram = bigram
+
+ def process(self, data_bundle):
+ data_bundle.copy_field(C.RAW_CHAR, C.CHAR_INPUT)
+ input_fields = [C.TARGET, C.CHAR_INPUT, C.INPUT_LEN]
+ target_fields = [C.TARGET, C.INPUT_LEN]
+ if self.bigram:
+ for dataset in data_bundle.datasets.values():
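+ # build bigrams by pairing each character with the next one; the last character is paired with an empty string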
+ dataset.apply_field(lambda chars:[c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])],
+ field_name=C.CHAR_INPUT, new_field_name='bigrams')
+ bigram_vocab = Vocabulary()
+ bigram_vocab.from_dataset(data_bundle.get_dataset('train'),field_name='bigrams',
+ no_create_entry_dataset=[ds for name, ds in data_bundle.datasets.items() if name!='train'])
+ bigram_vocab.index_dataset(*data_bundle.datasets.values(), field_name='bigrams')
+ data_bundle.set_vocab(bigram_vocab, field_name='bigrams')
+ input_fields.append('bigrams')
+
+ _add_chars_field(data_bundle, lower=False)
+
+ # index
+ _indexize(data_bundle, input_field_names=C.CHAR_INPUT, target_field_names=C.TARGET)
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.set_pad_val(C.TARGET, self.target_pad_val)
+ dataset.add_seq_len(C.CHAR_INPUT)
+
+ data_bundle.set_input(*input_fields)
+ data_bundle.set_target(*target_fields)
+
+ return data_bundle
+
+
+class CNBiLSTMCRFNER(nn.Module):
+ def __init__(self, char_embed, num_classes, bigram_embed=None, trigram_embed=None, num_layers=1, hidden_size=100,
+ dropout=0.5, target_vocab=None, encoding_type=None):
+ super().__init__()
+
+ self.char_embed = get_embeddings(char_embed)
+ embed_size = self.char_embed.embedding_dim
+ if bigram_embed:
+ self.bigram_embed = get_embeddings(bigram_embed)
+ embed_size += self.bigram_embed.embedding_dim
+ if trigram_embed:
+ self.trigram_embed = get_embeddings(trigram_embed)
+ embed_size += self.trigram_embed.embedding_dim
+
+ if num_layers>1:
+ self.lstm = LSTM(embed_size, num_layers=num_layers, hidden_size=hidden_size//2, bidirectional=True,
+ batch_first=True, dropout=dropout)
+ else:
+ self.lstm = LSTM(embed_size, num_layers=num_layers, hidden_size=hidden_size//2, bidirectional=True,
+ batch_first=True)
+
+ self.dropout = nn.Dropout(dropout)
+ self.fc = nn.Linear(hidden_size, num_classes)
+
+ trans = None
+ if target_vocab is not None and encoding_type is not None:
+ trans = allowed_transitions(target_vocab.idx2word, encoding_type=encoding_type, include_start_end=True)
+
+ self.crf = ConditionalRandomField(num_classes, include_start_end_trans=True, allowed_transitions=trans)
+
+ def _forward(self, chars, bigrams=None, trigrams=None, seq_len=None, target=None):
+ chars = self.char_embed(chars)
+ if hasattr(self, 'bigram_embed'):
+ bigrams = self.bigram_embed(bigrams)
+ chars = torch.cat((chars, bigrams), dim=-1)
+ if hasattr(self, 'trigram_embed'):
+ trigrams = self.trigram_embed(trigrams)
+ chars = torch.cat((chars, trigrams), dim=-1)
+ feats, _ = self.lstm(chars, seq_len=seq_len)
+ feats = self.fc(feats)
+ feats = self.dropout(feats)
+ logits = F.log_softmax(feats, dim=-1)
+ mask = seq_len_to_mask(seq_len)
+ if target is None:
+ pred, _ = self.crf.viterbi_decode(logits, mask)
+ return {C.OUTPUT: pred}
+ else:
+ loss = self.crf(logits, target, mask).mean()
+ return {C.LOSS:loss}
+
+ def forward(self, chars, target, bigrams=None, trigrams=None, seq_len=None):
+ return self._forward(chars, bigrams, trigrams, seq_len, target)
+
+ def predict(self, chars, seq_len=None, bigrams=None, trigrams=None):
+ return self._forward(chars, bigrams, trigrams, seq_len)
+
+# data_bundle = pickle.load(open('caches/msra.pkl', 'rb'))
+@cache_results('caches/weibo-lstm.pkl', _refresh=False)
+def get_data():
+ data_bundle = WeiboNERLoader().load()
+ data_bundle = ChineseNERPipe(encoding_type='bioes', bigram=True).process(data_bundle)
+ char_embed = StaticEmbedding(data_bundle.get_vocab(C.CHAR_INPUT), model_dir_or_name='cn-fasttext')
+ bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), embedding_dim=100, min_freq=3)
+ return data_bundle, char_embed, bigram_embed
+data_bundle, char_embed, bigram_embed = get_data()
+# data_bundle = get_data()
+print(data_bundle)
+
+# exit(0)
+model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed)
+
+Trainer(data_bundle.datasets['train'], model, batch_size=20,
+ metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes'),
+ num_workers=2, dev_data=data_bundle.datasets['dev'], device=0).train()
+
diff --git a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py b/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
deleted file mode 100644
index 3c82d814..00000000
--- a/reproduction/seqence_labelling/cws/data/CWSDataLoader.py
+++ /dev/null
@@ -1,249 +0,0 @@
-
-from fastNLP.io.embed_loader import EmbeddingOption, EmbedLoader
-from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
-from typing import Union, Dict, List, Iterator
-from fastNLP import DataSet
-from fastNLP import Instance
-from fastNLP import Vocabulary
-from fastNLP import Const
-from reproduction.utils import check_dataloader_paths
-from functools import partial
-
-class SigHanLoader(DataSetLoader):
- """
- 任务相关的说明可以在这里找到http://sighan.cs.uchicago.edu/
- 支持的数据格式为,一行一句,不同的word用空格隔开。如下例
-
- 共同 创造 美好 的 新 世纪 —— 二○○一年 新年
- 女士 们 , 先生 们 , 同志 们 , 朋友 们 :
-
- 读取sighan中的数据集,返回的DataSet将包含以下的内容fields:
- raw_chars: list(str), 每个元素是一个汉字
- chars: list(str), 每个元素是一个index(汉字对应的index)
- target: list(int), 根据不同的encoding_type会有不同的变化
-
- :param target_type: target的类型,当前支持以下的两种: "bmes", "shift_relay"
- """
-
- def __init__(self, target_type:str):
- super().__init__()
-
- if target_type.lower() not in ('bmes', 'shift_relay'):
- raise ValueError("target_type only supports 'bmes', 'shift_relay'.")
-
- self.target_type = target_type
- if target_type=='bmes':
- self._word_len_to_target = self._word_len_to_bems
- elif target_type=='shift_relay':
- self._word_len_to_target = self._word_lens_to_relay
-
- @staticmethod
- def _word_lens_to_relay(word_lens: Iterator[int]):
- """
- [1, 2, 3, ..] 转换为[0, 1, 0, 2, 1, 0,](start指示seg有多长);
- :param word_lens:
- :return: {'target': , 'end_seg_mask':, 'start_seg_mask':}
- """
- tags = []
- end_seg_mask = []
- start_seg_mask = []
- for word_len in word_lens:
- tags.extend([idx for idx in range(word_len - 1, -1, -1)])
- end_seg_mask.extend([0] * (word_len - 1) + [1])
- start_seg_mask.extend([1] + [0] * (word_len - 1))
- return {'target': tags, 'end_seg_mask': end_seg_mask, 'start_seg_mask': start_seg_mask}
-
- @staticmethod
- def _word_len_to_bems(word_lens:Iterator[int])->Dict[str, List[str]]:
- """
-
- :param word_lens: 每个word的长度
- :return:
- """
- tags = []
- for word_len in word_lens:
- if word_len==1:
- tags.append('S')
- else:
- tags.append('B')
- for _ in range(word_len-2):
- tags.append('M')
- tags.append('E')
- return {'target':tags}
-
- @staticmethod
- def _gen_bigram(chars:List[str])->List[str]:
- """
-
- :param chars:
- :return:
- """
- return [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])]
-
- def load(self, path:str, bigram:bool=False)->DataSet:
- """
- :param path: str
- :param bigram: 是否使用bigram feature
- :return:
- """
- dataset = DataSet()
- with open(path, 'r', encoding='utf-8') as f:
- for line in f:
- line = line.strip()
- if not line: # 去掉空行
- continue
- parts = line.split()
- word_lens = map(len, parts)
- chars = list(''.join(parts))
- tags = self._word_len_to_target(word_lens)
- assert len(chars)==len(tags['target'])
- dataset.append(Instance(raw_chars=chars, **tags, seq_len=len(chars)))
- if len(dataset)==0:
- raise RuntimeError(f"{path} has no valid data.")
- if bigram:
- dataset.apply_field(self._gen_bigram, field_name='raw_chars', new_field_name='bigrams')
- return dataset
-
- def process(self, paths: Union[str, Dict[str, str]], char_vocab_opt:VocabularyOption=None,
- char_embed_opt:EmbeddingOption=None, bigram_vocab_opt:VocabularyOption=None,
- bigram_embed_opt:EmbeddingOption=None, L:int=4):
- """
- 支持的数据格式为一行一个sample,并且用空格隔开不同的词语。例如
-
- Option::
-
- 共同 创造 美好 的 新 世纪 —— 二○○一年 新年 贺词
- ( 二○○○年 十二月 三十一日 ) ( 附 图片 1 张 )
- 女士 们 , 先生 们 , 同志 们 , 朋友 们 :
-
- paths支持两种格式,第一种是str,第二种是Dict[str, str].
-
- Option::
-
- # 1. str类型
- # 1.1 传入具体的文件路径
- data = SigHanLoader('bmes').process('/path/to/cws/data.txt') # 将读取data.txt的内容
- # 包含以下的内容data.vocabs['chars']:Vocabulary对象,
- # data.vocabs['target']: Vocabulary对象,根据encoding_type可能会没有该值
- # data.embeddings['chars']: Embedding对象. 只有提供了预训练的词向量的路径才有该项
- # data.datasets['train']: DataSet对象
- # 包含的field有:
- # raw_chars: list[str], 每个元素是一个汉字
- # chars: list[int], 每个元素是汉字对应的index
- # target: list[int], 根据encoding_type有对应的变化
- # 1.2 传入一个目录, 里面必须包含train.txt文件
- data = SigHanLoader('bmes').process('path/to/cws/') #将尝试在该目录下读取 train.txt, test.txt以及dev.txt
- # 包含以下的内容data.vocabs['chars']: Vocabulary对象
- # data.vocabs['target']:Vocabulary对象
- # data.embeddings['chars']: 仅在提供了预训练embedding路径的情况下,为Embedding对象;
- # data.datasets['train']: DataSet对象
- # 包含的field有:
- # raw_chars: list[str], 每个元素是一个汉字
- # chars: list[int], 每个元素是汉字对应的index
- # target: list[int], 根据encoding_type有对应的变化
- # data.datasets['dev']: DataSet对象,如果文件夹下包含了dev.txt;内容与data.datasets['train']一样
-
- # 2. dict类型, key是文件的名称,value是对应的读取路径. 必须包含'train'这个key
- paths = {'train': '/path/to/train/train.txt', 'test':'/path/to/test/test.txt', 'dev':'/path/to/dev/dev.txt'}
- data = SigHanLoader(paths).process(paths)
- # 结果与传入目录时是一致的,但是可以传入多个数据集。data.datasets中的key将与这里传入的一致
-
- :param paths: 支持传入目录,文件路径,以及dict。
- :param char_vocab_opt: 用于构建chars的vocabulary参数,默认为min_freq=2
- :param char_embed_opt: 用于读取chars的Embedding的参数,默认不读取pretrained的embedding
- :param bigram_vocab_opt: 用于构建bigram的vocabulary参数,默认不使用bigram, 仅在指定该参数的情况下会带有bigrams这个field。
- 为List[int], 每个instance长度与chars一样, abcde的bigram为ab bc cd de e
- :param bigram_embed_opt: 用于读取预训练bigram的参数,仅在传入bigram_vocab_opt有效
- :param L: 当target_type为shift_relay时传入的segment长度
- :return:
- """
- # 推荐大家使用这个check_data_loader_paths进行paths的验证
- paths = check_dataloader_paths(paths)
- datasets = {}
- data = DataBundle()
- bigram = bigram_vocab_opt is not None
- for name, path in paths.items():
- dataset = self.load(path, bigram=bigram)
- datasets[name] = dataset
- input_fields = []
- target_fields = []
- # 创建vocab
- char_vocab = Vocabulary(min_freq=2) if char_vocab_opt is None else Vocabulary(**char_vocab_opt)
- char_vocab.from_dataset(datasets['train'], field_name='raw_chars')
- char_vocab.index_dataset(*datasets.values(), field_name='raw_chars', new_field_name='chars')
- data.vocabs[Const.CHAR_INPUT] = char_vocab
- input_fields.extend([Const.CHAR_INPUT, Const.INPUT_LEN, Const.TARGET])
- target_fields.append(Const.TARGET)
- # 创建target
- if self.target_type == 'bmes':
- target_vocab = Vocabulary(unknown=None, padding=None)
- target_vocab.add_word_lst(['B']*4+['M']*3+['E']*2+['S'])
- target_vocab.index_dataset(*datasets.values(), field_name='target')
- data.vocabs[Const.TARGET] = target_vocab
- if char_embed_opt is not None:
- char_embed = EmbedLoader.load_with_vocab(**char_embed_opt, vocab=char_vocab)
- data.embeddings['chars'] = char_embed
- if bigram:
- bigram_vocab = Vocabulary(**bigram_vocab_opt)
- bigram_vocab.from_dataset(datasets['train'], field_name='bigrams')
- bigram_vocab.index_dataset(*datasets.values(), field_name='bigrams')
- data.vocabs['bigrams'] = bigram_vocab
- if bigram_embed_opt is not None:
- bigram_embed = EmbedLoader.load_with_vocab(**bigram_embed_opt, vocab=bigram_vocab)
- data.embeddings['bigrams'] = bigram_embed
- input_fields.append('bigrams')
- if self.target_type == 'shift_relay':
- func = partial(self._clip_target, L=L)
- for name, dataset in datasets.items():
- res = dataset.apply_field(func, field_name='target')
- relay_target = [res_i[0] for res_i in res]
- relay_mask = [res_i[1] for res_i in res]
- dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False)
- dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False)
- if self.target_type == 'shift_relay':
- input_fields.extend(['end_seg_mask'])
- target_fields.append('start_seg_mask')
- # 将dataset加入DataInfo
- for name, dataset in datasets.items():
- dataset.set_input(*input_fields)
- dataset.set_target(*target_fields)
- data.datasets[name] = dataset
-
- return data
-
- @staticmethod
- def _clip_target(target:List[int], L:int):
- """
-
- 只有在target_type为shift_relay的使用
- :param target: List[int]
- :param L:
- :return:
- """
- relay_target_i = []
- tmp = []
- for j in range(len(target) - 1):
- tmp.append(target[j])
- if target[j] > target[j + 1]:
- pass
- else:
- relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
- tmp = []
- # 处理未结束的部分
- if len(tmp) == 0:
- relay_target_i.append(0)
- else:
- tmp.append(target[-1])
- relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
- relay_mask_i = []
- j = 0
- while j < len(target):
- seg_len = target[j] + 1
- if target[j] < L:
- relay_mask_i.extend([0] * (seg_len))
- else:
- relay_mask_i.extend([1] * (seg_len - L) + [0] * L)
- j = seg_len + j
- return relay_target_i, relay_mask_i
-
diff --git a/reproduction/seqence_labelling/cws/data/cws_shift_pipe.py b/reproduction/seqence_labelling/cws/data/cws_shift_pipe.py
new file mode 100644
index 00000000..0ae4064d
--- /dev/null
+++ b/reproduction/seqence_labelling/cws/data/cws_shift_pipe.py
@@ -0,0 +1,202 @@
+from fastNLP.io.pipe import Pipe
+from fastNLP.io import DataBundle
+from fastNLP.io.loader import CWSLoader
+from fastNLP import Const
+from itertools import chain
+from fastNLP.io.pipe.utils import _indexize
+from functools import partial
+from fastNLP.io.pipe.cws import _find_and_replace_alpha_spans, _find_and_replace_digit_spans
+
+
+def _word_lens_to_relay(word_lens):
+ """
+ Convert word lengths into relay tags: each character is tagged with the number of characters left until the end of its word,
+ e.g. word lengths [1, 2, 3] -> [0, 1, 0, 2, 1, 0].
+ :param word_lens: iterable of word lengths
+ :return: a list of relay tags, one per character
+ """
+ tags = []
+ for word_len in word_lens:
+ tags.extend([idx for idx in range(word_len - 1, -1, -1)])
+ return tags
+
+def _word_lens_to_end_seg_mask(word_lens):
+ """
+ Convert word lengths into an end-of-segment mask: 1 at the last character of each word, 0 elsewhere,
+ e.g. word lengths [1, 2, 3] -> [1, 0, 1, 0, 0, 1].
+ :param word_lens: iterable of word lengths
+ :return: a list of 0/1 flags, one per character
+ """
+ end_seg_mask = []
+ for word_len in word_lens:
+ end_seg_mask.extend([0] * (word_len - 1) + [1])
+ return end_seg_mask
+
+def _word_lens_to_start_seg_mask(word_lens):
+ """
+ Convert word lengths into a start-of-segment mask: 1 at the first character of each word, 0 elsewhere,
+ e.g. word lengths [1, 2, 3] -> [1, 1, 0, 1, 0, 0].
+ :param word_lens: iterable of word lengths
+ :return: a list of 0/1 flags, one per character
+ """
+ start_seg_mask = []
+ for word_len in word_lens:
+ start_seg_mask.extend([1] + [0] * (word_len - 1))
+ return start_seg_mask
+
+
+class CWSShiftRelayPipe(Pipe):
+ """
+
+ :param str,None dataset_name: one of 'pku', 'msra', 'cityu', 'as', or None
+ :param int L: hyper-parameter of the ShiftRelay model; relay targets are clipped to at most L-1
+ :param bool replace_num_alpha: whether to replace digit and letter spans with special tokens
+ :param bool bigrams: whether to add a bigram column; bigrams are built as ['复', '旦', '大', '学', ...] -> ["复旦", "旦大", ...]
+ """
+ def __init__(self, dataset_name=None, L=5, replace_num_alpha=True, bigrams=True):
+ self.dataset_name = dataset_name
+ self.bigrams = bigrams
+ self.replace_num_alpha = replace_num_alpha
+ self.L = L
+
+ def _tokenize(self, data_bundle):
+ """
+ Split the 'chars' field of data_bundle into words, each word being a list of characters,
+ e.g. "共同 创造 美好.." -> [[共, 同], [创, 造], [...], ]
+
+ :param data_bundle:
+ :return:
+ """
+ def split_word_into_chars(raw_chars):
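+ # characters wrapped in '<...>' (e.g. the special tokens inserted when digits/letters are replaced) are kept as single tokens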
+ words = raw_chars.split()
+ chars = []
+ for word in words:
+ char = []
+ subchar = []
+ for c in word:
+ if c=='<':
+ subchar.append(c)
+ continue
+ if c=='>' and subchar[0]=='<':
+ char.append(''.join(subchar))
+ subchar = []
+ if subchar:
+ subchar.append(c)
+ else:
+ char.append(c)
+ char.extend(subchar)
+ chars.append(char)
+ return chars
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(split_word_into_chars, field_name=Const.CHAR_INPUT,
+ new_field_name=Const.CHAR_INPUT)
+ return data_bundle
+
+ def process(self, data_bundle: DataBundle) -> DataBundle:
+ """
+ The DataSets to be processed must contain a raw_words field, for example:
+
+ .. csv-table::
+ :header: "raw_words"
+
+ "上海 浦东 开发 与 法制 建设 同步"
+ "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )"
+ "..."
+
+ :param data_bundle:
+ :return:
+ """
+ data_bundle.copy_field(Const.RAW_WORD, Const.CHAR_INPUT)
+
+ if self.replace_num_alpha:
+ data_bundle.apply_field(_find_and_replace_alpha_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
+ data_bundle.apply_field(_find_and_replace_digit_spans, Const.CHAR_INPUT, Const.CHAR_INPUT)
+
+ self._tokenize(data_bundle)
+ input_field_names = [Const.CHAR_INPUT]
+ target_field_names = []
+
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(lambda chars:_word_lens_to_relay(map(len, chars)), field_name=Const.CHAR_INPUT,
+ new_field_name=Const.TARGET)
+ dataset.apply_field(lambda chars:_word_lens_to_start_seg_mask(map(len, chars)), field_name=Const.CHAR_INPUT,
+ new_field_name='start_seg_mask')
+ dataset.apply_field(lambda chars:_word_lens_to_end_seg_mask(map(len, chars)), field_name=Const.CHAR_INPUT,
+ new_field_name='end_seg_mask')
+ dataset.apply_field(lambda chars:list(chain(*chars)), field_name=Const.CHAR_INPUT,
+ new_field_name=Const.CHAR_INPUT)
+ target_field_names.append('start_seg_mask')
+ input_field_names.append('end_seg_mask')
+ if self.bigrams:
+ for name, dataset in data_bundle.datasets.items():
+ dataset.apply_field(lambda chars: [c1+c2 for c1, c2 in zip(chars, chars[1:]+[''])],
+ field_name=Const.CHAR_INPUT, new_field_name='bigrams')
+ input_field_names.append('bigrams')
+
+ _indexize(data_bundle, ['chars', 'bigrams'], [])
+
+ func = partial(_clip_target, L=self.L)
+ for name, dataset in data_bundle.datasets.items():
+ res = dataset.apply_field(func, field_name='target')
+ relay_target = [res_i[0] for res_i in res]
+ relay_mask = [res_i[1] for res_i in res]
+ dataset.add_field('relay_target', relay_target, is_input=True, is_target=False, ignore_type=False)
+ dataset.add_field('relay_mask', relay_mask, is_input=True, is_target=False, ignore_type=False)
+ input_field_names.append('relay_target')
+ input_field_names.append('relay_mask')
+
+ input_fields = [Const.TARGET, Const.INPUT_LEN] + input_field_names
+ target_fields = [Const.TARGET, Const.INPUT_LEN] + target_field_names
+ for name, dataset in data_bundle.datasets.items():
+ dataset.add_seq_len(Const.CHAR_INPUT)
+
+ data_bundle.set_input(*input_fields)
+ data_bundle.set_target(*target_fields)
+
+ return data_bundle
+
+ def process_from_file(self, paths=None) -> DataBundle:
+ """
+
+ :param str paths:
+ :return:
+ """
+ if self.dataset_name is None and paths is None:
+ raise RuntimeError("You have to set `paths` when calling process_from_file() or `dataset_name `when initialization.")
+ if self.dataset_name is not None and paths is not None:
+ raise RuntimeError("You cannot specify `paths` and `dataset_name` simultaneously")
+ data_bundle = CWSLoader(self.dataset_name).load(paths)
+ return self.process(data_bundle)
+
+def _clip_target(target, L:int):
+ """
+
+ Only used when the target type is shift_relay: clips relay targets to at most L-1 and builds the relay mask.
+ :param target: List[int], relay tags of one sentence
+ :param L: int
+ :return: (relay_target, relay_mask)
+ """
+ relay_target_i = []
+ tmp = []
+ for j in range(len(target) - 1):
+ tmp.append(target[j])
+ if target[j] > target[j + 1]:
+ pass
+ else:
+ relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
+ tmp = []
+ # handle the trailing segment that has not been flushed yet
+ if len(tmp) == 0:
+ relay_target_i.append(0)
+ else:
+ tmp.append(target[-1])
+ relay_target_i.extend([L - 1 if t >= L else t for t in tmp[::-1]])
+ relay_mask_i = []
+ j = 0
+ while j < len(target):
+ seg_len = target[j] + 1
+ if target[j] < L:
+ relay_mask_i.extend([0] * (seg_len))
+ else:
+ relay_mask_i.extend([1] * (seg_len - L) + [0] * L)
+ j = seg_len + j
+ return relay_target_i, relay_mask_i
diff --git a/reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py b/reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py
new file mode 100644
index 00000000..4f87a81c
--- /dev/null
+++ b/reproduction/seqence_labelling/cws/model/bilstm_crf_cws.py
@@ -0,0 +1,60 @@
+
+import torch
+from fastNLP.modules import LSTM
+from fastNLP.modules import allowed_transitions, ConditionalRandomField
+from fastNLP import seq_len_to_mask
+from torch import nn
+from fastNLP import Const
+import torch.nn.functional as F
+
+class BiLSTMCRF(nn.Module):
+ def __init__(self, char_embed, hidden_size, num_layers, target_vocab=None, bigram_embed=None, trigram_embed=None,
+ dropout=0.5):
+ super().__init__()
+
+ embed_size = char_embed.embed_size
+ self.char_embed = char_embed
+ if bigram_embed:
+ embed_size += bigram_embed.embed_size
+ self.bigram_embed = bigram_embed
+ if trigram_embed:
+ embed_size += trigram_embed.embed_size
+ self.trigram_embed = trigram_embed
+
+ self.lstm = LSTM(embed_size, hidden_size=hidden_size//2, bidirectional=True, batch_first=True,
+ num_layers=num_layers)
+ self.dropout = nn.Dropout(p=dropout)
+ self.fc = nn.Linear(hidden_size, len(target_vocab))
+
+ transitions = None
+ if target_vocab:
+ transitions = allowed_transitions(target_vocab, include_start_end=True, encoding_type='bmes')
+
+ self.crf = ConditionalRandomField(num_tags=len(target_vocab), allowed_transitions=transitions)
+
+ def _forward(self, chars, bigrams, trigrams, seq_len, target=None):
+ chars = self.char_embed(chars)
+ if bigrams is not None:
+ bigrams = self.bigram_embed(bigrams)
+ chars = torch.cat([chars, bigrams], dim=-1)
+ if trigrams is not None:
+ trigrams = self.trigram_embed(trigrams)
+ chars = torch.cat([chars, trigrams], dim=-1)
+
+ output, _ = self.lstm(chars, seq_len)
+ output = self.dropout(output)
+ output = self.fc(output)
+ output = F.log_softmax(output, dim=-1)
+ mask = seq_len_to_mask(seq_len)
+ if target is None:
+ pred, _ = self.crf.viterbi_decode(output, mask)
+ return {Const.OUTPUT:pred}
+ else:
+ loss = self.crf.forward(output, tags=target, mask=mask)
+ return {Const.LOSS:loss}
+
+ def forward(self, chars, seq_len, target, bigrams=None, trigrams=None):
+ return self._forward(chars, bigrams, trigrams, seq_len, target)
+
+ def predict(self, chars, seq_len, bigrams=None, trigrams=None):
+ return self._forward(chars, bigrams, trigrams, seq_len)
\ No newline at end of file
diff --git a/reproduction/seqence_labelling/cws/model/model.py b/reproduction/seqence_labelling/cws/model/bilstm_shift_relay.py
similarity index 75%
rename from reproduction/seqence_labelling/cws/model/model.py
rename to reproduction/seqence_labelling/cws/model/bilstm_shift_relay.py
index bdd9002d..4ce1cc51 100644
--- a/reproduction/seqence_labelling/cws/model/model.py
+++ b/reproduction/seqence_labelling/cws/model/bilstm_shift_relay.py
@@ -1,7 +1,5 @@
from torch import nn
import torch
-from fastNLP.modules import Embedding
-import numpy as np
from reproduction.seqence_labelling.cws.model.module import FeatureFunMax, SemiCRFShiftRelay
from fastNLP.modules import LSTM
@@ -21,25 +19,21 @@ class ShiftRelayCWSModel(nn.Module):
:param num_bigram_per_char: 每个character对应的bigram的数量
:param drop_p: Dropout的大小
"""
- def __init__(self, char_embed:Embedding, bigram_embed:Embedding, hidden_size:int=400, num_layers:int=1,
- L:int=6, num_bigram_per_char:int=1, drop_p:float=0.2):
+ def __init__(self, char_embed, bigram_embed, hidden_size:int=400, num_layers:int=1, L:int=6, drop_p:float=0.2):
super().__init__()
- self.char_embedding = Embedding(char_embed, dropout=drop_p)
- self._pretrained_embed = False
- if isinstance(char_embed, np.ndarray):
- self._pretrained_embed = True
- self.bigram_embedding = Embedding(bigram_embed, dropout=drop_p)
- self.lstm = LSTM(100 * (num_bigram_per_char + 1), hidden_size // 2, num_layers=num_layers, bidirectional=True,
+ self.char_embedding = char_embed
+ self.bigram_embedding = bigram_embed
+ self.lstm = LSTM(char_embed.embed_size+bigram_embed.embed_size, hidden_size // 2, num_layers=num_layers,
+ bidirectional=True,
batch_first=True)
self.feature_fn = FeatureFunMax(hidden_size, L)
self.semi_crf_relay = SemiCRFShiftRelay(L)
self.feat_drop = nn.Dropout(drop_p)
self.reset_param()
- # self.feature_fn.reset_parameters()
def reset_param(self):
for name, param in self.named_parameters():
- if 'embedding' in name and self._pretrained_embed:
+ if 'embedding' in name:
continue
if 'bias_hh' in name:
nn.init.constant_(param, 0)
@@ -51,10 +45,8 @@ class ShiftRelayCWSModel(nn.Module):
nn.init.xavier_uniform_(param)
def get_feats(self, chars, bigrams, seq_len):
- batch_size, max_len = chars.size()
chars = self.char_embedding(chars)
bigrams = self.bigram_embedding(bigrams)
- bigrams = bigrams.view(bigrams.size(0), max_len, -1)
chars = torch.cat([chars, bigrams], dim=-1)
feats, _ = self.lstm(chars, seq_len)
feats = self.feat_drop(feats)
diff --git a/reproduction/seqence_labelling/cws/readme.md b/reproduction/seqence_labelling/cws/readme.md
new file mode 100644
index 00000000..a25bb0ed
--- /dev/null
+++ b/reproduction/seqence_labelling/cws/readme.md
@@ -0,0 +1,32 @@
+Statistics of the four datasets. The raw data can be downloaded from [http://sighan.cs.uchicago.edu/bakeoff2005/](http://sighan.cs.uchicago.edu/bakeoff2005/).
+
+| pku | # of sents | # of tokens |
+| ----- | ---------- | ----------- |
+| train | 17173 | 1650222 |
+| dev | 1881 | 176226 |
+| test | 1944 | 172733 |
+| total | 20998 | 1999181 |
+
+
+| cityu | # of sents | # of tokens |
+| ----- | ---------- | ----------- |
+| train | 47696 | 2164907 |
+| dev | 5323 | 238447 |
+| test | 1492 | 67690 |
+| total | 54511 | 2471044 |
+
+
+| msra | # of sents | # of tokens |
+| ----- | ---------- | ----------- |
+| train | 78242 | 3644550 |
+| dev | 8676 | 405919 |
+| test | 3985 | 184355 |
+| total | 90903 | 4234824 |
+
+
+| as | # of sents | # of tokens |
+| ----- | ---------- | ----------- |
+| train | 638273 | 7536586 |
+| dev | 70680 | 831464 |
+| test | 14429 | 197681 |
+| total | 723382 | 8565731 |
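+
+A minimal usage sketch (mirroring `train_bilstm_crf.py` in this folder; the dataset name can be any of the four above):
+
+```python
+from fastNLP.io.pipe.cws import CWSPipe
+
+# load and preprocess the chosen SIGHAN dataset into a DataBundle
+data_bundle = CWSPipe(dataset_name='pku', bigrams=True).process_from_file()
+print(data_bundle)  # shows the datasets and vocabularies that were built
+```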
diff --git a/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py b/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py
deleted file mode 100644
index f4260849..00000000
--- a/reproduction/seqence_labelling/cws/test/test_CWSDataLoader.py
+++ /dev/null
@@ -1,17 +0,0 @@
-
-
-import unittest
-from ..data.CWSDataLoader import SigHanLoader
-from fastNLP.core.vocabulary import VocabularyOption
-
-
-class TestCWSDataLoader(unittest.TestCase):
- def test_case1(self):
- cws_loader = SigHanLoader(target_type='bmes')
- data = cws_loader.process('pku_demo.txt')
- print(data.datasets)
-
- def test_calse2(self):
- cws_loader = SigHanLoader(target_type='bmes')
- data = cws_loader.process('pku_demo.txt', bigram_vocab_opt=VocabularyOption())
- print(data.datasets)
\ No newline at end of file
diff --git a/reproduction/seqence_labelling/cws/train_bilstm_crf.py b/reproduction/seqence_labelling/cws/train_bilstm_crf.py
new file mode 100644
index 00000000..b9a77249
--- /dev/null
+++ b/reproduction/seqence_labelling/cws/train_bilstm_crf.py
@@ -0,0 +1,52 @@
+import sys
+sys.path.append('../../..')
+
+from fastNLP.io.pipe.cws import CWSPipe
+from reproduction.seqence_labelling.cws.model.bilstm_crf_cws import BiLSTMCRF
+from fastNLP import Trainer, cache_results
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP import EvaluateCallback, BucketSampler, SpanFPreRecMetric, GradientClipCallback
+from torch.optim import Adagrad
+
+###########hyper
+dataname = 'pku'
+hidden_size = 400
+num_layers = 1
+lr = 0.05
+###########hyper
+
+
+@cache_results('{}.pkl'.format(dataname), _refresh=False)
+def get_data():
+ data_bundle = CWSPipe(dataset_name=dataname, bigrams=True, trigrams=False).process_from_file()
+ char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.33, word_dropout=0.01,
+ model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt')
+ bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.33,min_freq=3, word_dropout=0.01,
+ model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt')
+ return data_bundle, char_embed, bigram_embed
+
+data_bundle, char_embed, bigram_embed = get_data()
+print(data_bundle)
+
+model = BiLSTMCRF(char_embed, hidden_size, num_layers, target_vocab=data_bundle.get_vocab('target'), bigram_embed=bigram_embed,
+ trigram_embed=None, dropout=0.3)
+model.cuda()
+
+callbacks = []
+callbacks.append(EvaluateCallback(data_bundle.get_dataset('test')))
+callbacks.append(GradientClipCallback(clip_type='value', clip_value=5))
+optimizer = Adagrad(model.parameters(), lr=lr)
+
+metrics = []
+metric1 = SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'), encoding_type='bmes')
+metrics.append(metric1)
+
+trainer = Trainer(data_bundle.get_dataset('train'), model, optimizer=optimizer, loss=None,
+ batch_size=128, sampler=BucketSampler(), update_every=1,
+ num_workers=1, n_epochs=10, print_every=5,
+ dev_data=data_bundle.get_dataset('dev'),
+ metrics=metrics,
+ metric_key=None,
+ validate_every=-1, save_path=None, use_tqdm=True, device=0,
+ callbacks=callbacks, check_code_level=0, dev_batch_size=128)
+trainer.train()
diff --git a/reproduction/seqence_labelling/cws/train_shift_relay.py b/reproduction/seqence_labelling/cws/train_shift_relay.py
index 55576575..322f42bb 100644
--- a/reproduction/seqence_labelling/cws/train_shift_relay.py
+++ b/reproduction/seqence_labelling/cws/train_shift_relay.py
@@ -1,64 +1,53 @@
-import os
+import sys
+sys.path.append('../../..')
from fastNLP import cache_results
-from reproduction.seqence_labelling.cws.data.CWSDataLoader import SigHanLoader
-from reproduction.seqence_labelling.cws.model.model import ShiftRelayCWSModel
-from fastNLP.io.embed_loader import EmbeddingOption
-from fastNLP.core.vocabulary import VocabularyOption
+from reproduction.seqence_labelling.cws.data.cws_shift_pipe import CWSShiftRelayPipe
+from reproduction.seqence_labelling.cws.model.bilstm_shift_relay import ShiftRelayCWSModel
from fastNLP import Trainer
from torch.optim import Adam
from fastNLP import BucketSampler
from fastNLP import GradientClipCallback
from reproduction.seqence_labelling.cws.model.metric import RelayMetric
-
-
-# 借助一下fastNLP的自动缓存机制,但是只能缓存4G以下的结果
-@cache_results(None)
-def prepare_data():
- data = SigHanLoader(target_type='shift_relay').process(file_dir, char_embed_opt=char_embed_opt,
- bigram_vocab_opt=bigram_vocab_opt,
- bigram_embed_opt=bigram_embed_opt,
- L=L)
- return data
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP import EvaluateCallback
#########hyper
L = 4
hidden_size = 200
num_layers = 1
drop_p = 0.2
-lr = 0.02
-
+lr = 0.008
+data_name = 'pku'
#########hyper
device = 0
-# !!!!这里千万不要放完全路径,因为这样会暴露你们在服务器上的用户名,比较危险。所以一定要使用相对路径,最好把数据放到
-# 你们的reproduction路径下,然后设置.gitignore
-file_dir = '/path/to/'
-char_embed_path = '/pretrain/vectors/1grams_t3_m50_corpus.txt'
-bigram_embed_path = '/pretrain/vectors/2grams_t3_m50_corpus.txt'
-bigram_vocab_opt = VocabularyOption(min_freq=3)
-char_embed_opt = EmbeddingOption(embed_filepath=char_embed_path)
-bigram_embed_opt = EmbeddingOption(embed_filepath=bigram_embed_path)
-
-data_name = os.path.basename(file_dir)
cache_fp = 'caches/{}.pkl'.format(data_name)
+@cache_results(_cache_fp=cache_fp, _refresh=True) # cache the result to cache_fp so that later runs can load it directly instead of re-processing the data
+def prepare_data():
+ data_bundle = CWSShiftRelayPipe(dataset_name=data_name, L=L).process_from_file()
+ # pre-trained character embedding and bigram embedding
+ char_embed = StaticEmbedding(data_bundle.get_vocab('chars'), dropout=0.5, word_dropout=0.01,
+ model_dir_or_name='~/exps/CWS/pretrain/vectors/1grams_t3_m50_corpus.txt')
+ bigram_embed = StaticEmbedding(data_bundle.get_vocab('bigrams'), dropout=0.5, min_freq=3, word_dropout=0.01,
+ model_dir_or_name='~/exps/CWS/pretrain/vectors/2grams_t3_m50_corpus.txt')
-data = prepare_data(_cache_fp=cache_fp, _refresh=True)
+ return data_bundle, char_embed, bigram_embed
-model = ShiftRelayCWSModel(char_embed=data.embeddings['chars'], bigram_embed=data.embeddings['bigrams'],
- hidden_size=hidden_size, num_layers=num_layers,
- L=L, num_bigram_per_char=1, drop_p=drop_p)
+data, char_embed, bigram_embed = prepare_data()
-sampler = BucketSampler(batch_size=32)
+model = ShiftRelayCWSModel(char_embed=char_embed, bigram_embed=bigram_embed,
+ hidden_size=hidden_size, num_layers=num_layers, drop_p=drop_p, L=L)
+
+sampler = BucketSampler()
optimizer = Adam(model.parameters(), lr=lr)
-clipper = GradientClipCallback(clip_value=5, clip_type='value')
-callbacks = [clipper]
-# if pretrain:
-# fixer = FixEmbedding([model.char_embedding, model.bigram_embedding], fix_until=fix_until)
-# callbacks.append(fixer)
-trainer = Trainer(data.datasets['train'], model, optimizer=optimizer, loss=None, batch_size=32, sampler=sampler,
- update_every=5, n_epochs=3, print_every=5, dev_data=data.datasets['dev'], metrics=RelayMetric(),
+clipper = GradientClipCallback(clip_value=5, clip_type='value') # clip overly large gradients
+evaluator = EvaluateCallback(data.get_dataset('test')) # additionally evaluate on the test set during training
+callbacks = [clipper, evaluator]
+
+trainer = Trainer(data.get_dataset('train'), model, optimizer=optimizer, loss=None, batch_size=128, sampler=sampler,
+ update_every=1, n_epochs=10, print_every=5, dev_data=data.get_dataset('dev'), metrics=RelayMetric(),
metric_key='f', validate_every=-1, save_path=None, use_tqdm=True, device=device, callbacks=callbacks,
- check_code_level=0)
+ check_code_level=0, num_workers=1)
trainer.train()
\ No newline at end of file
diff --git a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py b/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
deleted file mode 100644
index 1aeddcf8..00000000
--- a/reproduction/seqence_labelling/ner/data/Conll2003Loader.py
+++ /dev/null
@@ -1,93 +0,0 @@
-
-from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
-from typing import Union, Dict
-from fastNLP import Vocabulary
-from fastNLP import Const
-from reproduction.utils import check_dataloader_paths
-
-from fastNLP.io import ConllLoader
-from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
-
-
-class Conll2003DataLoader(DataSetLoader):
- def __init__(self, task:str='ner', encoding_type:str='bioes'):
- """
- 加载Conll2003格式的英语语料,该数据集的信息可以在https://www.clips.uantwerpen.be/conll2003/ner/找到。当task为pos
- 时,返回的DataSet中target取值于第2列; 当task为chunk时,返回的DataSet中target取值于第3列;当task为ner时,返回
- 的DataSet中target取值于第4列。所有"-DOCSTART- -X- O O"将被忽略,这会导致数据的数量少于很多文献报道的值,但
- 鉴于"-DOCSTART- -X- O O"只是用于文档分割的符号,并不应该作为预测对象,所以我们忽略了数据中的-DOCTSTART-开头的行
- ner与chunk任务读取后的数据的target将为encoding_type类型。pos任务读取后就是pos列的数据。
-
- :param task: 指定需要标注任务。可选ner, pos, chunk
- """
- assert task in ('ner', 'pos', 'chunk')
- index = {'ner':3, 'pos':1, 'chunk':2}[task]
- self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index])
- self._tag_converters = []
- if task in ('ner', 'chunk'):
- self._tag_converters = [iob2]
- if encoding_type == 'bioes':
- self._tag_converters.append(iob2bioes)
-
- def load(self, path: str):
- dataset = self._loader.load(path)
- def convert_tag_schema(tags):
- for converter in self._tag_converters:
- tags = converter(tags)
- return tags
- if self._tag_converters:
- dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET)
- return dataset
-
- def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None, lower:bool=False):
- """
- 读取并处理数据。数据中的'-DOCSTART-'开头的行会被忽略
-
- :param paths:
- :param word_vocab_opt: vocabulary的初始化值
- :param lower: 是否将所有字母转为小写。
- :return:
- """
- # 读取数据
- paths = check_dataloader_paths(paths)
- data = DataBundle()
- input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
- target_fields = [Const.TARGET, Const.INPUT_LEN]
- for name, path in paths.items():
- dataset = self.load(path)
- dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
- if lower:
- dataset.words.lower()
- data.datasets[name] = dataset
-
- # 对construct vocab
- word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
- word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
- no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
- word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
- data.vocabs[Const.INPUT] = word_vocab
-
- # cap words
- cap_word_vocab = Vocabulary()
- cap_word_vocab.from_dataset(data.datasets['train'], field_name='raw_words',
- no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
- cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
- input_fields.append('cap_words')
- data.vocabs['cap_words'] = cap_word_vocab
-
- # 对target建vocab
- target_vocab = Vocabulary(unknown=None, padding=None)
- target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
- target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
- data.vocabs[Const.TARGET] = target_vocab
-
- for name, dataset in data.datasets.items():
- dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
- dataset.set_input(*input_fields)
- dataset.set_target(*target_fields)
-
- return data
-
-if __name__ == '__main__':
- pass
\ No newline at end of file
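
The Conll2003DataLoader deleted above is presumably superseded by the Loader/Pipe API that the rest of this change moves to. A minimal sketch of the replacement flow, assuming fastNLP's Conll2003NERPipe and its encoding_type/process_from_file interface; the file paths are placeholders:

    # Sketch only: assumes Conll2003NERPipe covers what the deleted process() did
    # (build word/target vocabs, index the datasets, add seq_len).
    from fastNLP.io.pipe.conll import Conll2003NERPipe

    paths = {'train': 'conll2003/train.txt',   # placeholder paths
             'dev': 'conll2003/dev.txt',
             'test': 'conll2003/test.txt'}
    data_bundle = Conll2003NERPipe(encoding_type='bioes').process_from_file(paths)
    print(data_bundle)
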
diff --git a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py b/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
deleted file mode 100644
index a6070f39..00000000
--- a/reproduction/seqence_labelling/ner/data/OntoNoteLoader.py
+++ /dev/null
@@ -1,152 +0,0 @@
-from fastNLP.core.vocabulary import VocabularyOption
-from fastNLP.io.base_loader import DataSetLoader, DataBundle
-from typing import Union, Dict
-from fastNLP import DataSet
-from fastNLP import Vocabulary
-from fastNLP import Const
-from reproduction.utils import check_dataloader_paths
-
-from fastNLP.io import ConllLoader
-from reproduction.seqence_labelling.ner.data.utils import iob2bioes, iob2
-
-class OntoNoteNERDataLoader(DataSetLoader):
- """
- 用于读取处理为Conll格式后的OntoNote数据。将OntoNote数据处理为conll格式的过程可以参考https://github.com/yhcc/OntoNotes-5.0-NER。
-
- """
- def __init__(self, encoding_type:str='bioes'):
- assert encoding_type in ('bioes', 'bio')
- self.encoding_type = encoding_type
- if encoding_type=='bioes':
- self.encoding_method = iob2bioes
- else:
- self.encoding_method = iob2
-
- def load(self, path:str)->DataSet:
- """
- 给定一个文件路径,读取数据。返回的DataSet包含以下的field
- raw_words: List[str]
- target: List[str]
-
- :param path:
- :return:
- """
- dataset = ConllLoader(headers=['raw_words', 'target'], indexes=[3, 10]).load(path)
- def convert_to_bio(tags):
- bio_tags = []
- flag = None
- for tag in tags:
- label = tag.strip("()*")
- if '(' in tag:
- bio_label = 'B-' + label
- flag = label
- elif flag:
- bio_label = 'I-' + flag
- else:
- bio_label = 'O'
- if ')' in tag:
- flag = None
- bio_tags.append(bio_label)
- return self.encoding_method(bio_tags)
-
- def convert_word(words):
- converted_words = []
- for word in words:
- word = word.replace('/.', '.') # 有些结尾的.是/.形式的
- if not word.startswith('-'):
- converted_words.append(word)
- continue
- # 以下是由于这些符号被转义了,再转回来
- tfrs = {'-LRB-':'(',
- '-RRB-': ')',
- '-LSB-': '[',
- '-RSB-': ']',
- '-LCB-': '{',
- '-RCB-': '}'
- }
- if word in tfrs:
- converted_words.append(tfrs[word])
- else:
- converted_words.append(word)
- return converted_words
-
- dataset.apply_field(convert_word, field_name='raw_words', new_field_name='raw_words')
- dataset.apply_field(convert_to_bio, field_name='target', new_field_name='target')
-
- return dataset
-
- def process(self, paths: Union[str, Dict[str, str]], word_vocab_opt:VocabularyOption=None,
- lower:bool=True)->DataBundle:
- """
- 读取并处理数据。返回的DataInfo包含以下的内容
- vocabs:
- word: Vocabulary
- target: Vocabulary
- datasets:
- train: DataSet
- words: List[int], 被设置为input
- target: int. label,被同时设置为input和target
- seq_len: int. 句子的长度,被同时设置为input和target
- raw_words: List[str]
- xxx(根据传入的paths可能有所变化)
-
- :param paths:
- :param word_vocab_opt: vocabulary的初始化值
- :param lower: 是否使用小写
- :return:
- """
- paths = check_dataloader_paths(paths)
- data = DataBundle()
- input_fields = [Const.TARGET, Const.INPUT, Const.INPUT_LEN]
- target_fields = [Const.TARGET, Const.INPUT_LEN]
- for name, path in paths.items():
- dataset = self.load(path)
- dataset.apply_field(lambda words: words, field_name='raw_words', new_field_name=Const.INPUT)
- if lower:
- dataset.words.lower()
- data.datasets[name] = dataset
-
- # 对construct vocab
- word_vocab = Vocabulary(min_freq=2) if word_vocab_opt is None else Vocabulary(**word_vocab_opt)
- word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT,
- no_create_entry_dataset=[dataset for name, dataset in data.datasets.items() if name!='train'])
- word_vocab.index_dataset(*data.datasets.values(), field_name=Const.INPUT, new_field_name=Const.INPUT)
- data.vocabs[Const.INPUT] = word_vocab
-
- # cap words
- cap_word_vocab = Vocabulary()
- cap_word_vocab.from_dataset(*data.datasets.values(), field_name='raw_words')
- cap_word_vocab.index_dataset(*data.datasets.values(), field_name='raw_words', new_field_name='cap_words')
- input_fields.append('cap_words')
- data.vocabs['cap_words'] = cap_word_vocab
-
- # 对target建vocab
- target_vocab = Vocabulary(unknown=None, padding=None)
- target_vocab.from_dataset(*data.datasets.values(), field_name=Const.TARGET)
- target_vocab.index_dataset(*data.datasets.values(), field_name=Const.TARGET)
- data.vocabs[Const.TARGET] = target_vocab
-
- for name, dataset in data.datasets.items():
- dataset.add_seq_len(Const.INPUT, new_field_name=Const.INPUT_LEN)
- dataset.set_input(*input_fields)
- dataset.set_target(*target_fields)
-
- return data
-
-
-if __name__ == '__main__':
- loader = OntoNoteNERDataLoader()
- dataset = loader.load('/hdd/fudanNLP/fastNLP/others/data/v4/english/test.txt')
- print(dataset.target.value_count())
- print(dataset[:4])
-
-
-"""
-train 115812 2200752
-development 15680 304684
-test 12217 230111
-
-train 92403 1901772
-valid 13606 279180
-test 10258 204135
-"""
\ No newline at end of file
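
As with the CoNLL-2003 loader, the OntoNote loader deleted above appears to be replaced by a Pipe. A sketch under that assumption; the OntoNotesNERPipe name and the data directory are illustrative:

    # Sketch only: assumes OntoNotesNERPipe reproduces the deleted process() logic
    # on OntoNotes data that has already been converted to conll format.
    from fastNLP.io.pipe.conll import OntoNotesNERPipe

    data_bundle = OntoNotesNERPipe(encoding_type='bioes').process_from_file('data/ontonotes')  # placeholder path
    print(data_bundle)
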
diff --git a/reproduction/seqence_labelling/ner/data/utils.py b/reproduction/seqence_labelling/ner/data/utils.py
deleted file mode 100644
index 8f7af792..00000000
--- a/reproduction/seqence_labelling/ner/data/utils.py
+++ /dev/null
@@ -1,49 +0,0 @@
-from typing import List
-
-def iob2(tags:List[str])->List[str]:
- """
- 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。
-
- :param tags: 需要转换的tags
- """
- for i, tag in enumerate(tags):
- if tag == "O":
- continue
- split = tag.split("-")
- if len(split) != 2 or split[0] not in ["I", "B"]:
- raise TypeError("The encoding schema is not a valid IOB type.")
- if split[0] == "B":
- continue
- elif i == 0 or tags[i - 1] == "O": # conversion IOB1 to IOB2
- tags[i] = "B" + tag[1:]
- elif tags[i - 1][1:] == tag[1:]:
- continue
- else: # conversion IOB1 to IOB2
- tags[i] = "B" + tag[1:]
- return tags
-
-def iob2bioes(tags:List[str])->List[str]:
- """
- 将iob的tag转换为bmeso编码
- :param tags:
- :return:
- """
- new_tags = []
- for i, tag in enumerate(tags):
- if tag == 'O':
- new_tags.append(tag)
- else:
- split = tag.split('-')[0]
- if split == 'B':
- if i+1!=len(tags) and tags[i+1].split('-')[0] == 'I':
- new_tags.append(tag)
- else:
- new_tags.append(tag.replace('B-', 'S-'))
- elif split == 'I':
-                if i + 1<len(tags) and tags[i+1].split('-')[0]=='I':
-                    new_tags.append(tag)
-                else:
-                    new_tags.append(tag.replace('I-', 'E-'))
-            else:
-                raise TypeError("Invalid IOB format.")
-    return new_tags
+data_bundle.datasets['train'].drop(lambda x:len(x['words'])>400)
+data_bundle.datasets['dev'].drop(lambda x:len(x['words'])>400)
+data_bundle.datasets['test'].drop(lambda x:len(x['words'])>400)
+bert_embed = BertEmbedding(data_bundle.vocabs['words'], requires_grad=False,
+ model_dir_or_name="en-base-uncased")
+model = BiLSTMSentiment(bert_embed, len(data_bundle.vocabs['target']))
+
+Trainer(data_bundle.datasets['train'], model, optimizer=None, loss=CrossEntropyLoss(), device=0,
+ batch_size=10, dev_data=data_bundle.datasets['dev'], metrics=AccuracyMetric()).train()
+
+# 在测试集上测试一下效果
+Tester(data_bundle.datasets['test'], model, batch_size=32, metrics=AccuracyMetric()).test()
\ No newline at end of file
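
For reference, the iob2bioes helper deleted from reproduction/seqence_labelling/ner/data/utils.py above applies the standard IOB-to-BIOES re-encoding: a single-token span's B- becomes S-, and the last token of a multi-token span turns I- into E-. A tiny worked example of that mapping:

    # Worked example of the standard IOB -> BIOES mapping (tags chosen for illustration).
    tags = ['B-PER', 'I-PER', 'O', 'B-LOC']
    # 'B-PER I-PER' is a two-token span: keep B-PER, the final I-PER becomes E-PER.
    # 'B-LOC' is a one-token span: B-LOC becomes S-LOC.
    expected = ['B-PER', 'E-PER', 'O', 'S-LOC']
    print(expected)
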
diff --git a/reproduction/text_classification/train_char_cnn.py b/reproduction/text_classification/train_char_cnn.py
index 050527fe..6b56608a 100644
--- a/reproduction/text_classification/train_char_cnn.py
+++ b/reproduction/text_classification/train_char_cnn.py
@@ -7,29 +7,29 @@ import sys
sys.path.append('../..')
from fastNLP.core.const import Const as C
import torch.nn as nn
-from data.yelpLoader import yelpLoader
-from data.sstLoader import sst2Loader
-from data.IMDBLoader import IMDBLoader
+from fastNLP.io.data_loader import YelpLoader
+from fastNLP.io.pipe.classification import YelpFullPipe,YelpPolarityPipe,SST2Pipe,IMDBPipe
+#from data.sstLoader import sst2Loader
from model.char_cnn import CharacterLevelCNN
-from fastNLP.core.vocabulary import Vocabulary
-from fastNLP.models.cnn_text_classification import CNNText
-from fastNLP.modules.encoder.embedding import CNNCharEmbedding,StaticEmbedding,StackEmbedding,LSTMCharEmbedding
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.core.trainer import Trainer
from torch.optim import SGD
from torch.autograd import Variable
import torch
-from fastNLP import BucketSampler
+from torch.optim.lr_scheduler import LambdaLR
+from fastNLP.core import LRScheduler
+
##hyper
#todo 这里加入fastnlp的记录
class Config():
+ #seed=7777
model_dir_or_name="en-base-uncased"
embedding_grad= False,
bert_embedding_larers= '4,-2,-1'
train_epoch= 50
num_classes=2
- task= "IMDB"
+ task= "yelp_p"
#yelp_p
datapath = {"train": "/remote-home/ygwang/yelp_polarity/train.csv",
"test": "/remote-home/ygwang/yelp_polarity/test.csv"}
@@ -46,6 +46,9 @@ class Config():
number_of_characters=69
extra_characters=''
max_length=1014
+ weight_decay = 1e-5
+ to_lower=True
+ tokenizer = 'spacy' # 使用spacy进行分词
char_cnn_config={
"alphabet": {
@@ -104,16 +107,42 @@ class Config():
}
ops=Config
+# set_rng_seeds(ops.seed)
+# print('RNG SEED: {}'.format(ops.seed))
+
##1.task相关信息:利用dataloader载入dataInfo
-#dataloader=sst2Loader()
+#dataloader=SST2Loader()
#dataloader=IMDBLoader()
-dataloader=yelpLoader(fine_grained=True)
-datainfo=dataloader.process(ops.datapath,char_level_op=True)
+# dataloader=YelpLoader(fine_grained=True)
+# datainfo=dataloader.process(ops.datapath,char_level_op=True,split_dev_op=False)
char_vocab=ops.char_cnn_config["alphabet"]["en"]["lower"]["alphabet"]
ops.number_of_characters=len(char_vocab)
ops.embedding_dim=ops.number_of_characters
+# load data set
+if ops.task == 'yelp_p':
+ data_bundle = YelpPolarityPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+elif ops.task == 'yelp_f':
+ data_bundle = YelpFullPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+elif ops.task == 'imdb':
+ data_bundle = IMDBPipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+elif ops.task == 'sst-2':
+ data_bundle = SST2Pipe(lower=ops.to_lower, tokenizer=ops.tokenizer).process_from_file()
+else:
+    raise RuntimeError(f'Task {ops.task} is not supported yet!')
+
+
+def wordtochar(words):
+ chars = []
+ for word in words:
+ #word = word.lower()
+ for char in word:
+ chars.append(char)
+ chars.append('')
+ chars.pop()
+ return chars
+
#chartoindex
def chartoindex(chars):
max_seq_len=ops.max_length
@@ -133,13 +162,14 @@ def chartoindex(chars):
char_index_list=[zero_index]*max_seq_len
return char_index_list
-for dataset in datainfo.datasets.values():
+for dataset in data_bundle.datasets.values():
+ dataset.apply_field(wordtochar, field_name="raw_words", new_field_name='chars')
dataset.apply_field(chartoindex,field_name='chars',new_field_name='chars')
-datainfo.datasets['train'].set_input('chars')
-datainfo.datasets['test'].set_input('chars')
-datainfo.datasets['train'].set_target('target')
-datainfo.datasets['test'].set_target('target')
+data_bundle.datasets['train'].set_input('chars')
+data_bundle.datasets['test'].set_input('chars')
+data_bundle.datasets['train'].set_target('target')
+data_bundle.datasets['test'].set_target('target')
##2. 定义/组装模型,这里可以随意,就如果是fastNLP封装好的,类似CNNText就直接用初始化调用就好了,这里只是给出一个伪框架表示占位,在这里建立符合fastNLP输入输出规范的model
class ModelFactory(nn.Module):
@@ -162,7 +192,7 @@ class ModelFactory(nn.Module):
## 2.或直接复用fastNLP的模型
#vocab=datainfo.vocabs['words']
-vocab_label=datainfo.vocabs['target']
+vocab_label=data_bundle.vocabs['target']
'''
# emded_char=CNNCharEmbedding(vocab)
# embed_word = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True)
@@ -186,13 +216,21 @@ model=CharacterLevelCNN(ops,embedding)
## 3. 声明loss,metric,optimizer
loss=CrossEntropyLoss
metric=AccuracyMetric
-optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+#optimizer= SGD([param for param in model.parameters() if param.requires_grad==True], lr=ops.lr)
+optimizer = SGD([param for param in model.parameters() if param.requires_grad == True],
+ lr=ops.lr, momentum=0.9, weight_decay=ops.weight_decay)
+callbacks = []
+# callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
+callbacks.append(
+ LRScheduler(LambdaLR(optimizer, lambda epoch: ops.lr if epoch <
+ ops.train_epoch * 0.8 else ops.lr * 0.1))
+)
## 4.定义train方法
def train(model,datainfo,loss,metrics,optimizer,num_epochs=100):
- trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),
- metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=0, check_code_level=-1,
- n_epochs=num_epochs)
+ trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss(target='target'),batch_size=ops.batch_size,
+ metrics=[metrics(target='target')], dev_data=datainfo.datasets['test'], device=[0,1,2], check_code_level=-1,
+ n_epochs=num_epochs,callbacks=callbacks)
print(trainer.train())
@@ -201,5 +239,5 @@ if __name__=="__main__":
#print(vocab_label)
#print(datainfo.datasets["train"])
- train(model,datainfo,loss,metric,optimizer,num_epochs=ops.train_epoch)
+ train(model,data_bundle,loss,metric,optimizer,num_epochs=ops.train_epoch)
\ No newline at end of file
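
The wordtochar helper added to train_char_cnn.py above flattens a token list into a character sequence, using an empty string as the word separator, before chartoindex maps characters to ids and pads or truncates to max_length. A quick self-contained check of what it produces:

    # Same helper as added above, repeated here only to show its output on a toy input.
    def wordtochar(words):
        chars = []
        for word in words:
            for char in word:
                chars.append(char)
            chars.append('')   # empty string marks a word boundary
        chars.pop()            # drop the trailing separator
        return chars

    print(wordtochar(['hi', 'yo']))   # ['h', 'i', '', 'y', 'o']
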
diff --git a/reproduction/text_classification/train_dpcnn.py b/reproduction/text_classification/train_dpcnn.py
index 70570970..f3f4e231 100644
--- a/reproduction/text_classification/train_dpcnn.py
+++ b/reproduction/text_classification/train_dpcnn.py
@@ -3,25 +3,26 @@
import torch.cuda
from fastNLP.core.utils import cache_results
from torch.optim import SGD
-from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
+from torch.optim.lr_scheduler import CosineAnnealingLR
from fastNLP.core.trainer import Trainer
from fastNLP import CrossEntropyLoss, AccuracyMetric
-from fastNLP.modules.encoder.embedding import StaticEmbedding, CNNCharEmbedding, StackEmbedding
+from fastNLP.embeddings import StaticEmbedding
from reproduction.text_classification.model.dpcnn import DPCNN
-from data.yelpLoader import yelpLoader
+from fastNLP.io.data_loader import YelpLoader
from fastNLP.core.sampler import BucketSampler
-import torch.nn as nn
-from fastNLP.core import LRScheduler, Callback
+from fastNLP.core import LRScheduler
from fastNLP.core.const import Const as C
from fastNLP.core.vocabulary import VocabularyOption
+from fastNLP.core.dist_trainer import DistTrainer
from utils.util_init import set_rng_seeds
+from fastNLP import logger
import os
-os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
-os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
+# os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
+# os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-
-
# hyper
+logger.add_file('log', 'INFO')
+print(logger.handlers)
class Config():
seed = 12345
@@ -46,11 +47,11 @@ class Config():
self.datapath = {k: os.path.join(self.datadir, v)
for k, v in self.datafile.items()}
-
ops = Config()
set_rng_seeds(ops.seed)
-print('RNG SEED: {}'.format(ops.seed))
+# print('RNG SEED: {}'.format(ops.seed))
+logger.info('RNG SEED %d'%ops.seed)
# 1.task相关信息:利用dataloader载入dataInfo
@@ -59,33 +60,35 @@ print('RNG SEED: {}'.format(ops.seed))
@cache_results(ops.model_dir_or_name+'-data-cache')
def load_data():
- datainfo = yelpLoader(fine_grained=True, lower=True).process(
+ datainfo = YelpLoader(fine_grained=True, lower=True).process(
paths=ops.datapath, train_ds=['train'], src_vocab_op=ops.src_vocab_op)
for ds in datainfo.datasets.values():
ds.apply_field(len, C.INPUT, C.INPUT_LEN)
ds.set_input(C.INPUT, C.INPUT_LEN)
ds.set_target(C.TARGET)
- embedding = StaticEmbedding(
- datainfo.vocabs['words'], model_dir_or_name='en-glove-840b-300', requires_grad=ops.embedding_grad,
- normalize=False
- )
- return datainfo, embedding
+ return datainfo
-datainfo, embedding = load_data()
+
+datainfo = load_data()
+embedding = StaticEmbedding(
+ datainfo.vocabs['words'], model_dir_or_name='en-glove-6b-100d', requires_grad=ops.embedding_grad,
+ normalize=False)
embedding.embedding.weight.data /= embedding.embedding.weight.data.std()
-print(embedding.embedding.weight.mean(), embedding.embedding.weight.std())
+print(embedding.embedding.weight.data.mean(), embedding.embedding.weight.data.std())
# 2.或直接复用fastNLP的模型
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
-
-print(datainfo)
-print(datainfo.datasets['train'][0])
+datainfo.datasets['train'] = datainfo.datasets['train'][:1000]
+datainfo.datasets['test'] = datainfo.datasets['test'][:1000]
+# print(datainfo)
+# print(datainfo.datasets['train'][0])
+logger.info(datainfo)
model = DPCNN(init_embed=embedding, num_cls=len(datainfo.vocabs[C.TARGET]),
embed_dropout=ops.embed_dropout, cls_dropout=ops.cls_dropout)
-print(model)
+# print(model)
# 3. 声明loss,metric,optimizer
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
@@ -107,16 +110,21 @@ callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
-print(device)
+# print(device)
+logger.info(device)
# 4.定义train方法
trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
sampler=BucketSampler(num_buckets=50, batch_size=ops.batch_size),
- metrics=[metric],
+ metrics=[metric], use_tqdm=False, save_path='save',
dev_data=datainfo.datasets['test'], device=device,
check_code_level=-1, batch_size=ops.batch_size, callbacks=callbacks,
n_epochs=ops.train_epoch, num_workers=4)
-
+# trainer = DistTrainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
+# metrics=[metric],
+# dev_data=datainfo.datasets['test'], device='cuda',
+# batch_size_per_gpu=ops.batch_size, callbacks_all=callbacks,
+# n_epochs=ops.train_epoch, num_workers=4)
if __name__ == "__main__":
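
train_dpcnn.py now wraps the expensive Yelp preprocessing in @cache_results so it only runs once per cache file. A minimal sketch of that pattern, assuming the usual behaviour of fastNLP's cache_results (first call computes and pickles the result to the given file, later calls read the cache); the cache name and the _refresh keyword are assumptions for illustration:

    # Sketch of the caching pattern used above; 'demo-data-cache' is a placeholder file name.
    from fastNLP.core.utils import cache_results

    @cache_results('demo-data-cache')
    def load_data():
        # expensive preprocessing would go here
        return {'train': list(range(10))}

    datainfo = load_data()                  # first run computes and writes demo-data-cache
    # datainfo = load_data(_refresh=True)   # assumed keyword to force recomputation
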
diff --git a/reproduction/text_classification/train_lstm.py b/reproduction/text_classification/train_lstm.py
index 4ecc61a1..40f77061 100644
--- a/reproduction/text_classification/train_lstm.py
+++ b/reproduction/text_classification/train_lstm.py
@@ -3,20 +3,13 @@ import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
-
-import torch.nn as nn
-
-from data.IMDBLoader import IMDBLoader
-from fastNLP.modules.encoder.embedding import StaticEmbedding
+from fastNLP.io.data_loader import IMDBLoader
+from fastNLP.embeddings import StaticEmbedding
from model.lstm import BiLSTMSentiment
-from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
-from fastNLP import Trainer, Tester
+from fastNLP import Trainer
from torch.optim import Adam
-from fastNLP.io.model_io import ModelLoader, ModelSaver
-
-import argparse
class Config():
diff --git a/reproduction/text_classification/train_lstm_att.py b/reproduction/text_classification/train_lstm_att.py
index a6f0dd03..1052f606 100644
--- a/reproduction/text_classification/train_lstm_att.py
+++ b/reproduction/text_classification/train_lstm_att.py
@@ -3,20 +3,13 @@ import os
os.environ['FASTNLP_BASE_URL'] = 'http://10.141.222.118:8888/file/download/'
os.environ['FASTNLP_CACHE_DIR'] = '/remote-home/hyan01/fastnlp_caches'
-
-import torch.nn as nn
-
-from data.IMDBLoader import IMDBLoader
-from fastNLP.modules.encoder.embedding import StaticEmbedding
+from fastNLP.io.data_loader import IMDBLoader
+from fastNLP.embeddings import StaticEmbedding
from model.lstm_self_attention import BiLSTM_SELF_ATTENTION
-from fastNLP.core.const import Const as C
from fastNLP import CrossEntropyLoss, AccuracyMetric
-from fastNLP import Trainer, Tester
+from fastNLP import Trainer
from torch.optim import Adam
-from fastNLP.io.model_io import ModelLoader, ModelSaver
-
-import argparse
class Config():
diff --git a/test/core/test_dataset.py b/test/core/test_dataset.py
index 0228f207..9c05c334 100644
--- a/test/core/test_dataset.py
+++ b/test/core/test_dataset.py
@@ -1,4 +1,5 @@
import os
+import sys
import unittest
from fastNLP import DataSet
@@ -79,6 +80,16 @@ class TestDataSetMethods(unittest.TestCase):
self.assertFalse("x" in dd.field_arrays)
self.assertTrue("y" in dd.field_arrays)
+ def test_delete_instance(self):
+ dd = DataSet()
+ old_length = 2
+ dd.add_field("x", [[1, 2, 3]] * old_length)
+ dd.add_field("y", [[1, 2, 3, 4]] * old_length)
+ dd.delete_instance(0)
+ self.assertEqual(len(dd), old_length-1)
+ dd.delete_instance(0)
+ self.assertEqual(len(dd), old_length-2)
+
def test_getitem(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
ins_1, ins_0 = ds[0], ds[1]
diff --git a/test/core/test_dist_trainer.py b/test/core/test_dist_trainer.py
new file mode 100644
index 00000000..c6879634
--- /dev/null
+++ b/test/core/test_dist_trainer.py
@@ -0,0 +1,167 @@
+import unittest
+
+import numpy as np
+import torch.cuda
+from fastNLP import DataSet
+from fastNLP import Instance
+from fastNLP import CrossEntropyLoss, BCELoss
+from fastNLP import SGD
+from fastNLP.core.dist_trainer import DistTrainer, get_local_rank
+from fastNLP.models.base_model import NaiveClassifier
+import shutil
+import os
+import subprocess
+from argparse import ArgumentParser
+from fastNLP.core.callback import EchoCallback
+from fastNLP import AccuracyMetric
+
+def prepare_fake_dataset():
+ mean = np.array([-3, -3])
+ cov = np.array([[1, 0], [0, 1]])
+ class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+ mean = np.array([3, 3])
+ cov = np.array([[1, 0], [0, 1]])
+ class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+ data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=0) for item in class_A] +
+ [Instance(x=[float(item[0]), float(item[1])], y=1) for item in class_B])
+ return data_set
+
+def prepare_fake_dataset2(*args, size=100):
+    ys = np.random.randint(4, size=size, dtype=np.int64)
+ data = {'y': ys}
+ for arg in args:
+ data[arg] = np.random.randn(size, 5)
+ return DataSet(data=data)
+
+def set_rng_seed(seed):
+ np.random.seed(seed)
+
+def prepare_env():
+ def prepare_fake_dataset():
+ mean = np.array([-3, -3])
+ cov = np.array([[1, 0], [0, 1]])
+ class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+ mean = np.array([3, 3])
+ cov = np.array([[1, 0], [0, 1]])
+ class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
+
+ data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
+ [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
+ return data_set
+
+ data_set = prepare_fake_dataset()
+ data_set.set_input("x")
+ data_set.set_target("y")
+ model = NaiveClassifier(2, 1)
+ return data_set, model
+
+class TestDistTrainer(unittest.TestCase):
+ save_path = './save_cp'
+
+ def run1(self):
+ # test distributed training
+ print('local rank', get_local_rank())
+ set_rng_seed(100)
+ data_set = prepare_fake_dataset()
+ data_set.set_input("x", flag=True)
+ data_set.set_target("y", flag=True)
+
+ model = NaiveClassifier(2, 2)
+
+ trainer = DistTrainer(
+ model=model, train_data=data_set, optimizer=SGD(lr=0.1),
+ loss=CrossEntropyLoss(pred="predict", target="y"),
+ batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
+ )
+ trainer.train()
+ """
+ # 应该正确运行
+ """
+ if trainer.is_master and os.path.exists(self.save_path):
+ shutil.rmtree(self.save_path)
+
+ def run2(self):
+ # test fp16 with distributed training
+ print('local rank', get_local_rank())
+ set_rng_seed(100)
+ data_set = prepare_fake_dataset()
+ data_set.set_input("x", flag=True)
+ data_set.set_target("y", flag=True)
+
+ model = NaiveClassifier(2, 2)
+
+ trainer = DistTrainer(
+ model=model, train_data=data_set, optimizer=SGD(lr=0.1),
+ loss=CrossEntropyLoss(pred="predict", target="y"),
+ batch_size_per_gpu=8, n_epochs=3, print_every=50, save_path=self.save_path,
+ fp16='O1'
+ )
+ trainer.train()
+ """
+ # 应该正确运行
+ """
+ if trainer.is_master and os.path.exists(self.save_path):
+ shutil.rmtree(self.save_path)
+
+ def run3(self):
+ set_rng_seed(100)
+ data_set, model = prepare_env()
+ trainer = DistTrainer(
+ data_set, model, optimizer=None,
+ loss=BCELoss(pred="predict", target="y"),
+ n_epochs=3, print_every=50,
+ callbacks_all=[EchoCallback('callbacks_all')],
+ callbacks_master=[EchoCallback('callbacks_master')]
+ )
+ trainer.train()
+
+ def run4(self):
+ set_rng_seed(100)
+ data_set, model = prepare_env()
+
+ train_set, dev_set = data_set.split(0.3)
+
+ model = NaiveClassifier(2, 1)
+
+ trainer = DistTrainer(
+ train_set, model, optimizer=SGD(lr=0.1),
+ loss=BCELoss(pred="predict", target="y"),
+ batch_size_per_gpu=32, n_epochs=3, print_every=50, dev_data=dev_set,
+ metrics=AccuracyMetric(pred="predict", target="y"), validate_every=-1, save_path=None,
+ )
+ trainer.train()
+ """
+ # 应该正确运行
+ """
+
+ def run_dist(self, run_id):
+ if torch.cuda.is_available():
+ ngpu = min(2, torch.cuda.device_count())
+ path = __file__
+ cmd = ['python', '-m', 'torch.distributed.launch',
+ '--nproc_per_node', str(ngpu), path, '--test', str(run_id)]
+ print(' '.join(cmd))
+ subprocess.check_call(cmd)
+
+ def test_normal_run(self):
+ self.run_dist(1)
+
+ def no_test_fp16(self):
+ self.run_dist(2)
+
+ def test_callback(self):
+ self.run_dist(3)
+
+ def test_dev_data(self):
+ self.run_dist(4)
+
+if __name__ == '__main__':
+ runner = TestDistTrainer()
+ parser = ArgumentParser()
+ parser.add_argument('--test', type=int)
+ args, _ = parser.parse_known_args()
+ if args.test and hasattr(runner, 'run%s'%args.test):
+ getattr(runner, 'run%s'%args.test)()
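
Note how these tests drive DistTrainer: the unittest methods do not run training in-process; run_dist re-launches this very file under torch.distributed.launch, and the __main__ block dispatches on --test to the matching runX method. The command assembled above looks like the following (a GPU count of 2 is illustrative):

    # Illustrative reconstruction of the command run_dist() builds above.
    import sys

    ngpu = 2   # assumed for illustration; the test uses min(2, torch.cuda.device_count())
    cmd = [sys.executable, '-m', 'torch.distributed.launch',
           '--nproc_per_node', str(ngpu),
           'test/core/test_dist_trainer.py', '--test', '1']
    print(' '.join(cmd))
    # i.e. python -m torch.distributed.launch --nproc_per_node 2 test/core/test_dist_trainer.py --test 1
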
diff --git a/test/core/test_field.py b/test/core/test_field.py
index e9053f37..c46e2de2 100644
--- a/test/core/test_field.py
+++ b/test/core/test_field.py
@@ -170,22 +170,22 @@ class TestFieldArray(unittest.TestCase):
def test_append(self):
with self.assertRaises(Exception):
- fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+ fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False)
fa.append(0)
with self.assertRaises(Exception):
- fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
+ fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True, use_1st_ins_infer_dim_type=False)
fa.append([1, 2, 3, 4, 5])
with self.assertRaises(Exception):
- fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+ fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False)
fa.append([])
with self.assertRaises(Exception):
- fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
+ fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True, use_1st_ins_infer_dim_type=False)
fa.append(["str", 0, 0, 0, 1.89])
- fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True)
+ fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]], is_input=True, use_1st_ins_infer_dim_type=False)
fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
self.assertEqual(len(fa), 3)
self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
diff --git a/test/core/test_metrics.py b/test/core/test_metrics.py
index 9c8a586c..236066d6 100644
--- a/test/core/test_metrics.py
+++ b/test/core/test_metrics.py
@@ -7,7 +7,7 @@ from fastNLP import AccuracyMetric
from fastNLP.core.metrics import _pred_topk, _accuracy_topk
from fastNLP.core.vocabulary import Vocabulary
from collections import Counter
-from fastNLP.core.metrics import SpanFPreRecMetric
+from fastNLP.core.metrics import SpanFPreRecMetric, ExtractiveQAMetric
def _generate_tags(encoding_type, number_labels=4):
@@ -347,3 +347,46 @@ class TestUsefulFunctions(unittest.TestCase):
_ = _pred_topk(np.random.randint(0, 3, size=(10, 1)))
# 跑通即可
+
+
+class TestExtractiveQAMetric(unittest.TestCase):
+
+ def test_cast_1(self):
+ qa_prediction = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169,
+ -0.3782, 0.8240],
+ [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, -1.1563,
+ -0.3562, -1.4116],
+ [-1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858,
+ -2.0023, 0.0075],
+ [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186,
+ 0.3832, -0.1540],
+ [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120,
+ -1.3508, -0.9513],
+ [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919,
+ -0.0842, -0.4294]],
+
+ [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490,
+ -1.4138, -0.8853],
+ [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845,
+ -1.0726, 0.0364],
+ [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264,
+ -0.8836, -0.9320],
+ [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044,
+ -1.6857, 1.1571],
+ [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125,
+ 3.5837, 1.0184],
+ [1.6495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943,
+ -0.9025, 0.0864]]])
+ qa_prediction = qa_prediction.permute(1, 2, 0)
+ pred1, pred2 = qa_prediction.split(1, dim=-1)
+ pred1 = pred1.squeeze(-1)
+ pred2 = pred2.squeeze(-1)
+ target1 = torch.LongTensor([3, 0, 2, 4, 4, 0])
+ target2 = torch.LongTensor([4, 1, 6, 8, 7, 1])
+ metric = ExtractiveQAMetric()
+ metric.evaluate(pred1, pred2, target1, target2)
+ result = metric.get_metric()
+ truth = {'EM': 62.5, 'f_1': 72.5, 'noAns-f_1': 50.0, 'noAns-EM': 50.0, 'hasAns-f_1': 95.0, 'hasAns-EM': 75.0}
+ for k, v in truth.items():
+ self.assertTrue(k in result)
+ self.assertEqual(v, result[k])
diff --git a/test/data_for_tests/sample_mnli.tsv b/test/data_for_tests/sample_mnli.tsv
new file mode 100644
index 00000000..9a30b95b
--- /dev/null
+++ b/test/data_for_tests/sample_mnli.tsv
@@ -0,0 +1,12 @@
+index promptID pairID genre sentence1_binary_parse sentence2_binary_parse sentence1_parse sentence2_parse sentence1 sentence2 label1 label2 label3 label4 label5 gold_label
+0 63735 63735n slate ( ( The ( new rights ) ) ( are ( nice enough ) ) ) ( Everyone ( really ( likes ( the ( newest benefits ) ) ) ) ) (ROOT (S (NP (DT The) (JJ new) (NNS rights)) (VP (VBP are) (ADJP (JJ nice) (RB enough))))) (ROOT (S (NP (NN Everyone)) (VP (ADVP (RB really)) (VBZ likes) (NP (DT the) (JJS newest) (NNS benefits))))) The new rights are nice enough Everyone really likes the newest benefits neutral entailment neutral neutral neutral neutral
+1 91383 91383c government ( ( This site ) ( ( includes ( ( ( ( a list ) ( of ( all ( award winners ) ) ) ) and ) ( ( a ( searchable database ) ) ( of ( Government ( Executive articles ) ) ) ) ) ) . ) ) ( ( ( The ( Government ( Executive articles ) ) ) ( housed ( on ( the website ) ) ) ) ( ( ( are not ) ( able ( to ( be searched ) ) ) ) . ) ) (ROOT (S (NP (DT This) (NN site)) (VP (VBZ includes) (NP (NP (NP (DT a) (NN list)) (PP (IN of) (NP (DT all) (NN award) (NNS winners)))) (CC and) (NP (NP (DT a) (JJ searchable) (NN database)) (PP (IN of) (NP (NNP Government) (NNP Executive) (NNS articles)))))) (. .))) (ROOT (S (NP (NP (DT The) (NNP Government) (NNP Executive) (NNS articles)) (VP (VBN housed) (PP (IN on) (NP (DT the) (NN website))))) (VP (VBP are) (RB not) (ADJP (JJ able) (S (VP (TO to) (VP (VB be) (ADJP (JJ searched))))))) (. .))) This site includes a list of all award winners and a searchable database of Government Executive articles. The Government Executive articles housed on the website are not able to be searched. contradiction contradiction contradiction contradiction contradiction contradiction
+2 755 755e telephone ( ( ( ( uh ( i ( ( do n't ) ( know ( ( i i ) ( have ( ( mixed emotions ) ( about ( him ( ( uh sometimes ) ( i ( like him ) ) ) ) ) ) ) ) ) ) ) ) but ) ( ( at ( the ( same times ) ) ) ( i ( love ( to ( see somebody ) ) ) ) ) ) ( beat him ) ) ( I ( ( ( ( ( ( like him ) ( for ( the ( most part ) ) ) ) , ) but ) ( ( would still ) ( enjoy ( seeing ( someone ( beat him ) ) ) ) ) ) . ) ) (ROOT (SINV (S (S (INTJ (UH uh)) (NP (FW i)) (VP (VBP do) (RB n't) (VP (VB know) (NP (NP (FW i) (FW i)) (SBAR (S (VP (VBP have) (VP (VBN mixed) (NP (NNS emotions)) (PP (IN about) (S (NP (PRP him)) (VP (VBG uh) (ADVP (RB sometimes)) (NP (NP (FW i)) (PP (IN like) (NP (PRP him))))))))))))))) (CC but) (S (PP (IN at) (NP (DT the) (JJ same) (NNS times))) (NP (FW i)) (VP (VBP love) (S (VP (TO to) (VP (VB see) (NP (NN somebody)))))))) (VP (VBD beat)) (NP (PRP him)))) (ROOT (S (NP (PRP I)) (VP (VP (VBP like) (NP (PRP him)) (PP (IN for) (NP (DT the) (JJS most) (NN part)))) (, ,) (CC but) (VP (MD would) (ADVP (RB still)) (VP (VB enjoy) (S (VP (VBG seeing) (S (NP (NN someone)) (VP (VB beat) (NP (PRP him))))))))) (. .))) uh i don't know i i have mixed emotions about him uh sometimes i like him but at the same times i love to see somebody beat him I like him for the most part, but would still enjoy seeing someone beat him. entailment entailment entailment entailment entailment entailment
+3 78013 78013c telephone ( yeah ( ( i i ) ( think ( ( my ( favorite restaurant ) ) ( ( is always ) ( been ( ( the ( one closest ) ) ( you ( ( know ( the closest ) ) ( ( as long ) ( as ( it ( 's ( it ( meets ( ( the ( minimum criteria ) ) ( you ( know ( of ( good food ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( ( My ( favorite restaurants ) ) ( ( ( ( are always ) ( ( ( ( ( at least ) a ) hundred ) miles ) away ) ) ( from ( my house ) ) ) . ) ) (ROOT (S (VP (VB yeah) (NP (NP (FW i) (FW i)) (SBAR (S (VP (VBP think) (SBAR (S (NP (PRP$ my) (JJ favorite) (NN restaurant)) (VP (VBZ is) (ADVP (RB always)) (VP (VBN been) (NP (NP (DT the) (CD one) (JJS closest)) (SBAR (S (NP (PRP you)) (VP (VBP know) (NP (DT the) (JJS closest)) (ADVP (ADVP (RB as) (RB long)) (SBAR (IN as) (S (NP (PRP it)) (VP (VBZ 's) (SBAR (S (NP (PRP it)) (VP (VBZ meets) (NP (NP (DT the) (JJ minimum) (NNS criteria)) (SBAR (S (NP (PRP you)) (VP (VBP know) (PP (IN of) (NP (JJ good) (NN food))))))))))))))))))))))))))))) (ROOT (S (NP (PRP$ My) (JJ favorite) (NNS restaurants)) (VP (VBP are) (ADVP (RB always)) (ADVP (NP (QP (IN at) (JJS least) (DT a) (CD hundred)) (NNS miles)) (RB away)) (PP (IN from) (NP (PRP$ my) (NN house)))) (. .))) yeah i i think my favorite restaurant is always been the one closest you know the closest as long as it's it meets the minimum criteria you know of good food My favorite restaurants are always at least a hundred miles away from my house. contradiction contradiction contradiction contradiction contradiction contradiction
+4 96377 96377c telephone ( i ( ( do n't ) ( know ( um ( do ( you ( do ( ( a lot ) ( of camping ) ) ) ) ) ) ) ) ) ( I ( ( know exactly ) . ) ) (ROOT (S (NP (FW i)) (VP (VBP do) (RB n't) (VP (VB know) (SBAR (S (NP (NN um)) (VP (VBP do) (SBAR (S (NP (PRP you)) (VP (VBP do) (NP (NP (DT a) (NN lot)) (PP (IN of) (NP (NN camping)))))))))))))) (ROOT (S (NP (PRP I)) (VP (VBP know) (ADVP (RB exactly))) (. .))) i don't know um do you do a lot of camping I know exactly. contradiction contradiction contradiction contradiction contradiction contradiction
+5 139749 139749c telephone ( well ( that ( would ( be ( ( a help ) ( i ( wish ( they ( would ( do ( that ( ( ( here ( we ( have ( got ( so ( ( little ( landfill space ) ) ( left ( that ( we ( 're ( going ( to ( ( run out ) ( before ( ( the end ) ( of ( this decade ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) and ) ( it ( ( 's really ) ( going ( to be ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ) ( We ( ( have ( plenty ( of ( space ( in ( the landfill ) ) ) ) ) ) . ) ) (ROOT (FRAG (ADVP (RB well)) (SBAR (WHNP (WDT that)) (S (VP (MD would) (VP (VB be) (NP (NP (DT a) (NN help)) (SBAR (S (NP (FW i)) (VP (VBP wish) (SBAR (S (NP (PRP they)) (VP (MD would) (VP (VB do) (SBAR (IN that) (S (S (ADVP (RB here)) (NP (PRP we)) (VP (VBP have) (VP (VBN got) (SBAR (IN so) (S (NP (JJ little) (NN landfill) (NN space)) (VP (VBD left) (SBAR (IN that) (S (NP (PRP we)) (VP (VBP 're) (VP (VBG going) (S (VP (TO to) (VP (VB run) (PRT (RP out)) (PP (IN before) (NP (NP (DT the) (NN end)) (PP (IN of) (NP (DT this) (NN decade)))))))))))))))))) (CC and) (S (NP (PRP it)) (VP (VBZ 's) (ADVP (RB really)) (VP (VBG going) (S (VP (TO to) (VP (VB be))))))))))))))))))))))) (ROOT (S (NP (PRP We)) (VP (VBP have) (NP (NP (RB plenty)) (PP (IN of) (NP (NP (NN space)) (PP (IN in) (NP (DT the) (NN landfill))))))) (. .))) well that would be a help i wish they would do that here we have got so little landfill space left that we're going to run out before the end of this decade and it's really going to be We have plenty of space in the landfill. contradiction contradiction contradiction contradiction contradiction contradiction
+6 101415 101415c telephone ( yeah ( ( ( i know ) and ) ( i ( did ( that ( ( ( all ( through college ) ) and ) ( it ( worked too ) ) ) ) ) ) ) ) ( I ( ( ( did ( that all ) ) ( through college ) ) ( but ( it ( never worked ) ) ) ) ) (ROOT (S (VP (VB yeah) (S (S (NP (FW i)) (VP (VBP know))) (CC and) (S (NP (FW i)) (VP (VBD did) (SBAR (IN that) (S (S (NP (DT all)) (PP (IN through) (NP (NN college)))) (CC and) (S (NP (PRP it)) (VP (VBD worked) (ADVP (RB too)))))))))))) (ROOT (S (NP (PRP I)) (VP (VBD did) (ADVP (IN that) (DT all)) (PP (IN through) (NP (NN college))) (SBAR (CC but) (S (NP (PRP it)) (ADVP (RB never)) (VP (VBD worked))))))) yeah i know and i did that all through college and it worked too I did that all through college but it never worked contradiction contradiction contradiction contradiction contradiction contradiction
+7 93958 93958n travel ( ( ( ( ( Calcutta ( seems ( to ( be ( ( the ( only ( other ( production center ) ) ) ) ( ( having ( any pretensions ) ) ( to ( ( artistic creativity ) ( at all ) ) ) ) ) ) ) ) ) , ) but ) ( ironically ( you ( ( 're actually ) ( ( more ( likely ( to ( see ( ( the works ) ( of ( ( ( Satyajit Ray ) or ) ( ( Mrinal Sen ) ( shown ( in ( Europe ( or ( North America ) ) ) ) ) ) ) ) ) ) ) ) ) ( than ( in ( India itself ) ) ) ) ) ) ) ) . ) ( ( Most ( of ( ( Mrinal ( Sen 's ) ) work ) ) ) ( ( can ( be ( found ( in ( European collections ) ) ) ) ) . ) ) (ROOT (S (S (NP (NNP Calcutta)) (VP (VBZ seems) (S (VP (TO to) (VP (VB be) (NP (NP (DT the) (JJ only) (JJ other) (NN production) (NN center)) (VP (VBG having) (NP (DT any) (NNS pretensions)) (PP (TO to) (NP (NP (JJ artistic) (NN creativity)) (ADVP (IN at) (DT all))))))))))) (, ,) (CC but) (S (ADVP (RB ironically)) (NP (PRP you)) (VP (VBP 're) (ADVP (RB actually)) (ADJP (ADJP (RBR more) (JJ likely) (S (VP (TO to) (VP (VB see) (NP (NP (DT the) (NNS works)) (PP (IN of) (NP (NP (NNP Satyajit) (NNP Ray)) (CC or) (NP (NP (NNP Mrinal) (NNP Sen)) (VP (VBN shown) (PP (IN in) (NP (NNP Europe) (CC or) (NNP North) (NNP America)))))))))))) (ADVP (IN than) (PP (IN in) (S (VP (VBG India) (NP (PRP itself))))))))) (. .))) (ROOT (S (NP (NP (JJS Most)) (PP (IN of) (NP (NP (NNP Mrinal) (NNP Sen) (POS 's)) (NN work)))) (VP (MD can) (VP (VB be) (VP (VBN found) (PP (IN in) (NP (JJ European) (NNS collections)))))) (. .))) Calcutta seems to be the only other production center having any pretensions to artistic creativity at all, but ironically you're actually more likely to see the works of Satyajit Ray or Mrinal Sen shown in Europe or North America than in India itself. Most of Mrinal Sen's work can be found in European collections. neutral neutral entailment neutral neutral neutral
+8 12567 12567c slate ( ( If ( ( that investor ) ( were ( willing ( to ( pay ( extra ( for ( ( the security ) ( of ( limited downside ) ) ) ) ) ) ) ) ) ) ) ( , ( she ( ( could ( ( buy ( put options ) ) ( with ( ( a ( strike price ) ) ( of ( ( ( $ 98 ) , ) ( which ( would ( ( ( lock ( in ( ( her profit ) ( on ( ( the shares ) ( at ( $ 18 ) ) ) ) ) ) ) , ) ( less ( whatever ( ( the options ) cost ) ) ) ) ) ) ) ) ) ) ) ) . ) ) ) ) ( ( THe ( strike price ) ) ( ( could ( be ( $ 8 ) ) ) . ) ) (ROOT (S (SBAR (IN If) (S (NP (DT that) (NN investor)) (VP (VBD were) (ADJP (JJ willing) (S (VP (TO to) (VP (VB pay) (NP (NP (JJ extra)) (PP (IN for) (NP (NP (DT the) (NN security)) (PP (IN of) (NP (JJ limited) (NN downside))))))))))))) (, ,) (NP (PRP she)) (VP (MD could) (VP (VB buy) (NP (NN put) (NNS options)) (PP (IN with) (NP (NP (DT a) (NN strike) (NN price)) (PP (IN of) (NP (NP ($ $) (CD 98)) (, ,) (SBAR (WHNP (WDT which)) (S (VP (MD would) (VP (VB lock) (PP (IN in) (NP (NP (PRP$ her) (NN profit)) (PP (IN on) (NP (NP (DT the) (NNS shares)) (PP (IN at) (NP ($ $) (CD 18))))))) (, ,) (ADVP (ADVP (RBR less)) (SBAR (WHNP (WDT whatever)) (S (NP (DT the) (NNS options)) (VP (VBD cost))))))))))))))) (. .))) (ROOT (S (NP (NNP THe) (NN strike) (NN price)) (VP (MD could) (VP (VB be) (NP ($ $) (CD 8)))) (. .))) If that investor were willing to pay extra for the security of limited downside, she could buy put options with a strike price of $98, which would lock in her profit on the shares at $18, less whatever the options cost. THe strike price could be $8. contradiction contradiction contradiction contradiction contradiction contradiction
+9 117487 117487n slate ( ( 3 -RRB- ) ( ( Dare ( you ( ( ( rise ( to ( ( ( ( the occasion ) , ) ( like Raskolnikov ) ) , ) ) ) and ) ( reject ( ( the ( petty rules ) ) ( that ( govern ( lesser men ) ) ) ) ) ) ) ) ? ) ) ( ( ( Would you ) ( ( ( rise up ) and ) ( defeaat ( ( all ( evil lords ) ) ( in ( the town ) ) ) ) ) ) ? ) (ROOT (S (LST (LS 3) (-RRB- -RRB-)) (VP (VB Dare) (S (NP (PRP you)) (VP (VP (VB rise) (PP (TO to) (NP (NP (DT the) (NN occasion)) (, ,) (PP (IN like) (NP (NNP Raskolnikov))) (, ,)))) (CC and) (VP (VB reject) (NP (NP (DT the) (JJ petty) (NNS rules)) (SBAR (WHNP (WDT that)) (S (VP (VBP govern) (NP (JJR lesser) (NNS men)))))))))) (. ?))) (ROOT (SQ (MD Would) (NP (PRP you)) (VP (VP (VB rise) (PRT (RP up))) (CC and) (VP (VB defeaat) (NP (NP (DT all) (JJ evil) (NNS lords)) (PP (IN in) (NP (DT the) (NN town)))))) (. ?))) 3) Dare you rise to the occasion, like Raskolnikov, and reject the petty rules that govern lesser men? Would you rise up and defeaat all evil lords in the town? neutral neutral neutral neutral neutral neutral
+10 9616 9616c travel ( ( The ( ( most important ) directions ) ) ( ( ( are ( simply ( ( up and ) up ) ) ) ( ( ( ( ( ( ( ( leads eventually ) ( to ( the cathedral ) ) ) and ) ( fortress ( commanding ( the hilltop ) ) ) ) , ) and ) down ) ( inevitably ( ( leads ( to ( one ( of ( three gates ) ) ) ) ) ( through ( ( the wall ) ( to ( the ( new town ) ) ) ) ) ) ) ) ) . ) ) ( Go ( ( downwards ( to ( one ( of ( ( ( the gates ) , ) ( ( all ( of which ) ) ( will ( ( lead you ) ( into ( the cathedral ) ) ) ) ) ) ) ) ) ) . ) ) (ROOT (S (NP (DT The) (ADJP (RBS most) (JJ important)) (NNS directions)) (VP (VBP are) (PRN (ADVP (RB simply)) (ADVP (RB up) (CC and) (RB up))) (VP (VP (VBZ leads) (ADVP (RB eventually)) (PP (TO to) (NP (DT the) (NN cathedral)))) (CC and) (VP (VBZ fortress) (NP (JJ commanding) (DT the) (NN hilltop))) (, ,) (CC and) (ADVP (RB down)) (VP (ADVP (RB inevitably)) (VBZ leads) (PP (TO to) (NP (NP (CD one)) (PP (IN of) (NP (CD three) (NNS gates))))) (PP (IN through) (NP (NP (DT the) (NN wall)) (PP (TO to) (NP (DT the) (JJ new) (NN town)))))))) (. .))) (ROOT (S (NP (NNP Go)) (VP (VBZ downwards) (PP (TO to) (NP (NP (CD one)) (PP (IN of) (NP (NP (DT the) (NNS gates)) (, ,) (SBAR (WHNP (DT all) (WHPP (IN of) (WHNP (WDT which)))) (S (VP (MD will) (VP (VB lead) (NP (PRP you)) (PP (IN into) (NP (DT the) (NN cathedral)))))))))))) (. .))) The most important directions are simply up and up leads eventually to the cathedral and fortress commanding the hilltop, and down inevitably leads to one of three gates through the wall to the new town. Go downwards to one of the gates, all of which will lead you into the cathedral. contradiction contradiction entailment contradiction contradiction contradiction
diff --git a/reproduction/seqence_labelling/cws/test/__init__.py b/test/embeddings/__init__.py
similarity index 100%
rename from reproduction/seqence_labelling/cws/test/__init__.py
rename to test/embeddings/__init__.py
diff --git a/test/embeddings/test_bert_embedding.py b/test/embeddings/test_bert_embedding.py
new file mode 100644
index 00000000..da81c8c9
--- /dev/null
+++ b/test/embeddings/test_bert_embedding.py
@@ -0,0 +1,21 @@
+import unittest
+from fastNLP import Vocabulary
+from fastNLP.embeddings import BertEmbedding
+import torch
+import os
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestDownload(unittest.TestCase):
+ def test_download(self):
+ # import os
+ vocab = Vocabulary().add_word_lst("This is a test .".split())
+ embed = BertEmbedding(vocab, model_dir_or_name='en')
+ words = torch.LongTensor([[2, 3, 4, 0]])
+ print(embed(words).size())
+
+ def test_word_drop(self):
+ vocab = Vocabulary().add_word_lst("This is a test .".split())
+ embed = BertEmbedding(vocab, model_dir_or_name='en', dropout=0.1, word_dropout=0.2)
+ for i in range(10):
+ words = torch.LongTensor([[2, 3, 4, 0]])
+ print(embed(words).size())
\ No newline at end of file
diff --git a/test/embeddings/test_char_embedding.py b/test/embeddings/test_char_embedding.py
new file mode 100644
index 00000000..ceafe4f5
--- /dev/null
+++ b/test/embeddings/test_char_embedding.py
@@ -0,0 +1,26 @@
+import unittest
+
+import torch
+
+from fastNLP import Vocabulary, DataSet, Instance
+from fastNLP.embeddings.char_embedding import LSTMCharEmbedding, CNNCharEmbedding
+
+
+class TestCharEmbed(unittest.TestCase):
+ def test_case_1(self):
+ ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
+ vocab = Vocabulary().from_dataset(ds, field_name='words')
+ self.assertEqual(len(vocab), 5)
+ embed = LSTMCharEmbedding(vocab, embed_size=60)
+ x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
+ y = embed(x)
+ self.assertEqual(tuple(y.size()), (2, 3, 60))
+
+ def test_case_2(self):
+ ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['Jack'])])
+ vocab = Vocabulary().from_dataset(ds, field_name='words')
+ self.assertEqual(len(vocab), 5)
+ embed = CNNCharEmbedding(vocab, embed_size=60)
+ x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
+ y = embed(x)
+ self.assertEqual(tuple(y.size()), (2, 3, 60))
diff --git a/test/embeddings/test_elmo_embedding.py b/test/embeddings/test_elmo_embedding.py
new file mode 100644
index 00000000..a087f0a4
--- /dev/null
+++ b/test/embeddings/test_elmo_embedding.py
@@ -0,0 +1,21 @@
+
+import unittest
+from fastNLP import Vocabulary
+from fastNLP.embeddings import ElmoEmbedding
+import torch
+import os
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestDownload(unittest.TestCase):
+ def test_download_small(self):
+ # import os
+ vocab = Vocabulary().add_word_lst("This is a test .".split())
+ elmo_embed = ElmoEmbedding(vocab, model_dir_or_name='en-small')
+ words = torch.LongTensor([[0, 1, 2]])
+ print(elmo_embed(words).size())
+
+
+# 首先保证所有权重可以加载;上传权重;验证可以下载
+
+
+
diff --git a/test/embeddings/test_stack_embeddings.py b/test/embeddings/test_stack_embeddings.py
new file mode 100644
index 00000000..2eb0b414
--- /dev/null
+++ b/test/embeddings/test_stack_embeddings.py
@@ -0,0 +1,20 @@
+import unittest
+
+import torch
+
+from fastNLP import Vocabulary, DataSet, Instance
+from fastNLP.embeddings import LSTMCharEmbedding, CNNCharEmbedding, StackEmbedding
+
+
+class TestCharEmbed(unittest.TestCase):
+ def test_case_1(self):
+ ds = DataSet([Instance(words=['hello', 'world']), Instance(words=['hello', 'Jack'])])
+ vocab = Vocabulary().from_dataset(ds, field_name='words')
+ self.assertEqual(len(vocab), 5)
+ cnn_embed = CNNCharEmbedding(vocab, embed_size=60)
+ lstm_embed = LSTMCharEmbedding(vocab, embed_size=70)
+ embed = StackEmbedding([cnn_embed, lstm_embed])
+ x = torch.LongTensor([[2, 1, 0], [4, 3, 4]])
+ y = embed(x)
+ self.assertEqual(tuple(y.size()), (2, 3, 130))
+
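
The test above confirms that StackEmbedding simply concatenates its parts (60 + 70 = 130 dimensions). A sketch of the typical word-plus-character stack used for sequence labelling; the dimensions and the randomly initialised StaticEmbedding are illustrative choices:

    # Sketch: stacking a word-level and a character-level embedding (dimensions illustrative).
    import torch
    from fastNLP import Vocabulary, DataSet, Instance
    from fastNLP.embeddings import StaticEmbedding, CNNCharEmbedding, StackEmbedding

    ds = DataSet([Instance(words=['hello', 'world'])])
    vocab = Vocabulary().from_dataset(ds, field_name='words')
    word_embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=50)  # random init, no pretrained file
    char_embed = CNNCharEmbedding(vocab, embed_size=30)
    embed = StackEmbedding([word_embed, char_embed])   # output size is 50 + 30 = 80
    words = torch.LongTensor([[vocab.to_index('hello'), vocab.to_index('world')]])
    print(embed(words).size())   # expected torch.Size([1, 2, 80])
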
diff --git a/test/embeddings/test_static_embedding.py b/test/embeddings/test_static_embedding.py
new file mode 100644
index 00000000..c17daa0a
--- /dev/null
+++ b/test/embeddings/test_static_embedding.py
@@ -0,0 +1,140 @@
+import unittest
+
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP import Vocabulary
+import torch
+import os
+
+
+class TestLoad(unittest.TestCase):
+ def test_norm1(self):
+ # 测试只对可以找到的norm
+ vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+ embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+ only_norm_found_vector=True)
+ self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
+ self.assertNotEqual(torch.norm(embed(torch.LongTensor([[4]]))).item(), 1)
+
+ def test_norm2(self):
+ # 测试对所有都norm
+ vocab = Vocabulary().add_word_lst(['the', 'a', 'notinfile'])
+ embed = StaticEmbedding(vocab, model_dir_or_name='test/data_for_tests/glove.6B.50d_test.txt',
+ normalize=True)
+ self.assertEqual(round(torch.norm(embed(torch.LongTensor([[2]]))).item(), 4), 1)
+ self.assertEqual(round(torch.norm(embed(torch.LongTensor([[4]]))).item(), 4), 1)
+
+ def test_dropword(self):
+ # 测试是否可以通过drop word
+ vocab = Vocabulary().add_word_lst([chr(i) for i in range(1, 200)])
+ embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=10, dropout=0.1, word_dropout=0.4)
+ for i in range(10):
+ length = torch.randint(1, 50, (1,)).item()
+ batch = torch.randint(1, 4, (1,)).item()
+ words = torch.randint(1, 200, (batch, length)).long()
+ embed(words)
+
+class TestRandomSameEntry(unittest.TestCase):
+ def test_same_vector(self):
+ vocab = Vocabulary().add_word_lst(["The", "the", "THE", 'a', "A"])
+ embed = StaticEmbedding(vocab, model_dir_or_name=None, embedding_dim=5, lower=True)
+ words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'a', 'A']]])
+ words = embed(words)
+ embed_0 = words[0, 0]
+ for i in range(1, 3):
+ assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
+ embed_0 = words[0, 3]
+ for i in range(3, 5):
+ assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_same_vector2(self):
+ vocab = Vocabulary().add_word_lst(["The", 'a', 'b', "the", "THE", "B", 'a', "A"])
+ embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=True)
+ words = torch.LongTensor([[vocab.to_index(word) for word in ["The", "the", "THE", 'b', "B", 'a', 'A']]])
+ words = embed(words)
+ embed_0 = words[0, 0]
+ for i in range(1, 3):
+ assert torch.sum(embed_0==words[0, i]).eq(len(embed_0))
+ embed_0 = words[0, 3]
+ for i in range(3, 5):
+ assert torch.sum(embed_0 == words[0, i]).eq(len(embed_0))
+
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_same_vector3(self):
+ # 验证lower
+ word_lst = ["The", "the"]
+ no_create_word_lst = ['of', 'Of', 'With', 'with']
+ vocab = Vocabulary().add_word_lst(word_lst)
+ vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+ embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=True)
+ words = torch.LongTensor([[vocab.to_index(word) for word in word_lst+no_create_word_lst]])
+ words = embed(words)
+
+ lowered_word_lst = [word.lower() for word in word_lst]
+ lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
+ lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
+ lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
+ lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=False)
+ lowered_words = torch.LongTensor([[lowered_vocab.to_index(word) for word in lowered_word_lst+lowered_no_create_word_lst]])
+ lowered_words = lowered_embed(lowered_words)
+
+ all_words = word_lst + no_create_word_lst
+
+ for idx, (word_i, word_j) in enumerate(zip(words[0], lowered_words[0])):
+ with self.subTest(idx=idx, word=all_words[idx]):
+ assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
+
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_same_vector4(self):
+ # 验证在有min_freq下的lower
+ word_lst = ["The", "the", "the", "The", "a", "A"]
+ no_create_word_lst = ['of', 'Of', "Of", "of", 'With', 'with']
+ all_words = word_lst[:-2] + no_create_word_lst[:-2]
+ vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
+ vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+ embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=True)
+ words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
+ words = embed(words)
+
+ lowered_word_lst = [word.lower() for word in word_lst]
+ lowered_no_create_word_lst = [word.lower() for word in no_create_word_lst]
+ lowered_vocab = Vocabulary().add_word_lst(lowered_word_lst)
+ lowered_vocab.add_word_lst(lowered_no_create_word_lst, no_create_entry=True)
+ lowered_embed = StaticEmbedding(lowered_vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=False)
+ lowered_words = torch.LongTensor([[lowered_vocab.to_index(word.lower()) for word in all_words]])
+ lowered_words = lowered_embed(lowered_words)
+
+ for idx in range(len(all_words)):
+ word_i, word_j = words[0, idx], lowered_words[0, idx]
+ with self.subTest(idx=idx, word=all_words[idx]):
+ assert torch.sum(word_i == word_j).eq(lowered_embed.embed_size)
+
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_same_vector5(self):
+ # 检查通过使用min_freq后的word是否内容一致
+ word_lst = ["they", "the", "they", "the", 'he', 'he', "a", "A"]
+ no_create_word_lst = ['of', "of", "she", "she", 'With', 'with']
+ all_words = word_lst[:-2] + no_create_word_lst[:-2]
+ vocab = Vocabulary().add_word_lst(word_lst)
+ vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+ embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=False, min_freq=2)
+ words = torch.LongTensor([[vocab.to_index(word) for word in all_words]])
+ words = embed(words)
+
+ min_freq_vocab = Vocabulary(min_freq=2).add_word_lst(word_lst)
+ min_freq_vocab.add_word_lst(no_create_word_lst, no_create_entry=True)
+ min_freq_embed = StaticEmbedding(min_freq_vocab, model_dir_or_name='en-glove-6B-100d',
+ lower=False)
+ min_freq_words = torch.LongTensor([[min_freq_vocab.to_index(word.lower()) for word in all_words]])
+ min_freq_words = min_freq_embed(min_freq_words)
+
+ for idx in range(len(all_words)):
+ word_i, word_j = words[0, idx], min_freq_words[0, idx]
+ with self.subTest(idx=idx, word=all_words[idx]):
+ assert torch.sum(word_i == word_j).eq(min_freq_embed.embed_size)
\ No newline at end of file
diff --git a/test/io/loader/test_classification_loader.py b/test/io/loader/test_classification_loader.py
new file mode 100644
index 00000000..28f08921
--- /dev/null
+++ b/test/io/loader/test_classification_loader.py
@@ -0,0 +1,19 @@
+
+import unittest
+from fastNLP.io.loader.classification import YelpFullLoader
+from fastNLP.io.loader.classification import YelpPolarityLoader
+from fastNLP.io.loader.classification import IMDBLoader
+from fastNLP.io.loader.classification import SST2Loader
+from fastNLP.io.loader.classification import SSTLoader
+import os
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestDownload(unittest.TestCase):
+ def test_download(self):
+ for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
+ loader().download()
+
+ def test_load(self):
+ for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader]:
+ data_bundle = loader().load()
+ print(data_bundle)
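
These loader tests only download and read the raw splits; in the new API a matching Pipe then tokenizes, builds the vocabularies and indexes the fields. A sketch of the two-step flow, assuming Pipe.process accepts the loader's DataBundle (elsewhere in this change the one-step process_from_file is used); IMDB is chosen for illustration:

    # Sketch of the Loader -> Pipe flow on the public IMDB data.
    from fastNLP.io.loader.classification import IMDBLoader
    from fastNLP.io.pipe.classification import IMDBPipe

    raw_bundle = IMDBLoader().load()                          # raw_words / target fields
    data_bundle = IMDBPipe(lower=True).process(raw_bundle)    # adds words, seq_len and the vocabs
    print(data_bundle)

    # equivalently, in one call:
    # data_bundle = IMDBPipe(lower=True).process_from_file()
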
diff --git a/test/io/loader/test_conll_loader.py b/test/io/loader/test_conll_loader.py
new file mode 100644
index 00000000..e44b8a2a
--- /dev/null
+++ b/test/io/loader/test_conll_loader.py
@@ -0,0 +1,21 @@
+
+import unittest
+import os
+from fastNLP.io.loader.conll import MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader
+
+class MSRANERTest(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_download(self):
+ MsraNERLoader().download(re_download=False)
+ data_bundle = MsraNERLoader().load()
+ print(data_bundle)
+
+class PeopleDailyTest(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_download(self):
+ PeopleDailyNERLoader().download()
+
+class WeiboNERTest(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_download(self):
+ WeiboNERLoader().download()
\ No newline at end of file
diff --git a/test/io/loader/test_cws_loader.py b/test/io/loader/test_cws_loader.py
new file mode 100644
index 00000000..6ad607c3
--- /dev/null
+++ b/test/io/loader/test_cws_loader.py
@@ -0,0 +1,13 @@
+import unittest
+import os
+from fastNLP.io.loader import CWSLoader
+
+
+class CWSLoaderTest(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_download(self):
+ dataset_names = ['pku', 'cityu', 'as', 'msra']
+ for dataset_name in dataset_names:
+ with self.subTest(dataset_name=dataset_name):
+ data_bundle = CWSLoader(dataset_name=dataset_name).load()
+ print(data_bundle)
\ No newline at end of file
diff --git a/test/io/loader/test_matching_loader.py b/test/io/loader/test_matching_loader.py
new file mode 100644
index 00000000..5c1a91f1
--- /dev/null
+++ b/test/io/loader/test_matching_loader.py
@@ -0,0 +1,22 @@
+
+import unittest
+from fastNLP.io.loader.matching import RTELoader
+from fastNLP.io.loader.matching import QNLILoader
+from fastNLP.io.loader.matching import SNLILoader
+from fastNLP.io.loader.matching import QuoraLoader
+from fastNLP.io.loader.matching import MNLILoader
+import os
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestDownload(unittest.TestCase):
+ def test_download(self):
+ for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
+ loader().download()
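+ # QuoraLoader is excluded from the download loop above; presumably it has no automatic download, so load() without a local path should raise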
+ with self.assertRaises(Exception):
+ QuoraLoader().load()
+
+ def test_load(self):
+ for loader in [RTELoader, QNLILoader, SNLILoader, MNLILoader]:
+ data_bundle = loader().load()
+ print(data_bundle)
+
diff --git a/test/io/pipe/test_classification.py b/test/io/pipe/test_classification.py
new file mode 100644
index 00000000..39dc71e0
--- /dev/null
+++ b/test/io/pipe/test_classification.py
@@ -0,0 +1,13 @@
+import unittest
+import os
+
+from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestPipe(unittest.TestCase):
+ def test_process_from_file(self):
+ for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
+ with self.subTest(pipe=pipe):
+ print(pipe)
+ data_bundle = pipe(tokenizer='raw').process_from_file()
+ print(data_bundle)
diff --git a/test/io/pipe/test_conll.py b/test/io/pipe/test_conll.py
new file mode 100644
index 00000000..e8879d71
--- /dev/null
+++ b/test/io/pipe/test_conll.py
@@ -0,0 +1,12 @@
+import unittest
+import os
+from fastNLP.io import MsraNERPipe, PeopleDailyPipe, WeiboNERPipe
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestPipe(unittest.TestCase):
+ def test_process_from_file(self):
+ for pipe in [MsraNERPipe, PeopleDailyPipe, WeiboNERPipe]:
+ with self.subTest(pipe=pipe):
+ print(pipe)
+ data_bundle = pipe().process_from_file()
+ print(data_bundle)
\ No newline at end of file
diff --git a/test/io/pipe/test_cws.py b/test/io/pipe/test_cws.py
new file mode 100644
index 00000000..2fc57ae2
--- /dev/null
+++ b/test/io/pipe/test_cws.py
@@ -0,0 +1,13 @@
+
+import unittest
+import os
+from fastNLP.io.pipe.cws import CWSPipe
+
+class CWSPipeTest(unittest.TestCase):
+ @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+ def test_process_from_file(self):
+ dataset_names = ['pku', 'cityu', 'as', 'msra']
+ for dataset_name in dataset_names:
+ with self.subTest(dataset_name=dataset_name):
+ data_bundle = CWSPipe(dataset_name=dataset_name).process_from_file()
+ print(data_bundle)
\ No newline at end of file
diff --git a/test/io/pipe/test_matching.py b/test/io/pipe/test_matching.py
new file mode 100644
index 00000000..c057bb0c
--- /dev/null
+++ b/test/io/pipe/test_matching.py
@@ -0,0 +1,26 @@
+
+import unittest
+import os
+
+from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe
+from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe
+
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestPipe(unittest.TestCase):
+ def test_process_from_file(self):
+ for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
+ with self.subTest(pipe=pipe):
+ print(pipe)
+ data_bundle = pipe(tokenizer='raw').process_from_file()
+ print(data_bundle)
+
+
+@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
+class TestBertPipe(unittest.TestCase):
+ def test_process_from_file(self):
+ for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
+ with self.subTest(pipe=pipe):
+ print(pipe)
+ data_bundle = pipe(tokenizer='raw').process_from_file()
+ print(data_bundle)
diff --git a/test/io/test_data_loader.py b/test/io/test_data_loader.py
new file mode 100644
index 00000000..5b1bb749
--- /dev/null
+++ b/test/io/test_data_loader.py
@@ -0,0 +1,15 @@
+import unittest
+
+from fastNLP.core.const import Const
+from fastNLP.io.data_loader import MNLILoader
+
+
+class TestDataLoader(unittest.TestCase):
+
+ def test_mnli_loader(self):
+ ds = MNLILoader().process('test/data_for_tests/sample_mnli.tsv',
+ to_lower=True, get_index=True, seq_len_type='mask')
+ self.assertTrue('train' in ds.datasets)
+ self.assertTrue(len(ds.datasets) == 1)
+ self.assertTrue(len(ds.datasets['train']) == 11)
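+ # with seq_len_type='mask', the length field is expected to be a per-token 0/1 list rather than an int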
+ self.assertTrue(isinstance(ds.datasets['train'][0][Const.INPUT_LENS(0)], list))
diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py
index 492545f6..6fb8e4f7 100644
--- a/test/io/test_dataset_loader.py
+++ b/test/io/test_dataset_loader.py
@@ -61,17 +61,17 @@ class TestDatasetLoader(unittest.TestCase):
print(info.datasets)
os.remove(train), os.remove(test)
- def test_import(self):
- import fastNLP
- from fastNLP.io import SNLILoader
- ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
- get_index=True, seq_len_type='seq_len', extra_split=['-'])
- assert 'train' in ds.datasets
- assert len(ds.datasets) == 1
- assert len(ds.datasets['train']) == 3
-
- ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
- get_index=True, seq_len_type='seq_len')
- assert 'train' in ds.datasets
- assert len(ds.datasets) == 1
- assert len(ds.datasets['train']) == 3
+ # def test_import(self):
+ # import fastNLP
+ # from fastNLP.io import SNLILoader
+ # ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+ # get_index=True, seq_len_type='seq_len', extra_split=['-'])
+ # assert 'train' in ds.datasets
+ # assert len(ds.datasets) == 1
+ # assert len(ds.datasets['train']) == 3
+ #
+ # ds = SNLILoader().process('test/data_for_tests/sample_snli.jsonl', to_lower=True,
+ # get_index=True, seq_len_type='seq_len')
+ # assert 'train' in ds.datasets
+ # assert len(ds.datasets) == 1
+ # assert len(ds.datasets['train']) == 3
diff --git a/test/io/test_embed_loader.py b/test/io/test_embed_loader.py
index ff8ecfcf..bbfe8858 100644
--- a/test/io/test_embed_loader.py
+++ b/test/io/test_embed_loader.py
@@ -16,7 +16,7 @@ class TestEmbedLoader(unittest.TestCase):
self.assertEqual(g_m.shape, (4, 50))
w_m = EmbedLoader.load_with_vocab(word2vec, vocab, normalize=True)
self.assertEqual(w_m.shape, (4, 50))
- self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4)
+ self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 4, delta=1e-4)
def test_load_without_vocab(self):
words = ['the', 'of', 'in', 'a', 'to', 'and']
@@ -28,13 +28,13 @@ class TestEmbedLoader(unittest.TestCase):
self.assertIn(word, vocab)
w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True)
self.assertEqual(w_m.shape, (8, 50))
- self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8)
+ self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 8, delta=1e-4)
for word in words:
self.assertIn(word, vocab)
# no unk
w_m, vocab = EmbedLoader.load_without_vocab(word2vec, normalize=True, unknown=None)
self.assertEqual(w_m.shape, (7, 50))
- self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7)
+ self.assertAlmostEqual(np.linalg.norm(w_m, axis=1).sum(), 7, delta=1e-4)
for word in words:
self.assertIn(word, vocab)
diff --git a/test/models/test_bert.py b/test/models/test_bert.py
index 38a16f9b..40b98c81 100644
--- a/test/models/test_bert.py
+++ b/test/models/test_bert.py
@@ -2,28 +2,34 @@ import unittest
import torch
-from fastNLP.models.bert import *
+from fastNLP.models.bert import BertForSequenceClassification, BertForQuestionAnswering, \
+ BertForTokenClassification, BertForMultipleChoice
class TestBert(unittest.TestCase):
def test_bert_1(self):
from fastNLP.core.const import Const
- from fastNLP.modules.encoder._bert import BertConfig
+ from fastNLP.modules.encoder.bert import BertConfig
model = BertForSequenceClassification(2, BertConfig(32000))
input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
- token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
- pred = model(input_ids, token_type_ids, input_mask)
+ pred = model(input_ids, input_mask)
+ self.assertTrue(isinstance(pred, dict))
+ self.assertTrue(Const.OUTPUT in pred)
+ self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
+
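+ # the same argument also appears to accept 1-D sequence lengths in place of a 2-D 0/1 mask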
+ input_mask = torch.LongTensor([3, 2])
+ pred = model(input_ids, input_mask)
self.assertTrue(isinstance(pred, dict))
self.assertTrue(Const.OUTPUT in pred)
self.assertEqual(tuple(pred[Const.OUTPUT].shape), (2, 2))
def test_bert_2(self):
from fastNLP.core.const import Const
- from fastNLP.modules.encoder._bert import BertConfig
+ from fastNLP.modules.encoder.bert import BertConfig
model = BertForMultipleChoice(2, BertConfig(32000))
@@ -38,7 +44,7 @@ class TestBert(unittest.TestCase):
def test_bert_3(self):
from fastNLP.core.const import Const
- from fastNLP.modules.encoder._bert import BertConfig
+ from fastNLP.modules.encoder.bert import BertConfig
model = BertForTokenClassification(7, BertConfig(32000))
@@ -53,7 +59,7 @@ class TestBert(unittest.TestCase):
def test_bert_4(self):
from fastNLP.core.const import Const
- from fastNLP.modules.encoder._bert import BertConfig
+ from fastNLP.modules.encoder.bert import BertConfig
model = BertForQuestionAnswering(BertConfig(32000))
diff --git a/test/models/test_snli.py b/test/models/test_snli.py
new file mode 100644
index 00000000..7a588a4c
--- /dev/null
+++ b/test/models/test_snli.py
@@ -0,0 +1,9 @@
+import unittest
+from .model_runner import *
+from fastNLP.models.snli import ESIM
+
+
+class TestSNLIModel(unittest.TestCase):
+ def test_snli(self):
+ model = ESIM((VOCAB_SIZE, 10), num_labels=NUM_CLS, dropout_rate=0)
+ RUNNER.run_model_with_task(NLI, model)
diff --git a/reproduction/seqence_labelling/ner/test/__init__.py b/test/modules/__init__.py
similarity index 100%
rename from reproduction/seqence_labelling/ner/test/__init__.py
rename to test/modules/__init__.py
diff --git a/test/modules/decoder/__init__.py b/test/modules/decoder/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/test/modules/encoder/test_bert.py b/test/modules/decoder/test_bert.py
similarity index 91%
rename from test/modules/encoder/test_bert.py
rename to test/modules/decoder/test_bert.py
index 2a799478..0fcf01e4 100644
--- a/test/modules/encoder/test_bert.py
+++ b/test/modules/decoder/test_bert.py
@@ -8,7 +8,7 @@ from fastNLP.models.bert import BertModel
class TestBert(unittest.TestCase):
def test_bert_1(self):
- from fastNLP.modules.encoder._bert import BertConfig
+ from fastNLP.modules.encoder.bert import BertConfig
config = BertConfig(32000)
model = BertModel(config)
diff --git a/test/test_tutorials.py b/test/test_tutorials.py
index 6f4a8347..3ec0e381 100644
--- a/test/test_tutorials.py
+++ b/test/test_tutorials.py
@@ -5,14 +5,13 @@ from fastNLP import Instance
from fastNLP import Vocabulary
from fastNLP.core.losses import CrossEntropyLoss
from fastNLP.core.metrics import AccuracyMetric
-
+from fastNLP.io.loader import CSVLoader
class TestTutorial(unittest.TestCase):
def test_fastnlp_10min_tutorial(self):
# 从csv读取数据到DataSet
sample_path = "test/data_for_tests/tutorial_sample_dataset.csv"
- dataset = DataSet.read_csv(sample_path, headers=('raw_sentence', 'label'),
- sep='\t')
+ dataset = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(sample_path)
print(len(dataset))
print(dataset[0])
print(dataset[-3])
@@ -110,7 +109,7 @@ class TestTutorial(unittest.TestCase):
def test_fastnlp_1min_tutorial(self):
# tutorials/fastnlp_1min_tutorial.ipynb
data_path = "test/data_for_tests/tutorial_sample_dataset.csv"
- ds = DataSet.read_csv(data_path, headers=('raw_sentence', 'label'), sep='\t')
+ ds = CSVLoader(headers=['raw_sentence', 'label'], sep=' ')._load(data_path)
print(ds[1])
# 将所有数字转为小写