@@ -8,7 +8,7 @@ install: | |||
- pip install pytest-cov | |||
# command to run tests | |||
script: | |||
- pytest --cov=./ test/ | |||
- pytest --cov=fastNLP test/ | |||
after_success: | |||
- bash <(curl -s https://codecov.io/bash) |
@@ -6,11 +6,12 @@ | |||
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) | |||
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) | |||
fastNLP 是一款轻量级的 NLP 处理套件。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner)、POS-Tagging等)、中文分词、[文本分类](reproduction/text_classification)、[Matching](reproduction/matching)、[指代消解](reproduction/coreference_resolution)、[摘要](reproduction/Summarization)等任务; 也可以使用它构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||
fastNLP 是一款轻量级的 NLP 工具包。你既可以使用它快速地完成一个序列标注([NER](reproduction/seqence_labelling/ner)、POS-Tagging等)、中文分词、[文本分类](reproduction/text_classification)、[Matching](reproduction/matching)、[指代消解](reproduction/coreference_resolution)、[摘要](reproduction/Summarization)等任务; 也可以使用它快速构建许多复杂的网络模型,进行科研。它具有如下的特性: | |||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的DataSet Loader,省去预处理代码; | |||
- 统一的Tabular式数据容器,让数据预处理过程简洁明了。内置多种数据集的Loader和Pipe,省去预处理代码; | |||
- 多种训练、测试组件,例如训练器Trainer;测试器Tester;以及各种评测metrics等等; | |||
- 各种方便的NLP工具,例如预处理embedding加载(包括ELMo和BERT); 中间数据cache等; | |||
- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载 | |||
- 详尽的中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)以供查阅; | |||
- 提供诸多高级模块,例如Variational LSTM, Transformer, CRF等; | |||
- 在序列标注、中文分词、文本分类、Matching、指代消解、摘要等任务上封装了各种模型可供直接使用,详细内容见 [reproduction](reproduction) 部分; | |||
@@ -36,7 +37,7 @@ pip install fastNLP | |||
python -m spacy download en | |||
``` | |||
目前使用pip安装fastNLP的版本是0.4.1,有较多功能仍未更新,最新内容以master分支为准。 | |||
目前使用pypi安装fastNLP的版本是0.4.1,有较多功能仍未更新,最新内容以master分支为准。 | |||
fastNLP0.5.0版本将在近期推出,请密切关注。 | |||
@@ -44,7 +45,7 @@ fastNLP0.5.0版本将在近期推出,请密切关注。 | |||
- [0. 快速入门](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html) | |||
- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html) | |||
- [2. 使用DataSetLoader加载数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html) | |||
- [2. 使用Loader和Pipe加载并处理数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_load_dataset.html) | |||
- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) | |||
- [4. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_loss_optimizer.html) | |||
- [5. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_datasetiter.html) | |||
@@ -118,7 +119,7 @@ fastNLP的大致工作流程如上图所示,而项目结构如下: | |||
</tr> | |||
<tr> | |||
<td><b> fastNLP.io </b></td> | |||
<td> 实现了读写功能,包括数据读入,模型读写等 </td> | |||
<td> 实现了读写功能,包括数据读入与预处理,模型读写,自动下载等 </td> | |||
</tr> | |||
</table> | |||
@@ -14,13 +14,13 @@ help: | |||
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) | |||
apidoc: | |||
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) | |||
$(SPHINXAPIDOC) -efM -o source ../$(SPHINXPROJ) && python3 format.py | |||
server: | |||
cd build/html && python -m http.server | |||
dev: | |||
rm -rf build/html && make html && make server | |||
rm -rf build && make html && make server | |||
.PHONY: help Makefile | |||
@@ -0,0 +1,142 @@ | |||
import inspect | |||
import os | |||
import sys | |||
def _colored_string(string: str, color: str or int) -> str: | |||
"""在终端中显示一串有颜色的文字 | |||
:param string: 在终端中显示的文字 | |||
:param color: 文字的颜色 | |||
:return: | |||
""" | |||
if isinstance(color, str): | |||
color = { | |||
"black": 30, "Black": 30, "BLACK": 30, | |||
"red": 31, "Red": 31, "RED": 31, | |||
"green": 32, "Green": 32, "GREEN": 32, | |||
"yellow": 33, "Yellow": 33, "YELLOW": 33, | |||
"blue": 34, "Blue": 34, "BLUE": 34, | |||
"purple": 35, "Purple": 35, "PURPLE": 35, | |||
"cyan": 36, "Cyan": 36, "CYAN": 36, | |||
"white": 37, "White": 37, "WHITE": 37 | |||
}[color] | |||
return "\033[%dm%s\033[0m" % (color, string) | |||
def gr(string, flag): | |||
if flag: | |||
return _colored_string(string, "green") | |||
else: | |||
return _colored_string(string, "red") | |||
def find_all_modules(): | |||
modules = {} | |||
children = {} | |||
to_doc = set() | |||
root = '../fastNLP' | |||
for path, dirs, files in os.walk(root): | |||
for file in files: | |||
if file.endswith('.py'): | |||
name = ".".join(path.split('/')[1:]) | |||
if file.split('.')[0] != "__init__": | |||
name = name + '.' + file.split('.')[0] | |||
__import__(name) | |||
m = sys.modules[name] | |||
modules[name] = m | |||
try: | |||
m.__all__ | |||
except: | |||
print(name, "__all__ missing") | |||
continue | |||
if m.__doc__ is None: | |||
print(name, "__doc__ missing") | |||
continue | |||
if "undocumented" not in m.__doc__: | |||
to_doc.add(name) | |||
for module in to_doc: | |||
t = ".".join(module.split('.')[:-1]) | |||
if t in to_doc: | |||
if t not in children: | |||
children[t] = set() | |||
children[t].add(module) | |||
for m in children: | |||
children[m] = sorted(children[m]) | |||
return modules, to_doc, children | |||
def create_rst_file(modules, name, children): | |||
m = modules[name] | |||
with open("./source/" + name + ".rst", "w") as fout: | |||
t = "=" * len(name) | |||
fout.write(name + "\n") | |||
fout.write(t + "\n") | |||
fout.write("\n") | |||
fout.write(".. automodule:: " + name + "\n") | |||
if name != "fastNLP.core" and len(m.__all__) > 0: | |||
fout.write(" :members: " + ", ".join(m.__all__) + "\n") | |||
short = name[len("fastNLP."):] | |||
if not (short.startswith('models') or short.startswith('modules') or short.startswith('embeddings')): | |||
fout.write(" :inherited-members:\n") | |||
fout.write("\n") | |||
if name in children: | |||
fout.write("子模块\n------\n\n.. toctree::\n :maxdepth: 1\n\n") | |||
for module in children[name]: | |||
fout.write(" " + module + "\n") | |||
def check_file(m, name): | |||
names = name.split('.') | |||
test_name = "test." + ".".join(names[1:-1]) + ".test_" + names[-1] | |||
try: | |||
__import__(test_name) | |||
tm = sys.modules[test_name] | |||
except ModuleNotFoundError: | |||
tm = None | |||
tested = tm is not None | |||
funcs = {} | |||
classes = {} | |||
for item, obj in inspect.getmembers(m): | |||
if inspect.isclass(obj) and obj.__module__ == name and not obj.__name__.startswith('_'): | |||
this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm), {}) | |||
for i in dir(obj): | |||
func = getattr(obj, i) | |||
if inspect.isfunction(func) and not i.startswith('_'): | |||
this[2][i] = (func.__doc__ is not None, False) | |||
classes[obj.__name__] = this | |||
if inspect.isfunction(obj) and obj.__module__ == name and not obj.__name__.startswith('_'): | |||
this = (obj.__doc__ is not None, tested and obj.__name__ in dir(tm)) # docs | |||
funcs[obj.__name__] = this | |||
return funcs, classes | |||
def check_files(modules, out=sys.stdout): | |||
for name in sorted(modules.keys()): | |||
print(name, file=out) | |||
funcs, classes = check_file(modules[name], name) | |||
for f in funcs: | |||
print("%-30s \t %s \t %s" % (f, gr("文档", funcs[f][0]), gr("测试", funcs[f][1])), file=out) | |||
for c in classes: | |||
print("%-30s \t %s \t %s" % (c, gr("文档", classes[c][0]), gr("测试", classes[c][1])), file=out) | |||
methods = classes[c][2] | |||
for f in methods: | |||
print(" %-28s \t %s" % (f, gr("文档", methods[f][0])), file=out) | |||
print(file=out) | |||
def main(): | |||
sys.path.append("..") | |||
print(_colored_string('Getting modules...', "Blue")) | |||
modules, to_doc, children = find_all_modules() | |||
print(_colored_string('Done!', "Green")) | |||
print(_colored_string('Creating rst files...', "Blue")) | |||
for name in to_doc: | |||
create_rst_file(modules, name, children) | |||
print(_colored_string('Done!', "Green")) | |||
print(_colored_string('Checking all files...', "Blue")) | |||
check_files(modules) | |||
print(_colored_string('Done!', "Green")) | |||
if __name__ == "__main__": | |||
main() |
@@ -48,12 +48,14 @@ extensions = [ | |||
autodoc_default_options = { | |||
'member-order': 'bysource', | |||
'special-members': '__init__', | |||
'undoc-members': True, | |||
'undoc-members': False, | |||
} | |||
autoclass_content = "class" | |||
# Add any paths that contain templates here, relative to this directory. | |||
templates_path = ['_templates'] | |||
# template_bridge | |||
# The suffix(es) of source filenames. | |||
# You can specify multiple suffix as a list of string: | |||
# | |||
@@ -113,7 +115,7 @@ html_static_path = ['_static'] | |||
# -- Options for HTMLHelp output --------------------------------------------- | |||
# Output file base name for HTML help builder. | |||
htmlhelp_basename = 'fastNLPdoc' | |||
htmlhelp_basename = 'fastNLP doc' | |||
# -- Options for LaTeX output ------------------------------------------------ | |||
@@ -166,10 +168,12 @@ texinfo_documents = [ | |||
# -- Extension configuration ------------------------------------------------- | |||
def maybe_skip_member(app, what, name, obj, skip, options): | |||
if name.startswith("_"): | |||
return True | |||
if obj.__doc__ is None: | |||
return True | |||
if name == "__init__": | |||
return False | |||
if name.startswith("_"): | |||
return True | |||
return False | |||
@@ -2,6 +2,6 @@ fastNLP.core.batch | |||
================== | |||
.. automodule:: fastNLP.core.batch | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: BatchIter, DataSetIter, TorchLoaderIter | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.callback | |||
===================== | |||
.. automodule:: fastNLP.core.callback | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, TesterCallback, CallbackException, EarlyStopError | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.const | |||
================== | |||
.. automodule:: fastNLP.core.const | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Const | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.dataset | |||
==================== | |||
.. automodule:: fastNLP.core.dataset | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: DataSet | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.field | |||
================== | |||
.. automodule:: fastNLP.core.field | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Padder, AutoPadder, EngChar2DPadder | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.instance | |||
===================== | |||
.. automodule:: fastNLP.core.instance | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Instance | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.losses | |||
=================== | |||
.. automodule:: fastNLP.core.losses | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: LossBase, LossFunc, LossInForward, CrossEntropyLoss, BCELoss, L1Loss, NLLLoss | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.metrics | |||
==================== | |||
.. automodule:: fastNLP.core.metrics | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: MetricBase, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.optimizer | |||
====================== | |||
.. automodule:: fastNLP.core.optimizer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Optimizer, SGD, Adam, AdamW | |||
:inherited-members: | |||
@@ -2,12 +2,9 @@ fastNLP.core | |||
============ | |||
.. automodule:: fastNLP.core | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
子模块 | |||
---------- | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
@@ -2,6 +2,6 @@ fastNLP.core.sampler | |||
==================== | |||
.. automodule:: fastNLP.core.sampler | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Sampler, BucketSampler, SequentialSampler, RandomSampler | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.tester | |||
=================== | |||
.. automodule:: fastNLP.core.tester | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Tester | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.trainer | |||
==================== | |||
.. automodule:: fastNLP.core.trainer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Trainer | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.utils | |||
================== | |||
.. automodule:: fastNLP.core.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: cache_results, seq_len_to_mask, get_seq_len | |||
:inherited-members: | |||
@@ -2,6 +2,6 @@ fastNLP.core.vocabulary | |||
======================= | |||
.. automodule:: fastNLP.core.vocabulary | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Vocabulary, VocabularyOption | |||
:inherited-members: | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.embeddings.bert\_embedding | |||
================================== | |||
fastNLP.embeddings.bert_embedding | |||
================================= | |||
.. automodule:: fastNLP.embeddings.bert_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: BertEmbedding, BertWordPieceEncoder | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.embeddings.char\_embedding | |||
================================== | |||
fastNLP.embeddings.char_embedding | |||
================================= | |||
.. automodule:: fastNLP.embeddings.char_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: CNNCharEmbedding, LSTMCharEmbedding | |||
@@ -0,0 +1,6 @@ | |||
fastNLP.embeddings.contextual_embedding | |||
======================================= | |||
.. automodule:: fastNLP.embeddings.contextual_embedding | |||
:members: ContextualEmbedding | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.embeddings.elmo\_embedding | |||
================================== | |||
fastNLP.embeddings.elmo_embedding | |||
================================= | |||
.. automodule:: fastNLP.embeddings.elmo_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: ElmoEmbedding | |||
@@ -2,6 +2,5 @@ fastNLP.embeddings.embedding | |||
============================ | |||
.. automodule:: fastNLP.embeddings.embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Embedding, TokenEmbedding | |||
@@ -2,18 +2,17 @@ fastNLP.embeddings | |||
================== | |||
.. automodule:: fastNLP.embeddings | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Embedding, TokenEmbedding, StaticEmbedding, ElmoEmbedding, BertEmbedding, BertWordPieceEncoder, StackEmbedding, LSTMCharEmbedding, CNNCharEmbedding, get_embeddings | |||
子模块 | |||
---------- | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.embeddings.bert_embedding | |||
fastNLP.embeddings.char_embedding | |||
fastNLP.embeddings.contextual_embedding | |||
fastNLP.embeddings.elmo_embedding | |||
fastNLP.embeddings.embedding | |||
fastNLP.embeddings.stack_embedding | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.embeddings.stack\_embedding | |||
=================================== | |||
fastNLP.embeddings.stack_embedding | |||
================================== | |||
.. automodule:: fastNLP.embeddings.stack_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: StackEmbedding | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.embeddings.static\_embedding | |||
==================================== | |||
fastNLP.embeddings.static_embedding | |||
=================================== | |||
.. automodule:: fastNLP.embeddings.static_embedding | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: StaticEmbedding | |||
@@ -2,6 +2,5 @@ fastNLP.embeddings.utils | |||
======================== | |||
.. automodule:: fastNLP.embeddings.utils | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: get_embeddings | |||
@@ -1,7 +0,0 @@ | |||
fastNLP.io.base\_loader | |||
======================= | |||
.. automodule:: fastNLP.io.base_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -0,0 +1,7 @@ | |||
fastNLP.io.data_bundle | |||
====================== | |||
.. automodule:: fastNLP.io.data_bundle | |||
:members: DataBundle | |||
:inherited-members: | |||
@@ -1,7 +0,0 @@ | |||
fastNLP.io.data\_loader | |||
========================== | |||
.. automodule:: fastNLP.io.data_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,7 +0,0 @@ | |||
fastNLP.io.dataset\_loader | |||
========================== | |||
.. automodule:: fastNLP.io.dataset_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: |
@@ -1,7 +1,7 @@ | |||
fastNLP.io.embed\_loader | |||
======================== | |||
fastNLP.io.embed_loader | |||
======================= | |||
.. automodule:: fastNLP.io.embed_loader | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: EmbedLoader, EmbeddingOption | |||
:inherited-members: | |||
@@ -0,0 +1,7 @@ | |||
fastNLP.io.file_utils | |||
===================== | |||
.. automodule:: fastNLP.io.file_utils | |||
:members: cached_path, get_filepath, get_cache_path, split_filename_suffix, get_from_cache | |||
:inherited-members: | |||
@@ -0,0 +1,7 @@ | |||
fastNLP.io.loader | |||
================= | |||
.. automodule:: fastNLP.io.loader | |||
:members: Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, PeopleDailyNERLoader, WeiboNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader | |||
:inherited-members: | |||
@@ -1,7 +1,7 @@ | |||
fastNLP.io.model\_io | |||
==================== | |||
fastNLP.io.model_io | |||
=================== | |||
.. automodule:: fastNLP.io.model_io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: ModelLoader, ModelSaver | |||
:inherited-members: | |||
@@ -0,0 +1,7 @@ | |||
fastNLP.io.pipe | |||
=============== | |||
.. automodule:: fastNLP.io.pipe | |||
:members: Pipe, CWSPipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, WeiboNERPipe, PeopleDailyPipe, Conll2003Pipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe | |||
:inherited-members: | |||
@@ -2,18 +2,19 @@ fastNLP.io | |||
========== | |||
.. automodule:: fastNLP.io | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: DataBundle, EmbedLoader, Loader, YelpLoader, YelpFullLoader, YelpPolarityLoader, IMDBLoader, SSTLoader, SST2Loader, ChnSentiCorpLoader, ConllLoader, Conll2003Loader, Conll2003NERLoader, OntoNotesNERLoader, CTBLoader, MsraNERLoader, WeiboNERLoader, PeopleDailyNERLoader, CSVLoader, JsonLoader, CWSLoader, MNLILoader, QuoraLoader, SNLILoader, QNLILoader, RTELoader, Pipe, YelpFullPipe, YelpPolarityPipe, SSTPipe, SST2Pipe, IMDBPipe, ChnSentiCorpPipe, Conll2003Pipe, Conll2003NERPipe, OntoNotesNERPipe, MsraNERPipe, PeopleDailyPipe, WeiboNERPipe, CWSPipe, MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, ModelLoader, ModelSaver | |||
:inherited-members: | |||
子模块 | |||
---------- | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.io.base_loader | |||
fastNLP.io.data_bundle | |||
fastNLP.io.embed_loader | |||
fastNLP.io.dataset_loader | |||
fastNLP.io.data_loader | |||
fastNLP.io.file_utils | |||
fastNLP.io.loader | |||
fastNLP.io.model_io | |||
fastNLP.io.pipe | |||
fastNLP.io.utils |
@@ -0,0 +1,7 @@ | |||
fastNLP.io.utils | |||
================ | |||
.. automodule:: fastNLP.io.utils | |||
:members: check_loader_paths | |||
:inherited-members: | |||
@@ -0,0 +1,6 @@ | |||
fastNLP.models.bert | |||
=================== | |||
.. automodule:: fastNLP.models.bert | |||
:members: BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.models.biaffine\_parser | |||
=============================== | |||
fastNLP.models.biaffine_parser | |||
============================== | |||
.. automodule:: fastNLP.models.biaffine_parser | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: BiaffineParser, GraphParser | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.models.cnn\_text\_classification | |||
======================================== | |||
fastNLP.models.cnn_text_classification | |||
====================================== | |||
.. automodule:: fastNLP.models.cnn_text_classification | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: CNNText | |||
@@ -2,16 +2,15 @@ fastNLP.models | |||
============== | |||
.. automodule:: fastNLP.models | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: CNNText, SeqLabeling, AdvSeqLabel, ESIM, StarTransEnc, STSeqLabel, STNLICls, STSeqCls, BiaffineParser, GraphParser, BertForSequenceClassification, BertForSentenceMatching, BertForMultipleChoice, BertForTokenClassification, BertForQuestionAnswering | |||
子模块 | |||
---------- | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
fastNLP.models.bert | |||
fastNLP.models.biaffine_parser | |||
fastNLP.models.cnn_text_classification | |||
fastNLP.models.sequence_labeling | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.models.sequence\_labeling | |||
================================= | |||
fastNLP.models.sequence_labeling | |||
================================ | |||
.. automodule:: fastNLP.models.sequence_labeling | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: SeqLabeling, AdvSeqLabel, BiLSTMCRF | |||
@@ -2,6 +2,5 @@ fastNLP.models.snli | |||
=================== | |||
.. automodule:: fastNLP.models.snli | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: ESIM | |||
@@ -1,7 +1,6 @@ | |||
fastNLP.models.star\_transformer | |||
================================ | |||
fastNLP.models.star_transformer | |||
=============================== | |||
.. automodule:: fastNLP.models.star_transformer | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: StarTransEnc, STNLICls, STSeqCls, STSeqLabel | |||
@@ -2,7 +2,5 @@ fastNLP.modules.decoder | |||
======================= | |||
.. automodule:: fastNLP.modules.decoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: MLP, ConditionalRandomField, viterbi_decode, allowed_transitions | |||
@@ -2,6 +2,5 @@ fastNLP.modules.encoder | |||
======================= | |||
.. automodule:: fastNLP.modules.encoder | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention | |||
@@ -2,16 +2,14 @@ fastNLP.modules | |||
=============== | |||
.. automodule:: fastNLP.modules | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: ConvolutionCharEncoder, LSTMCharEncoder, ConvMaxpool, LSTM, StarTransformer, TransformerEncoder, VarRNN, VarLSTM, VarGRU, MaxPool, MaxPoolWithMask, AvgPool, AvgPoolWithMask, MultiHeadAttention, MLP, ConditionalRandomField, viterbi_decode, allowed_transitions, TimestepDropout | |||
子模块 | |||
----------- | |||
------ | |||
.. toctree:: | |||
:titlesonly: | |||
:maxdepth: 1 | |||
fastNLP.modules.decoder | |||
fastNLP.modules.encoder | |||
fastNLP.modules.encoder | |||
fastNLP.modules.utils |
@@ -0,0 +1,6 @@ | |||
fastNLP.modules.utils | |||
===================== | |||
.. automodule:: fastNLP.modules.utils | |||
:members: initial_parameter, summary | |||
@@ -1,13 +1,12 @@ | |||
API 文档 | |||
=============== | |||
fastNLP | |||
======= | |||
.. automodule:: fastNLP | |||
:members: | |||
:undoc-members: | |||
:show-inheritance: | |||
:members: Instance, FieldArray, DataSetIter, BatchIter, TorchLoaderIter, Vocabulary, DataSet, Const, Trainer, Tester, Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC, LRFinder, Padder, AutoPadder, EngChar2DPadder, AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric, Optimizer, SGD, Adam, AdamW, Sampler, SequentialSampler, BucketSampler, RandomSampler, LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward, cache_results, logger | |||
:inherited-members: | |||
内部模块 | |||
----------- | |||
子模块 | |||
------ | |||
.. toctree:: | |||
:maxdepth: 1 | |||
@@ -2,7 +2,6 @@ fastNLP | |||
======= | |||
.. toctree:: | |||
:titlesonly: | |||
:maxdepth: 4 | |||
fastNLP |
@@ -1,57 +1,53 @@ | |||
================================= | |||
使用DataSetLoader加载数据集 | |||
================================= | |||
======================================= | |||
使用Loader和Pipe加载并处理数据集 | |||
======================================= | |||
这一部分是一个关于如何加载数据集的教程 | |||
教程目录: | |||
- `Part I: 数据集容器`_ | |||
- `Part II: 数据集的使用方式`_ | |||
- `Part III: 不同数据类型的DataSetLoader`_ | |||
- `Part IV: DataSetLoader举例`_ | |||
- `Part V: fastNLP封装好的数据集加载器`_ | |||
- `Part I: 数据集容器DataBundle`_ | |||
- `Part II: 加载数据集的基类Loader`_ | |||
- `Part III: 不同格式类型的基础Loader`_ | |||
- `Part IV: 使用Pipe对数据集进行预处理`_ | |||
- `Part V: fastNLP封装好的Loader和Pipe`_ | |||
---------------------------- | |||
Part I: 数据集容器 | |||
---------------------------- | |||
------------------------------------ | |||
Part I: 数据集容器DataBundle | |||
------------------------------------ | |||
在fastNLP中,我们使用 :class:`~fastNLP.io.base_loader.DataBundle` 来存储数据集信息。 | |||
:class:`~fastNLP.io.base_loader.DataBundle` 类包含了两个重要内容: `datasets` 和 `vocabs` 。 | |||
在fastNLP中,我们使用 :class:`~fastNLP.io.data_bundle.DataBundle` 来存储数据集信息。 | |||
:class:`~fastNLP.io.data_bundle.DataBundle` 类包含了两个重要内容: `datasets` 和 `vocabs` 。 | |||
`datasets` 是一个 `key` 为数据集名称(如 `train` , `dev` ,和 `test` 等), `value` 为 :class:`~fastNLP.DataSet` 的字典。 | |||
`vocabs` 是一个 `key` 为词表名称(如 :attr:`fastNLP.Const.INPUT` 表示输入文本的词表名称, :attr:`fastNLP.Const.TARGET` 表示目标 | |||
的真实标签词表的名称,等等), `value` 为词表内容( :class:`~fastNLP.Vocabulary` )的字典。 | |||
---------------------------- | |||
Part II: 数据集的使用方式 | |||
---------------------------- | |||
------------------------------------- | |||
Part II: 加载数据集的基类Loader | |||
------------------------------------- | |||
在fastNLP中,我们采用 :class:`~fastNLP.io.base_loader.DataSetLoader` 来作为加载数据集的基类。 | |||
:class:`~fastNLP.io.base_loader.DataSetLoader` 定义了各种DataSetLoader所需的API接口,开发者应该继承它实现各种的DataSetLoader。 | |||
在各种数据集的DataSetLoader当中,至少应该编写如下内容: | |||
在fastNLP中,我们采用 :class:`~fastNLP.io.loader.Loader` 来作为加载数据集的基类。 | |||
:class:`~fastNLP.io.loader.Loader` 定义了各种Loader所需的API接口,开发者应该继承它实现各种的Loader。 | |||
在各种数据集的Loader当中,至少应该编写如下内容: | |||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的 :class:`~fastNLP.io.DataBundle` | |||
- _load 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` | |||
- load 函数:从文件或者文件夹中读取数据并组装成 :class:`~fastNLP.io.data_bundle.DataBundle` | |||
**\*process函数中可以调用load函数或_load函数** | |||
DataSetLoader的_load或者load函数返回的 :class:`~fastNLP.DataSet` 当中,内容为数据集的文本信息,process函数返回的 | |||
:class:`~fastNLP.io.DataBundle` 当中, `datasets` 的内容为已经index好的、可以直接被 :class:`~fastNLP.Trainer` | |||
接受的内容。 | |||
Loader的load函数返回的 :class:`~fastNLP.io.data_bundle.DataBundle` 里面包含了数据集的原始数据。 | |||
-------------------------------------------------------- | |||
Part III: 不同数据类型的DataSetLoader | |||
Part III: 不同格式类型的基础Loader | |||
-------------------------------------------------------- | |||
:class:`~fastNLP.io.dataset_loader.CSVLoader` | |||
:class:`~fastNLP.io.loader.CSVLoader` | |||
读取CSV类型的数据集文件。例子如下: | |||
.. code-block:: python | |||
from fastNLP.io.loader import CSVLoader | |||
data_set_loader = CSVLoader( | |||
headers=('words', 'target'), sep='\t' | |||
) | |||
@@ -67,17 +63,18 @@ Part III: 不同数据类型的DataSetLoader | |||
The performances are an absolute joy . 4 | |||
:class:`~fastNLP.io.dataset_loader.JsonLoader` | |||
:class:`~fastNLP.io.loader.JsonLoader` | |||
读取Json类型的数据集文件,数据必须按行存储,每行是一个包含各类属性的Json对象。例子如下: | |||
.. code-block:: python | |||
data_set_loader = JsonLoader( | |||
from fastNLP.io.loader import JsonLoader | |||
oader = JsonLoader( | |||
fields={'sentence1': 'words1', 'sentence2': 'words2', 'gold_label': 'target'} | |||
) | |||
# 表示将Json对象中'sentence1'、'sentence2'和'gold_label'对应的值赋给'words1'、'words2'、'target'这三个fields | |||
data_set = data_set_loader._load('path/to/your/file') | |||
data_set = loader._load('path/to/your/file') | |||
数据集内容样例如下 :: | |||
@@ -86,139 +83,68 @@ Part III: 不同数据类型的DataSetLoader | |||
{"annotator_labels": ["entailment"], "captionID": "3416050480.jpg#4", "gold_label": "entailment", "pairID": "3416050480.jpg#4r1e", "sentence1": "A person on a horse jumps over a broken down airplane.", "sentence1_binary_parse": "( ( ( A person ) ( on ( a horse ) ) ) ( ( jumps ( over ( a ( broken ( down airplane ) ) ) ) ) . ) )", "sentence1_parse": "(ROOT (S (NP (NP (DT A) (NN person)) (PP (IN on) (NP (DT a) (NN horse)))) (VP (VBZ jumps) (PP (IN over) (NP (DT a) (JJ broken) (JJ down) (NN airplane)))) (. .)))", "sentence2": "A person is outdoors, on a horse.", "sentence2_binary_parse": "( ( A person ) ( ( ( ( is outdoors ) , ) ( on ( a horse ) ) ) . ) )", "sentence2_parse": "(ROOT (S (NP (DT A) (NN person)) (VP (VBZ is) (ADVP (RB outdoors)) (, ,) (PP (IN on) (NP (DT a) (NN horse)))) (. .)))"} | |||
------------------------------------------ | |||
Part IV: DataSetLoader举例 | |||
Part IV: 使用Pipe对数据集进行预处理 | |||
------------------------------------------ | |||
以Matching任务为例子: | |||
:class:`~fastNLP.io.data_loader.MatchingLoader` | |||
我们在fastNLP当中封装了一个Matching任务数据集的数据加载类: :class:`~fastNLP.io.data_loader.MatchingLoader` . | |||
在MatchingLoader类当中我们封装了一个对数据集中的文本内容进行进一步的预处理的函数: | |||
:meth:`~fastNLP.io.data_loader.MatchingLoader.process` | |||
这个函数具有各种预处理option,如: | |||
- 是否将文本转成全小写 | |||
- 是否需要序列长度信息,需要什么类型的序列长度信息 | |||
- 是否需要用BertTokenizer来获取序列的WordPiece信息 | |||
- 等等 | |||
在fastNLP中,我们采用 :class:`~fastNLP.io.pipe.Pipe` 来作为加载数据集的基类。 | |||
:class:`~fastNLP.io.pipe.Pipe` 定义了各种Pipe所需的API接口,开发者应该继承它实现各种的Pipe。 | |||
在各种数据集的Pipe当中,至少应该编写如下内容: | |||
具体内容参见 :meth:`fastNLP.io.MatchingLoader.process` 。 | |||
- process 函数:对输入的 :class:`~fastNLP.io.data_bundle.DataBundle` 进行处理(如构建词表、 | |||
将dataset的文本内容转成index等等),然后返回该 :class:`~fastNLP.io.data_bundle.DataBundle` | |||
- process_from_file 函数:输入数据集所在文件夹,读取内容并组装成 :class:`~fastNLP.io.data_bundle.DataBundle` , | |||
然后调用相对应的process函数对数据进行预处理 | |||
:class:`~fastNLP.io.data_loader.SNLILoader` | |||
一个关于SNLI数据集的DataSetLoader。SNLI数据集来自 | |||
`SNLI Data Set <https://nlp.stanford.edu/projects/snli/snli_1.0.zip>`_ . | |||
以SNLI数据集为例,写一个自定义Pipe的例子如下: | |||
在 :class:`~fastNLP.io.data_loader.SNLILoader` 的 :meth:`~fastNLP.io.data_loader.SNLILoader._load` | |||
函数中,我们用以下代码将数据集内容从文本文件读入内存: | |||
.. code-block:: python | |||
.. code-block:: python | |||
from fastNLP.io.loader import SNLILoader | |||
from fastNLP.io.pipe import MatchingPipe | |||
data = SNLILoader().process( | |||
paths='path/to/snli/data', to_lower=False, seq_len_type='seq_len', | |||
get_index=True, concat=False, | |||
) | |||
print(data) | |||
class MySNLIPipe(MatchingPipe): | |||
输出的内容是:: | |||
def process(self, data_bundle): | |||
data_bundle = super(MySNLIPipe, self).process(data_bundle) | |||
# MatchingPipe类里封装了一个关于matching任务的process函数,可以直接继承使用 | |||
# 如果有需要进行额外的预处理操作可以在这里加入您的代码 | |||
return data_bundle | |||
In total 3 datasets: | |||
train has 549367 instances. | |||
dev has 9842 instances. | |||
test has 9824 instances. | |||
In total 2 vocabs: | |||
words has 43154 entries. | |||
target has 3 entries. | |||
def process_from_file(self, paths=None): | |||
data_bundle = SNLILoader().load(paths) # 使用SNLILoader读取原始数据集 | |||
# SNLILoader的load函数中,paths如果为None则会自动下载 | |||
return self.process(data_bundle) # 调用相对应的process函数对data_bundle进行处理 | |||
调用Pipe示例: | |||
这里的data是一个 :class:`~fastNLP.io.base_loader.DataBundle` ,取 ``datasets`` 字典里的内容即可直接传入 | |||
:class:`~fastNLP.Trainer` 或者 :class:`~fastNLP.Tester` 进行训练或者测试。 | |||
.. code-block:: python | |||
:class:`~fastNLP.io.data_loader.IMDBLoader` | |||
以IMDB数据集为例,在 :class:`~fastNLP.io.data_loader.IMDBLoader` 的 :meth:`~fastNLP.io.data_loader.IMDBLoader._load` | |||
函数中,我们用以下代码将数据集内容从文本文件读入内存: | |||
from fastNLP.io.pipe import SNLIBertPipe | |||
data_bundle = SNLIBertPipe(lower=True, tokenizer=arg.tokenizer).process_from_file() | |||
print(data_bundle) | |||
.. code-block:: python | |||
输出的内容是:: | |||
data = IMDBLoader().process( | |||
paths={'train': 'path/to/train/file', 'test': 'path/to/test/file'} | |||
) | |||
print(data) | |||
In total 3 datasets: | |||
train has 549367 instances. | |||
dev has 9842 instances. | |||
test has 9824 instances. | |||
In total 2 vocabs: | |||
words has 34184 entries. | |||
target has 3 entries. | |||
输出的内容是:: | |||
In total 3 datasets: | |||
train has 22500 instances. | |||
test has 25000 instances. | |||
dev has 2500 instances. | |||
In total 2 vocabs: | |||
words has 82846 entries. | |||
target has 2 entries. | |||
这里将原来的train集按9:1的比例分成了训练集和验证集。
这里表示一共有3个数据集和2个词表。其中: | |||
- 3个数据集分别为train、dev、test数据集,分别有549367、9842、9824个instance | |||
- 2个词表分别为words词表与target词表。其中words词表为句子文本所构建的词表,一共有34184个单词; | |||
target词表为目标标签所构建的词表,一共有3种标签。(注:如果有多个输入,则句子文本所构建的词表将 | |||
会被命名为words1以对应相对应的列名) | |||
------------------------------------------ | |||
Part V: fastNLP封装好的数据集加载器 | |||
Part V: fastNLP封装好的Loader和Pipe | |||
------------------------------------------ | |||
fastNLP封装好的数据集加载器可以适用于多种类型的任务: | |||
- `文本分类任务`_ | |||
- `序列标注任务`_ | |||
- `Matching任务`_ | |||
文本分类任务 | |||
------------------- | |||
========================== ================================================================== | |||
数据集名称 数据集加载器 | |||
-------------------------- ------------------------------------------------------------------ | |||
IMDb :class:`~fastNLP.io.data_loader.IMDBLoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
SST :class:`~fastNLP.io.data_loader.SSTLoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
SST-2 :class:`~fastNLP.io.data_loader.SST2Loader` | |||
-------------------------- ------------------------------------------------------------------ | |||
Yelp Polarity :class:`~fastNLP.io.data_loader.YelpLoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
Yelp Full :class:`~fastNLP.io.data_loader.YelpLoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
MTL16 :class:`~fastNLP.io.data_loader.MTL16Loader` | |||
========================== ================================================================== | |||
序列标注任务 | |||
------------------- | |||
========================== ================================================================== | |||
数据集名称 数据集加载器 | |||
-------------------------- ------------------------------------------------------------------ | |||
Conll :class:`~fastNLP.io.data_loader.ConllLoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
Conll2003 :class:`~fastNLP.io.data_loader.Conll2003Loader` | |||
-------------------------- ------------------------------------------------------------------ | |||
人民日报数据集 :class:`~fastNLP.io.data_loader.PeopleDailyCorpusLoader` | |||
========================== ================================================================== | |||
Matching任务 | |||
------------------- | |||
========================== ================================================================== | |||
数据集名称 数据集加载器 | |||
-------------------------- ------------------------------------------------------------------ | |||
SNLI :class:`~fastNLP.io.data_loader.SNLILoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
MultiNLI :class:`~fastNLP.io.data_loader.MNLILoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
QNLI :class:`~fastNLP.io.data_loader.QNLILoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
RTE :class:`~fastNLP.io.data_loader.RTELoader` | |||
-------------------------- ------------------------------------------------------------------ | |||
Quora Pair Dataset :class:`~fastNLP.io.data_loader.QuoraLoader` | |||
========================== ================================================================== | |||
fastNLP封装了多种任务/数据集的Loader和Pipe并提供自动下载功能,具体参见文档 | |||
`fastNLP可加载的embedding与数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
@@ -12,6 +12,7 @@ | |||
- `Part IV: 使用预训练的Contextual Embedding(ELMo & BERT)`_ | |||
- `Part V: 使用character-level的embedding`_ | |||
- `Part VI: 叠加使用多个embedding`_ | |||
- `Part VII: fastNLP支持的预训练Embedding`_ | |||
@@ -35,12 +36,14 @@ Part II: 使用随机初始化的embedding | |||
.. code-block:: python | |||
from fastNLP import Embedding | |||
embed = Embedding(10000, 50) | |||
也可以传入一个初始化的参数矩阵: | |||
.. code-block:: python | |||
from fastNLP import Embedding | |||
embed = Embedding(init_embed) | |||
其中的init_embed可以是torch.FloatTensor、torch.nn.Embedding或者numpy.ndarray。 | |||
@@ -59,6 +62,7 @@ Embedding,例子如下: | |||
.. code-block:: python | |||
from fastNLP import StaticEmbedding | |||
embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
vocab为根据数据集构建的词表,model_dir_or_name可以是一个路径,也可以是embedding模型的名称: | |||
@@ -67,34 +71,13 @@ vocab为根据数据集构建的词表,model_dir_or_name可以是一个路径 | |||
和word2vec类型的权重文件都支持) | |||
2 如果传入的是模型名称,那么fastNLP将会根据名称查找embedding模型,如果在cache目录下找到模型则会 | |||
自动加载;如果找不到则会自动下载。可以通过环境变量 ``FASTNLP_CACHE_DIR`` 来自定义cache目录,如:: | |||
自动加载;如果找不到则会自动下载到cache目录。默认的cache目录为 `~/.fastNLP` 文件夹。可以通过环境 | |||
变量 ``FASTNLP_CACHE_DIR`` 来自定义cache目录,如:: | |||
$ FASTNLP_CACHE_DIR=~/fastnlp_cache_dir python your_python_file.py | |||
这个命令表示fastNLP将会在 `~/fastnlp_cache_dir` 这个目录下寻找模型,找不到则会自动将模型下载到这个目录 | |||
目前支持的静态embedding模型有: | |||
========================== ================================ | |||
模型名称 模型 | |||
-------------------------- -------------------------------- | |||
en glove.840B.300d | |||
-------------------------- -------------------------------- | |||
en-glove-840d-300 glove.840B.300d | |||
-------------------------- -------------------------------- | |||
en-glove-6b-50 glove.6B.50d | |||
-------------------------- -------------------------------- | |||
en-word2vec-300 谷歌word2vec 300维 | |||
-------------------------- -------------------------------- | |||
en-fasttext 英文fasttext 300维 | |||
-------------------------- -------------------------------- | |||
cn 腾讯中文词向量 200维 | |||
-------------------------- -------------------------------- | |||
cn-fasttext 中文fasttext 300维 | |||
========================== ================================ | |||
----------------------------------------------------------- | |||
Part IV: 使用预训练的Contextual Embedding(ELMo & BERT) | |||
----------------------------------------------------------- | |||
@@ -106,62 +89,20 @@ Part IV: 使用预训练的Contextual Embedding(ELMo & BERT) | |||
.. code-block:: python | |||
from fastNLP import ElmoEmbedding | |||
embed = ElmoEmbedding(vocab, model_dir_or_name='small', requires_grad=False) | |||
目前支持的ElmoEmbedding模型有: | |||
========================== ================================ | |||
模型名称 模型 | |||
-------------------------- -------------------------------- | |||
small allennlp ELMo的small | |||
-------------------------- -------------------------------- | |||
medium allennlp ELMo的medium | |||
-------------------------- -------------------------------- | |||
original allennlp ELMo的original | |||
-------------------------- -------------------------------- | |||
5.5b-original allennlp ELMo的5.5B original | |||
========================== ================================ | |||
BERT-embedding的使用方法如下: | |||
.. code-block:: python | |||
from fastNLP import BertEmbedding | |||
embed = BertEmbedding( | |||
vocab, model_dir_or_name='en-base-cased', requires_grad=False, layers='4,-2,-1' | |||
) | |||
其中layers变量表示需要取哪几层的encode结果。 | |||
目前支持的BertEmbedding模型有: | |||
========================== ==================================== | |||
模型名称 模型 | |||
-------------------------- ------------------------------------ | |||
en bert-base-cased | |||
-------------------------- ------------------------------------ | |||
en-base-uncased bert-base-uncased | |||
-------------------------- ------------------------------------ | |||
en-base-cased bert-base-cased | |||
-------------------------- ------------------------------------ | |||
en-large-uncased bert-large-uncased | |||
-------------------------- ------------------------------------ | |||
en-large-cased bert-large-cased | |||
-------------------------- ------------------------------------ | |||
-------------------------- ------------------------------------ | |||
en-large-cased-wwm bert-large-cased-whole-word-mask | |||
-------------------------- ------------------------------------ | |||
en-large-uncased-wwm bert-large-uncased-whole-word-mask | |||
-------------------------- ------------------------------------ | |||
en-base-cased-mrpc bert-base-cased-finetuned-mrpc | |||
-------------------------- ------------------------------------ | |||
-------------------------- ------------------------------------ | |||
multilingual bert-base-multilingual-cased | |||
-------------------------- ------------------------------------ | |||
multilingual-base-uncased bert-base-multilingual-uncased | |||
-------------------------- ------------------------------------ | |||
multilingual-base-cased bert-base-multilingual-cased | |||
========================== ==================================== | |||
----------------------------------------------------- | |||
Part V: 使用character-level的embedding | |||
----------------------------------------------------- | |||
@@ -173,6 +114,7 @@ CNNCharEmbedding的使用例子如下: | |||
.. code-block:: python | |||
from fastNLP import CNNCharEmbedding | |||
embed = CNNCharEmbedding(vocab, embed_size=100, char_emb_size=50) | |||
这表示这个CNNCharEmbedding当中character的embedding维度大小为50,返回的embedding结果维度大小为100。 | |||
@@ -181,12 +123,12 @@ CNNCharEmbedding的使用例子如下: | |||
.. code-block:: python | |||
from fastNLP import LSTMCharEmbedding | |||
embed = LSTMCharEmbedding(vocab, embed_size=100, char_emb_size=50) | |||
这表示这个LSTMCharEmbedding当中character的embedding维度大小为50,返回的embedding结果维度大小为100。 | |||
----------------------------------------------------- | |||
Part VI: 叠加使用多个embedding | |||
----------------------------------------------------- | |||
@@ -197,6 +139,7 @@ Part VI: 叠加使用多个embedding | |||
.. code-block:: python | |||
from fastNLP import StaticEmbedding, StackEmbedding | |||
embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) | |||
@@ -208,7 +151,17 @@ StackEmbedding会把多个embedding的结果拼接起来,如上面例子的sta | |||
.. code-block:: python | |||
from fastNLP import StaticEmbedding, StackEmbedding, ElmoEmbedding | |||
elmo_embedding = ElmoEmbedding(vocab, model_dir_or_name='medium', layers='0,1,2', requires_grad=False) | |||
glove_embedding = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
stack_embed = StackEmbedding([elmo_embedding, glove_embedding]) | |||
------------------------------------------ | |||
Part VII: fastNLP支持的预训练Embedding | |||
------------------------------------------ | |||
fastNLP支持多种预训练Embedding并提供自动下载功能,具体参见文档 | |||
`fastNLP可加载的embedding与数据集 <https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0>`_ | |||
@@ -1,4 +1,4 @@ | |||
============================================================================== | |||
============================================================================== | |||
动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试 | |||
============================================================================== | |||
@@ -19,7 +19,9 @@ | |||
loader = SSTLoader() | |||
#这里的all.txt是下载好数据后train.txt、dev.txt、test.txt的组合 | |||
dataset = loader.load("./trainDevTestTrees_PTB/trees/all.txt") | |||
#loader.load(path)会首先判断path是否为None,若是则自动从网站下载数据,若不是则读入数据并返回DataBundle
databundle_ = loader.load("./trainDevTestTrees_PTB/trees/all.txt") | |||
dataset = databundle_.datasets['train'] | |||
print(dataset[0]) | |||
输出数据如下:: | |||
@@ -31,6 +33,7 @@ | |||
数据处理 | |||
可以使用事先定义的 :class:`~fastNLP.io.SSTPipe` 类对数据进行基本预处理,这里我们手动进行处理。 | |||
我们使用 :class:`~fastNLP.DataSet` 类的 :meth:`~fastNLP.DataSet.apply` 方法将 ``target`` :mod:`~fastNLP.core.field` 转化为整数。 | |||
.. code-block:: python | |||
@@ -158,6 +161,7 @@ Vocabulary 的使用 | |||
损失函数 | |||
训练模型需要提供一个损失函数 | |||
,fastNLP中提供了直接可以导入使用的四种loss,分别为: | |||
* :class:`~fastNLP.CrossEntropyLoss`:包装了torch.nn.functional.cross_entropy()函数,返回交叉熵损失(可以运用于多分类场景) | |||
* :class:`~fastNLP.BCELoss`:包装了torch.nn.functional.binary_cross_entropy()函数,返回二分类的交叉熵 | |||
* :class:`~fastNLP.L1Loss`:包装了torch.nn.functional.l1_loss()函数,返回L1 损失 | |||
@@ -209,7 +213,7 @@ Vocabulary 的使用 | |||
#使用CNNText的时候第一个参数输入一个tuple,作为模型定义embedding的参数 | |||
#还可以传入 kernel_nums, kernel_sizes, padding, dropout的自定义值 | |||
model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=3, padding=2, dropout=0.1) | |||
model_cnn = CNNText((len(vocab),EMBED_DIM), num_classes=3, dropout=0.1) | |||
#如果在定义trainer的时候没有传入optimizer参数,模型默认的优化器为torch.optim.Adam且learning rate为lr=4e-3 | |||
#这里只使用了optimizer_1作为优化器输入,感兴趣可以尝试optimizer_2或者其他优化器作为输入 | |||
@@ -20,7 +20,9 @@ | |||
loader = SSTLoader() | |||
#这里的all.txt是下载好数据后train.txt、dev.txt、test.txt的组合 | |||
dataset = loader.load("./trainDevTestTrees_PTB/trees/all.txt") | |||
#loader.load(path)会首先判断path是否为None,若是则自动从网站下载数据,若不是则读入数据并返回DataBundle
databundle_ = loader.load("./trainDevTestTrees_PTB/trees/all.txt") | |||
dataset = databundle_.datasets['train'] | |||
print(dataset[0]) | |||
输出数据如下:: | |||
@@ -32,6 +34,7 @@ | |||
数据处理 | |||
可以使用事先定义的 :class:`~fastNLP.io.SSTPipe` 类对数据进行基本预处理,这里我们手动进行处理。 | |||
我们使用 :class:`~fastNLP.DataSet` 类的 :meth:`~fastNLP.DataSet.apply` 方法将 ``target`` :mod:`~fastNLP.core.field` 转化为整数。 | |||
.. code-block:: python | |||
@@ -192,7 +195,7 @@ sampler | |||
import time | |||
embed_dim = 100 | |||
model = CNNText((len(vocab),embed_dim), num_classes=3, padding=2, dropout=0.1) | |||
model = CNNText((len(vocab),embed_dim), num_classes=3, dropout=0.1) | |||
def train(epoch, data, devdata): | |||
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) | |||
@@ -3,64 +3,52 @@ | |||
===================== | |||
这一部分的内容主要展示如何使用fastNLP 实现序列标注任务。你可以使用fastNLP的各个组件快捷,方便地完成序列标注任务,达到出色的效果。 | |||
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,包括基本数据结构以及数据预处理,embedding的嵌入等,希望你对之前的教程有更进一步的掌握。 | |||
我们将对CoNLL-03的英文数据集进行处理,展示如何完成命名实体标注任务整个训练的过程。 | |||
在阅读这篇Tutorial前,希望你已经熟悉了fastNLP的基础使用,尤其是数据的载入以及模型的构建,通过这个小任务的能让你进一步熟悉fastNLP的使用。 | |||
我们将对基于Weibo的中文社交数据集进行处理,展示如何完成命名实体标注任务的整个过程。 | |||
载入数据 | |||
=================================== | |||
fastNLP可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含CoNLL-03数据集。 | |||
fastNLP的数据载入主要是由Loader与Pipe两个基类衔接完成的。通过Loader可以方便地载入各种类型的数据。同时,针对常见的数据集,我们已经预先实现了载入方法,其中包含weibo数据集。 | |||
在设计dataloader时,以DataSetLoader为基类,可以改写并应用于其他数据集的载入。 | |||
.. code-block:: python | |||
class Conll2003DataLoader(DataSetLoader): | |||
def __init__(self, task:str='ner', encoding_type:str='bioes'): | |||
assert task in ('ner', 'pos', 'chunk') | |||
index = {'ner':3, 'pos':1, 'chunk':2}[task] | |||
#ConllLoader是fastNLP内置的类 | |||
self._loader = ConllLoader(headers=['raw_words', 'target'], indexes=[0, index]) | |||
self._tag_converters = None | |||
if task in ('ner', 'chunk'): | |||
#iob和iob2bioes会对tag进行统一,标准化 | |||
self._tag_converters = [iob2] | |||
if encoding_type == 'bioes': | |||
self._tag_converters.append(iob2bioes) | |||
def load(self, path: str): | |||
dataset = self._loader.load(path) | |||
def convert_tag_schema(tags): | |||
for converter in self._tag_converters: | |||
tags = converter(tags) | |||
return tags | |||
if self._tag_converters: | |||
#使用apply实现convert_tag_schema函数,实际上也支持匿名函数 | |||
dataset.apply_field(convert_tag_schema, field_name=Const.TARGET, new_field_name=Const.TARGET) | |||
return dataset | |||
输出数据格式如: | |||
{'raw_words': ['on', 'Friday', ':'] type=list, | |||
'target': ['O', 'O', 'O'] type=list}, | |||
from fastNLP.io import WeiboNERLoader | |||
data_bundle = WeiboNERLoader().load() | |||
载入后的数据如 :: | |||
{'dev': DataSet( | |||
{{'raw_chars': ['用', '最', '大', '努', '力', '去', '做''人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', ' | |||
'target': ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O',, 'O', 'O', 'O', 'O', 'O', 'O'] type=list})} | |||
{'test': DataSet( | |||
{{'raw_chars': ['感', '恩', '大', '回', '馈'] type=list, 'target': ['O', 'O', 'O', 'O', 'O'] type=list})} | |||
{'train': DataSet( | |||
{'raw_chars': ['国', '安', '老', '球', '迷'] type=list, 'target': ['B-ORG.NAM', 'I-ORG.NAM', 'B-PER.NOM', 'I-PER.NOM', 'I-PER.NOM'] type=list})} | |||
数据处理 | |||
---------------------------- | |||
我们进一步处理数据。将数据和词表封装在 :class:`~fastNLP.DataBundle` 类中。data是DataBundle的实例。 | |||
我们输入模型的数据包括char embedding,以及word embedding。在数据处理部分,我们尝试完成词表的构建。 | |||
使用fastNLP中的Vocabulary类来构建词表。 | |||
我们进一步处理数据。通过Pipe基类处理Loader载入的数据。 如果你还有印象,应该还能想起,实现自定义数据集的Pipe时,至少要编写process 函数或者process_from_file 函数。前者接受 :class:`~fastNLP.DataBundle` 类的数据,并返回该 :class:`~fastNLP.DataBundle` 。后者接收数据集所在文件夹为参数,读取并处理为 :class:`~fastNLP.DataBundle` 后,通过process 函数处理数据。 | |||
这里我们已经实现通过Loader载入数据,并已返回 :class:`~fastNLP.DataBundle` 类的数据。我们编写process 函数以处理Loader载入后的数据。 | |||
.. code-block:: python | |||
word_vocab = Vocabulary(min_freq=2) | |||
word_vocab.from_dataset(data.datasets['train'], field_name=Const.INPUT) | |||
word_vocab.index_dataset(*data.datasets.values(),field_name=Const.INPUT, new_field_name=Const.INPUT) | |||
from fastNLP.io import ChineseNERPipe | |||
data_bundle = ChineseNERPipe(encoding_type='bioes', bigram=True).process(data_bundle) | |||
处理后的data对象内部为: | |||
载入后的数据如下 :: | |||
dataset | |||
vocabs | |||
dataset保存了train和test中的数据,并保存为dataset类型 | |||
vocab保存了words,raw-words以及target的词表。 | |||
{'raw_chars': ['用', '最', '大', '努', '力', '去', '做', '值', '得', '的', '事', '人', '生', '。', '哈', '哈', '哈', '哈', '哈', '哈', '我', '在'] type=list, | |||
'target': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] type=list, | |||
'chars': [97, 71, 34, 422, 104, 72, 144, 628, 66, 3, 158, 2, 9, 647, 485, 196, 2,19] type=list, | |||
'bigrams': [5948, 1950, 34840, 98, 8413, 3961, 34841, 631, 34842, 407, 462, 45, 3 1959, 1619, 3, 3, 3, 3, 3, 2663, 29, 90] type=list, | |||
'seq_len': 30 type=int} | |||
模型构建 | |||
-------------------------------- | |||
@@ -69,27 +57,23 @@ fastNLP可以方便地载入各种类型的数据。同时,针对常见的数 | |||
模型的训练 | |||
首先实例化模型,导入所需的char embedding以及word embedding。Embedding的载入可以参考教程。 | |||
也可以查看 :mod:`~fastNLP.modules.encoder.embedding` 使用所需的embedding 载入方法。 | |||
fastNLP将模型的训练过程封装在了 :class:`~fastnlp.trainer` 类中。 | |||
也可以查看 :mod:`~fastNLP.embedding` 使用所需的embedding 载入方法。 | |||
fastNLP将模型的训练过程封装在了 :class:`~fastNLP.Trainer` 类中。 | |||
根据不同的任务调整trainer中的参数即可。通常,一个trainer实例需要有:指定的训练数据集,模型,优化器,loss函数,评测指标,以及指定训练的epoch数,batch size等参数。 | |||
.. code-block:: python | |||
#实例化模型 | |||
model = CNNBiLSTMCRF(word_embed, char_embed, hidden_size=200, num_layers=1, tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) | |||
#定义优化器 | |||
optimizer = Adam(model.parameters(), lr=0.005) | |||
model = CNBiLSTMCRFNER(char_embed, num_classes=len(data_bundle.vocabs['target']), bigram_embed=bigram_embed) | |||
#定义评估指标 | |||
Metrics=SpanFPreRecMetric(tag_vocab=data.vocabs[Const.TARGET], encoding_type=encoding_type) | |||
#实例化trainer | |||
trainer = Trainer(train_data=data.datasets['train'], model=model, optimizer=optimizer, dev_data=data.datasets['test'], batch_size=10, metrics=Metrics,callbacks=callbacks, n_epochs=100) | |||
#开始训练 | |||
trainer.train() | |||
Metrics=SpanFPreRecMetric(data_bundle.vocabs['target'], encoding_type='bioes') | |||
#实例化trainer并训练 | |||
Trainer(data_bundle.datasets['train'], model, batch_size=20, metrics=Metrics, num_workers=2, dev_data=data_bundle.datasets['dev']).train() | |||
训练中会保存最优的参数配置。 | |||
训练的结果如下: | |||
.. code-block:: python | |||
训练的结果如下 :: | |||
Evaluation on DataSet test: | |||
SpanFPreRecMetric: f=0.727661, pre=0.732293, rec=0.723088 | |||
@@ -23,7 +23,7 @@ Callback的构建和使用 | |||
class LRDecay(fastNLP.Callback): | |||
def __init__(self): | |||
super(MyCallback, self).__init__() | |||
super(LRDecay, self).__init__() | |||
self.base_lrs = [] | |||
self.delta = [] | |||
@@ -8,7 +8,7 @@ fastNLP 详细使用教程 | |||
:maxdepth: 1 | |||
使用DataSet预处理文本 </tutorials/tutorial_1_data_preprocess> | |||
使用DataSetLoader加载数据集 </tutorials/tutorial_2_load_dataset> | |||
使用Loader和Pipe加载并处理数据集 </tutorials/tutorial_2_load_dataset> | |||
使用Embedding模块将文本转成向量 </tutorials/tutorial_3_embedding> | |||
动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试 </tutorials/tutorial_4_loss_optimizer> | |||
动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程 </tutorials/tutorial_5_datasetiter> | |||
@@ -13,11 +13,12 @@ fastNLP 中最常用的组件可以直接从 fastNLP 包中 import ,他们的 | |||
__all__ = [ | |||
"Instance", | |||
"FieldArray", | |||
"DataSetIter", | |||
"BatchIter", | |||
"TorchLoaderIter", | |||
"Vocabulary", | |||
"DataSet", | |||
"Const", | |||
@@ -31,6 +32,7 @@ __all__ = [ | |||
"TensorboardCallback", | |||
"LRScheduler", | |||
"ControlC", | |||
"LRFinder", | |||
"Padder", | |||
"AutoPadder", | |||
@@ -43,7 +45,8 @@ __all__ = [ | |||
"Optimizer", | |||
"SGD", | |||
"Adam", | |||
"AdamW", | |||
"Sampler", | |||
"SequentialSampler", | |||
"BucketSampler", | |||
@@ -51,16 +54,23 @@ __all__ = [ | |||
"LossFunc", | |||
"CrossEntropyLoss", | |||
"L1Loss", "BCELoss", | |||
"L1Loss", | |||
"BCELoss", | |||
"NLLLoss", | |||
"LossInForward", | |||
"cache_results" | |||
"cache_results", | |||
'logger' | |||
] | |||
__version__ = '0.4.5' | |||
from .core import * | |||
from . import embeddings | |||
from . import models | |||
from . import modules | |||
from . import embeddings | |||
from .io import data_loader | |||
from .core import * | |||
from .io import loader, pipe | |||
import sys | |||
from .doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -10,21 +10,85 @@ core 模块里实现了 fastNLP 的核心框架,常用的功能都可以从 fa | |||
对于常用的功能,你只需要在 :doc:`fastNLP` 中查看即可。如果想了解各个子模块的具体作用,您可以在下面找到每个子模块的具体文档。 | |||
.. todo:: | |||
介绍core 的子模块的分工,好像必要性不大 | |||
""" | |||
__all__ = [ | |||
"DataSet", | |||
"Instance", | |||
"FieldArray", | |||
"Padder", | |||
"AutoPadder", | |||
"EngChar2DPadder", | |||
"Vocabulary", | |||
"DataSetIter", | |||
"BatchIter", | |||
"TorchLoaderIter", | |||
"Const", | |||
"Tester", | |||
"Trainer", | |||
"cache_results", | |||
"seq_len_to_mask", | |||
"get_seq_len", | |||
"logger", | |||
"Callback", | |||
"GradientClipCallback", | |||
"EarlyStopCallback", | |||
"FitlogCallback", | |||
"EvaluateCallback", | |||
"LRScheduler", | |||
"ControlC", | |||
"LRFinder", | |||
"TensorboardCallback", | |||
"WarmupCallback", | |||
'SaveModelCallback', | |||
"EchoCallback", | |||
"TesterCallback", | |||
"CallbackException", | |||
"EarlyStopError", | |||
"LossFunc", | |||
"CrossEntropyLoss", | |||
"L1Loss", | |||
"BCELoss", | |||
"NLLLoss", | |||
"LossInForward", | |||
"AccuracyMetric", | |||
"SpanFPreRecMetric", | |||
"ExtractiveQAMetric", | |||
"Optimizer", | |||
"SGD", | |||
"Adam", | |||
"AdamW", | |||
"SequentialSampler", | |||
"BucketSampler", | |||
"RandomSampler", | |||
"Sampler", | |||
] | |||
from ._logger import logger | |||
from .batch import DataSetIter, BatchIter, TorchLoaderIter | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, TensorboardCallback, LRScheduler, ControlC | |||
from .callback import Callback, GradientClipCallback, EarlyStopCallback, FitlogCallback, EvaluateCallback, \ | |||
LRScheduler, ControlC, LRFinder, TensorboardCallback, WarmupCallback, SaveModelCallback, EchoCallback, \ | |||
TesterCallback, CallbackException, EarlyStopError | |||
from .const import Const | |||
from .dataset import DataSet | |||
from .field import FieldArray, Padder, AutoPadder, EngChar2DPadder | |||
from .instance import Instance | |||
from .losses import LossFunc, CrossEntropyLoss, L1Loss, BCELoss, NLLLoss, LossInForward | |||
from .metrics import AccuracyMetric, SpanFPreRecMetric, ExtractiveQAMetric | |||
from .optimizer import Optimizer, SGD, Adam | |||
from .optimizer import Optimizer, SGD, Adam, AdamW | |||
from .sampler import SequentialSampler, BucketSampler, RandomSampler, Sampler | |||
from .tester import Tester | |||
from .trainer import Trainer | |||
from .utils import cache_results, seq_len_to_mask | |||
from .utils import cache_results, seq_len_to_mask, get_seq_len | |||
from .vocabulary import Vocabulary |
@@ -0,0 +1,155 @@ | |||
"""undocumented""" | |||
__all__ = [ | |||
'logger', | |||
] | |||
import logging | |||
import logging.config | |||
import os | |||
import sys | |||
import warnings | |||
ROOT_NAME = 'fastNLP' | |||
try: | |||
import fitlog | |||
except ImportError: | |||
fitlog = None | |||
try: | |||
from tqdm.auto import tqdm | |||
except ImportError: | |||
tqdm = None | |||
if tqdm is not None:
    class TqdmLoggingHandler(logging.Handler):
        """Logging handler that emits records via ``tqdm.write`` so that log
        lines do not break an active tqdm progress bar."""
        def __init__(self, level=logging.INFO):
            super().__init__(level)
        def emit(self, record):
            # Format the record and hand it to tqdm, which redraws the bar
            # below the printed message.
            try:
                msg = self.format(record)
                tqdm.write(msg)
                self.flush()
            except (KeyboardInterrupt, SystemExit):
                # Never swallow interpreter-exit signals.
                raise
            except:
                # Delegate all other failures to logging's standard error path.
                self.handleError(record)
else:
    class TqdmLoggingHandler(logging.StreamHandler):
        """Fallback used when tqdm is not installed: a plain stdout handler."""
        def __init__(self, level=logging.INFO):
            super().__init__(sys.stdout)
            self.setLevel(level)
def _get_level(level):
    """Normalize a log level to its ``logging`` integer constant.

    :param level: either an int (returned unchanged) or a case-insensitive
        level name: ``'info'``, ``'debug'``, ``'warn'``/``'warning'``,
        ``'error'`` or ``'critical'``.
    :return: the corresponding ``logging`` level as an int
    :raises ValueError: if ``level`` is a string naming no known level
        (previously this surfaced as an uninformative ``KeyError``)
    """
    if isinstance(level, int):
        return level
    name_to_level = {'info': logging.INFO, 'debug': logging.DEBUG,
                     'warn': logging.WARN, 'warning': logging.WARN,
                     'error': logging.ERROR, 'critical': logging.CRITICAL}
    try:
        return name_to_level[level.lower()]
    except KeyError:
        raise ValueError('unknown log level: {!r}'.format(level)) from None
def _add_file_handler(logger, path, level='INFO'):
    """Attach a ``FileHandler`` appending to *path* to *logger*.

    Idempotent per file: if a handler for the same absolute path is already
    attached, the call is a no-op. Warns when the log file already exists.

    :param logger: logger to attach the handler to
    :param path: log file path; parent directories are created if missing
    :param level: level name or int for the new handler
    """
    abs_path = os.path.abspath(path)
    # Avoid attaching two handlers for the same file.
    if any(isinstance(h, logging.FileHandler) and h.baseFilename == abs_path
           for h in logger.handlers):
        return
    if os.path.exists(path):
        assert os.path.isfile(path)
        warnings.warn('log already exists in {}'.format(path))
    os.makedirs(os.path.abspath(os.path.dirname(path)), exist_ok=True)
    handler = logging.FileHandler(path, mode='a')
    handler.setLevel(_get_level(level))
    handler.setFormatter(logging.Formatter(
        fmt='%(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
        datefmt='%Y/%m/%d %H:%M:%S'))
    logger.addHandler(handler)
def _set_stdout_handler(logger, stdout='tqdm', level='INFO'):
    """Configure *logger*'s console output, replacing any existing console handler.

    :param logger: logger whose console handler is (re)set
    :param stdout: ``'plain'`` for a plain stdout handler, ``'tqdm'`` for a
        tqdm-friendly handler, ``'none'`` for no console output
    :param level: level name or int for the new handler
    :raises ValueError: if *stdout* is not one of the three modes
    """
    level = _get_level(level)
    if stdout not in ['none', 'plain', 'tqdm']:
        raise ValueError('stdout must in one of {}'.format(['none', 'plain', 'tqdm']))
    # Remove the first existing console handler so at most one is attached.
    for existing in logger.handlers:
        if isinstance(existing, (logging.StreamHandler, TqdmLoggingHandler)):
            logger.removeHandler(existing)
            break
    # Build the replacement handler for the requested mode ('none' -> nothing).
    if stdout == 'plain':
        new_handler = logging.StreamHandler(sys.stdout)
    elif stdout == 'tqdm':
        new_handler = TqdmLoggingHandler(level)
    else:
        new_handler = None
    if new_handler is not None:
        new_handler.setLevel(level)
        new_handler.setFormatter(logging.Formatter('%(message)s'))
        logger.addHandler(new_handler)
class FastNLPLogger(logging.getLoggerClass()):
    """Logger subclass used by fastNLP; adds convenience methods to attach a
    log file and to configure console (stdout/tqdm) output."""
    def __init__(self, name):
        super().__init__(name)
    def add_file(self, path='./log.txt', level='INFO'):
        """add log output file and level"""
        _add_file_handler(self, path, level)
    def set_stdout(self, stdout='tqdm', level='INFO'):
        """set stdout format and level"""
        _set_stdout_handler(self, stdout, level)
# Register the subclass so loggers created after this point via
# logging.getLogger() are FastNLPLogger instances.
logging.setLoggerClass(FastNLPLogger)
def _init_logger(path=None, stdout='tqdm', level='INFO'):
    """Create and configure the package root logger (``fastNLP``).

    :param path: optional log file; when given, a file handler is attached
    :param stdout: console mode, one of ``'none'``/``'plain'``/``'tqdm'``
    :param level: level name or ``logging`` int constant
    :return: the configured ``fastNLP`` root logger
    """
    level = _get_level(level)
    logger = logging.getLogger(ROOT_NAME)
    # Keep records within fastNLP's own handlers; don't bubble to the root logger.
    logger.propagate = False
    logger.setLevel(level)
    _set_stdout_handler(logger, stdout, level)
    # File Handler
    if path is not None:
        _add_file_handler(logger, path, level)
    return logger
def _get_logger(name=None, level='INFO'):
    """Return a logger inside the ``fastNLP`` namespace at the given level.

    :param name: logger name; ``None`` means the root ``fastNLP`` logger, and
        names outside the namespace are prefixed with ``fastNLP.``
    :param level: level name or int to set on the returned logger
    :return: the namespaced logger
    """
    resolved_level = _get_level(level)
    if name is None:
        name = ROOT_NAME
    assert isinstance(name, str)
    # Force every logger into the fastNLP.* hierarchy.
    if not name.startswith(ROOT_NAME):
        name = '{}.{}'.format(ROOT_NAME, name)
    child = logging.getLogger(name)
    child.setLevel(resolved_level)
    return child
logger = _init_logger(path=None) |
@@ -1,10 +1,14 @@ | |||
"""undocumented""" | |||
__all__ = [] | |||
import threading | |||
import torch | |||
from torch import nn | |||
from torch.nn.parallel.parallel_apply import get_a_var | |||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |||
from torch.nn.parallel.replicate import replicate | |||
from torch.nn.parallel.scatter_gather import scatter_kwargs, gather | |||
def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||
@@ -26,11 +30,11 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||
assert len(modules) == len(devices) | |||
else: | |||
devices = [None] * len(modules) | |||
lock = threading.Lock() | |||
results = {} | |||
grad_enabled = torch.is_grad_enabled() | |||
def _worker(i, module, input, kwargs, device=None): | |||
torch.set_grad_enabled(grad_enabled) | |||
if device is None: | |||
@@ -46,20 +50,20 @@ def parallel_apply(modules, func_name, inputs, kwargs_tup=None, devices=None): | |||
except Exception as e: | |||
with lock: | |||
results[i] = e | |||
if len(modules) > 1: | |||
threads = [threading.Thread(target=_worker, | |||
args=(i, module, input, kwargs, device)) | |||
for i, (module, input, kwargs, device) in | |||
enumerate(zip(modules, inputs, kwargs_tup, devices))] | |||
for thread in threads: | |||
thread.start() | |||
for thread in threads: | |||
thread.join() | |||
else: | |||
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) | |||
outputs = [] | |||
for i in range(len(inputs)): | |||
output = results[i] | |||
@@ -78,6 +82,7 @@ def _data_parallel_wrapper(func_name, device_ids, output_device): | |||
:param output_device: nn.DataParallel中的output_device | |||
:return: | |||
""" | |||
def wrapper(network, *inputs, **kwargs): | |||
inputs, kwargs = scatter_kwargs(inputs, kwargs, device_ids, dim=0) | |||
if len(device_ids) == 1: | |||
@@ -85,4 +90,18 @@ def _data_parallel_wrapper(func_name, device_ids, output_device): | |||
replicas = replicate(network, device_ids[:len(inputs)]) | |||
outputs = parallel_apply(replicas, func_name, inputs, kwargs, device_ids[:len(replicas)]) | |||
return gather(outputs, output_device) | |||
return wrapper | |||
def _model_contains_inner_module(model): | |||
""" | |||
:param nn.Module model: 模型文件,判断是否内部包含model.module, 多用于check模型是否是nn.DataParallel, | |||
nn.parallel.DistributedDataParallel。主要是在做形参匹配的时候需要使用最内部的model的function。 | |||
:return: bool | |||
""" | |||
if isinstance(model, nn.Module): | |||
if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)): | |||
return True | |||
return False |
@@ -9,14 +9,15 @@ __all__ = [ | |||
] | |||
import atexit | |||
from numbers import Number | |||
import numpy as np | |||
import torch | |||
import torch.utils.data | |||
from numbers import Number | |||
from .sampler import SequentialSampler | |||
from ._logger import logger | |||
from .dataset import DataSet | |||
from .sampler import SequentialSampler | |||
_python_is_exit = False | |||
@@ -48,6 +49,11 @@ class DataSetGetter: | |||
return len(self.dataset) | |||
def collate_fn(self, batch: list): | |||
""" | |||
:param batch: [[idx1, x_dict1, y_dict1], [idx2, x_dict2, y_dict2], [xx, xx, xx]] | |||
:return: | |||
""" | |||
# TODO 支持在DataSet中定义collate_fn,因为有时候可能需要不同的field之间融合,比如BERT的场景 | |||
batch_x = {n:[] for n in self.inputs.keys()} | |||
batch_y = {n:[] for n in self.targets.keys()} | |||
@@ -70,7 +76,7 @@ class DataSetGetter: | |||
try: | |||
data, flag = _to_tensor(data, f.dtype) | |||
except TypeError as e: | |||
print(f"Field {n} cannot be converted to torch.tensor.") | |||
logger.error(f"Field {n} cannot be converted to torch.tensor.") | |||
raise e | |||
batch_dict[n] = data | |||
return batch_dict | |||
@@ -93,9 +99,13 @@ class DataSetGetter: | |||
class SamplerAdapter(torch.utils.data.Sampler):
    """Adapt a fastNLP-style sampler (a callable taking the dataset and
    returning an iterable of indices) to the ``torch.utils.data.Sampler``
    interface that ``DataLoader`` expects.
    """

    def __init__(self, sampler, dataset):
        super().__init__(dataset)
        # wrapped fastNLP sampler; invoked anew on every iteration pass
        self.sampler = sampler
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __iter__(self):
        # Call the underlying sampler each time so samplers that shuffle
        # yield a fresh order per epoch.
        indices = self.sampler(self.dataset)
        return iter(indices)
@@ -136,8 +146,6 @@ class BatchIter: | |||
class DataSetIter(BatchIter): | |||
""" | |||
别名::class:`fastNLP.DataSetIter` :class:`fastNLP.core.batch.DataSetIter` | |||
DataSetIter 用于从 `DataSet` 中按一定的顺序, 依次按 ``batch_size`` 的大小将数据取出, | |||
组成 `x` 和 `y`:: | |||
@@ -146,34 +154,41 @@ class DataSetIter(BatchIter): | |||
for batch_x, batch_y in batch: | |||
# do stuff ... | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||
Default: ``None`` | |||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||
Default: ``False`` | |||
:param int num_workers: 使用多少个进程来预处理数据 | |||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
""" | |||
def __init__(self, dataset, batch_size=1, sampler=None, as_numpy=False, | |||
num_workers=0, pin_memory=False, drop_last=False, | |||
timeout=0, worker_init_fn=None): | |||
""" | |||
:param dataset: :class:`~fastNLP.DataSet` 对象, 数据集 | |||
:param int batch_size: 取出的batch大小 | |||
:param sampler: 规定使用的 :class:`~fastNLP.Sampler` 方式. 若为 ``None`` , 使用 :class:`~fastNLP.SequentialSampler`. | |||
Default: ``None`` | |||
:param bool as_numpy: 若为 ``True`` , 输出batch为 numpy.array. 否则为 :class:`torch.Tensor`. | |||
Default: ``False`` | |||
:param int num_workers: 使用多少个进程来预处理数据 | |||
:param bool pin_memory: 是否将产生的tensor使用pin memory, 可能会加快速度。 | |||
:param bool drop_last: 如果最后一个batch没有batch_size这么多sample,就扔掉最后一个 | |||
:param timeout: | |||
:param worker_init_fn: 在每个worker启动时调用该函数,会传入一个值,该值是worker的index。 | |||
""" | |||
super().__init__() | |||
assert isinstance(dataset, DataSet) | |||
sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
if not isinstance(sampler, torch.utils.data.Sampler): | |||
self.sampler = SamplerAdapter(sampler=sampler or SequentialSampler(), dataset=dataset) | |||
else: | |||
self.sampler = sampler | |||
dataset = DataSetGetter(dataset, as_numpy) | |||
collate_fn = dataset.collate_fn if hasattr(dataset, 'collate_fn') else None | |||
self.dataiter = torch.utils.data.DataLoader( | |||
dataset=dataset, batch_size=batch_size, sampler=sampler, | |||
dataset=dataset, batch_size=batch_size, sampler=self.sampler, | |||
collate_fn=collate_fn, num_workers=num_workers, | |||
pin_memory=pin_memory, drop_last=drop_last, | |||
timeout=timeout, worker_init_fn=worker_init_fn) | |||
self.num_batches = self.get_num_batches(len(dataset), batch_size, drop_last) | |||
# 以sampler的数量为准,因为DistributedSampler的时候每个进程上并不是所有的数据都用上了 | |||
self.num_batches = self.get_num_batches(len(self.dataiter.sampler), batch_size, drop_last) | |||
self.batch_size = batch_size | |||
@@ -182,7 +197,7 @@ class TorchLoaderIter(BatchIter): | |||
super().__init__() | |||
assert isinstance(dataset, torch.utils.data.DataLoader) | |||
self.dataiter = dataset | |||
self.num_batches = self.get_num_batches(len(dataset), dataset.batch_size, dataset.drop_last) | |||
self.num_batches = self.get_num_batches(len(dataset.sampler), dataset.batch_size, dataset.drop_last) | |||
self.batch_size = dataset.batch_size | |||
@@ -200,6 +215,13 @@ class OnlineDataIter(BatchIter): | |||
def _to_tensor(batch, field_dtype): | |||
""" | |||
:param batch: np.array() | |||
:param field_dtype: 数据类型 | |||
:return: batch, flag. 如果传入的数据支持转为tensor,返回的batch就是tensor,且flag为True;如果传入的数据不支持转为tensor, | |||
返回的batch就是原来的数据,且flag为False | |||
""" | |||
try: | |||
if field_dtype is not None and isinstance(field_dtype, type)\ | |||
and issubclass(field_dtype, Number) \ | |||
@@ -51,22 +51,30 @@ callback模块实现了 fastNLP 中的许多 callback 类,用于增强 :class: | |||
""" | |||
__all__ = [ | |||
"Callback", | |||
"GradientClipCallback", | |||
"EarlyStopCallback", | |||
"TensorboardCallback", | |||
"FitlogCallback", | |||
"EvaluateCallback", | |||
"LRScheduler", | |||
"ControlC", | |||
"LRFinder", | |||
"TensorboardCallback", | |||
"WarmupCallback", | |||
"SaveModelCallback", | |||
"EchoCallback", | |||
"TesterCallback", | |||
"CallbackException", | |||
"EarlyStopError" | |||
] | |||
import os | |||
import sys | |||
from copy import deepcopy | |||
import torch | |||
from copy import deepcopy | |||
import sys | |||
from .utils import _save_model | |||
try: | |||
@@ -76,9 +84,9 @@ try: | |||
except: | |||
tensorboardX_flag = False | |||
from ..io.model_io import ModelSaver, ModelLoader | |||
from .dataset import DataSet | |||
from .tester import Tester | |||
from ._logger import logger | |||
try: | |||
import fitlog | |||
@@ -88,8 +96,6 @@ except: | |||
class Callback(object): | |||
""" | |||
别名::class:`fastNLP.Callback` :class:`fastNLP.core.callback.Callback` | |||
Callback是fastNLP中被设计用于增强 :class:`~fastNLP.Trainer` 的类。 | |||
如果Callback被传递给了 Trainer , 则 Trainer 会在对应的阶段调用Callback的函数, | |||
具体调用时机可以通过 :doc:`trainer 模块<fastNLP.core.trainer>` 查看。 | |||
@@ -100,7 +106,8 @@ class Callback(object): | |||
def __init__(self): | |||
super(Callback, self).__init__() | |||
self._trainer = None # 在Trainer内部被重新赋值 | |||
self._disabled = False | |||
@property | |||
def trainer(self): | |||
""" | |||
@@ -158,7 +165,19 @@ class Callback(object): | |||
def batch_per_epoch(self):
    """Number of batches in one epoch; only valid after ``on_epoch_begin`` has been called. Proxied from the bound trainer."""
    return self._trainer.batch_per_epoch
@property
def is_master(self):
    """Whether the bound trainer reports this process as the master (see ``Trainer.is_master``); presumably rank 0 in distributed training — confirm against the trainer."""
    return self._trainer.is_master
@property
def disabled(self):
    """bool: when True, the CallbackManager dispatch skips every hook of this callback."""
    return self._disabled
@property
def logger(self):
    """Logger for use inside callbacks; falls back to the module-level logger when the trainer has none."""
    return getattr(self._trainer, 'logger', logger)
def on_train_begin(self): | |||
""" | |||
在Train过程开始之前调用。 | |||
@@ -250,6 +269,14 @@ class Callback(object): | |||
:return: | |||
""" | |||
pass | |||
def on_validation(self):
    """
    Called every time validation is due, when the Trainer is configured with validation.

    :return:
    """
    pass
def on_epoch_end(self): | |||
""" | |||
@@ -281,6 +308,8 @@ def _transfer(func): | |||
def wrapper(manager, *arg): | |||
returns = [] | |||
for callback in manager.callbacks: | |||
if callback.disabled: | |||
continue | |||
returns.append(getattr(callback, func.__name__)(*arg)) | |||
return returns | |||
@@ -288,31 +317,39 @@ def _transfer(func): | |||
class CallbackManager(Callback): | |||
""" | |||
内部使用的Callback管理类 | |||
""" | |||
def __init__(self, env, callbacks=None): | |||
""" | |||
内部使用的Callback管理类 | |||
:param dict env: The key is the name of the Trainer attribute(str). The value is the attribute itself. | |||
:param List[Callback] callbacks: | |||
""" | |||
super(CallbackManager, self).__init__() | |||
# set attribute of trainer environment | |||
self._env = env | |||
self.callbacks = [] | |||
if callbacks is not None: | |||
if isinstance(callbacks, list): | |||
if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||
self.callbacks.extend(callbacks) | |||
else: | |||
obj = [not isinstance(cb, Callback) for cb in callbacks][0] | |||
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||
if callbacks: | |||
self.callbacks = self.prepare_callbacks(callbacks) | |||
def prepare_callbacks(self, callbacks): | |||
if not callbacks: | |||
return [] | |||
if isinstance(callbacks, list): | |||
if all([isinstance(cb, Callback) for cb in callbacks]) is True: | |||
pass | |||
else: | |||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||
for env_name, env_val in env.items(): | |||
for callback in self.callbacks: | |||
obj = [not isinstance(cb, Callback) for cb in callbacks][0] | |||
raise TypeError(f"Expect sub-classes of Callback. Got {type(obj)}") | |||
else: | |||
raise TypeError(f"Expect callbacks in CallbackManager(callbacks) to be list. Got {type(callbacks)}.") | |||
for env_name, env_val in self._env.items(): | |||
for callback in callbacks: | |||
setattr(callback, '_' + env_name, env_val) # Callback.trainer | |||
return callbacks | |||
@_transfer | |||
def on_train_begin(self): | |||
pass | |||
@@ -352,6 +389,10 @@ class CallbackManager(Callback): | |||
@_transfer | |||
def on_valid_end(self, eval_result, metric_key, optimizer, is_better_eval): | |||
pass | |||
@_transfer | |||
def on_validation(self): | |||
pass | |||
@_transfer | |||
def on_epoch_end(self): | |||
@@ -366,28 +407,53 @@ class CallbackManager(Callback): | |||
pass | |||
class DistCallbackManager(CallbackManager): | |||
def __init__(self, env, callbacks_all=None, callbacks_master=None): | |||
super(DistCallbackManager, self).__init__(env) | |||
assert 'trainer' in env | |||
self._trainer = env['trainer'] | |||
self.callbacks_master = [] | |||
self.callbacks_all = [] | |||
self.add_callback(callbacks_all, master=False) | |||
self.add_callback(callbacks_master, master=True) | |||
def patch_callback(self, callbacks, disabled): | |||
if not callbacks: | |||
return | |||
if not isinstance(callbacks, (list, tuple)): | |||
callbacks = [callbacks] | |||
for cb in callbacks: | |||
cb._disabled = disabled | |||
def add_callback(self, cb, master=False): | |||
if master: | |||
self.patch_callback(cb, not self.is_master) | |||
self.callbacks_master += self.prepare_callbacks(cb) | |||
else: | |||
self.callbacks_all += self.prepare_callbacks(cb) | |||
self.callbacks = self.callbacks_all + self.callbacks_master | |||
class GradientClipCallback(Callback): | |||
""" | |||
别名::class:`fastNLP.GradientClipCallback` :class:`fastNLP.core.callback.GradientClipCallback` | |||
每次backward前,将parameter的gradient clip到某个范围。 | |||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||
如果为None则默认对Trainer的model中所有参数进行clip | |||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
:param str clip_type: 支持'norm', 'value' | |||
两种:: | |||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
2 'value', 将gradient限制在[-clip_value, clip_value], | |||
小于-clip_value的gradient被赋值为-clip_value; | |||
大于clip_value的gradient被赋值为clip_value. | |||
""" | |||
def __init__(self, parameters=None, clip_value=1, clip_type='norm'): | |||
""" | |||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | |||
如果为None则默认对Trainer的model中所有参数进行clip | |||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | |||
:param str clip_type: 支持'norm', 'value' | |||
两种:: | |||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||
2 'value', 将gradient限制在[-clip_value, clip_value], | |||
小于-clip_value的gradient被赋值为-clip_value; | |||
大于clip_value的gradient被赋值为clip_value. | |||
""" | |||
super().__init__() | |||
from torch import nn | |||
@@ -403,6 +469,9 @@ class GradientClipCallback(Callback): | |||
def on_backward_end(self): | |||
if self.step%self.update_every==0: | |||
if self.parameters is None: | |||
if getattr(self.trainer, 'fp16', ''): | |||
from apex import amp | |||
self.clip_fun(amp.master_params(self.optimizer), self.clip_value) | |||
self.clip_fun(self.model.parameters(), self.clip_value) | |||
else: | |||
self.clip_fun(self.parameters, self.clip_value) | |||
@@ -410,14 +479,14 @@ class GradientClipCallback(Callback): | |||
class EarlyStopCallback(Callback): | |||
""" | |||
别名::class:`fastNLP.EarlyStopCallback` :class:`fastNLP.core.callback.EarlyStopCallback` | |||
多少个epoch没有变好就停止训练,相关类 :class:`EarlyStopError` | |||
:param int patience: epoch的数量 | |||
多少个epoch没有变好就停止训练,相关类 :class:`~fastNLP.core.callback.EarlyStopError` | |||
""" | |||
def __init__(self, patience): | |||
""" | |||
:param int patience: epoch的数量 | |||
""" | |||
super(EarlyStopCallback, self).__init__() | |||
self.patience = patience | |||
self.wait = 0 | |||
@@ -434,52 +503,54 @@ class EarlyStopCallback(Callback): | |||
def on_exception(self, exception): | |||
if isinstance(exception, EarlyStopError): | |||
print("Early Stopping triggered in epoch {}!".format(self.epoch)) | |||
logger.info("Early Stopping triggered in epoch {}!".format(self.epoch)) | |||
else: | |||
raise exception # 抛出陌生Error | |||
class FitlogCallback(Callback): | |||
""" | |||
别名: :class:`fastNLP.FitlogCallback` :class:`fastNLP.core.callback.FitlogCallback` | |||
该callback可将loss和progress写入到fitlog中; 如果Trainer有dev的数据,将自动把dev的结果写入到log中; 同时还支持传入 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。若tester不为None时,data需要通过 | |||
dict的方式传入。如果仅传入DataSet, 则被命名为test | |||
:param ~fastNLP.Tester tester: Tester对象,将在on_valid_end时调用。tester中的DataSet会被称为为`test` | |||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||
:param bool log_exception: fitlog是否记录发生的exception信息 | |||
一个(或多个)test数据集进行测试(只有在trainer具有dev时才能使用),每次在dev上evaluate之后会在这些数据集上验证一下。 | |||
并将验证结果写入到fitlog中。这些数据集的结果是根据dev上最好的结果报道的,即如果dev在第3个epoch取得了最佳,则 | |||
fitlog中记录的关于这些数据集的结果就是来自第三个epoch的结果。 | |||
""" | |||
def __init__(self, data=None, tester=None, log_loss_every=0, verbose=0, log_exception=False): | |||
""" | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要 | |||
传入多个DataSet请通过dict的方式传入,dict的key将作为对应dataset的name传递给fitlog。data的结果的名称以'data'开头。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.Tester] tester: Tester对象,将在on_valid_end时调用。tester的结果的名称以'tester'开头 | |||
:param int log_loss_every: 多少个step记录一次loss(记录的是这几个batch的loss平均值),如果数据集较大建议将该值设置得 | |||
大一些,不然会导致log文件巨大。默认为0, 即不要记录loss。 | |||
:param int verbose: 是否在终端打印evaluation的结果,0不打印。 | |||
:param bool log_exception: fitlog是否记录发生的exception信息 | |||
""" | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
self._log_exception = log_exception | |||
assert isinstance(log_loss_every, int) and log_loss_every>=0 | |||
if tester is not None: | |||
assert isinstance(tester, Tester), "Only fastNLP.Tester allowed." | |||
assert isinstance(data, dict) or data is None, "If tester is not None, only dict[DataSet] allowed for data." | |||
if data is not None: | |||
assert 'test' not in data, "Cannot use `test` as DataSet key, when tester is passed." | |||
setattr(tester, 'verbose', 0) | |||
self.testers['test'] = tester | |||
if isinstance(tester, dict): | |||
for name, test in tester.items(): | |||
if not isinstance(test, Tester): | |||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||
self.testers['tester-' + name] = test | |||
if isinstance(tester, Tester): | |||
self.testers['tester-test'] = tester | |||
for tester in self.testers.values(): | |||
setattr(tester, 'verbose', 0) | |||
if isinstance(data, dict): | |||
for key, value in data.items(): | |||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | |||
for key, value in data.items(): | |||
self.datasets[key] = value | |||
self.datasets['data-' + key] = value | |||
elif isinstance(data, DataSet): | |||
self.datasets['test'] = data | |||
else: | |||
self.datasets['data-test'] = data | |||
elif data is not None: | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
self.verbose = verbose | |||
@@ -492,8 +563,11 @@ class FitlogCallback(Callback): | |||
if len(self.datasets) > 0: | |||
for key, data in self.datasets.items(): | |||
tester = Tester(data=data, model=self.model, batch_size=self.batch_size, metrics=self.trainer.metrics, | |||
verbose=0) | |||
tester = Tester(data=data, model=self.model, | |||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), | |||
metrics=self.trainer.metrics, | |||
verbose=0, | |||
use_tqdm=self.trainer.test_use_tqdm) | |||
self.testers[key] = tester | |||
fitlog.add_progress(total_steps=self.n_steps) | |||
@@ -533,17 +607,76 @@ class FitlogCallback(Callback): | |||
fitlog.add_other(repr(exception), name='except_info') | |||
class LRScheduler(Callback): | |||
class EvaluateCallback(Callback): | |||
""" | |||
该callback用于扩展Trainer训练过程中只能对dev数据进行验证的问题。 | |||
""" | |||
别名::class:`fastNLP.LRScheduler` :class:`fastNLP.core.callback.LRScheduler` | |||
对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | |||
def __init__(self, data=None, tester=None): | |||
""" | |||
:param ~fastNLP.DataSet,Dict[~fastNLP.DataSet] data: 传入DataSet对象,会使用多个Trainer中的metric对数据进行验证。如果需要传入多个 | |||
DataSet请通过dict的方式传入。 | |||
:param ~fastNLP.Tester,Dict[~fastNLP.DataSet] tester: Tester对象,将在on_valid_end时调用。 | |||
""" | |||
super().__init__() | |||
self.datasets = {} | |||
self.testers = {} | |||
if tester is not None: | |||
if isinstance(tester, dict): | |||
for name, test in tester.items(): | |||
if not isinstance(test, Tester): | |||
raise TypeError(f"{name} in tester is not a valid fastNLP.Tester.") | |||
self.testers['tester-' + name] = test | |||
if isinstance(tester, Tester): | |||
self.testers['tester-test'] = tester | |||
for tester in self.testers.values(): | |||
setattr(tester, 'verbose', 0) | |||
if isinstance(data, dict): | |||
for key, value in data.items(): | |||
assert isinstance(value, DataSet), f"Only DataSet object is allowed, not {type(value)}." | |||
for key, value in data.items(): | |||
self.datasets['data-' + key] = value | |||
elif isinstance(data, DataSet): | |||
self.datasets['data-test'] = data | |||
elif data is not None: | |||
raise TypeError("data receives dict[DataSet] or DataSet object.") | |||
def on_train_begin(self): | |||
if len(self.datasets) > 0 and self.trainer.dev_data is None: | |||
raise RuntimeError("Trainer has no dev data, you cannot pass extra DataSet to do evaluation.") | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
if len(self.datasets) > 0: | |||
for key, data in self.datasets.items(): | |||
tester = Tester(data=data, model=self.model, | |||
batch_size=self.trainer.kwargs.get('dev_batch_size', self.batch_size), | |||
metrics=self.trainer.metrics, verbose=0, | |||
use_tqdm=self.trainer.test_use_tqdm) | |||
self.testers[key] = tester | |||
def on_valid_end(self, eval_result, metric_key, optimizer, better_result): | |||
if len(self.testers) > 0: | |||
for key, tester in self.testers.items(): | |||
try: | |||
eval_result = tester.test() | |||
# self.pbar.write("Evaluation on {}:".format(key)) | |||
self.logger.info("Evaluation on {}:".format(key)) | |||
# self.pbar.write(tester._format_eval_results(eval_result)) | |||
self.logger.info(tester._format_eval_results(eval_result)) | |||
except Exception: | |||
# self.pbar.write("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
self.logger.info("Exception happens when evaluate on DataSet named `{}`.".format(key)) | |||
class LRScheduler(Callback): | |||
""" | |||
对PyTorch LR Scheduler的包装以使得其可以被Trainer所使用 | |||
""" | |||
def __init__(self, lr_scheduler): | |||
""" | |||
:param torch.optim.lr_scheduler._LRScheduler lr_scheduler: PyTorch的lr_scheduler | |||
""" | |||
super(LRScheduler, self).__init__() | |||
import torch.optim | |||
if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): | |||
@@ -557,13 +690,13 @@ class LRScheduler(Callback): | |||
class ControlC(Callback): | |||
""" | |||
别名::class:`fastNLP.ControlC` :class:`fastNLP.core.callback.ControlC` | |||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
检测到 control+C 时的反馈 | |||
""" | |||
def __init__(self, quit_all): | |||
""" | |||
:param bool quit_all: 若为True,则检测到control+C 直接退出程序;否则只退出Trainer | |||
""" | |||
super(ControlC, self).__init__() | |||
if type(quit_all) != bool: | |||
raise ValueError("In KeyBoardInterrupt, quit_all arguemnt must be a bool.") | |||
@@ -586,7 +719,7 @@ class SmoothValue(object): | |||
self.smooth = None | |||
def add_value(self, val: float) -> None: | |||
"Add `val` to calculate updated smoothed value." | |||
"""Add `val` to calculate updated smoothed value.""" | |||
self.n += 1 | |||
self.mov_avg = self.beta * self.mov_avg + (1 - self.beta) * val | |||
self.smooth = self.mov_avg / (1 - self.beta ** self.n) | |||
@@ -594,16 +727,15 @@ class SmoothValue(object): | |||
class LRFinder(Callback): | |||
""" | |||
别名::class:`fastNLP.LRFinder` :class:`fastNLP.core.callback.LRFinder` | |||
用第一个 epoch 找最佳的学习率,从第二个epoch开始应用它 | |||
:param float start_lr: 学习率下界 | |||
:param float end_lr: 学习率上界 | |||
""" | |||
def __init__(self, start_lr=1e-6, end_lr=10): | |||
""" | |||
:param float start_lr: 学习率下界 | |||
:param float end_lr: 学习率上界 | |||
""" | |||
super(LRFinder, self).__init__() | |||
self.start_lr, self.end_lr = start_lr, end_lr | |||
@@ -614,8 +746,7 @@ class LRFinder(Callback): | |||
self.smooth_value = SmoothValue(0.8) | |||
self.opt = None | |||
self.find = None | |||
self.loader = ModelLoader() | |||
@property | |||
def lr_gen(self): | |||
scale = (self.end_lr - self.start_lr) / self.batch_per_epoch | |||
@@ -630,7 +761,7 @@ class LRFinder(Callback): | |||
self.opt = self.trainer.optimizer # pytorch optimizer | |||
self.opt.param_groups[0]["lr"] = self.start_lr | |||
# save model | |||
ModelSaver("tmp").save_pytorch(self.trainer.model, param_only=True) | |||
torch.save(self.model.state_dict(), 'tmp') | |||
self.find = True | |||
def on_backward_begin(self, loss): | |||
@@ -659,14 +790,14 @@ class LRFinder(Callback): | |||
self.opt.param_groups[0]["lr"] = self.best_lr | |||
self.find = False | |||
# reset model | |||
ModelLoader().load_pytorch(self.trainer.model, "tmp") | |||
states = torch.load('tmp') | |||
self.model.load_state_dict(states) | |||
os.remove('tmp') | |||
self.pbar.write("Model reset. \nFind best lr={}".format(self.best_lr)) | |||
class TensorboardCallback(Callback): | |||
""" | |||
别名::class:`fastNLP.TensorboardCallback` :class:`fastNLP.core.callback.TensorboardCallback` | |||
接受以下一个或多个字符串作为参数: | |||
- "model" | |||
- "loss" | |||
@@ -742,13 +873,15 @@ class TensorboardCallback(Callback): | |||
class WarmupCallback(Callback): | |||
""" | |||
按一定的周期调节Learning rate的大小。 | |||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||
:param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后 | |||
warmup的step下降到0; constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||
""" | |||
def __init__(self, warmup=0.1, schedule='constant'): | |||
""" | |||
:param int,float warmup: 如果warmup为int,则在该step之前,learning rate根据schedule的策略变化; 如果warmup为float, | |||
如0.1, 则前10%的step是按照schedule策略调整learning rate。 | |||
:param str schedule: 以哪种方式调整。linear: 前warmup的step上升到指定的learning rate(从Trainer中的optimizer处获取的), 后 | |||
warmup的step下降到0; constant前warmup的step上升到指定learning rate,后面的step保持learning rate. | |||
""" | |||
super().__init__() | |||
self.warmup = max(warmup, 0.) | |||
@@ -790,19 +923,23 @@ class WarmupCallback(Callback): | |||
class SaveModelCallback(Callback): | |||
""" | |||
由于Trainer在训练过程中只会保存最佳的模型, 该callback可实现多种方式的结果存储。 | |||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型 | |||
-save_dir | |||
-2019-07-03-15-06-36 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||
-2019-07-03-15-10-00 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | |||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | |||
:param bool only_param: 是否只保存模型d饿权重。 | |||
:param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}. | |||
会根据训练开始的时间戳在save_dir下建立文件夹,再在文件夹下存放多个模型:: | |||
-save_dir | |||
-2019-07-03-15-06-36 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_performance是性能 | |||
-epoch:1_step:40_{metric_key}:{evaluate_performance}.pt | |||
-2019-07-03-15-10-00 | |||
-epoch:0_step:20_{metric_key}:{evaluate_performance}.pt # metric是给定的metric_key, evaluate_perfomance是性能 | |||
""" | |||
def __init__(self, save_dir, top=3, only_param=False, save_on_exception=False): | |||
""" | |||
:param str save_dir: 将模型存放在哪个目录下,会在该目录下创建以时间戳命名的目录,并存放模型 | |||
:param int top: 保存dev表现top多少模型。-1为保存所有模型。 | |||
:param bool only_param: 是否只保存模型d饿权重。 | |||
:param save_on_exception: 发生exception时,是否保存一份发生exception的模型。模型名称为epoch:x_step:x_Exception:{exception_name}. | |||
""" | |||
super().__init__() | |||
if not os.path.isdir(save_dir): | |||
@@ -850,14 +987,14 @@ class SaveModelCallback(Callback): | |||
try: | |||
_save_model(self.model, model_name=name, save_dir=self.save_dir, only_param=self.only_param) | |||
except Exception as e: | |||
print(f"The following exception:{e} happens when save model to {self.save_dir}.") | |||
logger.error(f"The following exception:{e} happens when save model to {self.save_dir}.") | |||
if delete_pair: | |||
try: | |||
delete_model_path = os.path.join(self.save_dir, delete_pair[1]) | |||
if os.path.exists(delete_model_path): | |||
os.remove(delete_model_path) | |||
except Exception as e: | |||
print(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.") | |||
logger.error(f"Fail to delete model {name} at {self.save_dir} caused by exception:{e}.") | |||
def on_exception(self, exception): | |||
if self.save_on_exception: | |||
@@ -868,11 +1005,13 @@ class SaveModelCallback(Callback): | |||
class CallbackException(BaseException):
    """Raised from a callback to break out of training; caught in ``on_exception``.

    Deliberately derives from ``BaseException`` (not ``Exception``) so that
    generic ``except Exception`` handlers do not swallow it.
    """

    def __init__(self, msg):
        """
        :param str msg: message describing why training was interrupted.
        """
        super().__init__(msg)
@@ -884,3 +1023,69 @@ class EarlyStopError(CallbackException): | |||
def __init__(self, msg): | |||
super(EarlyStopError, self).__init__(msg) | |||
class EchoCallback(Callback):
    """Debugging callback that logs every ``on_*`` hook lookup together with
    the current process id, before delegating to normal attribute access.
    Useful for tracing callback dispatch across (sub)processes."""

    def __init__(self, name, out=sys.stdout):
        super(EchoCallback, self).__init__()
        self.name = name  # label prefixed to every log line
        self.out = out  # deprecated: retained for backward compatibility, never written to here

    def __getattribute__(self, item):
        # Intercept all attribute access; hook names ('on_*') are logged.
        # Accessing self.name below is safe: 'name' does not start with 'on_',
        # so the recursive lookup takes the plain branch.
        if item.startswith('on_'):
            logger.info('{}.{} has been called at pid: {}'.format(self.name, item, os.getpid()))
        return super(EchoCallback, self).__getattribute__(item)
class TesterCallback(Callback):
    """Run evaluation through a dedicated :class:`Tester` whenever the trainer
    triggers validation (``on_validation``), and remember the best result so
    far so improvements can be detected.
    """

    def __init__(self, data, model, metrics, metric_key=None, batch_size=16, num_workers=None):
        """
        :param data: dataset to evaluate on (passed straight to :class:`Tester`)
        :param model: model to evaluate
        :param metrics: metric object(s) understood by :class:`Tester`
        :param str metric_key: indicator used to compare results; may be prefixed
            with '+' (larger is better, the default) or '-' (smaller is better).
            When None, the first indicator of the first metric is used lazily.
        :param int batch_size: evaluation batch size
        :param num_workers: DataLoader workers for evaluation
        """
        super(TesterCallback, self).__init__()
        self.tester = Tester(data, model,
                             metrics=metrics, batch_size=batch_size,
                             num_workers=num_workers, verbose=0)
        # parse metric_key
        # increase_better is True. It means the exp result gets better if the indicator increases.
        # It is true by default.
        self.increase_better = True
        if metric_key is not None:
            self.increase_better = False if metric_key[0] == "-" else True
            self.metric_key = metric_key[1:] if metric_key[0] == "+" or metric_key[0] == "-" else metric_key
        else:
            self.metric_key = None
        self.score = None  # best evaluation result seen so far

    def on_validation(self):
        """Evaluate now, log the result, and return ``(cur_score, is_better)``."""
        cur_score = self.tester.test()
        eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. - {}".format(
            self.epoch, self.n_epochs, self.step, self.n_steps,
            self.tester._format_eval_results(cur_score))
        self.logger.info(eval_str)
        is_better = self.compare_better(cur_score)
        if is_better:
            self.score = cur_score
        return cur_score, is_better

    def _get_score(self, metric_dict, key):
        """Find ``key`` inside the per-metric result dicts of ``metric_dict``.

        ``metric_dict`` maps metric names to dicts of indicator values
        (e.g. ``{'AccuracyMetric': {'acc': 0.9}}``, as consumed by
        ``compare_better`` below). Returns None when no metric reports ``key``.
        """
        # Bug fix: iterate the inner result dicts via .values(). The previous
        # .items() iteration yielded (name, dict) tuples, so `key in metric`
        # was effectively always False and `metric[key]` on a tuple with a
        # string key would raise TypeError — improvements were never detected.
        for metric in metric_dict.values():
            if key in metric:
                return metric[key]
        return None

    def compare_better(self, a):
        """Return True when result ``a`` beats the best score recorded so far."""
        if self.score is None:
            # nothing recorded yet: any first result counts as an improvement
            return True
        if self.metric_key is None:
            # lazily adopt the first indicator of the first metric as the key
            self.metric_key = list(list(self.score.values())[0].keys())[0]
        k = self.metric_key
        score = self._get_score(self.score, k)
        new_score = self._get_score(a, k)
        if score is None or new_score is None:
            # the key is missing from one of the results: cannot claim improvement
            return False
        if self.increase_better:
            return score <= new_score
        else:
            return score >= new_score

    def on_train_end(self):
        # one final evaluation pass when training finishes
        self.logger.info('Evaluate on training ends.')
        self.on_validation()
@@ -1,3 +1,13 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"Const" | |||
] | |||
class Const: | |||
""" | |||
fastNLP中field命名常量。 | |||
@@ -7,12 +17,14 @@ class Const: | |||
具体列表:: | |||
INPUT 模型的序列输入 words(复数words1, words2) | |||
CHAR_INPUT 模型character输入 chars(复数chars1, chars2) | |||
INPUT_LEN 序列长度 seq_len(复数seq_len1,seq_len2) | |||
OUTPUT 模型输出 pred(复数pred1, pred2) | |||
TARGET 真实目标 target(复数target1,target2) | |||
LOSS 损失函数 loss (复数loss1,loss2) | |||
INPUT 模型的序列输入 words(具有多列words时,依次使用words1, words2, ) | |||
CHAR_INPUT 模型character输入 chars(具有多列chars时,依次使用chars1, chars2) | |||
INPUT_LEN 序列长度 seq_len(具有多列seq_len时,依次使用seq_len1,seq_len2) | |||
OUTPUT 模型输出 pred(具有多列pred时,依次使用pred1, pred2) | |||
TARGET 真实目标 target(具有多列target时,依次使用target1,target2) | |||
LOSS 损失函数 loss (具有多列loss时,依次使用loss1,loss2) | |||
RAW_WORD 原文的词 raw_words (具有多列raw_words时,依次使用raw_words1, raw_words2) | |||
RAW_CHAR 原文的字 raw_chars (具有多列raw_chars时,依次使用raw_chars1, raw_chars2) | |||
""" | |||
INPUT = 'words' | |||
@@ -21,37 +33,49 @@ class Const: | |||
OUTPUT = 'pred' | |||
TARGET = 'target' | |||
LOSS = 'loss' | |||
RAW_WORD = 'raw_words' | |||
RAW_CHAR = 'raw_chars' | |||
@staticmethod
def INPUTS(i):
    """Name of the i-th (0-based) ``INPUT`` field; suffix is i+1, e.g. ``words1`` for i=0."""
    i = int(i) + 1
    return Const.INPUT + str(i)
@staticmethod
def CHAR_INPUTS(i):
    """Name of the i-th (0-based) ``CHAR_INPUT`` field; suffix is i+1, e.g. ``chars1`` for i=0."""
    i = int(i) + 1
    return Const.CHAR_INPUT + str(i)
@staticmethod
def RAW_WORDS(i):
    """Name of the i-th (0-based) ``RAW_WORD`` field; suffix is i+1, e.g. ``raw_words1`` for i=0."""
    i = int(i) + 1
    return Const.RAW_WORD + str(i)
@staticmethod
def RAW_CHARS(i):
    """Name of the i-th (0-based) ``RAW_CHAR`` field; suffix is i+1, e.g. ``raw_chars1`` for i=0."""
    i = int(i) + 1
    return Const.RAW_CHAR + str(i)
@staticmethod
def INPUT_LENS(i):
    """Name of the i-th (0-based) ``INPUT_LEN`` field; suffix is i+1, e.g. ``seq_len1`` for i=0."""
    i = int(i) + 1
    return Const.INPUT_LEN + str(i)
@staticmethod | |||
def OUTPUTS(i): | |||
"""得到第 i 个 ``OUTPUT`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.OUTPUT + str(i) | |||
@staticmethod | |||
def TARGETS(i): | |||
"""得到第 i 个 ``TARGET`` 的命名""" | |||
i = int(i) + 1 | |||
return Const.TARGET + str(i) | |||
@staticmethod | |||
def LOSSES(i): | |||
"""得到第 i 个 ``LOSS`` 的命名""" | |||
@@ -288,29 +288,33 @@ __all__ = [ | |||
] | |||
import _pickle as pickle | |||
import warnings | |||
from copy import deepcopy | |||
import numpy as np | |||
from ._logger import logger | |||
from .const import Const | |||
from .field import AppendToTargetOrInputException | |||
from .field import AutoPadder | |||
from .field import FieldArray | |||
from .field import SetInputOrTargetException | |||
from .instance import Instance | |||
from .utils import _get_func_signature | |||
from .field import AppendToTargetOrInputException | |||
from .field import SetInputOrTargetException | |||
from .utils import pretty_table_printer | |||
from prettytable import PrettyTable | |||
class DataSet(object): | |||
""" | |||
别名::class:`fastNLP.DataSet` :class:`fastNLP.core.dataset.DataSet` | |||
fastNLP的数据容器,详细的使用方法见文档 :doc:`fastNLP.core.dataset` | |||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | |||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
""" | |||
def __init__(self, data=None): | |||
""" | |||
:param data: 如果为dict类型,则每个key的value应该为等长的list; 如果为list, | |||
每个元素应该为具有相同field的 :class:`~fastNLP.Instance` 。 | |||
""" | |||
self.field_arrays = {} | |||
if data is not None: | |||
if isinstance(data, dict): | |||
@@ -324,41 +328,45 @@ class DataSet(object): | |||
for ins in data: | |||
assert isinstance(ins, Instance), "Must be Instance type, not {}.".format(type(ins)) | |||
self.append(ins) | |||
else: | |||
raise ValueError("data only be dict or list type.") | |||
def __contains__(self, item): | |||
return item in self.field_arrays | |||
def __iter__(self): | |||
def iter_func(): | |||
for idx in range(len(self)): | |||
yield self[idx] | |||
return iter_func() | |||
def _inner_iter(self): | |||
class Iter_ptr: | |||
def __init__(self, dataset, idx): | |||
self.dataset = dataset | |||
self.idx = idx | |||
def __getitem__(self, item): | |||
assert item in self.dataset.field_arrays, "no such field:{} in Instance {}".format(item, self.dataset[ | |||
self.idx]) | |||
assert self.idx < len(self.dataset.field_arrays[item]), "index:{} out of range".format(self.idx) | |||
return self.dataset.field_arrays[item][self.idx] | |||
def items(self): | |||
ins = self.dataset[self.idx] | |||
return ins.items() | |||
def __repr__(self): | |||
return self.dataset[self.idx].__repr__() | |||
def inner_iter_func(): | |||
for idx in range(len(self)): | |||
yield Iter_ptr(self, idx) | |||
return inner_iter_func() | |||
def __getitem__(self, idx): | |||
"""给定int的index,返回一个Instance; 给定slice,返回包含这个slice内容的新的DataSet。 | |||
@@ -391,20 +399,20 @@ class DataSet(object): | |||
return dataset | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __getattr__(self, item): | |||
# Not tested. Don't use !! | |||
if item == "field_arrays": | |||
raise AttributeError | |||
if isinstance(item, str) and item in self.field_arrays: | |||
return self.field_arrays[item] | |||
def __setstate__(self, state): | |||
self.__dict__ = state | |||
def __getstate__(self): | |||
return self.__dict__ | |||
def __len__(self): | |||
"""Fetch the length of the dataset. | |||
@@ -414,16 +422,66 @@ class DataSet(object): | |||
return 0 | |||
field = iter(self.field_arrays.values()).__next__() | |||
return len(field) | |||
def __inner_repr__(self): | |||
if len(self) < 20: | |||
return ",\n".join([ins.__repr__() for ins in self]) | |||
else: | |||
return self[:5].__inner_repr__() + "\n...\n" + self[-5:].__inner_repr__() | |||
def __repr__(self): | |||
return "DataSet(" + self.__inner_repr__() + ")" | |||
return str(pretty_table_printer(self)) | |||
def print_field_meta(self): | |||
""" | |||
输出当前field的meta信息, 形似下列的输出 | |||
+-------------+-------+-------+ | |||
| field_names | x | y | | |||
+-------------+-------+-------+ | |||
| is_input | True | False | | |||
| is_target | False | False | | |||
| ignore_type | False | | | |||
| pad_value | 0 | | | |||
+-------------+-------+-------+ | |||
field_names: DataSet中field的名称 | |||
is_input: field是否为input | |||
is_target: field是否为target | |||
ignore_type: 是否忽略该field的type, 一般仅在该field至少为input或target时才有意义 | |||
pad_value: 该field的pad的值,仅在该field为input或target时有意义 | |||
:return: | |||
""" | |||
if len(self.field_arrays)>0: | |||
field_names = ['field_names'] | |||
is_inputs = ['is_input'] | |||
is_targets = ['is_target'] | |||
pad_values = ['pad_value'] | |||
ignore_types = ['ignore_type'] | |||
for name, field_array in self.field_arrays.items(): | |||
field_names.append(name) | |||
if field_array.is_input: | |||
is_inputs.append(True) | |||
else: | |||
is_inputs.append(False) | |||
if field_array.is_target: | |||
is_targets.append(True) | |||
else: | |||
is_targets.append(False) | |||
if (field_array.is_input or field_array.is_target) and field_array.padder is not None: | |||
pad_values.append(field_array.padder.get_pad_val()) | |||
else: | |||
pad_values.append(' ') | |||
if field_array._ignore_type: | |||
ignore_types.append(True) | |||
elif field_array.is_input or field_array.is_target: | |||
ignore_types.append(False) | |||
else: | |||
ignore_types.append(' ') | |||
table = PrettyTable(field_names=field_names) | |||
fields = [is_inputs, is_targets, ignore_types, pad_values] | |||
for field in fields: | |||
table.add_row(field) | |||
logger.info(table) | |||
def append(self, instance): | |||
""" | |||
将一个instance对象append到DataSet后面。 | |||
@@ -446,9 +504,9 @@ class DataSet(object): | |||
try: | |||
self.field_arrays[name].append(field) | |||
except AppendToTargetOrInputException as e: | |||
print(f"Cannot append to field:{name}.") | |||
logger.error(f"Cannot append to field:{name}.") | |||
raise e | |||
def add_fieldarray(self, field_name, fieldarray): | |||
""" | |||
将fieldarray添加到DataSet中. | |||
@@ -463,7 +521,7 @@ class DataSet(object): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fieldarray)}") | |||
self.field_arrays[field_name] = fieldarray | |||
def add_field(self, field_name, fields, padder=AutoPadder(), is_input=False, is_target=False, ignore_type=False): | |||
""" | |||
新增一个field | |||
@@ -475,19 +533,19 @@ class DataSet(object): | |||
:param bool is_target: 新加入的field是否是target | |||
:param bool ignore_type: 是否忽略对新加入的field的类型检查 | |||
""" | |||
if len(self.field_arrays) != 0: | |||
if len(self) != len(fields): | |||
raise RuntimeError(f"The field to add must have the same size as dataset. " | |||
f"Dataset size {len(self)} != field size {len(fields)}") | |||
self.field_arrays[field_name] = FieldArray(field_name, fields, is_target=is_target, is_input=is_input, | |||
padder=padder, ignore_type=ignore_type) | |||
def delete_instance(self, index): | |||
""" | |||
删除第index个instance | |||
:param int index: 需要删除的instance的index,从0开始 | |||
:param int index: 需要删除的instance的index,序号从0开始。 | |||
""" | |||
assert isinstance(index, int), "Only integer supported." | |||
if len(self) <= index: | |||
@@ -497,7 +555,8 @@ class DataSet(object): | |||
else: | |||
for field in self.field_arrays.values(): | |||
field.pop(index) | |||
return self | |||
def delete_field(self, field_name): | |||
""" | |||
删除名为field_name的field | |||
@@ -505,7 +564,22 @@ class DataSet(object): | |||
:param str field_name: 需要删除的field的名称. | |||
""" | |||
self.field_arrays.pop(field_name) | |||
return self | |||
def copy_field(self, field_name, new_field_name): | |||
""" | |||
深度copy名为field_name的field到new_field_name | |||
:param str field_name: 需要copy的field。 | |||
:param str new_field_name: copy生成的field名称 | |||
:return: self | |||
""" | |||
if not self.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} not found in DataSet.") | |||
fieldarray = deepcopy(self.get_field(field_name)) | |||
self.add_fieldarray(field_name=new_field_name, fieldarray=fieldarray) | |||
return self | |||
def has_field(self, field_name): | |||
""" | |||
判断DataSet中是否有名为field_name这个field | |||
@@ -516,7 +590,7 @@ class DataSet(object): | |||
if isinstance(field_name, str): | |||
return field_name in self.field_arrays | |||
return False | |||
def get_field(self, field_name): | |||
""" | |||
获取field_name这个field | |||
@@ -527,7 +601,7 @@ class DataSet(object): | |||
if field_name not in self.field_arrays: | |||
raise KeyError("Field name {} not found in DataSet".format(field_name)) | |||
return self.field_arrays[field_name] | |||
def get_all_fields(self): | |||
""" | |||
返回一个dict,key为field_name, value为对应的 :class:`~fastNLP.FieldArray` | |||
@@ -535,7 +609,7 @@ class DataSet(object): | |||
:return dict: 返回如上所述的字典 | |||
""" | |||
return self.field_arrays | |||
def get_field_names(self) -> list: | |||
""" | |||
返回一个list,包含所有 field 的名字 | |||
@@ -543,7 +617,7 @@ class DataSet(object): | |||
:return list: 返回如上所述的列表 | |||
""" | |||
return sorted(self.field_arrays.keys()) | |||
def get_length(self): | |||
""" | |||
获取DataSet的元素数量 | |||
@@ -551,22 +625,22 @@ class DataSet(object): | |||
:return: int: DataSet中Instance的个数。 | |||
""" | |||
return len(self) | |||
def rename_field(self, old_name, new_name): | |||
def rename_field(self, field_name, new_field_name): | |||
""" | |||
将某个field重新命名. | |||
:param str old_name: 原来的field名称。 | |||
:param str new_name: 修改为new_name。 | |||
:param str field_name: 原来的field名称。 | |||
:param str new_field_name: 修改为new_name。 | |||
""" | |||
if old_name in self.field_arrays: | |||
self.field_arrays[new_name] = self.field_arrays.pop(old_name) | |||
self.field_arrays[new_name].name = new_name | |||
if field_name in self.field_arrays: | |||
self.field_arrays[new_field_name] = self.field_arrays.pop(field_name) | |||
self.field_arrays[new_field_name].name = new_field_name | |||
else: | |||
raise KeyError("DataSet has no field named {}.".format(old_name)) | |||
raise KeyError("DataSet has no field named {}.".format(field_name)) | |||
return self | |||
def set_target(self, *field_names, flag=True): | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
将field_names的field设置为target | |||
@@ -577,19 +651,23 @@ class DataSet(object): | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的target状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
""" | |||
assert isinstance(flag, bool), "Only bool type supported." | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
try: | |||
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self.field_arrays[name].is_target = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as target.") | |||
logger.error(f"Cannot set field:{name} as target.") | |||
raise e | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
def set_input(self, *field_names, flag=True): | |||
return self | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True): | |||
""" | |||
将field_names的field设置为input:: | |||
@@ -598,17 +676,21 @@ class DataSet(object): | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的input状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
""" | |||
for name in field_names: | |||
if name in self.field_arrays: | |||
try: | |||
self.field_arrays[name]._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self.field_arrays[name].is_input = flag | |||
except SetInputOrTargetException as e: | |||
print(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") | |||
logger.error(f"Cannot set field:{name} as input, exception happens at the {e.index} value.") | |||
raise e | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True): | |||
""" | |||
将field设置为忽略类型状态。当某个field被设置了ignore_type, 则在被设置为target或者input时将不进行类型检查, | |||
@@ -624,7 +706,8 @@ class DataSet(object): | |||
self.field_arrays[name].ignore_type = flag | |||
else: | |||
raise KeyError("{} is not a valid field name.".format(name)) | |||
return self | |||
def set_padder(self, field_name, padder): | |||
""" | |||
为field_name设置padder:: | |||
@@ -639,7 +722,8 @@ class DataSet(object): | |||
if field_name not in self.field_arrays: | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_padder(padder) | |||
return self | |||
def set_pad_val(self, field_name, pad_val): | |||
""" | |||
为某个field设置对应的pad_val. | |||
@@ -650,7 +734,8 @@ class DataSet(object): | |||
if field_name not in self.field_arrays: | |||
raise KeyError("There is no field named {}.".format(field_name)) | |||
self.field_arrays[field_name].set_pad_val(pad_val) | |||
return self | |||
def get_input_name(self): | |||
""" | |||
返回所有is_input被设置为True的field名称 | |||
@@ -658,7 +743,7 @@ class DataSet(object): | |||
:return list: 里面的元素为被设置为input的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_input] | |||
def get_target_name(self): | |||
""" | |||
返回所有is_target被设置为True的field名称 | |||
@@ -666,7 +751,7 @@ class DataSet(object): | |||
:return list: 里面的元素为被设置为target的field名称 | |||
""" | |||
return [name for name, field in self.field_arrays.items() if field.is_target] | |||
def apply_field(self, func, field_name, new_field_name=None, **kwargs): | |||
""" | |||
将DataSet中的每个instance中的名为 `field_name` 的field传给func,并获取它的返回值。 | |||
@@ -695,16 +780,16 @@ class DataSet(object): | |||
results.append(func(ins[field_name])) | |||
except Exception as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
logger.error("Exception happens at the `{}`th(from 1) instance.".format(idx + 1)) | |||
raise e | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def _add_apply_field(self, results, new_field_name, kwargs): | |||
""" | |||
将results作为加入到新的field中,field名称为new_field_name | |||
@@ -736,7 +821,7 @@ class DataSet(object): | |||
self.add_field(field_name=new_field_name, fields=results, is_input=extra_param.get("is_input", None), | |||
is_target=extra_param.get("is_target", None), | |||
ignore_type=extra_param.get("ignore_type", False)) | |||
def apply(self, func, new_field_name=None, **kwargs): | |||
""" | |||
将DataSet中每个instance传入到func中,并获取它的返回值. | |||
@@ -760,20 +845,21 @@ class DataSet(object): | |||
results = [] | |||
for idx, ins in enumerate(self._inner_iter()): | |||
results.append(func(ins)) | |||
except Exception as e: | |||
except BaseException as e: | |||
if idx != -1: | |||
print("Exception happens at the `{}`th instance.".format(idx)) | |||
logger.error("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
# results = [func(ins) for ins in self._inner_iter()] | |||
if not (new_field_name is None) and len(list(filter(lambda x: x is not None, results))) == 0: # all None | |||
raise ValueError("{} always return None.".format(_get_func_signature(func=func))) | |||
if new_field_name is not None: | |||
self._add_apply_field(results, new_field_name, kwargs) | |||
return results | |||
def add_seq_len(self, field_name:str, new_field_name='seq_len'): | |||
def add_seq_len(self, field_name: str, new_field_name=Const.INPUT_LEN): | |||
""" | |||
将使用len()直接对field_name中每个元素作用,将其结果作为seqence length, 并放入seq_len这个field。 | |||
@@ -810,7 +896,7 @@ class DataSet(object): | |||
return dataset | |||
else: | |||
return DataSet() | |||
def split(self, ratio, shuffle=True): | |||
""" | |||
将DataSet按照ratio的比例拆分,返回两个DataSet | |||
@@ -836,51 +922,9 @@ class DataSet(object): | |||
for field_name in self.field_arrays: | |||
train_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
dev_set.field_arrays[field_name].to(self.field_arrays[field_name]) | |||
return train_set, dev_set | |||
@classmethod | |||
def read_csv(cls, csv_path, headers=None, sep=",", dropna=True): | |||
r""" | |||
.. warning:: | |||
此方法会在下个版本移除,请使用 :class:`fastNLP.io.CSVLoader` | |||
从csv_path路径下以csv的格式读取数据。 | |||
:param str csv_path: 从哪里读取csv文件 | |||
:param list[str] headers: 如果为None,则使用csv文件的第一行作为header; 如果传入list(str), 则元素的个数必须 | |||
与csv文件中每行的元素个数相同。 | |||
:param str sep: 分割符 | |||
:param bool dropna: 是否忽略与header数量不一致行。 | |||
:return: 读取后的 :class:`~fastNLP.读取后的DataSet`。 | |||
""" | |||
warnings.warn('DataSet.read_csv is deprecated, use CSVLoader instead', | |||
category=DeprecationWarning) | |||
with open(csv_path, "r", encoding='utf-8') as f: | |||
start_idx = 0 | |||
if headers is None: | |||
headers = f.readline().rstrip('\r\n') | |||
headers = headers.split(sep) | |||
start_idx += 1 | |||
else: | |||
assert isinstance(headers, (list, tuple)), "headers should be list or tuple, not {}.".format( | |||
type(headers)) | |||
_dict = {} | |||
for col in headers: | |||
_dict[col] = [] | |||
for line_idx, line in enumerate(f, start_idx): | |||
contents = line.rstrip('\r\n').split(sep) | |||
if len(contents) != len(headers): | |||
if dropna: | |||
continue | |||
else: | |||
# TODO change error type | |||
raise ValueError("Line {} has {} parts, while header has {} parts." \ | |||
.format(line_idx, len(contents), len(headers))) | |||
for header, content in zip(headers, contents): | |||
_dict[header].append(content) | |||
return cls(_dict) | |||
def save(self, path): | |||
""" | |||
保存DataSet. | |||
@@ -889,7 +933,7 @@ class DataSet(object): | |||
""" | |||
with open(path, 'wb') as f: | |||
pickle.dump(self, f) | |||
@staticmethod | |||
def load(path): | |||
r""" | |||
@@ -0,0 +1,356 @@ | |||
"""undocumented | |||
正在开发中的分布式训练代码 | |||
""" | |||
import logging | |||
import os | |||
import time | |||
from datetime import datetime | |||
import torch | |||
import torch.cuda | |||
import torch.distributed as dist | |||
import torch.optim | |||
from pkg_resources import parse_version | |||
from torch.nn.parallel import DistributedDataParallel as DDP | |||
from torch.utils.data.distributed import DistributedSampler | |||
from tqdm import tqdm | |||
from ._logger import logger | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import DistCallbackManager, CallbackException, TesterCallback | |||
from .dataset import DataSet | |||
from .losses import _prepare_losser | |||
from .optimizer import Optimizer | |||
from .utils import _build_args | |||
from .utils import _get_func_signature | |||
from .utils import _move_dict_value_to_device | |||
__all__ = [ | |||
'get_local_rank', | |||
'DistTrainer', | |||
] | |||
def get_local_rank(): | |||
if 'LOCAL_RANK' in os.environ: | |||
return int(os.environ['LOCAL_RANK']) | |||
from argparse import ArgumentParser | |||
parser = ArgumentParser() | |||
parser.add_argument('--local_rank', type=int) | |||
args, _ = parser.parse_known_args() | |||
if 'local_rank' in args and args.local_rank: | |||
os.environ['LOCAL_RANK'] = str(args.local_rank) # for multiple calls for this function | |||
return args.local_rank | |||
raise RuntimeError('Please use "python -m torch.distributed.launch --nproc_per_node=N train_script.py') | |||
class DistTrainer(): | |||
""" | |||
Distributed Trainer that support distributed and mixed precision training | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
callbacks_all=None, callbacks_master=None, | |||
batch_size_per_gpu=8, n_epochs=1, | |||
num_workers=1, drop_last=False, | |||
dev_data=None, metrics=None, metric_key=None, | |||
update_every=1, print_every=10, validate_every=-1, | |||
save_every=-1, save_path=None, device='auto', | |||
fp16='', backend=None, init_method=None): | |||
assert device in ['auto', 'cuda', 'cpu'], "Please set correct device in [auto', 'cuda', 'cpu']" | |||
if device == 'auto': | |||
device = 'cuda' if torch.cuda.is_available() else 'cpu' | |||
if backend is None: | |||
backend = 'nccl' if device == 'cuda' else 'gloo' | |||
# init distributed | |||
if device == 'cuda': | |||
torch.cuda.set_device(get_local_rank()) | |||
self.device = torch.device("cuda", get_local_rank()) | |||
else: | |||
self.device = torch.device(device) | |||
dist.init_process_group(backend=backend, init_method=init_method) | |||
self.world_size = dist.get_world_size() | |||
self.rank = dist.get_rank() # unique id for each process | |||
self.model = model | |||
self.train_data = train_data | |||
self.batch_size_per_gpu = int(batch_size_per_gpu) | |||
self.n_epochs = int(n_epochs) | |||
self.num_data_workers = int(num_workers) | |||
self.drop_last = drop_last | |||
self.update_every = int(update_every) | |||
self.print_every = int(print_every) | |||
self.validate_every = int(validate_every) | |||
self.save_every = int(save_every) | |||
self.save_path = save_path | |||
self.losser = _prepare_losser(loss) | |||
self.fp16 = fp16 | |||
self.init_method = init_method | |||
self.backend = backend | |||
self.local_rank = get_local_rank() | |||
self._forward_func = model.forward | |||
self.callback_manager = DistCallbackManager( | |||
env={"trainer": self}, callbacks_all=callbacks_all, | |||
callbacks_master=callbacks_master) | |||
self.metric_key = metric_key | |||
model.to(self.device) | |||
optimizer = self._get_optimizer(optimizer) | |||
# init fp16, must before DataParallel init | |||
if len(self.fp16): | |||
assert isinstance(self.fp16, str), "Please set Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3']" | |||
try: | |||
from apex import amp | |||
except ImportError: | |||
raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") | |||
assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled." | |||
assert device == 'cuda', "Amp requires cuda device" | |||
model, optimizer = amp.initialize(model, optimizer, opt_level=self.fp16) | |||
# init DataParallel | |||
if parse_version(torch.__version__)>=parse_version('1.1'): | |||
self.model = DDP(model, device_ids=[self.local_rank], | |||
output_device=self.local_rank, find_unused_parameters=True) | |||
else: | |||
self.model = DDP(model, device_ids=[self.local_rank], | |||
output_device=self.local_rank) | |||
self.optimizer = optimizer | |||
self.sampler = DistributedSampler(self.train_data) | |||
self.data_iterator = self._get_data_iter(self.train_data) | |||
self.n_steps = self._get_n_steps() | |||
# for evaluation, only run eval on master proc | |||
if dev_data and metrics: | |||
cb = TesterCallback( | |||
dev_data, model, metrics, | |||
batch_size=batch_size_per_gpu, num_workers=num_workers) | |||
self.callback_manager.add_callback([cb], master=True) | |||
# Setup logging | |||
dist.barrier() | |||
self.start_time = datetime.now().strftime('%m_%d_%Y-%H_%M') | |||
if self.save_path: | |||
self.cp_save_path = os.path.join(self.save_path, 'checkpoints', self.start_time) | |||
else: | |||
self.cp_save_path = None | |||
# use INFO in the master, WARN for others | |||
logger.setLevel(logging.INFO if self.is_master else logging.WARNING) | |||
self.logger = logger | |||
self.logger.info("Setup Distributed Trainer") | |||
self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format( | |||
os.getpid(), self.rank, self.local_rank, self.device, self.fp16 if self.fp16 else False)) | |||
self.logger.info("Num of processes: {}".format(self.world_size)) | |||
self.logger.info("Use device: {}".format(device)) | |||
self.logger.info("Training with fp16: {}, optimization level: {}".format( | |||
len(self.fp16) > 0, self.fp16 if self.fp16 else None)) | |||
def _get_n_steps(self): | |||
batch_size = self.world_size * self.batch_size_per_gpu | |||
return (len(self.train_data) // batch_size + int( | |||
len(self.train_data) % batch_size != 0)) * int(self.drop_last == 0) * self.n_epochs | |||
def _get_data_iter(self, dataset): | |||
if isinstance(dataset, DataSet): | |||
return DataSetIter( | |||
dataset=dataset, batch_size=self.batch_size_per_gpu, | |||
num_workers=self.num_data_workers, sampler=self.sampler, | |||
drop_last=self.drop_last | |||
) | |||
elif isinstance(dataset, BatchIter): | |||
return dataset | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(dataset))) | |||
def _get_optimizer(self, optimizer): | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
return optimizer | |||
elif isinstance(optimizer, Optimizer): | |||
return optimizer.construct_from_pytorch(self.model.parameters()) | |||
elif optimizer is None: | |||
return torch.optim.Adam(self.model.parameters(), lr=4e-3) | |||
else: | |||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
@property | |||
def is_master(self): | |||
return self.rank == 0 | |||
def train(self, on_exception='auto'): | |||
try: | |||
self.logger.info("###### Training epochs started ######") | |||
self.logger.info('Total epochs: %d'% self.n_epochs) | |||
self.logger.info('Total steps: %d'% self.n_steps) | |||
self.logger.info('Num instances per GPU %d'% self.batch_size_per_gpu) | |||
self.logger.info('Total batch_size: %d'% self.batch_size_per_gpu * dist.get_world_size()) | |||
self.logger.info('Total num of samples: %d'% len(self.train_data)) | |||
self.logger.info("Num of callbacks for all workers: {}".format( | |||
len(self.callback_manager.callbacks_all))) | |||
self.logger.info("Num of callbacks for master workers: {}".format( | |||
len(self.callback_manager.callbacks_master))) | |||
self.logger.info("Callbacks for all workers: {}".format( | |||
[repr(cb) for cb in self.callback_manager.callbacks_all])) | |||
self.logger.info("Callbacks for master workers: {}".format( | |||
[repr(cb) for cb in self.callback_manager.callbacks_master])) | |||
start_time = time.time() | |||
results = {} | |||
if self.n_epochs <= 0: | |||
self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs)) | |||
results['seconds'] = 0. | |||
return results | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
self.callback_manager.on_train_end() | |||
except BaseException as e: | |||
self.callback_manager.on_exception(e) | |||
if on_exception == 'auto': | |||
if not isinstance(e, (CallbackException, KeyboardInterrupt)): | |||
raise e | |||
else: | |||
self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__)) | |||
elif on_exception == 'raise': | |||
raise e | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
self.logger.info("###### Train finished ######") | |||
self.logger.info('Total train time: {} seconds.'. format(results['seconds'])) | |||
return results | |||
finally: | |||
self.close() | |||
def _train(self): | |||
if self.fp16: | |||
# skip check, done in __init__() | |||
from apex import amp | |||
self.step = 0 | |||
self.epoch = 0 | |||
self.pbar = tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', | |||
leave=False, dynamic_ncols=True, disable=not self.is_master) | |||
pbar = self.pbar | |||
avg_loss = 0 | |||
data_iterator = self.data_iterator | |||
self.model.zero_grad() | |||
for epoch in range(1, self.n_epochs + 1): | |||
self.epoch = epoch | |||
pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs)) | |||
# early stopping | |||
self.callback_manager.on_epoch_begin() | |||
for batch_x, batch_y in data_iterator: | |||
self.model.train() | |||
self.step += 1 | |||
_move_dict_value_to_device(batch_x, batch_y, device=self.device) | |||
indices = data_iterator.get_batch_indices() | |||
# negative sampling; replace unknown; re-weight batch_y | |||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
prediction = self._data_forward(self.model, batch_x) | |||
# edit prediction | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y) | |||
avg_loss += loss.item() | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
if self.fp16: | |||
with amp.scale_loss(loss, self.optimizer) as scale_loss: | |||
scale_loss.backward() | |||
else: | |||
loss.backward() | |||
self.callback_manager.on_backward_end() | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
print_output = "loss:{:<6.5f}".format(avg_loss) | |||
pbar.update(self.print_every) | |||
pbar.set_postfix_str(print_output) | |||
avg_loss = 0 | |||
self.callback_manager.on_batch_end() | |||
if (self.validate_every > 0 and self.step % self.validate_every == 0): | |||
self._do_validation() | |||
if self.cp_save_path and \ | |||
self.save_every > 0 and \ | |||
self.step % self.save_every == 0: | |||
self.save_check_point() | |||
# ================= mini-batch end ==================== # | |||
if self.validate_every < 0: | |||
self._do_validation() | |||
if self.save_every < 0 and self.cp_save_path: | |||
self.save_check_point() | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
pbar.close() | |||
self.pbar = None | |||
# ============ tqdm end ============== # | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
""" | |||
if self.step % self.update_every == 0: | |||
self.optimizer.step() | |||
self.model.zero_grad() | |||
def _data_forward(self, network, x): | |||
x = _build_args(self._forward_func, **x) | |||
y = network(**x) | |||
if not isinstance(y, dict): | |||
raise TypeError( | |||
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") | |||
return y | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
:param predict: prediction dict, produced by model.forward | |||
:param truth: ground truth dict, produced by batch_y | |||
:return: a scalar | |||
""" | |||
loss = self.losser(predict, truth) | |||
if self.update_every > 1: | |||
loss = loss / self.update_every | |||
return loss.mean() | |||
def save_check_point(self, only_params=False): | |||
# only master save models | |||
if self.is_master: | |||
os.makedirs(self.cp_save_path, exist_ok=True) | |||
path = os.path.join(self.cp_save_path, 'checkpoint-{}.bin'.format(self.step)) | |||
self.logger.info("Save checkpoint to {}".format(path)) | |||
model_to_save = self.model.module | |||
if only_params: | |||
model_to_save = model_to_save.state_dict() | |||
torch.save(model_to_save, path) | |||
def _do_validation(self): | |||
self.callback_manager.on_valid_begin() | |||
eval_res = self.callback_manager.on_validation() | |||
eval_res = list(filter(lambda x: x is not None, eval_res)) | |||
if len(eval_res): | |||
eval_res, is_better = list(zip(*eval_res)) | |||
else: | |||
eval_res, is_better = None, None | |||
self.callback_manager.on_valid_end( | |||
eval_res, self.metric_key, self.optimizer, is_better) | |||
dist.barrier() | |||
def close(self): | |||
dist.destroy_process_group() |
@@ -1,73 +1,91 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"Padder", | |||
"AutoPadder", | |||
"EngChar2DPadder", | |||
] | |||
from numbers import Number | |||
import torch | |||
import numpy as np | |||
from typing import Any | |||
from abc import abstractmethod | |||
from copy import deepcopy | |||
from collections import Counter | |||
from copy import deepcopy | |||
from numbers import Number | |||
from typing import Any | |||
import numpy as np | |||
import torch | |||
from ._logger import logger | |||
from .utils import _is_iterable | |||
class SetInputOrTargetException(Exception): | |||
def __init__(self, msg, index=None, field_name=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index # 标示在哪个数据遭遇到问题了 | |||
self.field_name = field_name # 标示当前field的名称 | |||
self.field_name = field_name # 标示当前field的名称 | |||
class AppendToTargetOrInputException(Exception): | |||
def __init__(self, msg, index=None, field_name=None): | |||
super().__init__(msg) | |||
self.msg = msg | |||
self.index = index # 标示在哪个数据遭遇到问题了 | |||
self.field_name = field_name # 标示当前field的名称 | |||
self.field_name = field_name # 标示当前field的名称 | |||
class FieldArray: | |||
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False): | |||
if len(content)==0: | |||
def __init__(self, name, content, is_target=False, is_input=False, padder=None, ignore_type=False, | |||
use_1st_ins_infer_dim_type=True): | |||
if len(content) == 0: | |||
raise RuntimeError("Empty fieldarray is not allowed.") | |||
_content = content | |||
try: | |||
_content = list(_content) | |||
except BaseException as e: | |||
print(f"Cannot convert content(of type:{type(content)}) into list.") | |||
logger.error(f"Cannot convert content(of type:{type(content)}) into list.") | |||
raise e | |||
self.name = name | |||
self.content = _content | |||
self._ignore_type = ignore_type | |||
# 根据input的情况设置input,target等 | |||
self._cell_ndim = None # 多少维度 | |||
self._cell_ndim = None # 多少维度, 如果value是1, dim为0; 如果value是[1, 2], dim=2 | |||
self.dtype = None # 最内层的element都是什么类型的 | |||
self._use_1st_ins_infer_dim_type = bool(use_1st_ins_infer_dim_type) | |||
self._is_input = False | |||
self._is_target = False | |||
if is_input: | |||
self.is_input = is_input | |||
if is_target: | |||
self.is_target = is_target | |||
if padder is None: | |||
padder = AutoPadder(pad_val=0) | |||
else: | |||
assert isinstance(padder, Padder), "padder must be of type fastNLP.Padder." | |||
padder = deepcopy(padder) | |||
self.set_padder(padder) | |||
@property | |||
def ignore_type(self): | |||
return self._ignore_type | |||
@ignore_type.setter | |||
def ignore_type(self, value): | |||
if value: | |||
self._cell_ndim = None | |||
self.dtype = None | |||
self._ignore_type = value | |||
@property | |||
def is_input(self): | |||
return self._is_input | |||
@is_input.setter | |||
def is_input(self, value): | |||
""" | |||
@@ -77,16 +95,16 @@ class FieldArray: | |||
if value is True and \ | |||
self._is_target is False and \ | |||
self._ignore_type is False: | |||
self._check_dtype_and_ndim() | |||
self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) | |||
if value is False and self._is_target is False: | |||
self.dtype = None | |||
self._cell_ndim = None | |||
self._is_input = value | |||
@property | |||
def is_target(self): | |||
return self._is_target | |||
@is_target.setter | |||
def is_target(self, value): | |||
""" | |||
@@ -95,70 +113,82 @@ class FieldArray: | |||
if value is True and \ | |||
self._is_input is False and \ | |||
self._ignore_type is False: | |||
self._check_dtype_and_ndim() | |||
self._check_dtype_and_ndim(only_check_1st_ins_dim_type=self._use_1st_ins_infer_dim_type) | |||
if value is False and self._is_input is False: | |||
self.dtype = None | |||
self._cell_ndim = None | |||
self._is_target = value | |||
def _check_dtype_and_ndim(self): | |||
def _check_dtype_and_ndim(self, only_check_1st_ins_dim_type=True): | |||
""" | |||
检查当前content所有的element是否是同一个类型,且是否每个元素具有相同的维度。通过的话,设置_cell_ndim与_ele_type属性;没有 | |||
通过将直接报错. | |||
:param bool only_check_1st_ins_dim_type: 是否只检查第一个元素的type和dim | |||
:return: | |||
""" | |||
cell_0 = self.content[0] | |||
index = 0 | |||
try: | |||
type_0, dim_0 = _get_ele_type_and_dim(cell_0) | |||
for cell in self.content[1:]: | |||
index += 1 | |||
type_i, dim_i = _get_ele_type_and_dim(cell) | |||
if type_i!=type_0: | |||
raise SetInputOrTargetException("Type:{} in index {} is different from the first element with type:{}." | |||
".".format(type_i, index, type_0)) | |||
if dim_0!=dim_i: | |||
raise SetInputOrTargetException("Dimension:{} in index {} is different from the first element with " | |||
"dimension:{}.".format(dim_i, index, dim_0)) | |||
if not only_check_1st_ins_dim_type: | |||
for cell in self.content[1:]: | |||
index += 1 | |||
type_i, dim_i = _get_ele_type_and_dim(cell) | |||
if type_i != type_0: | |||
raise SetInputOrTargetException( | |||
"Type:{} in index {} is different from the first element with "
"type:{}.".format(type_i, index, type_0)) | |||
if dim_0 != dim_i: | |||
raise SetInputOrTargetException( | |||
"Dimension:{} in index {} is different from the first element with " | |||
"dimension:{}.".format(dim_i, index, dim_0)) | |||
self._cell_ndim = dim_0 | |||
self.dtype = type_0 | |||
except SetInputOrTargetException as e: | |||
e.index = index | |||
raise e | |||
def append(self, val:Any): | |||
def append(self, val: Any): | |||
""" | |||
:param val: 把该val append到fieldarray。 | |||
:return: | |||
""" | |||
if (self._is_target or self._is_input) and self._ignore_type is False: | |||
if (self._is_target or self._is_input) and self._ignore_type is False and not self._use_1st_ins_infer_dim_type: | |||
type_, dim_ = _get_ele_type_and_dim(val) | |||
if self.dtype!=type_: | |||
if self.dtype != type_: | |||
raise AppendToTargetOrInputException(f"Value(type:{type_}) are of different types with " | |||
f"previous values(type:{self.dtype}).") | |||
if self._cell_ndim!=dim_: | |||
if self._cell_ndim != dim_: | |||
raise AppendToTargetOrInputException(f"Value(dim:{dim_}) are of different dimensions with " | |||
f"previous values(dim:{self._cell_ndim}).") | |||
self.content.append(val) | |||
else: | |||
self.content.append(val) | |||
def pop(self, index): | |||
""" | |||
删除该field中index处的元素 | |||
:param int index: 从0开始的数据下标。 | |||
:return: | |||
""" | |||
self.content.pop(index) | |||
def __getitem__(self, indices): | |||
return self.get(indices, pad=False) | |||
def __setitem__(self, idx, val): | |||
assert isinstance(idx, int) | |||
if (self._is_target or self._is_input) and self.ignore_type is False: # 需要检测类型 | |||
type_, dim_ = _get_ele_type_and_dim(val) | |||
if self.dtype!=type_: | |||
if self.dtype != type_: | |||
raise RuntimeError(f"Value(type:{type_}) are of different types with " | |||
f"other values(type:{self.dtype}).") | |||
if self._cell_ndim!=dim_: | |||
f"other values(type:{self.dtype}).") | |||
if self._cell_ndim != dim_: | |||
raise RuntimeError(f"Value(dim:{dim_}) are of different dimensions with " | |||
f"previous values(dim:{self._cell_ndim}).") | |||
f"previous values(dim:{self._cell_ndim}).") | |||
self.content[idx] = val | |||
def get(self, indices, pad=True): | |||
""" | |||
根据给定的indices返回内容 | |||
@@ -171,16 +201,16 @@ class FieldArray: | |||
return self.content[indices] | |||
if self.is_input is False and self.is_target is False: | |||
raise RuntimeError("Please specify either is_input or is_target to True for {}".format(self.name)) | |||
contents = [self.content[i] for i in indices] | |||
if self.padder is None or pad is False: | |||
return np.array(contents) | |||
else: | |||
return self.pad(contents) | |||
def pad(self, contents): | |||
return self.padder(contents, field_name=self.name, field_ele_dtype=self.dtype, dim=self._cell_ndim) | |||
def set_padder(self, padder): | |||
""" | |||
设置padder,在这个field进行pad的时候用这个padder进行pad,如果为None则不进行pad。 | |||
@@ -192,7 +222,7 @@ class FieldArray: | |||
self.padder = deepcopy(padder) | |||
else: | |||
self.padder = None | |||
def set_pad_val(self, pad_val): | |||
""" | |||
修改padder的pad_val. | |||
@@ -202,7 +232,7 @@ class FieldArray: | |||
if self.padder is not None: | |||
self.padder.set_pad_val(pad_val) | |||
return self | |||
def __len__(self): | |||
""" | |||
Returns the size of FieldArray. | |||
@@ -210,7 +240,7 @@ class FieldArray: | |||
:return int length: | |||
""" | |||
return len(self.content) | |||
def to(self, other): | |||
""" | |||
将other的属性复制给本FieldArray(other必须为FieldArray类型). | |||
@@ -220,15 +250,15 @@ class FieldArray: | |||
:return: :class:`~fastNLP.FieldArray` | |||
""" | |||
assert isinstance(other, FieldArray), "Only supports fastNLP.FieldArray type, not {}.".format(type(other)) | |||
self.ignore_type = other.ignore_type | |||
self.is_input = other.is_input | |||
self.is_target = other.is_target | |||
self.padder = other.padder | |||
return self | |||
def split(self, sep:str=None, inplace:bool=True): | |||
def split(self, sep: str = None, inplace: bool = True): | |||
""" | |||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | |||
@@ -241,11 +271,11 @@ class FieldArray: | |||
try: | |||
new_contents.append(cell.split(sep)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def int(self, inplace:bool=True): | |||
def int(self, inplace: bool = True): | |||
""" | |||
将本field中的值调用int(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
(2) [['1', '2', ..], ['3', ..], ...](即field中每个值为一个list,list中的值会被依次转换。) | |||
@@ -261,10 +291,10 @@ class FieldArray: | |||
else: | |||
new_contents.append(int(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
print(e) | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def float(self, inplace=True): | |||
""" | |||
将本field中的值调用float(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
@@ -281,10 +311,10 @@ class FieldArray: | |||
else: | |||
new_contents.append(float(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def bool(self, inplace=True): | |||
""" | |||
将本field中的值调用bool(cell). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
@@ -301,11 +331,11 @@ class FieldArray: | |||
else: | |||
new_contents.append(bool(cell)) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def lower(self, inplace=True): | |||
""" | |||
将本field中的值调用cell.lower(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
@@ -322,10 +352,10 @@ class FieldArray: | |||
else: | |||
new_contents.append(cell.lower()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def upper(self, inplace=True): | |||
""" | |||
将本field中的值调用cell.upper(). 支持field中内容为以下两种情况(1)['1', '2', ...](即field中每个值为str的), | |||
@@ -342,10 +372,10 @@ class FieldArray: | |||
else: | |||
new_contents.append(cell.upper()) | |||
except Exception as e: | |||
print(f"Exception happens when process value in index {index}.") | |||
logger.error(f"Exception happens when process value in index {index}.") | |||
raise e | |||
return self._after_process(new_contents, inplace=inplace) | |||
def value_count(self): | |||
""" | |||
返回该field下不同value的数量。多用于统计label数量 | |||
@@ -353,17 +383,18 @@ class FieldArray: | |||
:return: Counter, key是label,value是出现次数 | |||
""" | |||
count = Counter() | |||
def cum(cell): | |||
if _is_iterable(cell) and not isinstance(cell, str): | |||
for cell_ in cell: | |||
cum(cell_) | |||
else: | |||
count[cell] += 1 | |||
for cell in self.content: | |||
cum(cell) | |||
return count | |||
def _after_process(self, new_contents, inplace): | |||
""" | |||
当调用处理函数之后,决定是否要替换field。 | |||
@@ -378,14 +409,14 @@ class FieldArray: | |||
self.is_input = self.is_input | |||
self.is_target = self.is_target | |||
except SetInputOrTargetException as e: | |||
print("The newly generated field cannot be set as input or target.") | |||
logger.error("The newly generated field cannot be set as input or target.") | |||
raise e | |||
return self | |||
else: | |||
return new_contents | |||
def _get_ele_type_and_dim(cell:Any, dim=0): | |||
def _get_ele_type_and_dim(cell: Any, dim=0): | |||
""" | |||
识别cell的类别与dimension的数量 | |||
@@ -401,13 +432,13 @@ def _get_ele_type_and_dim(cell:Any, dim=0): | |||
elif isinstance(cell, list): | |||
dim += 1 | |||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
types = set([i for i,j in res]) | |||
dims = set([j for i,j in res]) | |||
if len(types)>1: | |||
types = set([i for i, j in res]) | |||
dims = set([j for i, j in res]) | |||
if len(types) > 1: | |||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
elif len(types)==0: | |||
elif len(types) == 0: | |||
raise SetInputOrTargetException("Empty value encountered.") | |||
if len(dims)>1: | |||
if len(dims) > 1: | |||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
return types.pop(), dims.pop() | |||
elif isinstance(cell, torch.Tensor): | |||
@@ -418,55 +449,47 @@ def _get_ele_type_and_dim(cell:Any, dim=0): | |||
# 否则需要继续往下iterate | |||
dim += 1 | |||
res = [_get_ele_type_and_dim(cell_i, dim) for cell_i in cell] | |||
types = set([i for i,j in res]) | |||
dims = set([j for i,j in res]) | |||
if len(types)>1: | |||
types = set([i for i, j in res]) | |||
dims = set([j for i, j in res]) | |||
if len(types) > 1: | |||
raise SetInputOrTargetException("Mixed types detected: {}.".format(list(types))) | |||
elif len(types)==0: | |||
elif len(types) == 0: | |||
raise SetInputOrTargetException("Empty value encountered.") | |||
if len(dims)>1: | |||
if len(dims) > 1: | |||
raise SetInputOrTargetException("Mixed dimension detected: {}.".format(list(dims))) | |||
return types.pop(), dims.pop() | |||
else: # 包含tuple, set, dict以及其它的类型 | |||
else: # 包含tuple, set, dict以及其它的类型 | |||
raise SetInputOrTargetException(f"Cannot process type:{type(cell)}.") | |||
def _is_iterable(value): | |||
# 检查是否是iterable的, duck typing | |||
try: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
class Padder: | |||
""" | |||
别名::class:`fastNLP.Padder` :class:`fastNLP.core.field.Padder` | |||
所有padder都需要继承这个类,并覆盖__call__方法。 | |||
用于对batch进行padding操作。传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前deepcopy一份。 | |||
.. py:function:: __call__(self, contents, field_name, field_ele_dtype): | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
:param List[Any] contents: 传入的element是inplace的,即直接修改element可能导致数据变化,建议inplace修改之前 | |||
deepcopy一份。 | |||
:param str, field_name: field的名称。 | |||
:param np.int64,np.float64,np.str,None, field_ele_dtype: 该field的内层元素的类型。如果该field的ignore_type为True,该这个值为None。 | |||
:return: np.array([padded_element]) | |||
""" | |||
def __init__(self, pad_val=0, **kwargs): | |||
""" | |||
self.pad_val = pad_val | |||
def set_pad_val(self, pad_val): | |||
self.pad_val = pad_val | |||
def get_pad_val(self): | |||
return self.pad_val | |||
@abstractmethod | |||
def __call__(self, contents, field_name, field_ele_dtype, dim:int): | |||
def __call__(self, contents, field_name, field_ele_dtype, dim: int): | |||
""" | |||
传入的是List内容。假设有以下的DataSet。 | |||
@@ -512,8 +535,6 @@ class Padder: | |||
class AutoPadder(Padder): | |||
""" | |||
别名::class:`fastNLP.AutoPadder` :class:`fastNLP.core.field.AutoPadder` | |||
根据contents的数据自动判定是否需要做padding。 | |||
1 如果元素类型(元素类型是指field中最里层元素的数据类型, 可以通过FieldArray.dtype查看,比如['This', 'is', ...]的元素类 | |||
@@ -533,23 +554,24 @@ class AutoPadder(Padder): | |||
3 其它情况不进行处理,返回一个np.array类型。 | |||
""" | |||
def __init__(self, pad_val=0): | |||
super().__init__(pad_val=pad_val) | |||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||
if field_ele_dtype: | |||
if dim>3: | |||
if dim > 3: | |||
return np.array(contents) | |||
if isinstance(field_ele_dtype, type) and \ | |||
(issubclass(field_ele_dtype, np.number) or issubclass(field_ele_dtype, Number)): | |||
if dim==0: | |||
if dim == 0: | |||
array = np.array(contents, dtype=field_ele_dtype) | |||
elif dim==1: | |||
elif dim == 1: | |||
max_len = max(map(len, contents)) | |||
array = np.full((len(contents), max_len), self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
array[i, :len(content_i)] = content_i | |||
elif dim==2: | |||
elif dim == 2: | |||
max_len = max(map(len, contents)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in contents]) | |||
@@ -559,20 +581,21 @@ class AutoPadder(Padder): | |||
array[i, j, :len(content_ii)] = content_ii | |||
else: | |||
shape = np.shape(contents) | |||
if len(shape)==4: # 说明各dimension是相同的大小 | |||
if len(shape) == 4: # 说明各dimension是相同的大小 | |||
array = np.array(contents, dtype=field_ele_dtype) | |||
else: | |||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
return array | |||
elif str(field_ele_dtype).startswith('torch'): | |||
if dim==0: | |||
if dim == 0: | |||
tensor = torch.tensor(contents).to(field_ele_dtype) | |||
elif dim==1: | |||
elif dim == 1: | |||
max_len = max(map(len, contents)) | |||
tensor = torch.full((len(contents), max_len), fill_value=self.pad_val, dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
tensor[i, :len(content_i)] = torch.tensor(content_i) | |||
elif dim==2: | |||
tensor[i, :len(content_i)] = content_i.clone().detach() | |||
elif dim == 2: | |||
max_len = max(map(len, contents)) | |||
max_word_len = max([max([len(content_ii) for content_ii in content_i]) for | |||
content_i in contents]) | |||
@@ -580,18 +603,21 @@ class AutoPadder(Padder): | |||
dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
for j, content_ii in enumerate(content_i): | |||
tensor[i, j, :len(content_ii)] = torch.tensor(content_ii) | |||
tensor[i, j, :len(content_ii)] = content_ii.clone().detach() | |||
else: | |||
shapes = set([np.shape(content_i) for content_i in contents]) | |||
if len(shapes)>1: | |||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
if len(shapes) > 1: | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
shape = shapes.pop() | |||
if len(shape)==3: | |||
tensor = torch.full([len(contents)]+list(shape), fill_value=self.pad_val, dtype=field_ele_dtype) | |||
if len(shape) == 3: | |||
tensor = torch.full([len(contents)] + list(shape), fill_value=self.pad_val, | |||
dtype=field_ele_dtype) | |||
for i, content_i in enumerate(contents): | |||
tensor[i] = torch.tensor(content_i, dtype=field_ele_dtype) | |||
tensor[i] = content_i.clone().detach().to(field_ele_dtype) | |||
else: | |||
raise RuntimeError(f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
raise RuntimeError( | |||
f"Field:{field_name} has 3 dimensions, every sample should have the same shape.") | |||
return tensor | |||
else: | |||
return np.array(contents) # 不进行任何操作 | |||
@@ -601,8 +627,6 @@ class AutoPadder(Padder): | |||
class EngChar2DPadder(Padder): | |||
""" | |||
别名::class:`fastNLP.EngChar2DPadder` :class:`fastNLP.core.field.EngChar2DPadder` | |||
用于为英语执行character级别的2D padding操作。对应的field内容应该类似[['T', 'h', 'i', 's'], ['a'], ['d', 'e', 'm', 'o']], | |||
但这个Padder只能处理index为int的情况。 | |||
@@ -622,7 +646,7 @@ class EngChar2DPadder(Padder): | |||
dataset.set_padder('chars', padder) # chars这个field的设置为了EnChar2DPadder | |||
""" | |||
def __init__(self, pad_val=0, pad_length=0): | |||
""" | |||
:param pad_val: int, pad的位置使用该index | |||
@@ -630,9 +654,9 @@ class EngChar2DPadder(Padder): | |||
都pad或截取到该长度. | |||
""" | |||
super().__init__(pad_val=pad_val) | |||
self.pad_length = pad_length | |||
def __call__(self, contents, field_name, field_ele_dtype, dim): | |||
""" | |||
期望输入类似于 | |||
@@ -651,7 +675,7 @@ class EngChar2DPadder(Padder): | |||
raise TypeError('dtype of Field:{} should be np.int64 or np.float64 to do 2D padding, get {}.'.format( | |||
field_name, field_ele_dtype | |||
)) | |||
assert dim==2, f"Field:{field_name} has {dim}, EngChar2DPadder only supports input with 2 dimensions." | |||
assert dim == 2, f"Field:{field_name} has {dim} dimensions, EngChar2DPadder only supports input with 2 dimensions." | |||
if self.pad_length < 1: | |||
max_char_length = max([max(len(char_lst) for char_lst in word_lst) for word_lst in contents]) | |||
else: | |||
@@ -659,12 +683,12 @@ class EngChar2DPadder(Padder): | |||
max_sent_length = max(len(word_lst) for word_lst in contents) | |||
batch_size = len(contents) | |||
dtype = type(contents[0][0][0]) | |||
padded_array = np.full((batch_size, max_sent_length, max_char_length), fill_value=self.pad_val, | |||
dtype=dtype) | |||
for b_idx, word_lst in enumerate(contents): | |||
for c_idx, char_lst in enumerate(word_lst): | |||
chars = char_lst[:max_char_length] | |||
padded_array[b_idx, c_idx, :len(chars)] = chars | |||
return padded_array |
@@ -3,15 +3,16 @@ instance 模块实现了Instance 类在fastNLP中对应sample。一个sample可 | |||
便于理解的例子可以参考文档 :doc:`fastNLP.core.dataset` 中的表格 | |||
""" | |||
__all__ = [ | |||
"Instance" | |||
] | |||
from .utils import pretty_table_printer | |||
class Instance(object): | |||
""" | |||
别名::class:`fastNLP.Instance` :class:`fastNLP.core.instance.Instance` | |||
Instance是fastNLP中对应一个sample的类。每个sample在fastNLP中是一个Instance对象。 | |||
Instance一般与 :class:`~fastNLP.DataSet` 一起使用, Instance的初始化如下面的Example所示:: | |||
@@ -22,11 +23,11 @@ class Instance(object): | |||
>>>ins.add_field("field_3", [3, 3, 3]) | |||
>>>ins = Instance(**{'x1': 1, 'x2':np.zeros((3, 4))}) | |||
""" | |||
def __init__(self, **fields): | |||
self.fields = fields | |||
def add_field(self, field_name, field): | |||
""" | |||
向Instance中增加一个field | |||
@@ -35,18 +36,23 @@ class Instance(object): | |||
:param Any field: 新增field的内容 | |||
""" | |||
self.fields[field_name] = field | |||
def items(self): | |||
""" | |||
返回一个迭代器,迭代器返回两个内容,第一个内容是field_name, 第二个内容是field_value | |||
:return: 一个迭代器 | |||
""" | |||
return self.fields.items() | |||
def __getitem__(self, name): | |||
if name in self.fields: | |||
return self.fields[name] | |||
else: | |||
raise KeyError("{} not found".format(name)) | |||
def __setitem__(self, name, field): | |||
return self.add_field(name, field) | |||
def __repr__(self): | |||
s = '\'' | |||
return "{" + ",\n".join( | |||
"\'" + field_name + "\': " + str(self.fields[field_name]) + \ | |||
f" type={(str(type(self.fields[field_name]))).split(s)[1]}" for field_name in self.fields) + "}" | |||
return str(pretty_table_printer(self)) |
@@ -20,7 +20,6 @@ from collections import defaultdict | |||
import torch | |||
import torch.nn.functional as F | |||
from ..core.const import Const | |||
from .utils import _CheckError | |||
from .utils import _CheckRes | |||
from .utils import _build_args | |||
@@ -28,6 +27,7 @@ from .utils import _check_arg_dict_list | |||
from .utils import _check_function_or_method | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
from ..core.const import Const | |||
class LossBase(object): | |||
@@ -166,8 +166,6 @@ class LossBase(object): | |||
class LossFunc(LossBase): | |||
""" | |||
别名::class:`fastNLP.LossFunc` :class:`fastNLP.core.losses.LossFunc` | |||
提供给用户使用自定义损失函数的类 | |||
:param func: 用户自行定义的损失函数,应当为一个函数或者callable(func)为True的ojbect | |||
@@ -199,13 +197,15 @@ class LossFunc(LossBase): | |||
class CrossEntropyLoss(LossBase): | |||
""" | |||
别名::class:`fastNLP.CrossEntropyLoss` :class:`fastNLP.core.losses.CrossEntropyLoss` | |||
交叉熵损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param seq_len: 句子的长度, 长度之外的token不会计算loss。。 | |||
:param seq_len: 句子的长度, 长度之外的token不会计算loss。 | |||
:param int class_in_dim: 在序列标注的场景中,pred可能的shape为(batch_size, max_len, num_classes) | |||
或(batch_size, num_classes, max_len), CrossEntropyLoss需要知道哪一维是class的维度以计算loss。如果为-1,就根据pred的第 | |||
二维是否等于target的第二维来判断是否需要交换pred的第二维和第三维,因为target的第二维是length的维度,如果这一维度上和pred相等, | |||
那么pred可能第二维也是长度维(存在误判的可能,如果有误判的情况,请显示设置该值)。其它大于0的值则认为该维度是class的维度。 | |||
:param padding_idx: padding的index,在计算loss时将忽略target中标号为padding_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
@@ -216,21 +216,25 @@ class CrossEntropyLoss(LossBase): | |||
""" | |||
def __init__(self, pred=None, target=None, seq_len=None, padding_idx=-100, reduction='mean'): | |||
def __init__(self, pred=None, target=None, seq_len=None, class_in_dim=-1, padding_idx=-100, reduction='mean'): | |||
super(CrossEntropyLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target, seq_len=seq_len) | |||
self.padding_idx = padding_idx | |||
assert reduction in ('mean', 'sum', 'none') | |||
self.reduction = reduction | |||
self.class_in_dim = class_in_dim | |||
def get_loss(self, pred, target, seq_len=None): | |||
if pred.dim() > 2: | |||
if pred.size(1) != target.size(1): | |||
pred = pred.transpose(1, 2) | |||
if self.class_in_dim == -1: | |||
if pred.size(1) != target.size(1): # 有可能顺序替换了 | |||
pred = pred.transpose(1, 2) | |||
else: | |||
pred = pred.transpose(-1, self.class_in_dim) | |||
pred = pred.reshape(-1, pred.size(-1)) | |||
target = target.reshape(-1) | |||
if seq_len is not None: | |||
mask = seq_len_to_mask(seq_len).reshape(-1).eq(0) | |||
if seq_len is not None and target.dim()>1: | |||
mask = seq_len_to_mask(seq_len, max_len=target.size(1)).reshape(-1).eq(0) | |||
target = target.masked_fill(mask, self.padding_idx) | |||
return F.cross_entropy(input=pred, target=target, | |||
@@ -239,8 +243,6 @@ class CrossEntropyLoss(LossBase): | |||
class L1Loss(LossBase): | |||
""" | |||
别名::class:`fastNLP.L1Loss` :class:`fastNLP.core.losses.L1Loss` | |||
L1损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -261,8 +263,6 @@ class L1Loss(LossBase): | |||
class BCELoss(LossBase): | |||
""" | |||
别名::class:`fastNLP.BCELoss` :class:`fastNLP.core.losses.BCELoss` | |||
二分类交叉熵损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
@@ -282,18 +282,18 @@ class BCELoss(LossBase): | |||
class NLLLoss(LossBase): | |||
""" | |||
别名::class:`fastNLP.NLLLoss` :class:`fastNLP.core.losses.NLLLoss` | |||
负对数似然损失函数 | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
""" | |||
def __init__(self, pred=None, target=None, ignore_idx=-100, reduction='mean'): | |||
""" | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param ignore_idx: ignore的index,在计算loss时将忽略target中标号为ignore_idx的内容, 可以通过该值代替 | |||
传入seq_len. | |||
:param str reduction: 支持 `mean` ,`sum` 和 `none` . | |||
""" | |||
super(NLLLoss, self).__init__() | |||
self._init_param_map(pred=pred, target=target) | |||
assert reduction in ('mean', 'sum', 'none') | |||
@@ -306,14 +306,14 @@ class NLLLoss(LossBase): | |||
class LossInForward(LossBase): | |||
""" | |||
别名::class:`fastNLP.LossInForward` :class:`fastNLP.core.losses.LossInForward` | |||
从forward()函数返回结果中获取loss | |||
:param str loss_key: 在forward函数中loss的键名,默认为loss | |||
""" | |||
def __init__(self, loss_key=Const.LOSS): | |||
""" | |||
:param str loss_key: 在forward函数中loss的键名,默认为loss | |||
""" | |||
super().__init__() | |||
if not isinstance(loss_key, str): | |||
raise TypeError(f"Only str allowed for loss_key, got {type(loss_key)}.") | |||
@@ -10,7 +10,10 @@ __all__ = [ | |||
] | |||
import inspect | |||
import warnings | |||
from abc import abstractmethod | |||
from collections import defaultdict | |||
from typing import Union | |||
import numpy as np | |||
import torch | |||
@@ -22,7 +25,6 @@ from .utils import _check_arg_dict_list | |||
from .utils import _get_func_signature | |||
from .utils import seq_len_to_mask | |||
from .vocabulary import Vocabulary | |||
from abc import abstractmethod | |||
class MetricBase(object): | |||
@@ -118,6 +120,7 @@ class MetricBase(object): | |||
def __init__(self): | |||
self._param_map = {} # key is param in function, value is input param. | |||
self._checked = False | |||
self._metric_name = self.__class__.__name__ | |||
@property | |||
def param_map(self): | |||
@@ -135,6 +138,24 @@ class MetricBase(object): | |||
@abstractmethod | |||
def get_metric(self, reset=True): | |||
raise NotImplementedError | |||
def set_metric_name(self, name:str): | |||
""" | |||
设置metric的名称,默认是Metric的class name. | |||
:param str name: | |||
:return: self | |||
""" | |||
self._metric_name = name | |||
return self | |||
def get_metric_name(self): | |||
""" | |||
返回metric的名称 | |||
:return: | |||
""" | |||
return self._metric_name | |||
def _init_param_map(self, key_map=None, **kwargs): | |||
"""检查key_map和其他参数map,并将这些映射关系添加到self._param_map | |||
@@ -275,17 +296,16 @@ class MetricBase(object): | |||
class AccuracyMetric(MetricBase): | |||
""" | |||
别名::class:`fastNLP.AccuracyMetric` :class:`fastNLP.core.metrics.AccuracyMetric` | |||
准确率Metric(其它的Metric参见 :doc:`fastNLP.core.metrics` ) | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||
""" | |||
def __init__(self, pred=None, target=None, seq_len=None): | |||
""" | |||
:param pred: 参数映射表中 `pred` 的映射关系,None表示映射关系为 `pred` -> `pred` | |||
:param target: 参数映射表中 `target` 的映射关系,None表示映射关系为 `target` -> `target` | |||
:param seq_len: 参数映射表中 `seq_len` 的映射关系,None表示映射关系为 `seq_len` -> `seq_len` | |||
""" | |||
super().__init__() | |||
@@ -318,15 +338,18 @@ class AccuracyMetric(MetricBase): | |||
raise TypeError(f"`seq_lens` in {_get_func_signature(self.evaluate)} must be torch.Tensor," | |||
f"got {type(seq_len)}.") | |||
if seq_len is not None: | |||
masks = seq_len_to_mask(seq_len=seq_len) | |||
if seq_len is not None and target.dim()>1: | |||
max_len = target.size(1) | |||
masks = seq_len_to_mask(seq_len=seq_len, max_len=max_len) | |||
else: | |||
masks = None | |||
if pred.size() == target.size(): | |||
if pred.dim() == target.dim(): | |||
pass | |||
elif len(pred.size()) == len(target.size()) + 1: | |||
elif pred.dim() == target.dim() + 1: | |||
pred = pred.argmax(dim=-1) | |||
if seq_len is None and target.dim()>1: | |||
warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.") | |||
else: | |||
raise RuntimeError(f"In {_get_func_signature(self.evaluate)}, when pred have " | |||
f"size:{pred.size()}, target should have size: {pred.size()} or " | |||
@@ -358,6 +381,7 @@ def _bmes_tag_to_spans(tags, ignore_labels=None): | |||
""" | |||
给定一个tags的lis,比如['S-song', 'B-singer', 'M-singer', 'E-singer', 'S-moive', 'S-actor']。 | |||
返回[('song', (0, 1)), ('singer', (1, 4)), ('moive', (4, 5)), ('actor', (5, 6))] (左闭右开区间) | |||
也可以是单纯的['S', 'B', 'M', 'E', 'B', 'M', 'M',...]序列 | |||
:param tags: List[str], | |||
:param ignore_labels: List[str], 在该list中的label将被忽略 | |||
@@ -473,10 +497,75 @@ def _bio_tag_to_spans(tags, ignore_labels=None): | |||
return [(span[0], (span[1][0], span[1][1] + 1)) for span in spans if span[0] not in ignore_labels] | |||
def _get_encoding_type_from_tag_vocab(tag_vocab:Union[Vocabulary, dict])->str: | |||
""" | |||
给定Vocabulary自动判断是哪种类型的encoding, 支持判断bmes, bioes, bmeso, bio | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
bmes_tag_set = set('bmes') | |||
if tag_set == bmes_tag_set: | |||
return 'bmes' | |||
bio_tag_set = set('bio') | |||
if tag_set == bio_tag_set: | |||
return 'bio' | |||
bmeso_tag_set = set('bmeso') | |||
if tag_set == bmeso_tag_set: | |||
return 'bmeso' | |||
bioes_tag_set = set('bioes') | |||
if tag_set == bioes_tag_set: | |||
return 'bioes' | |||
raise RuntimeError("encoding_type cannot be inferred automatically. Only support " | |||
"'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
def _check_tag_vocab_and_encoding_type(tag_vocab:Union[Vocabulary, dict], encoding_type:str): | |||
""" | |||
检查vocab中的tag是否与encoding_type是匹配的 | |||
:param tag_vocab: 支持传入tag Vocabulary; 或者传入形如{0:"O", 1:"B-tag1"},即index在前,tag在后的dict。 | |||
:param encoding_type: bio, bmes, bioes, bmeso | |||
:return: | |||
""" | |||
tag_set = set() | |||
unk_token = '<unk>' | |||
pad_token = '<pad>' | |||
if isinstance(tag_vocab, Vocabulary): | |||
unk_token = tag_vocab.unknown | |||
pad_token = tag_vocab.padding | |||
tag_vocab = tag_vocab.idx2word | |||
for idx, tag in tag_vocab.items(): | |||
if tag in (unk_token, pad_token): | |||
continue | |||
tag = tag[:1].lower() | |||
tag_set.add(tag) | |||
tags = encoding_type | |||
for tag in tag_set: | |||
assert tag in tags, f"{tag} is not a valid tag in encoding type:{encoding_type}. Please check your " \ | |||
f"encoding_type." | |||
tags = tags.replace(tag, '') # 删除该值 | |||
if tags: # 如果不为空,说明出现了未使用的tag | |||
warnings.warn(f"Tag:{tags} in encoding type:{encoding_type} is not presented in your Vocabulary. Check your " | |||
"encoding_type.") | |||
class SpanFPreRecMetric(MetricBase): | |||
r""" | |||
别名::class:`fastNLP.SpanFPreRecMetric` :class:`fastNLP.core.metrics.SpanFPreRecMetric` | |||
在序列标注问题中,以span的方式计算F, pre, rec. | |||
比如中文Part of speech中,会以character的方式进行标注,句子 `中国在亚洲` 对应的POS可能为(以BMES为例) | |||
['B-NN', 'E-NN', 'S-DET', 'B-NN', 'E-NN']。该metric就是为类似情况下的F1计算。 | |||
@@ -499,34 +588,36 @@ class SpanFPreRecMetric(MetricBase): | |||
'rec-label':xxx, | |||
... | |||
} | |||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'这 | |||
个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个 | |||
label的f1, pre, rec | |||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : | |||
分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type='bio', ignore_labels=None, | |||
def __init__(self, tag_vocab, pred=None, target=None, seq_len=None, encoding_type=None, ignore_labels=None, | |||
only_gross=True, f_type='micro', beta=1): | |||
encoding_type = encoding_type.lower() | |||
r""" | |||
:param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), | |||
在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. | |||
:param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 | |||
:param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 | |||
:param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 | |||
:param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. | |||
:param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label | |||
:param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec | |||
:param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
""" | |||
if not isinstance(tag_vocab, Vocabulary): | |||
raise TypeError("tag_vocab can only be fastNLP.Vocabulary, not {}.".format(type(tag_vocab))) | |||
if f_type not in ('micro', 'macro'): | |||
raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) | |||
self.encoding_type = encoding_type | |||
if encoding_type: | |||
encoding_type = encoding_type.lower() | |||
_check_tag_vocab_and_encoding_type(tag_vocab, encoding_type) | |||
self.encoding_type = encoding_type | |||
else: | |||
self.encoding_type = _get_encoding_type_from_tag_vocab(tag_vocab) | |||
if self.encoding_type == 'bmes': | |||
self.tag_to_span_func = _bmes_tag_to_spans | |||
elif self.encoding_type == 'bio': | |||
@@ -536,7 +627,7 @@ class SpanFPreRecMetric(MetricBase): | |||
elif self.encoding_type == 'bioes': | |||
self.tag_to_span_func = _bioes_tag_to_spans | |||
else: | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso' type.") | |||
raise ValueError("Only support 'bio', 'bmes', 'bmeso', 'bioes' type.") | |||
self.ignore_labels = ignore_labels | |||
self.f_type = f_type | |||
@@ -624,7 +715,7 @@ class SpanFPreRecMetric(MetricBase): | |||
f, pre, rec = self._compute_f_pre_rec(tp, fn, fp) | |||
f_sum += f | |||
pre_sum += pre | |||
rec_sum + rec | |||
rec_sum += rec | |||
if not self.only_gross and tag != '': # tag!=''防止无tag的情况 | |||
f_key = 'f-{}'.format(tag) | |||
pre_key = 'pre-{}'.format(tag) | |||
@@ -738,24 +829,23 @@ def _pred_topk(y_prob, k=1): | |||
class ExtractiveQAMetric(MetricBase): | |||
r""" | |||
别名::class:`fastNLP.ExtractiveQAMetric` :class:`fastNLP.core.metrics.ExtractiveQAMetric` | |||
抽取式QA(如SQuAD)的metric. | |||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||
:param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` | |||
:param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | |||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | |||
""" | |||
def __init__(self, pred1=None, pred2=None, target1=None, target2=None, | |||
beta=1, right_open=True, print_predict_stat=False): | |||
r""" | |||
:param pred1: 参数映射表中 `pred1` 的映射关系,None表示映射关系为 `pred1` -> `pred1` | |||
:param pred2: 参数映射表中 `pred2` 的映射关系,None表示映射关系为 `pred2` -> `pred2` | |||
:param target1: 参数映射表中 `target1` 的映射关系,None表示映射关系为 `target1` -> `target1` | |||
:param target2: 参数映射表中 `target2` 的映射关系,None表示映射关系为 `target2` -> `target2` | |||
:param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . | |||
常用为beta=0.5, 1, 2. 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 | |||
:param bool right_open: right_open为true表示start跟end指针指向一个左闭右开区间,为false表示指向一个左闭右闭区间。 | |||
:param bool print_predict_stat: True则输出预测答案是否为空与正确答案是否为空的统计信息, False则不输出 | |||
""" | |||
super(ExtractiveQAMetric, self).__init__() | |||
self._init_param_map(pred1=pred1, pred2=pred2, target1=target1, target2=target2) | |||
@@ -814,8 +904,8 @@ class ExtractiveQAMetric(MetricBase): | |||
if not self.right_open: | |||
e += 1 | |||
te += 1 | |||
if ts == 0 and te == int(not self.right_open): | |||
if s == 0 and e == int(not self.right_open): | |||
if ts == 0 and te == 1: | |||
if s == 0 and e == 1: | |||
self.no_ans_correct += 1 | |||
self.no2no += 1 | |||
else: | |||
@@ -9,21 +9,23 @@ __all__ = [ | |||
"AdamW" | |||
] | |||
import torch | |||
import math | |||
import torch | |||
from torch.optim.optimizer import Optimizer as TorchOptimizer | |||
class Optimizer(object): | |||
""" | |||
别名::class:`fastNLP.Optimizer` :class:`fastNLP.core.optimizer.Optimizer` | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
:param kwargs: additional parameters. | |||
Optimizer | |||
""" | |||
def __init__(self, model_params, **kwargs): | |||
""" | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
:param kwargs: additional parameters. | |||
""" | |||
if model_params is not None and not hasattr(model_params, "__next__"): | |||
raise RuntimeError("model parameters should be a generator, rather than {}.".format(type(model_params))) | |||
self.model_params = model_params | |||
@@ -49,7 +51,7 @@ class NullOptimizer(Optimizer): | |||
super().__init__(None) | |||
def construct_from_pytorch(self, model_params): | |||
pass | |||
return self | |||
def __getattr__(self, item): | |||
def pass_func(*args, **kwargs): | |||
@@ -60,14 +62,15 @@ class NullOptimizer(Optimizer): | |||
class SGD(Optimizer): | |||
""" | |||
别名::class:`fastNLP.SGD` :class:`fastNLP.core.optimizer.SGD` | |||
:param float lr: learning rate. Default: 0.01 | |||
:param float momentum: momentum. Default: 0 | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
SGD | |||
""" | |||
def __init__(self, lr=0.001, momentum=0, model_params=None): | |||
""" | |||
:param float lr: learning rate. Default: 0.01 | |||
:param float momentum: momentum. Default: 0 | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
""" | |||
if not isinstance(lr, float): | |||
raise TypeError("learning rate has to be float.") | |||
super(SGD, self).__init__(model_params, lr=lr, momentum=momentum) | |||
@@ -82,14 +85,18 @@ class SGD(Optimizer): | |||
class Adam(Optimizer): | |||
""" | |||
别名::class:`fastNLP.Adam` :class:`fastNLP.core.optimizer.Adam` | |||
:param float lr: learning rate | |||
:param float weight_decay: | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
""" | |||
def __init__(self, lr=0.001, weight_decay=0, betas=(0.9, 0.999), eps=1e-8, amsgrad=False, model_params=None): | |||
""" | |||
:param float lr: learning rate | |||
:param float weight_decay: | |||
:param eps: | |||
:param amsgrad: | |||
:param model_params: a generator. E.g. ``model.parameters()`` for PyTorch models. | |||
""" | |||
if not isinstance(lr, float): | |||
raise TypeError("learning rate has to be float.") | |||
super(Adam, self).__init__(model_params, lr=lr, betas=betas, eps=eps, amsgrad=amsgrad, | |||
@@ -105,8 +112,6 @@ class Adam(Optimizer): | |||
class AdamW(TorchOptimizer): | |||
r""" | |||
别名::class:`fastNLP.AdamW` :class:`fastNLP.core.optimizer.AdamW` | |||
对AdamW的实现,该实现应该会在pytorch更高版本中出现,https://github.com/pytorch/pytorch/pull/21250。这里提前加入 | |||
.. todo:: | |||
@@ -115,27 +120,28 @@ class AdamW(TorchOptimizer): | |||
The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_. | |||
The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_. | |||
:param params (iterable): iterable of parameters to optimize or dicts defining | |||
parameter groups | |||
:param lr (float, optional): learning rate (default: 1e-3) | |||
:param betas (Tuple[float, float], optional): coefficients used for computing | |||
running averages of gradient and its square (default: (0.9, 0.99)) | |||
:param eps (float, optional): term added to the denominator to improve | |||
numerical stability (default: 1e-8) | |||
:param weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |||
algorithm from the paper `On the Convergence of Adam and Beyond`_ | |||
(default: False) | |||
.. _Adam\: A Method for Stochastic Optimization: | |||
https://arxiv.org/abs/1412.6980 | |||
.. _Decoupled Weight Decay Regularization: | |||
https://arxiv.org/abs/1711.05101 | |||
.. _On the Convergence of Adam and Beyond: | |||
https://openreview.net/forum?id=ryQu7f-RZ | |||
.. _Adam\: A Method for Stochastic Optimization: https://arxiv.org/abs/1412.6980 | |||
.. _Decoupled Weight Decay Regularization: https://arxiv.org/abs/1711.05101 | |||
.. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ | |||
""" | |||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, | |||
weight_decay=1e-2, amsgrad=False): | |||
""" | |||
:param params (iterable): iterable of parameters to optimize or dicts defining | |||
parameter groups | |||
:param lr (float, optional): learning rate (default: 1e-3) | |||
:param betas (Tuple[float, float], optional): coefficients used for computing | |||
running averages of gradient and its square (default: (0.9, 0.99)) | |||
:param eps (float, optional): term added to the denominator to improve | |||
numerical stability (default: 1e-8) | |||
:param weight_decay (float, optional): weight decay coefficient (default: 1e-2) | |||
algorithm from the paper `On the Convergence of Adam and Beyond`_ | |||
(default: False) | |||
""" | |||
if not 0.0 <= lr: | |||
raise ValueError("Invalid learning rate: {}".format(lr)) | |||
if not 0.0 <= eps: | |||
@@ -1,13 +1,15 @@ | |||
""" | |||
..todo:: | |||
检查这个类是否需要 | |||
""" | |||
"""undocumented""" | |||
__all__ = [ | |||
"Predictor" | |||
] | |||
from collections import defaultdict | |||
import torch | |||
from . import DataSetIter | |||
from . import DataSet | |||
from . import DataSetIter | |||
from . import SequentialSampler | |||
from .utils import _build_args, _move_dict_value_to_device, _get_model_device | |||
@@ -18,18 +20,20 @@ class Predictor(object): | |||
与测试器(Tester)不同的是,predictor不关心模型性能的评价指标,只做inference。 | |||
这是一个fastNLP调用的高级模型包装器。它与Trainer、Tester不共享任何操作。 | |||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||
""" | |||
def __init__(self, network): | |||
""" | |||
:param torch.nn.Module network: 用来完成预测任务的模型 | |||
""" | |||
if not isinstance(network, torch.nn.Module): | |||
raise ValueError( | |||
"Only fastNLP.models.BaseModel or torch.nn,Module is allowed, not {}".format(type(network))) | |||
self.network = network | |||
self.batch_size = 1 | |||
self.batch_output = [] | |||
def predict(self, data: DataSet, seq_len_field_name=None): | |||
"""用已经训练好的模型进行inference. | |||
@@ -41,27 +45,27 @@ class Predictor(object): | |||
raise ValueError("Only Dataset class is allowed, not {}.".format(type(data))) | |||
if seq_len_field_name is not None and seq_len_field_name not in data.field_arrays: | |||
raise ValueError("Field name {} not found in DataSet {}.".format(seq_len_field_name, data)) | |||
prev_training = self.network.training | |||
self.network.eval() | |||
network_device = _get_model_device(self.network) | |||
batch_output = defaultdict(list) | |||
data_iterator = DataSetIter(data, batch_size=self.batch_size, sampler=SequentialSampler(), as_numpy=False) | |||
if hasattr(self.network, "predict"): | |||
predict_func = self.network.predict | |||
else: | |||
predict_func = self.network.forward | |||
with torch.no_grad(): | |||
for batch_x, _ in data_iterator: | |||
_move_dict_value_to_device(batch_x, _, device=network_device) | |||
refined_batch_x = _build_args(predict_func, **batch_x) | |||
prediction = predict_func(**refined_batch_x) | |||
if seq_len_field_name is not None: | |||
seq_lens = batch_x[seq_len_field_name].tolist() | |||
for key, value in prediction.items(): | |||
value = value.cpu().numpy() | |||
if len(value.shape) == 1 or (len(value.shape) == 2 and value.shape[1] == 1): | |||
@@ -74,6 +78,6 @@ class Predictor(object): | |||
batch_output[key].extend(tmp_batch) | |||
else: | |||
batch_output[key].append(value) | |||
self.network.train(prev_training) | |||
return batch_output |
@@ -15,9 +15,6 @@ import numpy as np | |||
class Sampler(object): | |||
""" | |||
别名::class:`fastNLP.Sampler` :class:`fastNLP.core.sampler.Sampler` | |||
`Sampler` 类的基类. 规定以何种顺序取出data中的元素 | |||
子类必须实现 ``__call__`` 方法. 输入 `DataSet` 对象, 返回其中元素的下标序列 | |||
@@ -25,16 +22,14 @@ class Sampler(object): | |||
def __call__(self, data_set): | |||
""" | |||
:param DataSet data_set: `DataSet` 对象, 需要Sample的数据 | |||
:return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 | |||
""" | |||
:param DataSet data_set: `DataSet` 对象, 需要Sample的数据 | |||
:return result: list(int) 其中元素的下标序列, ``data_set`` 中元素会按 ``result`` 中顺序取出 | |||
""" | |||
raise NotImplementedError | |||
class SequentialSampler(Sampler): | |||
""" | |||
别名::class:`fastNLP.SequentialSampler` :class:`fastNLP.core.sampler.SequentialSampler` | |||
顺序取出元素的 `Sampler` | |||
""" | |||
@@ -45,8 +40,6 @@ class SequentialSampler(Sampler): | |||
class RandomSampler(Sampler): | |||
""" | |||
别名::class:`fastNLP.RandomSampler` :class:`fastNLP.core.sampler.RandomSampler` | |||
随机化取元素的 `Sampler` | |||
""" | |||
@@ -57,17 +50,17 @@ class RandomSampler(Sampler): | |||
class BucketSampler(Sampler): | |||
""" | |||
别名::class:`fastNLP.BucketSampler` :class:`fastNLP.core.sampler.BucketSampler` | |||
带Bucket的 `Random Sampler`. 可以随机地取出长度相似的元素 | |||
:param int num_buckets: bucket的数量 | |||
:param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | |||
要显示传递该值 | |||
:param str seq_len_field_name: 对应序列长度的 `field` 的名字 | |||
""" | |||
def __init__(self, num_buckets=10, batch_size=None, seq_len_field_name='seq_len'): | |||
""" | |||
:param int num_buckets: bucket的数量 | |||
:param int batch_size: batch的大小. 默认为None,Trainer在调用BucketSampler时,会将该值正确设置,如果是非Trainer场景使用,需 | |||
要显示传递该值 | |||
:param str seq_len_field_name: 对应序列长度的 `field` 的名字 | |||
""" | |||
self.num_buckets = num_buckets | |||
self.batch_size = batch_size | |||
self.seq_len_field_name = seq_len_field_name | |||
@@ -32,9 +32,16 @@ Tester在验证进行之前会调用model.eval()提示当前进入了evaluation | |||
""" | |||
import time | |||
import torch | |||
import torch.nn as nn | |||
try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from .utils import _pseudo_tqdm as tqdm | |||
from .batch import BatchIter, DataSetIter | |||
from .dataset import DataSet | |||
from .metrics import _prepare_metrics | |||
@@ -47,7 +54,9 @@ from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from ._parallel_utils import _data_parallel_wrapper | |||
from ._parallel_utils import _model_contains_inner_module | |||
from functools import partial | |||
from ._logger import logger | |||
__all__ = [ | |||
"Tester" | |||
@@ -56,36 +65,35 @@ __all__ = [ | |||
class Tester(object): | |||
""" | |||
别名::class:`fastNLP.Tester` :class:`fastNLP.core.tester.Tester` | |||
Tester是在提供数据,模型以及metric的情况下进行性能测试的类。需要传入模型,数据以及metric进行验证。 | |||
:param ~fastNLP.DataSet data: 需要测试的数据集 | |||
:param torch.nn.module model: 使用的模型 | |||
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics | |||
:param int batch_size: evaluation时使用的batch_size有多大。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
""" | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1): | |||
super(Tester, self).__init__() | |||
def __init__(self, data, model, metrics, batch_size=16, num_workers=0, device=None, verbose=1, use_tqdm=True): | |||
""" | |||
if not isinstance(data, DataSet): | |||
raise TypeError(f"The type of data must be `fastNLP.DataSet`, got `{type(data)}`.") | |||
:param ~fastNLP.DataSet data: 需要测试的数据集 | |||
:param torch.nn.module model: 使用的模型 | |||
:param ~fastNLP.core.metrics.MetricBase,List[~fastNLP.core.metrics.MetricBase] metrics: 测试时使用的metrics | |||
:param int batch_size: evaluation时使用的batch_size有多大。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中,可见的第一个GPU中,可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
如果模型是通过predict()进行预测的话,那么将不能使用多卡(DataParallel)进行验证,只会使用第一张卡上的模型。 | |||
:param int verbose: 如果为0不输出任何信息; 如果为1,打印出验证结果。 | |||
:param bool use_tqdm: 是否使用tqdm来显示测试进度; 如果为False,则不会显示任何内容。 | |||
""" | |||
super(Tester, self).__init__() | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be `torch.nn.Module`, got `{type(model)}`.") | |||
@@ -95,6 +103,8 @@ class Tester(object): | |||
self._model = _move_model_to_device(model, device=device) | |||
self.batch_size = batch_size | |||
self.verbose = verbose | |||
self.use_tqdm = use_tqdm | |||
self.logger = logger | |||
if isinstance(data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
@@ -106,19 +116,22 @@ class Tester(object): | |||
# check predict | |||
if (hasattr(self._model, 'predict') and callable(self._model.predict)) or \ | |||
(isinstance(self._model, nn.DataParallel) and hasattr(self._model.module, 'predict') and | |||
callable(self._model.module.predict)): | |||
(_model_contains_inner_module(self._model) and hasattr(self._model.module, 'predict') and | |||
callable(self._model.module.predict)): | |||
if isinstance(self._model, nn.DataParallel): | |||
self._predict_func_wrapper = partial(_data_parallel_wrapper('predict', | |||
self._model.device_ids, | |||
self._model.output_device), | |||
network=self._model.module) | |||
self._predict_func = self._model.module.predict # 用于匹配参数 | |||
elif isinstance(self._model, nn.parallel.DistributedDataParallel): | |||
self._predict_func = self._model.module.predict | |||
self._predict_func_wrapper = self._model.module.predict # 用于调用 | |||
else: | |||
self._predict_func = self._model.predict | |||
self._predict_func_wrapper = self._model.predict | |||
else: | |||
if isinstance(self._model, nn.DataParallel): | |||
if _model_contains_inner_module(model): | |||
self._predict_func_wrapper = self._model.forward | |||
self._predict_func = self._model.module.forward | |||
else: | |||
@@ -126,10 +139,9 @@ class Tester(object): | |||
self._predict_func_wrapper = self._model.forward | |||
def test(self): | |||
"""开始进行验证,并返回验证结果。 | |||
r"""开始进行验证,并返回验证结果。 | |||
:return Dict[Dict] : dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。 | |||
一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 | |||
:return Dict[Dict]: dict的二层嵌套结构,dict的第一层是metric的名称; 第二层是这个metric的指标。一个AccuracyMetric的例子为{'AccuracyMetric': {'acc': 1.0}}。 | |||
""" | |||
# turn on the testing mode; clean up the history | |||
self._model_device = _get_model_device(self._model) | |||
@@ -139,21 +151,39 @@ class Tester(object): | |||
eval_results = {} | |||
try: | |||
with torch.no_grad(): | |||
for batch_x, batch_y in data_iterator: | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " | |||
f"must be `dict`, got {type(pred_dict)}.") | |||
if not self.use_tqdm: | |||
from .utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
with inner_tqdm(total=len(data_iterator), leave=False, dynamic_ncols=True) as pbar: | |||
pbar.set_description_str(desc="Test") | |||
start_time = time.time() | |||
for batch_x, batch_y in data_iterator: | |||
_move_dict_value_to_device(batch_x, batch_y, device=self._model_device) | |||
pred_dict = self._data_forward(self._predict_func, batch_x) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} " | |||
f"must be `dict`, got {type(pred_dict)}.") | |||
for metric in self.metrics: | |||
metric(pred_dict, batch_y) | |||
if self.use_tqdm: | |||
pbar.update() | |||
for metric in self.metrics: | |||
metric(pred_dict, batch_y) | |||
for metric in self.metrics: | |||
eval_result = metric.get_metric() | |||
if not isinstance(eval_result, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " | |||
f"`dict`, got {type(eval_result)}") | |||
metric_name = metric.__class__.__name__ | |||
eval_results[metric_name] = eval_result | |||
eval_result = metric.get_metric() | |||
if not isinstance(eval_result, dict): | |||
raise TypeError(f"The return value of {_get_func_signature(metric.get_metric)} must be " | |||
f"`dict`, got {type(eval_result)}") | |||
metric_name = metric.get_metric_name() | |||
eval_results[metric_name] = eval_result | |||
pbar.close() | |||
end_time = time.time() | |||
test_str = f'Evaluate data in {round(end_time - start_time, 2)} seconds!' | |||
# pbar.write(test_str) | |||
self.logger.info(test_str) | |||
except _CheckError as e: | |||
prev_func_signature = _get_func_signature(self._predict_func) | |||
_check_loss_evaluate(prev_func_signature=prev_func_signature, func_signature=e.func_signature, | |||
@@ -161,7 +191,7 @@ class Tester(object): | |||
dataset=self.data, check_level=0) | |||
if self.verbose >= 1: | |||
print("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||
logger.info("[tester] \n{}".format(self._format_eval_results(eval_results))) | |||
self._mode(network, is_test=False) | |||
return eval_results | |||
@@ -336,7 +336,7 @@ except: | |||
import warnings | |||
from .batch import DataSetIter, BatchIter | |||
from .callback import CallbackManager, CallbackException | |||
from .callback import CallbackManager, CallbackException, Callback | |||
from .dataset import DataSet | |||
from .losses import _prepare_losser | |||
from .metrics import _prepare_metrics | |||
@@ -352,12 +352,11 @@ from .utils import _move_dict_value_to_device | |||
from .utils import _get_func_signature | |||
from .utils import _get_model_device | |||
from .utils import _move_model_to_device | |||
from ._parallel_utils import _model_contains_inner_module | |||
from ._logger import logger | |||
class Trainer(object): | |||
""" | |||
别名::class:`fastNLP.Trainer` :class:`fastNLP.core.trainer.Trainer` | |||
Trainer在fastNLP中用于组织单任务的训练过程,可以避免用户在不同训练任务中重复撰写 | |||
(1) epoch循环; | |||
(2) 将数据分成不同的Batch; | |||
@@ -366,87 +365,84 @@ class Trainer(object): | |||
(5) 保存获得更好验证性能的模型等。 | |||
详细的介绍参见 :doc:`fastNLP.core.trainer` | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||
:param int n_epochs: 需要优化迭代多少次。 | |||
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 | |||
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 | |||
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , | |||
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 | |||
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, | |||
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 | |||
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, | |||
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 | |||
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 | |||
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 | |||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | |||
:param str,None save_path: 将模型保存路径。如果为None,则不保存模型。如果dev_data为None,则保存最后一次迭代的模型。 | |||
保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||
可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
已知可能会出现的问题:Adagrad优化器可能无法正常使用这个参数,请手动管理模型位置。 | |||
:param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | |||
通过callback机制实现。 可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` | |||
:param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, | |||
报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是 | |||
这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况; | |||
(2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。 | |||
""" | |||
def __init__(self, train_data, model, optimizer=None, loss=None, | |||
batch_size=32, sampler=None, drop_last=False, update_every=1, | |||
num_workers=0, n_epochs=10, print_every=5, | |||
dev_data=None, metrics=None, metric_key=None, | |||
validate_every=-1, save_path=None, use_tqdm=True, device=None, prefetch=False, | |||
callbacks=None, check_code_level=0): | |||
if prefetch and num_workers==0: | |||
num_workers = 1 | |||
if prefetch: | |||
warnings.warn("prefetch is deprecated, will be removed in version 0.5.0, please use num_workers instead.") | |||
validate_every=-1, save_path=None, use_tqdm=True, device=None, | |||
callbacks=None, check_code_level=0, **kwargs): | |||
""" | |||
:param train_data: 训练集, :class:`~fastNLP.DataSet` 类型。 | |||
:param nn.modules model: 待训练的模型 | |||
:param optimizer: `torch.optim.Optimizer` 优化器。如果为None,则Trainer使用默认的Adam(model.parameters(), lr=4e-3)这个优化器 | |||
:param int batch_size: 训练和验证的时候的batch大小。 | |||
:param loss: 使用的 :class:`~fastNLP.core.losses.LossBase` 对象。当为None时,默认使用 :class:`~fastNLP.LossInForward` | |||
:param sampler: Batch数据生成的顺序, :class:`~fastNLP.Sampler` 类型。如果为None,默认使用 :class:`~fastNLP.RandomSampler` | |||
:param drop_last: 如果最后一个batch没有正好为batch_size这么多数据,就扔掉最后一个batch | |||
:param num_workers: int, 有多少个线程来进行数据pad处理。 | |||
:param update_every: int, 多少步更新一次梯度。用于希望累计梯度的场景,比如需要128的batch_size, 但是直接设为128 | |||
会导致内存不足,通过设置batch_size=32, update_every=4达到目的。当optimizer为None时,该参数无效。 | |||
:param int n_epochs: 需要优化迭代多少次。 | |||
:param int print_every: 多少次反向传播更新tqdm显示的loss; 如果use_tqdm=False, 则多少次反向传播打印loss。 | |||
:param dev_data: 用于做验证的DataSet, :class:`~fastNLP.DataSet` 类型。 | |||
:param metrics: 验证的评估函数。可以只使用一个 :class:`Metric<fastNLP.core.metrics.MetricBase>` , | |||
也可以使用多个 :class:`Metric<fastNLP.core.metrics.MetricBase>` ,通过列表传入。 | |||
如验证时取得了更好的验证结果(如果有多个Metric,以列表中第一个Metric为准),且save_path不为None, | |||
则保存当前模型。Metric种类详见 :doc:`metrics模块 <fastNLP.core.metrics>` 。仅在传入dev_data时有效。 | |||
:param str,None metric_key: :class:`Metric<fastNLP.core.metrics.MetricBase>` 有时会有多个指标, | |||
比如 :class:`~fastNLP.core.metrics.SpanFPreRecMetric` 中包含了'f', 'pre', 'rec'。此时需 | |||
要指定以哪个指标为准。另外有些指标是越小效果越好,比如语言模型的困惑度,这种情况下,在key前面增加一个'-'来表 | |||
明验证时,值越小越好(比如: "-ppl")。仅在传入dev_data时有效。 | |||
:param int validate_every: 多少个step在验证集上验证一次; 如果为-1,则每个epoch结束验证一次。仅在传入dev_data时有效。 | |||
:param str,None save_path: 将模型保存路径,如果路径不存在,将自动创建文件夹。如果为None,则不保存模型。如果dev_data为None,则保存 | |||
最后一次迭代的模型。保存的时候不仅保存了参数,还保存了模型结构。即便使用DataParallel,这里也只保存模型。 | |||
:param bool use_tqdm: 是否使用tqdm来显示训练进度; 如果为False,则将loss打印在终端中。 | |||
:param str,int,torch.device,list(int) device: 将模型load到哪个设备。默认为None,即Trainer不对模型 | |||
的计算位置进行管理。支持以下的输入: | |||
1. str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, | |||
可见的第二个GPU中; | |||
2. torch.device:将模型装载到torch.device上。 | |||
3. int: 将使用device_id为该值的gpu进行训练 | |||
4. list(int):如果多于1个device,将使用torch.nn.DataParallel包裹model, 并使用传入的device。 | |||
5. None. 为None则不对模型进行任何处理,如果传入的model为torch.nn.DataParallel该值必须为None。 | |||
已知可能会出现的问题:Adagrad优化器可能无法正常使用这个参数,请手动管理模型位置。 | |||
:param list(callbacks) callbacks: 用于在train过程中起调节作用的回调函数。比如early stop,negative sampling等可以 | |||
通过callback机制实现。 可使用的callback参见 :doc:`callback模块 <fastNLP.core.callback>` | |||
:param int check_code_level: 模型检查等级. -1: 不进行检查; 0: 仅出现错误时停止; 1: 如果有field没有被使用, | |||
报告警告信息; 2: 有任何field没有被使用都报错. 检查的原理是通过使用很小的batch(默认2个sample)来运行代码,但是 | |||
这个过程理论上不会修改任何参数,只是会检查能否运行。但如果(1)模型中存在将batch_size写为某个固定值的情况; | |||
(2)模型中存在累加前向计算次数的,可能会多计算1次。以上情况建议将check_code_level设置为-1。 | |||
""" | |||
super(Trainer, self).__init__() | |||
if not isinstance(model, nn.Module): | |||
raise TypeError(f"The type of model must be torch.nn.Module, got {type(model)}.") | |||
# check metrics and dev_data | |||
if (not metrics) and dev_data is not None: | |||
raise ValueError("No metric for dev_data evaluation.") | |||
if metrics and (dev_data is None): | |||
raise ValueError("No dev_data for evaluations, pass dev_data or set metrics to None. ") | |||
# check update every | |||
assert update_every >= 1, "update_every must be no less than 1." | |||
self.update_every = int(update_every) | |||
# check save_path | |||
if not (save_path is None or isinstance(save_path, str)): | |||
raise ValueError("save_path can only be None or `str`.") | |||
# prepare evaluate | |||
metrics = _prepare_metrics(metrics) | |||
# parse metric_key | |||
# increase_better is True. It means the exp result gets better if the indicator increases. | |||
# It is true by default. | |||
@@ -458,30 +454,69 @@ class Trainer(object): | |||
self.metric_key = None | |||
# prepare loss | |||
losser = _prepare_losser(loss) | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, Sampler): | |||
raise ValueError("The type of sampler should be fastNLP.BaseSampler, got {}.".format(type(sampler))) | |||
if sampler is None: | |||
sampler = RandomSampler() | |||
elif hasattr(sampler, 'set_batch_size'): | |||
sampler.set_batch_size(batch_size) | |||
if isinstance(train_data, BatchIter): | |||
if sampler is not None: | |||
warnings.warn("sampler is ignored when train_data is a BatchIter.") | |||
if num_workers>0: | |||
warnings.warn("num_workers is ignored when train_data is BatchIter.") | |||
if drop_last: | |||
warnings.warn("drop_last is ignored when train_data is BatchIter.") | |||
if isinstance(model, nn.parallel.DistributedDataParallel): # 如果是分布式的 | |||
# device为None | |||
if device is not None: | |||
warnings.warn("device is ignored when model is nn.parallel.DistributedDataParallel.") | |||
device = None | |||
# Sampler要是分布式的 | |||
if sampler is None: | |||
sampler = torch.utils.data.DistributedSampler(train_data) | |||
elif not isinstance(sampler, torch.utils.data.DistributedSampler): | |||
raise TypeError("When using nn.parallel.DistributedDataParallel, " | |||
"sampler must be None or torch.utils.data.DistributedSampler.") | |||
# 不能保存模型 | |||
if save_path: | |||
raise RuntimeError("Saving model in Distributed situation is not allowed right now.") | |||
else: | |||
# sampler check | |||
if sampler is not None and not isinstance(sampler, (Sampler, torch.utils.data.Sampler)): | |||
raise ValueError(f"The type of sampler should be fastNLP.BaseSampler or pytorch's Sampler, got {type(sampler)}") | |||
if sampler is None: | |||
sampler = RandomSampler() | |||
elif hasattr(sampler, 'set_batch_size'): | |||
sampler.set_batch_size(batch_size) | |||
if isinstance(train_data, DataSet): | |||
self.data_iterator = DataSetIter( | |||
dataset=train_data, batch_size=batch_size, num_workers=num_workers, sampler=sampler, drop_last=drop_last) | |||
elif isinstance(train_data, BatchIter): | |||
self.data_iterator = train_data | |||
train_data = train_data.dataset | |||
else: | |||
raise TypeError("train_data type {} not support".format(type(train_data))) | |||
if check_code_level > -1 and isinstance(self.data_iterator, DataSetIter): | |||
_check_code(dataset=train_data, model=model, losser=losser, metrics=metrics, dev_data=dev_data, | |||
metric_key=self.metric_key, check_level=check_code_level, | |||
batch_size=min(batch_size, DEFAULT_CHECK_BATCH_SIZE)) | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的代码 | |||
self.model = _move_model_to_device(model, device=device) | |||
if _model_contains_inner_module(self.model): | |||
self._forward_func = self.model.module.forward | |||
else: | |||
self._forward_func = self.model.forward | |||
if check_code_level > -1: | |||
# _check_code 是 fastNLP 帮助你检查代码是否正确的方法 。如果你在错误栈中看到这行注释,请认真检查你的field名与模型的输入 | |||
# 名是否匹配 | |||
dev_dataset = dev_data | |||
if isinstance(dev_data, BatchIter): | |||
dev_dataset = None | |||
warnings.warn("dev_data is of BatchIter type, ignore validation checking.") | |||
check_batch_size = min(batch_size, DEFAULT_CHECK_BATCH_SIZE) | |||
if isinstance(self.model, nn.DataParallel): | |||
_num_devices = len(self.model.device_ids) | |||
if batch_size//_num_devices>1: # 如果多卡是每个卡可以分多个数据的,则用每个卡给两个sample | |||
check_batch_size = max(len(self.model.device_ids)*2, check_batch_size) | |||
else: | |||
check_batch_size = max(len(self.model.device_ids), check_batch_size) | |||
_check_code(dataset=train_data, model=self.model, losser=losser, forward_func=self._forward_func, metrics=metrics, | |||
dev_data=dev_dataset, metric_key=self.metric_key, check_level=check_code_level, | |||
batch_size=check_batch_size) | |||
self.train_data = train_data | |||
self.dev_data = dev_data # If None, No validation. | |||
@@ -496,8 +531,7 @@ class Trainer(object): | |||
self.best_dev_epoch = None | |||
self.best_dev_step = None | |||
self.best_dev_perf = None | |||
self.n_steps = (len(self.train_data) // self.batch_size + int( | |||
len(self.train_data) % self.batch_size != 0)) * int(drop_last==0) * self.n_epochs | |||
self.n_steps = len(self.data_iterator) * self.n_epochs | |||
if isinstance(optimizer, torch.optim.Optimizer): | |||
self.optimizer = optimizer | |||
@@ -507,22 +541,32 @@ class Trainer(object): | |||
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=4e-3) | |||
else: | |||
raise TypeError("optimizer can only be torch.optim.Optimizer type, not {}.".format(type(optimizer))) | |||
self.logger = logger | |||
self.use_tqdm = use_tqdm | |||
if 'test_use_tqdm' in kwargs: | |||
self.test_use_tqdm = kwargs.get('test_use_tqdm') | |||
else: | |||
self.test_use_tqdm = self.use_tqdm | |||
self.pbar = None | |||
self.print_every = abs(self.print_every) | |||
self.kwargs = kwargs | |||
if self.dev_data is not None: | |||
self.tester = Tester(model=self.model, | |||
data=self.dev_data, | |||
metrics=self.metrics, | |||
batch_size=self.batch_size, | |||
batch_size=kwargs.get("dev_batch_size", self.batch_size), | |||
device=None, # 由上面的部分处理device | |||
verbose=0) | |||
verbose=0, | |||
use_tqdm=self.test_use_tqdm) | |||
self.step = 0 | |||
self.start_time = None # start timestamp | |||
if isinstance(callbacks, Callback): | |||
callbacks = [callbacks] | |||
self.callback_manager = CallbackManager(env={"trainer": self}, | |||
callbacks=callbacks) | |||
@@ -548,7 +592,7 @@ class Trainer(object): | |||
""" | |||
results = {} | |||
if self.n_epochs <= 0: | |||
print(f"training epoch is {self.n_epochs}, nothing was done.") | |||
self.logger.info(f"training epoch is {self.n_epochs}, nothing was done.") | |||
results['seconds'] = 0. | |||
return results | |||
try: | |||
@@ -557,8 +601,8 @@ class Trainer(object): | |||
self._load_best_model = load_best_model | |||
self.start_time = str(datetime.now().strftime('%Y-%m-%d-%H-%M-%S')) | |||
start_time = time.time() | |||
print("training epochs started " + self.start_time, flush=True) | |||
self.logger.info("training epochs started " + self.start_time) | |||
try: | |||
self.callback_manager.on_train_begin() | |||
self._train() | |||
@@ -571,11 +615,11 @@ class Trainer(object): | |||
raise e | |||
elif on_exception == 'raise': | |||
raise e | |||
if self.dev_data is not None and self.best_dev_perf is not None: | |||
print( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step) + | |||
self.tester._format_eval_results(self.best_dev_perf), ) | |||
self.logger.info( | |||
"\nIn Epoch:{}/Step:{}, got best dev performance:".format(self.best_dev_epoch, self.best_dev_step)) | |||
self.logger.info(self.tester._format_eval_results(self.best_dev_perf)) | |||
results['best_eval'] = self.best_dev_perf | |||
results['best_epoch'] = self.best_dev_epoch | |||
results['best_step'] = self.best_dev_step | |||
@@ -583,27 +627,23 @@ class Trainer(object): | |||
model_name = "best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time]) | |||
load_succeed = self._load_model(self.model, model_name) | |||
if load_succeed: | |||
print("Reloaded the best model.") | |||
self.logger.info("Reloaded the best model.") | |||
else: | |||
print("Fail to reload best model.") | |||
self.logger.info("Fail to reload best model.") | |||
finally: | |||
pass | |||
results['seconds'] = round(time.time() - start_time, 2) | |||
return results | |||
def _train(self): | |||
if not self.use_tqdm: | |||
from fastNLP.core.utils import _pseudo_tqdm as inner_tqdm | |||
from .utils import _pseudo_tqdm as inner_tqdm | |||
else: | |||
inner_tqdm = tqdm | |||
self.step = 0 | |||
self.epoch = 0 | |||
start = time.time() | |||
if isinstance(self.model, nn.DataParallel): | |||
self._forward_func = self.model.module.forward | |||
else: | |||
self._forward_func = self.model.forward | |||
with inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}', leave=False, dynamic_ncols=True) as pbar: | |||
self.pbar = pbar | |||
avg_loss = 0 | |||
@@ -621,21 +661,21 @@ class Trainer(object): | |||
# negative sampling; replace unknown; re-weight batch_y | |||
self.callback_manager.on_batch_begin(batch_x, batch_y, indices) | |||
prediction = self._data_forward(self.model, batch_x) | |||
# edit prediction | |||
self.callback_manager.on_loss_begin(batch_y, prediction) | |||
loss = self._compute_loss(prediction, batch_y).mean() | |||
avg_loss += loss.item() | |||
loss = loss / self.update_every | |||
# Is loss NaN or inf? requires_grad = False | |||
self.callback_manager.on_backward_begin(loss) | |||
self._grad_backward(loss) | |||
self.callback_manager.on_backward_end() | |||
self._update() | |||
self.callback_manager.on_step_end() | |||
if self.step % self.print_every == 0: | |||
avg_loss = float(avg_loss) / self.print_every | |||
if self.use_tqdm: | |||
@@ -649,36 +689,36 @@ class Trainer(object): | |||
pbar.set_postfix_str(print_output) | |||
avg_loss = 0 | |||
self.callback_manager.on_batch_end() | |||
if ((self.validate_every > 0 and self.step % self.validate_every == 0) or | |||
(self.validate_every < 0 and self.step % len(data_iterator) == 0)) \ | |||
and self.dev_data is not None: | |||
eval_res = self._do_validation(epoch=epoch, step=self.step) | |||
eval_str = "Evaluation at Epoch {}/{}. Step:{}/{}. ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) + \ | |||
self.tester._format_eval_results(eval_res) | |||
pbar.write(eval_str + '\n') | |||
eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step, | |||
self.n_steps) | |||
# pbar.write(eval_str + '\n') | |||
self.logger.info(eval_str) | |||
self.logger.info(self.tester._format_eval_results(eval_res)+'\n') | |||
# ================= mini-batch end ==================== # | |||
# lr decay; early stopping | |||
self.callback_manager.on_epoch_end() | |||
# =============== epochs end =================== # | |||
pbar.close() | |||
self.pbar = None | |||
# ============ tqdm end ============== # | |||
def _do_validation(self, epoch, step): | |||
self.callback_manager.on_valid_begin() | |||
res = self.tester.test() | |||
is_better_eval = False | |||
if self._better_eval_result(res): | |||
if self.save_path is not None: | |||
self._save_model(self.model, | |||
"best_" + "_".join([self.model.__class__.__name__, self.metric_key, self.start_time])) | |||
elif self._load_best_model: | |||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.named_parameters()} | |||
self._best_model_states = {name: param.cpu().clone() for name, param in self.model.state_dict().items()} | |||
self.best_dev_perf = res | |||
self.best_dev_epoch = epoch | |||
self.best_dev_step = step | |||
@@ -686,7 +726,7 @@ class Trainer(object): | |||
# get validation results; adjust optimizer | |||
self.callback_manager.on_valid_end(res, self.metric_key, self.optimizer, is_better_eval) | |||
return res | |||
def _mode(self, model, is_test=False): | |||
"""Train mode or Test mode. This is for PyTorch currently. | |||
@@ -698,14 +738,14 @@ class Trainer(object): | |||
model.eval() | |||
else: | |||
model.train() | |||
def _update(self): | |||
"""Perform weight update on a model. | |||
""" | |||
if self.step % self.update_every == 0: | |||
self.optimizer.step() | |||
def _data_forward(self, network, x): | |||
x = _build_args(self._forward_func, **x) | |||
y = network(**x) | |||
@@ -713,7 +753,7 @@ class Trainer(object): | |||
raise TypeError( | |||
f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.") | |||
return y | |||
def _grad_backward(self, loss): | |||
"""Compute gradient with link rules. | |||
@@ -724,7 +764,7 @@ class Trainer(object): | |||
if (self.step-1) % self.update_every == 0: | |||
self.model.zero_grad() | |||
loss.backward() | |||
def _compute_loss(self, predict, truth): | |||
"""Compute loss given prediction and ground truth. | |||
@@ -733,7 +773,7 @@ class Trainer(object): | |||
:return: a scalar | |||
""" | |||
return self.losser(predict, truth) | |||
def _save_model(self, model, model_name, only_param=False): | |||
""" 存储不含有显卡信息的state_dict或model | |||
:param model: | |||
@@ -745,7 +785,7 @@ class Trainer(object): | |||
model_path = os.path.join(self.save_path, model_name) | |||
if not os.path.exists(self.save_path): | |||
os.makedirs(self.save_path, exist_ok=True) | |||
if isinstance(model, nn.DataParallel): | |||
if _model_contains_inner_module(model): | |||
model = model.module | |||
if only_param: | |||
state_dict = model.state_dict() | |||
@@ -756,7 +796,7 @@ class Trainer(object): | |||
model.cpu() | |||
torch.save(model, model_path) | |||
model.to(self._model_device) | |||
def _load_model(self, model, model_name, only_param=False): | |||
# 返回bool值指示是否成功reload模型 | |||
if self.save_path is not None: | |||
@@ -765,7 +805,7 @@ class Trainer(object): | |||
states = torch.load(model_path) | |||
else: | |||
states = torch.load(model_path).state_dict() | |||
if isinstance(model, nn.DataParallel): | |||
if _model_contains_inner_module(model): | |||
model.module.load_state_dict(states) | |||
else: | |||
model.load_state_dict(states) | |||
@@ -774,7 +814,7 @@ class Trainer(object): | |||
else: | |||
return False | |||
return True | |||
def _better_eval_result(self, metrics): | |||
"""Check if the current epoch yields better validation results. | |||
@@ -789,17 +829,20 @@ class Trainer(object): | |||
self.best_metric_indicator = indicator_val | |||
else: | |||
if self.increase_better is True: | |||
if indicator_val > self.best_metric_indicator: | |||
if indicator_val >= self.best_metric_indicator: | |||
self.best_metric_indicator = indicator_val | |||
else: | |||
is_better = False | |||
else: | |||
if indicator_val < self.best_metric_indicator: | |||
if indicator_val <= self.best_metric_indicator: | |||
self.best_metric_indicator = indicator_val | |||
else: | |||
is_better = False | |||
return is_better | |||
@property | |||
def is_master(self): | |||
return True | |||
DEFAULT_CHECK_BATCH_SIZE = 2 | |||
DEFAULT_CHECK_NUM_BATCH = 2 | |||
@@ -821,14 +864,15 @@ def _get_value_info(_dict): | |||
strs.append(_str) | |||
return strs | |||
from numbers import Number | |||
from .batch import _to_tensor | |||
def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, | |||
check_level=0): | |||
def _check_code(dataset, model, losser, metrics, forward_func, batch_size=DEFAULT_CHECK_BATCH_SIZE, | |||
dev_data=None, metric_key=None, check_level=0): | |||
# check get_loss 方法 | |||
model_devcie = _get_model_device(model=model) | |||
model_device = _get_model_device(model=model) | |||
def _iter(): | |||
start_idx = 0 | |||
while start_idx<len(dataset): | |||
@@ -849,7 +893,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
start_idx += batch_size | |||
for batch_count, (batch_x, batch_y) in enumerate(_iter()): | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_devcie) | |||
_move_dict_value_to_device(batch_x, batch_y, device=model_device) | |||
# forward check | |||
if batch_count == 0: | |||
info_str = "" | |||
@@ -867,16 +911,12 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
info_str += '\n' | |||
else: | |||
info_str += 'There is no target field.' | |||
print(info_str) | |||
_check_forward_error(forward_func=model.forward, dataset=dataset, | |||
logger.info(info_str) | |||
_check_forward_error(forward_func=forward_func, dataset=dataset, | |||
batch_x=batch_x, check_level=check_level) | |||
if isinstance(model, nn.DataParallel): | |||
forward_func = model.module.forward | |||
else: | |||
forward_func = model.forward | |||
refined_batch_x = _build_args(forward_func, **batch_x) | |||
pred_dict = model(**refined_batch_x) | |||
func_signature = _get_func_signature(model.forward) | |||
func_signature = _get_func_signature(forward_func) | |||
if not isinstance(pred_dict, dict): | |||
raise TypeError(f"The return value of {func_signature} should be `dict`, not `{type(pred_dict)}`.") | |||
@@ -896,7 +936,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
loss.backward() | |||
except _CheckError as e: | |||
# TODO: another error raised if _CheckError caught | |||
pre_func_signature = _get_func_signature(model.forward) | |||
pre_func_signature = _get_func_signature(forward_func) | |||
_check_loss_evaluate(prev_func_signature=pre_func_signature, func_signature=e.func_signature, | |||
check_res=e.check_res, pred_dict=pred_dict, target_dict=batch_y, | |||
dataset=dataset, check_level=check_level) | |||
@@ -906,7 +946,7 @@ def _check_code(dataset, model, losser, metrics, batch_size=DEFAULT_CHECK_BATCH_ | |||
if dev_data is not None: | |||
tester = Tester(data=dev_data[:batch_size * DEFAULT_CHECK_NUM_BATCH], model=model, metrics=metrics, | |||
batch_size=batch_size, verbose=-1) | |||
batch_size=batch_size, verbose=-1, use_tqdm=False) | |||
evaluate_results = tester.test() | |||
_check_eval_results(metrics=evaluate_results, metric_key=metric_key, metric_list=metrics) | |||
@@ -1,9 +1,11 @@ | |||
""" | |||
utils模块实现了 fastNLP 内部和外部所需的很多工具。其中用户可以使用的是 :func:`cache_results` 修饰器。 | |||
""" | |||
__all__ = [ | |||
"cache_results", | |||
"seq_len_to_mask", | |||
"get_seq_len" | |||
] | |||
import _pickle | |||
@@ -11,11 +13,12 @@ import inspect | |||
import os | |||
import warnings | |||
from collections import Counter, namedtuple | |||
import numpy as np | |||
import torch | |||
import torch.nn as nn | |||
from typing import List | |||
from ._logger import logger | |||
from prettytable import PrettyTable | |||
_CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'required', 'all_needed', | |||
'varargs']) | |||
@@ -23,27 +26,27 @@ _CheckRes = namedtuple('_CheckRes', ['missing', 'unused', 'duplicated', 'require | |||
class Option(dict): | |||
"""a dict can treat keys as attributes""" | |||
def __getattr__(self, item): | |||
try: | |||
return self.__getitem__(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __setattr__(self, key, value): | |||
if key.startswith('__') and key.endswith('__'): | |||
raise AttributeError(key) | |||
self.__setitem__(key, value) | |||
def __delattr__(self, item): | |||
try: | |||
self.pop(item) | |||
except KeyError: | |||
raise AttributeError(item) | |||
def __getstate__(self): | |||
return self | |||
def __setstate__(self, state): | |||
self.update(state) | |||
@@ -62,11 +65,8 @@ def _prepare_cache_filepath(filepath): | |||
os.makedirs(cache_dir) | |||
# TODO 可以保存下缓存时的参数,如果load的时候发现参数不一致,发出警告。 | |||
def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
""" | |||
别名::class:`fastNLP.cache_results` :class:`fastNLP.core.uitls.cache_results` | |||
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | |||
import time | |||
@@ -113,13 +113,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
:param int _verbose: 是否打印cache的信息。 | |||
:return: | |||
""" | |||
def wrapper_(func): | |||
signature = inspect.signature(func) | |||
for key, _ in signature.parameters.items(): | |||
if key in ('_cache_fp', '_refresh', '_verbose'): | |||
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | |||
def wrapper(*args, **kwargs): | |||
if '_cache_fp' in kwargs: | |||
cache_filepath = kwargs.pop('_cache_fp') | |||
@@ -137,16 +137,16 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
else: | |||
verbose = _verbose | |||
refresh_flag = True | |||
if cache_filepath is not None and refresh is False: | |||
# load data | |||
if os.path.exists(cache_filepath): | |||
with open(cache_filepath, 'rb') as f: | |||
results = _pickle.load(f) | |||
if verbose == 1: | |||
print("Read cache from {}.".format(cache_filepath)) | |||
logger.info("Read cache from {}.".format(cache_filepath)) | |||
refresh_flag = False | |||
if refresh_flag: | |||
results = func(*args, **kwargs) | |||
if cache_filepath is not None: | |||
@@ -155,12 +155,12 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1): | |||
_prepare_cache_filepath(cache_filepath) | |||
with open(cache_filepath, 'wb') as f: | |||
_pickle.dump(results, f) | |||
print("Save cache to {}.".format(cache_filepath)) | |||
logger.info("Save cache to {}.".format(cache_filepath)) | |||
return results | |||
return wrapper | |||
return wrapper_ | |||
@@ -189,49 +189,6 @@ def _save_model(model, model_name, save_dir, only_param=False): | |||
model.to(_model_device) | |||
# def save_pickle(obj, pickle_path, file_name): | |||
# """Save an object into a pickle file. | |||
# | |||
# :param obj: an object | |||
# :param pickle_path: str, the directory where the pickle file is to be saved | |||
# :param file_name: str, the name of the pickle file. In general, it should be ended by "pkl". | |||
# """ | |||
# if not os.path.exists(pickle_path): | |||
# os.mkdir(pickle_path) | |||
# print("make dir {} before saving pickle file".format(pickle_path)) | |||
# with open(os.path.join(pickle_path, file_name), "wb") as f: | |||
# _pickle.dump(obj, f) | |||
# print("{} saved in {}".format(file_name, pickle_path)) | |||
# | |||
# | |||
# def load_pickle(pickle_path, file_name): | |||
# """Load an object from a given pickle file. | |||
# | |||
# :param pickle_path: str, the directory where the pickle file is. | |||
# :param file_name: str, the name of the pickle file. | |||
# :return obj: an object stored in the pickle | |||
# """ | |||
# with open(os.path.join(pickle_path, file_name), "rb") as f: | |||
# obj = _pickle.load(f) | |||
# print("{} loaded from {}".format(file_name, pickle_path)) | |||
# return obj | |||
# | |||
# | |||
# def pickle_exist(pickle_path, pickle_name): | |||
# """Check if a given pickle file exists in the directory. | |||
# | |||
# :param pickle_path: the directory of target pickle file | |||
# :param pickle_name: the filename of target pickle file | |||
# :return: True if file exists else False | |||
# """ | |||
# if not os.path.exists(pickle_path): | |||
# os.makedirs(pickle_path) | |||
# file_name = os.path.join(pickle_path, pickle_name) | |||
# if os.path.exists(file_name): | |||
# return True | |||
# else: | |||
# return False | |||
def _move_model_to_device(model, device): | |||
""" | |||
将model移动到device | |||
@@ -254,9 +211,9 @@ def _move_model_to_device(model, device): | |||
:return: torch.nn.DataParallel or torch.nn.Module | |||
""" | |||
if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | |||
# if isinstance(model, torch.nn.parallel.DistributedDataParallel): | |||
# raise RuntimeError("model of `torch.nn.parallel.DistributedDataParallel` is not supported right now.") | |||
if device is None: | |||
if isinstance(model, torch.nn.DataParallel): | |||
model.cuda() | |||
@@ -265,10 +222,10 @@ def _move_model_to_device(model, device): | |||
if not torch.cuda.is_available() and ( | |||
device != 'cpu' or (isinstance(device, torch.device) and device.type != 'cpu')): | |||
raise ValueError("There is no usable gpu. set `device` as `cpu` or `None`.") | |||
if isinstance(model, torch.nn.DataParallel): | |||
raise RuntimeError("When model is `torch.nn.DataParallel`, the device has to be `None`.") | |||
if isinstance(device, int): | |||
assert device > -1, "device can only be non-negative integer" | |||
assert torch.cuda.device_count() > device, "Only has {} gpus, cannot use device {}.".format( | |||
@@ -312,7 +269,7 @@ def _get_model_device(model): | |||
""" | |||
# TODO 这个函数存在一定的风险,因为同一个模型可能存在某些parameter不在显卡中,比如BertEmbedding. 或者跨显卡 | |||
assert isinstance(model, nn.Module) | |||
parameters = list(model.parameters()) | |||
if len(parameters) == 0: | |||
return None | |||
@@ -352,7 +309,6 @@ def _map_args(maps: dict, **kwargs): | |||
output.update({name: val}) | |||
for keys in maps.keys(): | |||
if keys not in output.keys(): | |||
# TODO: add UNUSED warning. | |||
pass | |||
return output | |||
@@ -473,10 +429,10 @@ def _move_dict_value_to_device(*args, device: torch.device, non_blocking=False): | |||
""" | |||
if not torch.cuda.is_available(): | |||
return | |||
if not isinstance(device, torch.device): | |||
raise TypeError(f"device must be `torch.device`, got `{type(device)}`") | |||
for arg in args: | |||
if isinstance(arg, dict): | |||
for key, value in arg.items(): | |||
@@ -491,10 +447,10 @@ class _CheckError(Exception): | |||
_CheckError. Used in losses.LossBase, metrics.MetricBase. | |||
""" | |||
def __init__(self, check_res: _CheckRes, func_signature: str): | |||
errs = [f'Problems occurred when calling `{func_signature}`'] | |||
if check_res.varargs: | |||
errs.append(f"\tvarargs: {check_res.varargs}(Does not support pass positional arguments, please delete it)") | |||
if check_res.missing: | |||
@@ -503,9 +459,9 @@ class _CheckError(Exception): | |||
errs.append(f"\tduplicated param: {check_res.duplicated}") | |||
if check_res.unused: | |||
errs.append(f"\tunused param: {check_res.unused}") | |||
Exception.__init__(self, '\n'.join(errs)) | |||
self.check_res = check_res | |||
self.func_signature = func_signature | |||
@@ -525,7 +481,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
# if check_res.varargs: | |||
# errs.append(f"\tvarargs: *{check_res.varargs}") | |||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||
if check_res.unused: | |||
for _unused in check_res.unused: | |||
if _unused in target_dict: | |||
@@ -536,7 +492,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
unuseds.append(f"\tunused field: {_unused_field}") | |||
if _unused_param: | |||
unuseds.append(f"\tunused param: {_unused_param}") # output from predict or forward | |||
module_name = func_signature.split('.')[0] | |||
if check_res.missing: | |||
errs.append(f"\tmissing param: {check_res.missing}") | |||
@@ -557,7 +513,7 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
mapped_missing.append(_miss) | |||
else: | |||
unmapped_missing.append(_miss) | |||
for _miss in mapped_missing + unmapped_missing: | |||
if _miss in dataset: | |||
suggestions.append(f"Set `{_miss}` as target.") | |||
@@ -570,29 +526,17 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
else: | |||
_tmp = f'Provide `{_miss}` in DataSet or output of {prev_func_signature}.' | |||
suggestions.append(_tmp) | |||
# for _miss in unmapped_missing: | |||
# if _miss in dataset: | |||
# suggestions.append(f"Set `{_miss}` as target.") | |||
# else: | |||
# _tmp = '' | |||
# if check_res.unused: | |||
# _tmp = f"Specify your assignment for `{input_func_map.get(_miss, _miss)}` when initialize {module_name}." | |||
# if _tmp: | |||
# _tmp += f' Or provide `{_miss}` in DataSet or output of {prev_func_signature}.' | |||
# else: | |||
# _tmp = f'Provide `{_miss}` in output of {prev_func_signature} or DataSet.' | |||
# suggestions.append(_tmp) | |||
if check_res.duplicated: | |||
errs.append(f"\tduplicated param: {check_res.duplicated}.") | |||
suggestions.append(f"Delete {check_res.duplicated} in the output of " | |||
f"{prev_func_signature} or do not set {check_res.duplicated} as targets. ") | |||
if len(errs) > 0: | |||
errs.extend(unuseds) | |||
elif check_level == STRICT_CHECK_LEVEL: | |||
errs.extend(unuseds) | |||
if len(errs) > 0: | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
@@ -619,11 +563,11 @@ def _check_loss_evaluate(prev_func_signature: str, func_signature: str, check_re | |||
def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
check_res = _check_arg_dict_list(forward_func, batch_x) | |||
func_signature = _get_func_signature(forward_func) | |||
errs = [] | |||
suggestions = [] | |||
_unused = [] | |||
# if check_res.varargs: | |||
# errs.append(f"\tvarargs: {check_res.varargs}") | |||
# suggestions.append(f"Does not support pass positional arguments, please delete *{check_res.varargs}.") | |||
@@ -644,14 +588,14 @@ def _check_forward_error(forward_func, batch_x, dataset, check_level): | |||
# _tmp += f"Or you might find it in `unused field:`, you can use DataSet.rename_field() to " \ | |||
# f"rename the field in `unused field:`." | |||
suggestions.append(_tmp) | |||
if check_res.unused: | |||
_unused = [f"\tunused field: {check_res.unused}"] | |||
if len(errs) > 0: | |||
errs.extend(_unused) | |||
elif check_level == STRICT_CHECK_LEVEL: | |||
errs.extend(_unused) | |||
if len(errs) > 0: | |||
errs.insert(0, f'Problems occurred when calling {func_signature}') | |||
sugg_str = "" | |||
@@ -699,7 +643,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
max_len = int(max_len) if max_len else int(seq_len.max()) | |||
broad_cast_seq_len = np.tile(np.arange(max_len), (len(seq_len), 1)) | |||
mask = broad_cast_seq_len < seq_len.reshape(-1, 1) | |||
elif isinstance(seq_len, torch.Tensor): | |||
assert seq_len.dim() == 1, f"seq_len can only have one dimension, got {seq_len.dim() == 1}." | |||
batch_size = seq_len.size(0) | |||
@@ -708,7 +652,7 @@ def seq_len_to_mask(seq_len, max_len=None): | |||
mask = broad_cast_seq_len.lt(seq_len.unsqueeze(1)) | |||
else: | |||
raise TypeError("Only support 1-d numpy.ndarray or 1-d torch.Tensor.") | |||
return mask | |||
@@ -716,25 +660,25 @@ class _pseudo_tqdm: | |||
""" | |||
当无法引入tqdm,或者Trainer中设置use_tqdm为false的时候,用该方法打印数据 | |||
""" | |||
def __init__(self, **kwargs): | |||
pass | |||
self.logger = logger | |||
def write(self, info): | |||
print(info) | |||
self.logger.info(info) | |||
def set_postfix_str(self, info): | |||
print(info) | |||
self.logger.info(info) | |||
def __getattr__(self, item): | |||
def pass_func(*args, **kwargs): | |||
pass | |||
return pass_func | |||
def __enter__(self): | |||
return self | |||
def __exit__(self, exc_type, exc_val, exc_tb): | |||
del self | |||
@@ -788,3 +732,76 @@ def iob2bioes(tags: List[str]) -> List[str]: | |||
else: | |||
raise TypeError("Invalid IOB format.") | |||
return new_tags | |||
def _is_iterable(value): | |||
# 检查是否是iterable的, duck typing | |||
try: | |||
iter(value) | |||
return True | |||
except BaseException as e: | |||
return False | |||
def get_seq_len(words, pad_value=0):
    """
    Given a batch_size x max_len matrix of word indices, return the length of
    each sentence, i.e. the number of entries that differ from ``pad_value``.

    :param words: batch_size x max_len tensor of word indices
    :param pad_value: the index used for padding (default 0)
    :return: (batch_size,) tensor of sentence lengths
    """
    return words.ne(pad_value).sum(dim=-1)
def pretty_table_printer(dataset_or_ins) -> PrettyTable:
    """
    Render a DataSet or an Instance as a PrettyTable, truncating cells so the
    table fits the current terminal::

        ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
        +-----------+-----------+-----------------+
        | field_1   | field_2   | field_3         |
        +-----------+-----------+-----------------+
        | [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] |
        +-----------+-----------+-----------------+

    :param dataset_or_ins: a DataSet or an Instance
    :return: a PrettyTable, with each column truncated to the terminal width
    """
    try:
        term = os.get_terminal_size()
        n_cols, n_rows = term.columns, term.lines
    except OSError:
        # no attached terminal (e.g. output piped): fall back to fixed size
        n_cols, n_rows = 144, 11
    table = PrettyTable()
    kind = type(dataset_or_ins).__name__
    if kind == "DataSet":
        table.field_names = list(dataset_or_ins.field_arrays.keys())
        width = len(table.field_names)
        for ins in dataset_or_ins:
            table.add_row([sub_column(ins[k], n_cols, width, k) for k in table.field_names])
            n_rows -= 1
            if n_rows < 0:
                # terminal height exhausted: elide the remaining rows
                table.add_row(["..." for _ in range(width)])
                break
    elif kind == "Instance":
        table.field_names = list(dataset_or_ins.fields.keys())
        width = len(table.field_names)
        table.add_row([sub_column(dataset_or_ins[k], n_cols, width, k) for k in table.field_names])
    else:
        raise Exception("only accept DataSet and Instance")
    return table
def sub_column(string: str, c: int, c_size: int, title: str) -> str:
    """
    Truncate an over-long cell value so its column fits the terminal.

    :param string: the value to (possibly) truncate; converted with str()
    :param c: number of terminal columns available
    :param c_size: number of fields in the instance/dataset row
    :param title: the column title (a column is never narrower than its title)
    :return: the value, cut to the per-column budget with a "..." suffix
    """
    avg = max(int(c / c_size), len(title))
    string = str(string)
    if len(string) > avg:
        # guard against a tiny budget (avg < 3): the previous `avg - 3`
        # produced a negative slice index and returned a string longer
        # than the budget itself
        string = string[:max(avg - 3, 0)] + "..."
    return string
@@ -1,14 +1,21 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"Vocabulary", | |||
"VocabularyOption", | |||
] | |||
from collections import Counter | |||
from functools import partial | |||
from functools import wraps | |||
from collections import Counter, defaultdict | |||
from ._logger import logger | |||
from .dataset import DataSet | |||
from .utils import Option | |||
from functools import partial | |||
import numpy as np | |||
from .utils import _is_iterable | |||
class VocabularyOption(Option): | |||
@@ -32,7 +39,7 @@ def _check_build_vocab(func): | |||
@wraps(func) # to solve missing docstring | |||
def _wrapper(self, *args, **kwargs): | |||
if self.word2idx is None or self.rebuild is True: | |||
if self._word2idx is None or self.rebuild is True: | |||
self.build_vocab() | |||
return func(self, *args, **kwargs) | |||
@@ -49,8 +56,8 @@ def _check_build_status(func): | |||
if self.rebuild is False: | |||
self.rebuild = True | |||
if self.max_size is not None and len(self.word_count) >= self.max_size: | |||
print("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||
logger.info("[Warning] Vocabulary has reached the max size {} when calling {} method. " | |||
"Adding more words may cause unexpected behaviour of Vocabulary. ".format( | |||
self.max_size, func.__name__)) | |||
return func(self, *args, **kwargs) | |||
@@ -59,8 +66,6 @@ def _check_build_status(func): | |||
class Vocabulary(object): | |||
""" | |||
别名::class:`fastNLP.Vocabulary` :class:`fastNLP.core.vocabulary.Vocabulary` | |||
用于构建, 存储和使用 `str` 到 `int` 的一一映射:: | |||
vocab = Vocabulary() | |||
@@ -68,32 +73,52 @@ class Vocabulary(object): | |||
vocab.update(word_list) | |||
vocab["word"] # str to int | |||
vocab.to_word(5) # int to str | |||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<pad>' | |||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<unk>' | |||
""" | |||
def __init__(self, max_size=None, min_freq=None, padding='<pad>', unknown='<unk>'): | |||
""" | |||
:param int max_size: `Vocabulary` 的最大大小, 即能存储词的最大数量 | |||
若为 ``None`` , 则不限制大小. Default: ``None`` | |||
:param int min_freq: 能被记录下的词在文本中的最小出现频率, 应大于或等于 1. | |||
若小于该频率, 词语将被视为 `unknown`. 若为 ``None`` , 所有文本中的词都被记录. Default: ``None`` | |||
:param str optional padding: padding的字符. 如果设置为 ``None`` , | |||
则vocabulary中不考虑padding, 也不计入词表大小,为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<pad>' | |||
:param str optional unknown: unknown的字符,所有未被记录的词在转为 `int` 时将被视为unknown. | |||
如果设置为 ``None`` ,则vocabulary中不考虑unknow, 也不计入词表大小. | |||
为 ``None`` 的情况多在为label建立Vocabulary的情况. | |||
Default: '<unk>' | |||
""" | |||
self.max_size = max_size | |||
self.min_freq = min_freq | |||
self.word_count = Counter() | |||
self.unknown = unknown | |||
self.padding = padding | |||
self.word2idx = None | |||
self.idx2word = None | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
# 用于承载不需要单独创建entry的词语,具体见from_dataset()方法 | |||
self._no_create_word = Counter() | |||
    @property
    @_check_build_vocab
    def word2idx(self):
        # word -> index dict; _check_build_vocab lazily (re)builds the vocab
        # before the mapping is exposed.
        return self._word2idx
    @word2idx.setter
    def word2idx(self, value):
        # Direct assignment to the backing dict; no rebuild is triggered here.
        self._word2idx = value
@property | |||
@_check_build_vocab | |||
def idx2word(self): | |||
return self._idx2word | |||
@idx2word.setter | |||
def idx2word(self, value): | |||
self._word2idx = value | |||
@_check_build_status | |||
def update(self, word_lst, no_create_entry=False): | |||
"""依次增加序列中词在词典中的出现频率 | |||
@@ -131,11 +156,11 @@ class Vocabulary(object): | |||
""" | |||
在新加入word时,检查_no_create_word的设置。 | |||
:param str, List[str] word: | |||
:param str List[str] word: | |||
:param bool no_create_entry: | |||
:return: | |||
""" | |||
if isinstance(word, str): | |||
if isinstance(word, str) or not _is_iterable(word): | |||
word = [word] | |||
for w in word: | |||
if no_create_entry and self.word_count.get(w, 0) == self._no_create_word.get(w, 0): | |||
@@ -180,36 +205,36 @@ class Vocabulary(object): | |||
但已经记录在词典中的词, 不会改变对应的 `int` | |||
""" | |||
if self.word2idx is None: | |||
self.word2idx = {} | |||
if self._word2idx is None: | |||
self._word2idx = {} | |||
if self.padding is not None: | |||
self.word2idx[self.padding] = len(self.word2idx) | |||
self._word2idx[self.padding] = len(self._word2idx) | |||
if self.unknown is not None: | |||
self.word2idx[self.unknown] = len(self.word2idx) | |||
self._word2idx[self.unknown] = len(self._word2idx) | |||
max_size = min(self.max_size, len(self.word_count)) if self.max_size else None | |||
words = self.word_count.most_common(max_size) | |||
if self.min_freq is not None: | |||
words = filter(lambda kv: kv[1] >= self.min_freq, words) | |||
if self.word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self.word2idx, words) | |||
start_idx = len(self.word2idx) | |||
self.word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
if self._word2idx is not None: | |||
words = filter(lambda kv: kv[0] not in self._word2idx, words) | |||
start_idx = len(self._word2idx) | |||
self._word2idx.update({w: i + start_idx for i, (w, _) in enumerate(words)}) | |||
self.build_reverse_vocab() | |||
self.rebuild = False | |||
return self | |||
def build_reverse_vocab(self): | |||
""" | |||
基于 `word to index` dict, 构建 `index to word` dict. | |||
""" | |||
self.idx2word = {i: w for w, i in self.word2idx.items()} | |||
self._idx2word = {i: w for w, i in self._word2idx.items()} | |||
return self | |||
@_check_build_vocab | |||
def __len__(self): | |||
return len(self.word2idx) | |||
return len(self._word2idx) | |||
@_check_build_vocab | |||
def __contains__(self, item): | |||
@@ -219,7 +244,7 @@ class Vocabulary(object): | |||
:param item: the word | |||
:return: True or False | |||
""" | |||
return item in self.word2idx | |||
return item in self._word2idx | |||
def has_word(self, w): | |||
""" | |||
@@ -241,12 +266,12 @@ class Vocabulary(object): | |||
vocab[w] | |||
""" | |||
if w in self.word2idx: | |||
return self.word2idx[w] | |||
if w in self._word2idx: | |||
return self._word2idx[w] | |||
if self.unknown is not None: | |||
return self.word2idx[self.unknown] | |||
return self._word2idx[self.unknown] | |||
else: | |||
raise ValueError("word {} not in vocabulary".format(w)) | |||
raise ValueError("word `{}` not in vocabulary".format(w)) | |||
@_check_build_vocab | |||
def index_dataset(self, *datasets, field_name, new_field_name=None): | |||
@@ -257,37 +282,47 @@ class Vocabulary(object): | |||
vocab.index_dataset(train_data, dev_data, test_data, field_name='words') | |||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||
:param str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||
目前仅支持 ``str`` , ``List[str]`` , ``List[List[str]]`` | |||
:param str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
Default: ``None`` | |||
:param list,str field_name: 需要转index的field, 若有多个 DataSet, 每个DataSet都必须有此 field. | |||
目前支持 ``str`` , ``List[str]`` | |||
:param list,str new_field_name: 保存结果的field_name. 若为 ``None`` , 将覆盖原field. | |||
Default: ``None``. | |||
""" | |||
def index_instance(ins): | |||
def index_instance(field): | |||
""" | |||
有几种情况, str, 1d-list, 2d-list | |||
:param ins: | |||
:return: | |||
""" | |||
field = ins[field_name] | |||
if isinstance(field, str): | |||
if isinstance(field, str) or not _is_iterable(field): | |||
return self.to_index(field) | |||
elif isinstance(field, list): | |||
if not isinstance(field[0], list): | |||
else: | |||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||
return [self.to_index(w) for w in field] | |||
else: | |||
if isinstance(field[0][0], list): | |||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
return [[self.to_index(c) for c in w] for w in field] | |||
if new_field_name is None: | |||
new_field_name = field_name | |||
new_field_name = new_field_name or field_name | |||
if type(new_field_name) == type(field_name): | |||
if isinstance(new_field_name, list): | |||
assert len(new_field_name) == len(field_name), "new_field_name should have same number elements with " \ | |||
"field_name." | |||
elif isinstance(new_field_name, str): | |||
field_name = [field_name] | |||
new_field_name = [new_field_name] | |||
else: | |||
raise TypeError("field_name and new_field_name can only be str or List[str].") | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
dataset.apply(index_instance, new_field_name=new_field_name) | |||
for f_n, n_f_n in zip(field_name, new_field_name): | |||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n) | |||
except Exception as e: | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
logger.info("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
raise e | |||
else: | |||
raise RuntimeError("Only DataSet type is allowed.") | |||
@@ -306,9 +341,8 @@ class Vocabulary(object): | |||
:param ~fastNLP.DataSet,List[~fastNLP.DataSet] datasets: 需要转index的一个或多个数据集 | |||
:param str,List[str] field_name: 可为 ``str`` 或 ``List[str]`` . | |||
构建词典所使用的 field(s), 支持一个或多个field | |||
若有多个 DataSet, 每个DataSet都必须有这些field. | |||
目前仅支持的field结构: ``str`` , ``List[str]`` , ``list[List[str]]`` | |||
构建词典所使用的 field(s), 支持一个或多个field,若有多个 DataSet, 每个DataSet都必须有这些field. 目前支持的field结构 | |||
: ``str`` , ``List[str]`` | |||
:param no_create_entry_dataset: 可以传入DataSet, List[DataSet]或者None(默认),该选项用在接下来的模型会使用pretrain | |||
的embedding(包括glove, word2vec, elmo与bert)且会finetune的情况。如果仅使用来自于train的数据建立vocabulary,会导致test与dev | |||
中的数据无法充分利用到来自于预训练embedding的信息,所以在建立词表的时候将test与dev考虑进来会使得最终的结果更好。 | |||
@@ -326,14 +360,14 @@ class Vocabulary(object): | |||
def construct_vocab(ins, no_create_entry=False): | |||
for fn in field_name: | |||
field = ins[fn] | |||
if isinstance(field, str): | |||
if isinstance(field, str) or not _is_iterable(field): | |||
self.add_word(field, no_create_entry=no_create_entry) | |||
elif isinstance(field, (list, np.ndarray)): | |||
if not isinstance(field[0], (list, np.ndarray)): | |||
else: | |||
if isinstance(field[0], str) or not _is_iterable(field[0]): | |||
for word in field: | |||
self.add_word(word, no_create_entry=no_create_entry) | |||
else: | |||
if isinstance(field[0][0], (list, np.ndarray)): | |||
if not isinstance(field[0][0], str) and _is_iterable(field[0][0]): | |||
raise RuntimeError("Only support field with 2 dimensions.") | |||
for words in field: | |||
for word in words: | |||
@@ -343,8 +377,8 @@ class Vocabulary(object): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
dataset.apply(construct_vocab) | |||
except Exception as e: | |||
print("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
except BaseException as e: | |||
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
raise e | |||
else: | |||
raise TypeError("Only DataSet type is allowed.") | |||
@@ -370,7 +404,7 @@ class Vocabulary(object): | |||
def to_index(self, w): | |||
""" | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出``ValueError``:: | |||
将词转为数字. 若词不再词典中被记录, 将视为 unknown, 若 ``unknown=None`` , 将抛出 ``ValueError`` :: | |||
index = vocab.to_index('abc') | |||
# equals to | |||
@@ -389,7 +423,7 @@ class Vocabulary(object): | |||
""" | |||
if self.unknown is None: | |||
return None | |||
return self.word2idx[self.unknown] | |||
return self._word2idx[self.unknown] | |||
@property | |||
@_check_build_vocab | |||
@@ -399,7 +433,7 @@ class Vocabulary(object): | |||
""" | |||
if self.padding is None: | |||
return None | |||
return self.word2idx[self.padding] | |||
return self._word2idx[self.padding] | |||
@_check_build_vocab | |||
def to_word(self, idx): | |||
@@ -409,7 +443,7 @@ class Vocabulary(object): | |||
:param int idx: the index | |||
:return str word: the word | |||
""" | |||
return self.idx2word[idx] | |||
return self._idx2word[idx] | |||
def clear(self): | |||
""" | |||
@@ -418,8 +452,8 @@ class Vocabulary(object): | |||
:return: | |||
""" | |||
self.word_count.clear() | |||
self.word2idx = None | |||
self.idx2word = None | |||
self._word2idx = None | |||
self._idx2word = None | |||
self.rebuild = True | |||
self._no_create_word.clear() | |||
return self | |||
@@ -430,8 +464,8 @@ class Vocabulary(object): | |||
""" | |||
len(self) # make sure vocab has been built | |||
state = self.__dict__.copy() | |||
# no need to pickle idx2word as it can be constructed from word2idx | |||
del state['idx2word'] | |||
# no need to pickle _idx2word as it can be constructed from _word2idx | |||
del state['_idx2word'] | |||
return state | |||
def __setstate__(self, state): | |||
@@ -446,5 +480,5 @@ class Vocabulary(object): | |||
@_check_build_vocab | |||
def __iter__(self): | |||
for word, index in self.word2idx.items(): | |||
for word, index in self._word2idx.items(): | |||
yield word, index |
@@ -0,0 +1,27 @@ | |||
"""undocumented""" | |||
__all__ = [] | |||
import inspect | |||
import sys | |||
def doc_process(m):
    """Prepend an "alias" line to the docstrings of re-exported members of *m*.

    For every class/function that module *m* re-exports from a submodule, walk
    up the dotted module path to find the module that publicly declares it
    (listed in that module's ``__all__`` and whose docstring is not marked
    "undocumented"), then prefix the object's docstring with both names.

    :param m: the module whose members should be annotated
    """
    for name, obj in inspect.getmembers(m):
        if inspect.isclass(obj) or inspect.isfunction(obj):
            # only members defined elsewhere and already documented are touched
            if obj.__module__ != m.__name__ and obj.__doc__ is not None:
                module_name = obj.__module__
                while True:
                    # guard against modules missing from sys.modules, modules
                    # without a docstring, and modules without __all__ — any of
                    # these previously raised instead of being skipped
                    defined_m = sys.modules.get(module_name)
                    if (defined_m is not None
                            and defined_m.__doc__ is not None
                            and "undocumented" not in defined_m.__doc__
                            and name in getattr(defined_m, "__all__", ())):
                        obj.__doc__ = r"别名 :class:`" + m.__name__ + "." + name + "`" \
                                      + " :class:`" + module_name + "." + name + "`\n" + obj.__doc__
                        break
                    module_name = ".".join(module_name.split('.')[:-1])
                    if module_name == m.__name__ or not module_name:
                        # reached the top package (or ran out of path) without
                        # finding a public declaration: leave the doc as-is
                        break
@@ -7,20 +7,25 @@ torch.FloatTensor。所有的embedding都可以使用 `self.num_embedding` 获 | |||
__all__ = [ | |||
"Embedding", | |||
"TokenEmbedding", | |||
"StaticEmbedding", | |||
"ElmoEmbedding", | |||
"BertEmbedding", | |||
"BertWordPieceEncoder", | |||
"StackEmbedding", | |||
"LSTMCharEmbedding", | |||
"CNNCharEmbedding", | |||
"get_embeddings" | |||
"get_embeddings", | |||
] | |||
from .embedding import Embedding | |||
from .embedding import Embedding, TokenEmbedding | |||
from .static_embedding import StaticEmbedding | |||
from .elmo_embedding import ElmoEmbedding | |||
from .bert_embedding import BertEmbedding | |||
from .bert_embedding import BertEmbedding, BertWordPieceEncoder | |||
from .char_embedding import CNNCharEmbedding, LSTMCharEmbedding | |||
from .stack_embedding import StackEmbedding | |||
from .utils import get_embeddings | |||
from .utils import get_embeddings | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -1,3 +1,12 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"BertEmbedding", | |||
"BertWordPieceEncoder" | |||
] | |||
import os | |||
import collections | |||
@@ -8,15 +17,15 @@ import numpy as np | |||
from itertools import chain | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ..io.file_utils import PRETRAINED_BERT_MODEL_DIR | |||
from ..modules.encoder.bert import _WordPieceBertModel, BertModel, BertTokenizer | |||
from .contextual_embedding import ContextualEmbedding | |||
import warnings | |||
from ..core import logger | |||
class BertEmbedding(ContextualEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.BertEmbedding` :class:`fastNLP.embeddings.bert_embedding.BertEmbedding` | |||
使用BERT对words进行编码的Embedding。建议将输入的words长度限制在430以内,而不要使用512(根据预训练模型参数,可能有变化)。这是由于 | |||
预训练的bert模型长度限制为512个token,而因为输入的word是未进行word piece分割的(word piece的分割有BertEmbedding在输入word | |||
时切分),在分割之后长度可能会超过最大长度限制。 | |||
@@ -27,6 +36,7 @@ class BertEmbedding(ContextualEmbedding): | |||
>>> import torch | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import BertEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased', requires_grad=False, layers='4,-2,-1') | |||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]]) | |||
@@ -37,8 +47,8 @@ class BertEmbedding(ContextualEmbedding): | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。当传入模型所在目录时,目录中应该包含一个词表文件(以.txt作为后缀名), | |||
权重文件(以.bin作为文件后缀名), 配置文件(以.json作为后缀名)。 | |||
:param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,可以以负数 | |||
去索引倒数几层。 | |||
:param str layers: 输出embedding表示来自于哪些层,不同层的结果按照layers中的顺序在最后一维concat起来。以','隔开层数,层的序号是 | |||
从0开始,可以以负数去索引倒数几层。 | |||
:param str pool_method: 因为在bert中,每个word会被表示为多个word pieces, 当获取一个word的表示的时候,怎样从它的word pieces | |||
中计算得到它对应的表示。支持 ``last`` , ``first`` , ``avg`` , ``max``。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
@@ -46,34 +56,40 @@ class BertEmbedding(ContextualEmbedding): | |||
:param bool include_cls_sep: bool,在bert计算句子的表示的时候,需要在前面加上[CLS]和[SEP], 是否在结果中保留这两个内容。 这样 | |||
会使得word embedding的结果比输入的结果长两个token。如果该值为True,则在使用 :class::StackEmbedding 可能会与其它类型的 | |||
embedding长度不匹配。 | |||
:param bool pooled_cls: 返回的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取[CLS]做预测, | |||
一般该值为True。 | |||
:param bool requires_grad: 是否需要gradient以更新Bert的权重。 | |||
:param bool auto_truncate: 当句子words拆分为word pieces长度超过bert最大允许长度(一般为512), 自动截掉拆分后的超过510个 | |||
word pieces后的内容,并将第512个word piece置为[SEP]。超过长度的部分的encode结果直接全部置零。一般仅有只使用[CLS] | |||
来进行分类的任务将auto_truncate置为True。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
pool_method: str='first', word_dropout=0, dropout=0, requires_grad: bool=False, | |||
include_cls_sep: bool=False): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', | |||
pool_method: str = 'first', word_dropout=0, dropout=0, include_cls_sep: bool = False, | |||
pooled_cls=True, requires_grad: bool = True, auto_truncate: bool = False): | |||
super(BertEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
if model_dir_or_name.lower() in PRETRAINED_BERT_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('bert') | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _WordBertModel(model_dir=model_dir, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep) | |||
if 'cn' in model_dir_or_name.lower() and pool_method not in ('first', 'last'): | |||
logger.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" | |||
" faster speed.") | |||
warnings.warn("For Chinese bert, pooled_method should choose from 'first', 'last' in order to achieve" | |||
" faster speed.") | |||
self._word_sep_index = None | |||
if '[SEP]' in vocab: | |||
self._word_sep_index = vocab['[SEP]'] | |||
self.model = _WordBertModel(model_dir_or_name=model_dir_or_name, vocab=vocab, layers=layers, | |||
pool_method=pool_method, include_cls_sep=include_cls_sep, | |||
pooled_cls=pooled_cls, auto_truncate=auto_truncate, min_freq=2) | |||
self.requires_grad = requires_grad | |||
self._embed_size = len(self.model.layers)*self.model.encoder.hidden_size | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
def _delete_model_weights(self): | |||
del self.model | |||
def forward(self, words): | |||
""" | |||
计算words的bert embedding表示。计算之前会在每句话的开始增加[CLS]在结束增加[SEP], 并根据include_cls_sep判断要不要 | |||
@@ -85,12 +101,32 @@ class BertEmbedding(ContextualEmbedding): | |||
words = self.drop_word(words) | |||
outputs = self._get_sent_reprs(words) | |||
if outputs is not None: | |||
return self.dropout(words) | |||
return self.dropout(outputs) | |||
outputs = self.model(words) | |||
outputs = torch.cat([*outputs], dim=-1) | |||
return self.dropout(outputs) | |||
def drop_word(self, words): | |||
""" | |||
按照设定随机将words设置为unknown_index。 | |||
:param torch.LongTensor words: batch_size x max_len | |||
:return: | |||
""" | |||
if self.word_dropout > 0 and self.training: | |||
with torch.no_grad(): | |||
if self._word_sep_index: # 不能drop sep | |||
sep_mask = words.eq(self._word_sep_index) | |||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||
pad_mask = words.ne(0) | |||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||
words = words.masked_fill(mask, self._word_unk_index) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._word_sep_index) | |||
return words | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -99,12 +135,12 @@ class BertEmbedding(ContextualEmbedding): | |||
:return: | |||
""" | |||
requires_grads = set([param.requires_grad for name, param in self.named_parameters() | |||
if 'word_pieces_lengths' not in name]) | |||
if 'word_pieces_lengths' not in name]) | |||
if len(requires_grads) == 1: | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
@@ -119,27 +155,26 @@ class BertWordPieceEncoder(nn.Module): | |||
:param str model_dir_or_name: 模型所在目录或者模型的名称。默认值为 ``en-base-uncased`` | |||
:param str layers: 最终结果中的表示。以','隔开层数,可以以负数去索引倒数几层 | |||
:param bool pooled_cls: 返回的句子开头的[CLS]是否使用预训练中的BertPool映射一下,仅在include_cls_sep时有效。如果下游任务只取 | |||
[CLS]做预测,一般该值为True。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||
:param bool requires_grad: 是否需要gradient。 | |||
""" | |||
def __init__(self, model_dir_or_name: str='en-base-uncased', layers: str='-1', | |||
requires_grad: bool=False): | |||
def __init__(self, model_dir_or_name: str = 'en-base-uncased', layers: str = '-1', pooled_cls: bool = False, | |||
word_dropout=0, dropout=0, requires_grad: bool = True): | |||
super().__init__() | |||
PRETRAIN_URL = _get_base_url('bert') | |||
if model_dir_or_name in PRETRAINED_BERT_MODEL_DIR: | |||
model_name = PRETRAINED_BERT_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
# 检查是否存在 | |||
elif os.path.isdir(model_dir_or_name): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _WordPieceBertModel(model_dir=model_dir, layers=layers) | |||
self.model = _WordPieceBertModel(model_dir_or_name=model_dir_or_name, layers=layers, pooled_cls=pooled_cls) | |||
self._sep_index = self.model._sep_index | |||
self._wordpiece_pad_index = self.model._wordpiece_pad_index | |||
self._wordpiece_unk_index = self.model._wordpiece_unknown_index | |||
self._embed_size = len(self.model.layers) * self.model.encoder.hidden_size | |||
self.requires_grad = requires_grad | |||
self.word_dropout = word_dropout | |||
self.dropout_layer = nn.Dropout(dropout) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -151,77 +186,129 @@ class BertWordPieceEncoder(nn.Module): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
param.requires_grad = value | |||
@property | |||
def embed_size(self): | |||
return self._embed_size | |||
def index_datasets(self, *datasets, field_name): | |||
@property | |||
def embedding_dim(self): | |||
return self._embed_size | |||
@property | |||
def num_embedding(self): | |||
return self.model.encoder.config.vocab_size | |||
def index_datasets(self, *datasets, field_name, add_cls_sep=True): | |||
""" | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input。如果首尾不是 | |||
[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP], 且将word_pieces这一列的pad value设置为了bert的pad value。 | |||
使用bert的tokenizer新生成word_pieces列加入到datasets中,并将他们设置为input,且将word_pieces这一列的pad value设置为了 | |||
bert的pad value。 | |||
:param datasets: DataSet对象 | |||
:param field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 | |||
:param ~fastNLP.DataSet datasets: DataSet对象 | |||
:param str field_name: 基于哪一列的内容生成word_pieces列。这一列中每个数据应该是List[str]的形式。 | |||
:param bool add_cls_sep: 如果首尾不是[CLS]与[SEP]会在首尾额外加入[CLS]与[SEP]。 | |||
:return: | |||
""" | |||
self.model.index_dataset(*datasets, field_name=field_name) | |||
self.model.index_dataset(*datasets, field_name=field_name, add_cls_sep=add_cls_sep) | |||
def forward(self, word_pieces, token_type_ids=None): | |||
""" | |||
计算words的bert embedding表示。传入的words中应该自行包含[CLS]与[SEP]的tag。 | |||
:param words: batch_size x max_len | |||
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话 | |||
:param token_type_ids: batch_size x max_len, 用于区分前一句和后一句话. 如果不传入,则自动生成(大部分情况,都不需要输入), | |||
第一个[SEP]及之前为0, 第二个[SEP]及到第一个[SEP]之间为1; 第三个[SEP]及到第二个[SEP]之间为0,依次往后推。 | |||
:return: torch.FloatTensor. batch_size x max_len x (768*len(self.layers)) | |||
""" | |||
with torch.no_grad(): | |||
sep_mask = word_pieces.eq(self._sep_index) # batch_size x max_len | |||
if token_type_ids is None: | |||
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) | |||
token_type_ids = sep_mask_cumsum.fmod(2) | |||
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 | |||
token_type_ids = token_type_ids.eq(0).long() | |||
word_pieces = self.drop_word(word_pieces) | |||
outputs = self.model(word_pieces, token_type_ids) | |||
outputs = torch.cat([*outputs], dim=-1) | |||
return self.dropout_layer(outputs) | |||
def drop_word(self, words): | |||
""" | |||
按照设定随机将words设置为unknown_index。 | |||
return outputs | |||
:param torch.LongTensor words: batch_size x max_len | |||
:return: | |||
""" | |||
if self.word_dropout > 0 and self.training: | |||
with torch.no_grad(): | |||
if self._word_sep_index: # 不能drop sep | |||
sep_mask = words.eq(self._wordpiece_unk_index) | |||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||
pad_mask = words.ne(self._wordpiece_pad_index) | |||
mask = pad_mask.__and__(mask) # pad的位置不为unk | |||
words = words.masked_fill(mask, self._word_unk_index) | |||
if self._word_sep_index: | |||
words.masked_fill_(sep_mask, self._wordpiece_unk_index) | |||
return words | |||
class _WordBertModel(nn.Module): | |||
def __init__(self, model_dir:str, vocab:Vocabulary, layers:str='-1', pool_method:str='first', include_cls_sep:bool=False): | |||
def __init__(self, model_dir_or_name: str, vocab: Vocabulary, layers: str = '-1', pool_method: str = 'first', | |||
include_cls_sep: bool = False, pooled_cls: bool = False, auto_truncate: bool = False, min_freq=2): | |||
super().__init__() | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir) | |||
self.encoder = BertModel.from_pretrained(model_dir) | |||
self.tokenzier = BertTokenizer.from_pretrained(model_dir_or_name) | |||
self.encoder = BertModel.from_pretrained(model_dir_or_name) | |||
self._max_position_embeddings = self.encoder.config.max_position_embeddings | |||
# 检查encoder_layer_number是否合理 | |||
encoder_layer_number = len(self.encoder.encoder.layer) | |||
self.layers = list(map(int, layers.split(','))) | |||
for layer in self.layers: | |||
if layer<0: | |||
assert -layer<=encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
if layer < 0: | |||
assert -layer <= encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
else: | |||
assert layer<encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
assert layer < encoder_layer_number, f"The layer index:{layer} is out of scope for " \ | |||
f"a bert model with {encoder_layer_number} layers." | |||
assert pool_method in ('avg', 'max', 'first', 'last') | |||
self.pool_method = pool_method | |||
self.include_cls_sep = include_cls_sep | |||
self.pooled_cls = pooled_cls | |||
self.auto_truncate = auto_truncate | |||
# 将所有vocab中word的wordpiece计算出来, 需要额外考虑[CLS]和[SEP] | |||
print("Start to generating word pieces for word.") | |||
logger.info("Start to generate word pieces for word.") | |||
# 第一步统计出需要的word_piece, 然后创建新的embed和word_piece_vocab, 然后填入值 | |||
word_piece_dict = {'[CLS]':1, '[SEP]':1} # 用到的word_piece以及新增的 | |||
word_piece_dict = {'[CLS]': 1, '[SEP]': 1} # 用到的word_piece以及新增的 | |||
found_count = 0 | |||
self._has_sep_in_vocab = '[SEP]' in vocab # 用来判断传入的数据是否需要生成token_ids | |||
if '[sep]' in vocab: | |||
warnings.warn("Lower cased [sep] detected, it cannot be correctly recognized as [SEP] by BertEmbedding.") | |||
if "[CLS]" in vocab: | |||
warnings.warn("[CLS] detected in your vocabulary. BertEmbedding will add [CSL] and [SEP] to the begin " | |||
"and end of the input automatically, make sure you don't add [CLS] and [SEP] at the begin" | |||
" and end.") | |||
for word, index in vocab: | |||
if index == vocab.padding_idx: # pad是个特殊的符号 | |||
word = '[PAD]' | |||
elif index == vocab.unknown_idx: | |||
word = '[UNK]' | |||
word_pieces = self.tokenzier.wordpiece_tokenizer.tokenize(word) | |||
if len(word_pieces)==1: | |||
if len(word_pieces) == 1: | |||
if not vocab._is_word_no_create_entry(word): # 如果是train中的值, 但是却没有找到 | |||
if index!=vocab.unknown_idx and word_pieces[0]=='[UNK]': # 说明这个词不在原始的word里面 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
if index != vocab.unknown_idx and word_pieces[0] == '[UNK]': # 说明这个词不在原始的word里面 | |||
if vocab.word_count[word] >= min_freq and not vocab._is_word_no_create_entry( | |||
word): # 出现次数大于这个次数才新增 | |||
word_piece_dict[word] = 1 # 新增一个值 | |||
continue | |||
for word_piece in word_pieces: | |||
word_piece_dict[word_piece] = 1 | |||
@@ -242,7 +329,7 @@ class _WordBertModel(nn.Module): | |||
new_word_piece_vocab[token] = len(new_word_piece_vocab) | |||
self.tokenzier._reinit_on_new_vocab(new_word_piece_vocab) | |||
self.encoder.embeddings.word_embeddings = embed | |||
word_to_wordpieces = [] | |||
word_pieces_lengths = [] | |||
for word, index in vocab: | |||
@@ -254,81 +341,126 @@ class _WordBertModel(nn.Module): | |||
word_pieces = self.tokenzier.convert_tokens_to_ids(word_pieces) | |||
word_to_wordpieces.append(word_pieces) | |||
word_pieces_lengths.append(len(word_pieces)) | |||
print("Found(Or seg into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
self._cls_index = self.tokenzier.vocab['[CLS]'] | |||
self._sep_index = self.tokenzier.vocab['[SEP]'] | |||
self._pad_index = vocab.padding_idx | |||
self._word_pad_index = vocab.padding_idx | |||
self._wordpiece_pad_index = self.tokenzier.vocab['[PAD]'] # 需要用于生成word_piece | |||
logger.info("Found(Or segment into word pieces) {} words out of {}.".format(found_count, len(vocab))) | |||
self.word_to_wordpieces = np.array(word_to_wordpieces) | |||
self.word_pieces_lengths = nn.Parameter(torch.LongTensor(word_pieces_lengths), requires_grad=False) | |||
print("Successfully generate word pieces.") | |||
self.register_buffer('word_pieces_lengths', torch.LongTensor(word_pieces_lengths)) | |||
logger.debug("Successfully generate word pieces.") | |||
def forward(self, words): | |||
""" | |||
:param words: torch.LongTensor, batch_size x max_len | |||
:return: num_layers x batch_size x max_len x hidden_size或者num_layers x batch_size x (max_len+2) x hidden_size | |||
""" | |||
batch_size, max_word_len = words.size() | |||
seq_len = words.ne(self._pad_index).sum(dim=-1) | |||
batch_word_pieces_length = self.word_pieces_lengths[words] # batch_size x max_len | |||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) | |||
max_word_piece_length = word_pieces_lengths.max().item() | |||
# +2是由于需要加入[CLS]与[SEP] | |||
word_pieces = words.new_full((batch_size, max_word_piece_length+2), fill_value=self._wordpiece_pad_index) | |||
word_pieces[:, 0].fill_(self._cls_index) | |||
batch_indexes = torch.arange(batch_size).to(words) | |||
word_pieces[batch_indexes, word_pieces_lengths+1] = self._sep_index | |||
attn_masks = torch.zeros_like(word_pieces) | |||
# 1. 获取words的word_pieces的id,以及对应的span范围 | |||
word_indexes = words.tolist() | |||
for i in range(batch_size): | |||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i]])) | |||
word_pieces[i, 1:len(word_pieces_i)+1] = torch.LongTensor(word_pieces_i) | |||
attn_masks[i, :len(word_pieces_i)+2].fill_(1) | |||
# TODO 截掉长度超过的部分。 | |||
with torch.no_grad(): | |||
batch_size, max_word_len = words.size() | |||
word_mask = words.ne(self._word_pad_index) # 为1的地方有word | |||
seq_len = word_mask.sum(dim=-1) | |||
batch_word_pieces_length = self.word_pieces_lengths[words].masked_fill(word_mask.eq(0), | |||
0) # batch_size x max_len | |||
word_pieces_lengths = batch_word_pieces_length.sum(dim=-1) # batch_size | |||
word_piece_length = batch_word_pieces_length.sum(dim=-1).max().item() # 表示word piece的长度(包括padding) | |||
if word_piece_length + 2 > self._max_position_embeddings: | |||
if self.auto_truncate: | |||
word_pieces_lengths = word_pieces_lengths.masked_fill( | |||
word_pieces_lengths + 2 > self._max_position_embeddings, | |||
self._max_position_embeddings - 2) | |||
else: | |||
raise RuntimeError( | |||
"After split words into word pieces, the lengths of word pieces are longer than the " | |||
f"maximum allowed sequence length:{self._max_position_embeddings} of bert. You can set " | |||
f"`auto_truncate=True` for BertEmbedding to automatically truncate overlong input.") | |||
# +2是由于需要加入[CLS]与[SEP] | |||
word_pieces = words.new_full((batch_size, min(word_piece_length + 2, self._max_position_embeddings)), | |||
fill_value=self._wordpiece_pad_index) | |||
attn_masks = torch.zeros_like(word_pieces) | |||
# 1. 获取words的word_pieces的id,以及对应的span范围 | |||
word_indexes = words.cpu().numpy() | |||
for i in range(batch_size): | |||
word_pieces_i = list(chain(*self.word_to_wordpieces[word_indexes[i, :seq_len[i]]])) | |||
if self.auto_truncate and len(word_pieces_i) > self._max_position_embeddings - 2: | |||
word_pieces_i = word_pieces_i[:self._max_position_embeddings - 2] | |||
word_pieces[i, 1:word_pieces_lengths[i] + 1] = torch.LongTensor(word_pieces_i) | |||
attn_masks[i, :word_pieces_lengths[i] + 2].fill_(1) | |||
# 添加[cls]和[sep] | |||
word_pieces[:, 0].fill_(self._cls_index) | |||
batch_indexes = torch.arange(batch_size).to(words) | |||
word_pieces[batch_indexes, word_pieces_lengths + 1] = self._sep_index | |||
if self._has_sep_in_vocab: # 但[SEP]在vocab中出现应该才会需要token_ids | |||
sep_mask = word_pieces.eq(self._sep_index).long() # batch_size x max_len | |||
sep_mask_cumsum = sep_mask.flip(dims=[-1]).cumsum(dim=-1).flip(dims=[-1]) | |||
token_type_ids = sep_mask_cumsum.fmod(2) | |||
if token_type_ids[0, 0].item(): # 如果开头是奇数,则需要flip一下结果,因为需要保证开头为0 | |||
token_type_ids = token_type_ids.eq(0).long() | |||
else: | |||
token_type_ids = torch.zeros_like(word_pieces) | |||
# 2. 获取hidden的结果,根据word_pieces进行对应的pool计算 | |||
# all_outputs: [batch_size x max_len x hidden_size, batch_size x max_len x hidden_size, ...] | |||
bert_outputs, _ = self.encoder(word_pieces, token_type_ids=None, attention_mask=attn_masks, | |||
output_all_encoded_layers=True) | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x max_word_piece_length x hidden_size | |||
bert_outputs, pooled_cls = self.encoder(word_pieces, token_type_ids=token_type_ids, attention_mask=attn_masks, | |||
output_all_encoded_layers=True) | |||
# output_layers = [self.layers] # len(self.layers) x batch_size x real_word_piece_length x hidden_size | |||
if self.include_cls_sep: | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
bert_outputs[-1].size(-1)) | |||
s_shift = 1 | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len + 2, | |||
bert_outputs[-1].size(-1)) | |||
else: | |||
s_shift = 0 | |||
outputs = bert_outputs[-1].new_zeros(len(self.layers), batch_size, max_word_len, | |||
bert_outputs[-1].size(-1)) | |||
s_shift = 0 | |||
batch_word_pieces_cum_length = batch_word_pieces_length.new_zeros(batch_size, max_word_len + 1) | |||
batch_word_pieces_cum_length[:, 1:] = batch_word_pieces_length.cumsum(dim=-1) # batch_size x max_len | |||
if self.pool_method == 'first': | |||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, :seq_len.max()] | |||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||
elif self.pool_method == 'last': | |||
batch_word_pieces_cum_length = batch_word_pieces_cum_length[:, 1:seq_len.max()+1] - 1 | |||
batch_word_pieces_cum_length.masked_fill_(batch_word_pieces_cum_length.ge(word_piece_length), 0) | |||
_batch_indexes = batch_indexes[:, None].expand((batch_size, batch_word_pieces_cum_length.size(1))) | |||
for l_index, l in enumerate(self.layers): | |||
output_layer = bert_outputs[l] | |||
real_word_piece_length = output_layer.size(1) - 2 | |||
if word_piece_length > real_word_piece_length: # 如果实际上是截取出来的 | |||
paddings = output_layer.new_zeros(batch_size, | |||
word_piece_length - real_word_piece_length, | |||
output_layer.size(2)) | |||
output_layer = torch.cat((output_layer, paddings), dim=1).contiguous() | |||
# 从word_piece collapse到word的表示 | |||
truncate_output_layer = output_layer[:, 1:-1] # 删除[CLS]与[SEP] batch_size x len x hidden_size | |||
outputs_seq_len = seq_len + s_shift | |||
if self.pool_method == 'first': | |||
for i in range(batch_size): | |||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, :seq_len[i]] # 每个word的start位置 | |||
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] # num_layer x batch_size x len x hidden_size | |||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | |||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | |||
elif self.pool_method == 'last': | |||
for i in range(batch_size): | |||
i_word_pieces_cum_length = batch_word_pieces_cum_length[i, 1:seq_len[i]+1] - 1 # 每个word的end | |||
outputs[l_index, i, s_shift:outputs_seq_len[i]] = truncate_output_layer[i, i_word_pieces_cum_length] | |||
tmp = truncate_output_layer[_batch_indexes, batch_word_pieces_cum_length] | |||
tmp = tmp.masked_fill(word_mask[:, :batch_word_pieces_cum_length.size(1), None].eq(0), 0) | |||
outputs[l_index, :, s_shift:batch_word_pieces_cum_length.size(1)+s_shift] = tmp | |||
elif self.pool_method == 'max': | |||
for i in range(batch_size): | |||
for j in range(seq_len[i]): | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] | |||
outputs[l_index, i, j+s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2) | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1] | |||
outputs[l_index, i, j + s_shift], _ = torch.max(truncate_output_layer[i, start:end], dim=-2) | |||
else: | |||
for i in range(batch_size): | |||
for j in range(seq_len[i]): | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j+1] | |||
outputs[l_index, i, j+s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) | |||
start, end = batch_word_pieces_cum_length[i, j], batch_word_pieces_cum_length[i, j + 1] | |||
outputs[l_index, i, j + s_shift] = torch.mean(truncate_output_layer[i, start:end], dim=-2) | |||
if self.include_cls_sep: | |||
outputs[l_index, :, 0] = output_layer[:, 0] | |||
outputs[l_index, batch_indexes, seq_len+s_shift] = output_layer[batch_indexes, seq_len+s_shift] | |||
if l in (len(bert_outputs) - 1, -1) and self.pooled_cls: | |||
outputs[l_index, :, 0] = pooled_cls | |||
else: | |||
outputs[l_index, :, 0] = output_layer[:, 0] | |||
outputs[l_index, batch_indexes, seq_len + s_shift] = output_layer[batch_indexes, seq_len + s_shift] | |||
# 3. 最终的embedding结果 | |||
return outputs | |||
@@ -3,27 +3,35 @@ | |||
词的index而不需要使用词语中的char的index来获取表达。 | |||
""" | |||
__all__ = [ | |||
"CNNCharEmbedding", | |||
"LSTMCharEmbedding" | |||
] | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
from typing import List | |||
from .static_embedding import StaticEmbedding | |||
from ..modules.encoder.lstm import LSTM | |||
from ..core.vocabulary import Vocabulary | |||
from .embedding import TokenEmbedding | |||
from .utils import _construct_char_vocab_from_vocab | |||
from .utils import get_embeddings | |||
from ..core import logger | |||
class CNNCharEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.CNNCharEmbedding` :class:`fastNLP.embeddings.char_embedding.CNNCharEmbedding` | |||
使用CNN生成character embedding。CNN的结构为, embed(x) -> Dropout(x) -> CNN(x) -> activation(x) -> pool -> fc -> Dropout. | |||
不同的kernel大小的fitler结果是concat起来然后通过一层fully connected layer, 然后输出word的表示。 | |||
Example:: | |||
>>> import torch | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import CNNCharEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed = CNNCharEmbedding(vocab, embed_size=50) | |||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]]) | |||
@@ -32,8 +40,8 @@ class CNNCharEmbedding(TokenEmbedding): | |||
>>> # torch.Size([1, 5,50]) | |||
:param vocab: 词表 | |||
:param embed_size: 该word embedding的大小,默认值为50. | |||
:param char_emb_size: character的embed的大小。character是从vocab中生成的。默认值为50. | |||
:param embed_size: 该CNNCharEmbedding的输出维度大小,默认值为50. | |||
:param char_emb_size: character的embed的维度。character是从vocab中生成的。默认值为50. | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param float dropout: 以多大的概率drop分布式表示与char embedding的输出。 | |||
:param filter_nums: filter的数量. 长度需要和kernels一致。默认值为[40, 30, 20]. | |||
@@ -41,17 +49,20 @@ class CNNCharEmbedding(TokenEmbedding): | |||
:param pool_method: character的表示在合成一个表示时所使用的pool方法,支持'avg', 'max'. | |||
:param activation: CNN之后使用的激活方法,支持'relu', 'sigmoid', 'tanh' 或者自定义函数. | |||
:param min_char_freq: character的最少出现次数。默认值为2. | |||
:param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹 | |||
(文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型, | |||
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. | |||
""" | |||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, | |||
dropout:float=0.5, filter_nums: List[int]=(40, 30, 20), kernel_sizes: List[int]=(5, 3, 1), | |||
pool_method: str='max', activation='relu', min_char_freq: int=2): | |||
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0, | |||
dropout: float = 0, filter_nums: List[int] = (40, 30, 20), kernel_sizes: List[int] = (5, 3, 1), | |||
pool_method: str = 'max', activation='relu', min_char_freq: int = 2, pre_train_char_embed: str = None): | |||
super(CNNCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
for kernel in kernel_sizes: | |||
assert kernel % 2 == 1, "Only odd kernel is allowed." | |||
assert pool_method in ('max', 'avg') | |||
self.dropout = nn.Dropout(dropout) | |||
self.pool_method = pool_method | |||
# activation function | |||
if isinstance(activation, str): | |||
@@ -68,32 +79,35 @@ class CNNCharEmbedding(TokenEmbedding): | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
print("Start constructing character vocabulary.") | |||
logger.info("Start constructing character vocabulary.") | |||
# 建立char的词表 | |||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq) | |||
self.char_pad_index = self.char_vocab.padding_idx | |||
print(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long), | |||
requires_grad=False) | |||
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long)) | |||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了。修改为不区分pad, 这样所有的<pad>也是同一个embed | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||
self.word_lengths[index] = len(word) | |||
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
if pre_train_char_embed: | |||
self.char_embedding = StaticEmbedding(self.char_vocab, model_dir_or_name=pre_train_char_embed) | |||
else: | |||
self.char_embedding = get_embeddings((len(self.char_vocab), char_emb_size)) | |||
self.convs = nn.ModuleList([nn.Conv1d( | |||
char_emb_size, filter_nums[i], kernel_size=kernel_sizes[i], bias=True, padding=kernel_sizes[i] // 2) | |||
for i in range(len(kernel_sizes))]) | |||
self._embed_size = embed_size | |||
self.fc = nn.Linear(sum(filter_nums), embed_size) | |||
self.init_param() | |||
self.reset_parameters() | |||
def forward(self, words): | |||
""" | |||
输入words的index后,生成对应的words的表示。 | |||
@@ -104,14 +118,14 @@ class CNNCharEmbedding(TokenEmbedding): | |||
words = self.drop_word(words) | |||
batch_size, max_len = words.size() | |||
chars = self.words_to_chars_embedding[words] # batch_size x max_len x max_word_len | |||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||
word_lengths = self.word_lengths[words] # batch_size x max_len | |||
max_word_len = word_lengths.max() | |||
chars = chars[:, :, :max_word_len] | |||
# 为1的地方为mask | |||
chars_masks = chars.eq(self.char_pad_index) # batch_size x max_len x max_word_len 如果为0, 说明是padding的位置了 | |||
chars = self.char_embedding(chars) # batch_size x max_len x max_word_len x embed_size | |||
chars = self.dropout(chars) | |||
reshaped_chars = chars.reshape(batch_size*max_len, max_word_len, -1) | |||
reshaped_chars = chars.reshape(batch_size * max_len, max_word_len, -1) | |||
reshaped_chars = reshaped_chars.transpose(1, 2) # B' x E x M | |||
conv_chars = [conv(reshaped_chars).transpose(1, 2).reshape(batch_size, max_len, max_word_len, -1) | |||
for conv in self.convs] | |||
@@ -119,13 +133,13 @@ class CNNCharEmbedding(TokenEmbedding): | |||
conv_chars = self.activation(conv_chars) | |||
if self.pool_method == 'max': | |||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters) | |||
chars, _ = torch.max(conv_chars, dim=-2) # batch_size x max_len x sum(filters) | |||
else: | |||
conv_chars = conv_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||
chars = torch.sum(conv_chars, dim=-2)/chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | |||
chars = torch.sum(conv_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | |||
chars = self.fc(chars) | |||
return self.dropout(chars) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -141,19 +155,21 @@ class CNNCharEmbedding(TokenEmbedding): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能加入到requires_grad中 | |||
continue | |||
param.requires_grad = value | |||
def init_param(self): | |||
def reset_parameters(self): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_chars_embedding' in name or 'word_lengths' in name: # 这个不能reset | |||
continue | |||
if param.data.dim()>1: | |||
if 'char_embedding' in name: | |||
continue | |||
if param.data.dim() > 1: | |||
nn.init.xavier_uniform_(param, 1) | |||
else: | |||
nn.init.uniform_(param, -1, 1) | |||
@@ -161,12 +177,13 @@ class CNNCharEmbedding(TokenEmbedding): | |||
class LSTMCharEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.LSTMCharEmbedding` :class:`fastNLP.embeddings.char_embedding.LSTMCharEmbedding` | |||
使用LSTM的方式对character进行encode. embed(x) -> Dropout(x) -> LSTM(x) -> activation(x) -> pool -> Dropout | |||
Example:: | |||
>>> import torch | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import LSTMCharEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed = LSTMCharEmbedding(vocab, embed_size=50) | |||
>>> words = torch.LongTensor([[vocab.to_index(word) for word in "The whether is good .".split()]]) | |||
@@ -175,8 +192,8 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
>>> # torch.Size([1, 5,50]) | |||
:param vocab: 词表 | |||
:param embed_size: embedding的大小。默认值为50. | |||
:param char_emb_size: character的embedding的大小。默认值为50. | |||
:param embed_size: LSTMCharEmbedding的输出维度。默认值为50. | |||
:param char_emb_size: character的embedding的维度。默认值为50. | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param dropout: 以多大概率drop character embedding的输出以及最终的word的输出。 | |||
:param hidden_size: LSTM的中间hidden的大小,如果为bidirectional的,hidden会除二,默认为50. | |||
@@ -184,17 +201,21 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
:param activation: 激活函数,支持'relu', 'sigmoid', 'tanh', 或者自定义函数. | |||
:param min_char_freq: character的最小出现次数。默认值为2. | |||
:param bidirectional: 是否使用双向的LSTM进行encode。默认值为True。 | |||
:param pre_train_char_embed: 可以有两种方式调用预训练好的character embedding:第一种是传入embedding文件夹 | |||
(文件夹下应该只有一个以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型, | |||
没有的话将自动下载。如果输入为None则使用embedding_dim的维度随机初始化一个embedding. | |||
""" | |||
def __init__(self, vocab: Vocabulary, embed_size: int=50, char_emb_size: int=50, word_dropout:float=0, | |||
dropout:float=0.5, hidden_size=50,pool_method: str='max', activation='relu', min_char_freq: int=2, | |||
bidirectional=True): | |||
super(LSTMCharEmbedding, self).__init__(vocab) | |||
def __init__(self, vocab: Vocabulary, embed_size: int = 50, char_emb_size: int = 50, word_dropout: float = 0, | |||
dropout: float = 0, hidden_size=50, pool_method: str = 'max', activation='relu', | |||
min_char_freq: int = 2, | |||
bidirectional=True, pre_train_char_embed: str = None): | |||
super(LSTMCharEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
assert hidden_size % 2 == 0, "Only even kernel is allowed." | |||
assert pool_method in ('max', 'avg') | |||
self.pool_method = pool_method | |||
self.dropout = nn.Dropout(dropout) | |||
# activation function | |||
if isinstance(activation, str): | |||
if activation.lower() == 'relu': | |||
@@ -210,32 +231,35 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
else: | |||
raise Exception( | |||
"Undefined activation function: choose from: [relu, tanh, sigmoid, or a callable function]") | |||
print("Start constructing character vocabulary.") | |||
logger.info("Start constructing character vocabulary.") | |||
# 建立char的词表 | |||
self.char_vocab = _construct_char_vocab_from_vocab(vocab, min_freq=min_char_freq) | |||
self.char_pad_index = self.char_vocab.padding_idx | |||
print(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
logger.info(f"In total, there are {len(self.char_vocab)} distinct characters.") | |||
# 对vocab进行index | |||
self.max_word_len = max(map(lambda x: len(x[0]), vocab)) | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab), self.max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long), | |||
requires_grad=False) | |||
self.word_lengths = nn.Parameter(torch.zeros(len(vocab)).long(), requires_grad=False) | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab), self.max_word_len), | |||
fill_value=self.char_pad_index, dtype=torch.long)) | |||
self.register_buffer('word_lengths', torch.zeros(len(vocab)).long()) | |||
for word, index in vocab: | |||
# if index!=vocab.padding_idx: # 如果是pad的话,直接就为pad_value了. 修改为不区分pad与否 | |||
self.words_to_chars_embedding[index, :len(word)] = \ | |||
torch.LongTensor([self.char_vocab.to_index(c) for c in word]) | |||
self.word_lengths[index] = len(word) | |||
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
# self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
if pre_train_char_embed: | |||
self.char_embedding = StaticEmbedding(self.char_vocab, pre_train_char_embed) | |||
else: | |||
self.char_embedding = nn.Embedding(len(self.char_vocab), char_emb_size) | |||
self.fc = nn.Linear(hidden_size, embed_size) | |||
hidden_size = hidden_size // 2 if bidirectional else hidden_size | |||
self.lstm = LSTM(char_emb_size, hidden_size, bidirectional=bidirectional, batch_first=True) | |||
self._embed_size = embed_size | |||
self.bidirectional = bidirectional | |||
def forward(self, words): | |||
""" | |||
输入words的index后,生成对应的words的表示。 | |||
@@ -257,7 +281,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
char_seq_len = chars_masks.eq(0).sum(dim=-1).reshape(batch_size * max_len) | |||
lstm_chars = self.lstm(reshaped_chars, char_seq_len)[0].reshape(batch_size, max_len, max_word_len, -1) | |||
# B x M x M x H | |||
lstm_chars = self.activation(lstm_chars) | |||
if self.pool_method == 'max': | |||
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), float('-inf')) | |||
@@ -265,11 +289,11 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
else: | |||
lstm_chars = lstm_chars.masked_fill(chars_masks.unsqueeze(-1), 0) | |||
chars = torch.sum(lstm_chars, dim=-2) / chars_masks.eq(0).sum(dim=-1, keepdim=True).float() | |||
chars = self.fc(chars) | |||
return self.dropout(chars) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -286,7 +310,7 @@ class LSTMCharEmbedding(TokenEmbedding): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
@@ -1,20 +1,30 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"ContextualEmbedding" | |||
] | |||
from abc import abstractmethod | |||
import torch | |||
from ..core.vocabulary import Vocabulary | |||
from ..core.dataset import DataSet | |||
from .embedding import TokenEmbedding | |||
from ..core import logger | |||
from ..core.batch import DataSetIter | |||
from ..core.dataset import DataSet | |||
from ..core.sampler import SequentialSampler | |||
from ..core.utils import _move_model_to_device, _get_model_device | |||
from .embedding import TokenEmbedding | |||
from ..core.vocabulary import Vocabulary | |||
class ContextualEmbedding(TokenEmbedding): | |||
def __init__(self, vocab: Vocabulary, word_dropout:float=0.0, dropout:float=0.0): | |||
def __init__(self, vocab: Vocabulary, word_dropout: float = 0.0, dropout: float = 0.0): | |||
super(ContextualEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool=True): | |||
def add_sentence_cache(self, *datasets, batch_size=32, device='cpu', delete_weights: bool = True): | |||
""" | |||
由于动态embedding生成比较耗时,所以可以把每句话embedding缓存下来,这样就不需要每次都运行生成过程。 | |||
@@ -29,14 +39,14 @@ class ContextualEmbedding(TokenEmbedding): | |||
assert isinstance(dataset, DataSet), "Only fastNLP.DataSet object is allowed." | |||
assert 'words' in dataset.get_input_name(), "`words` field has to be set as input." | |||
except Exception as e: | |||
print(f"Exception happens at {index} dataset.") | |||
logger.error(f"Exception happens at {index} dataset.") | |||
raise e | |||
sent_embeds = {} | |||
_move_model_to_device(self, device=device) | |||
device = _get_model_device(self) | |||
pad_index = self._word_vocab.padding_idx | |||
print("Start to calculate sentence representations.") | |||
logger.info("Start to calculate sentence representations.") | |||
with torch.no_grad(): | |||
for index, dataset in enumerate(datasets): | |||
try: | |||
@@ -51,18 +61,18 @@ class ContextualEmbedding(TokenEmbedding): | |||
word_embeds = self(words).detach().cpu().numpy() | |||
for b in range(words.size(0)): | |||
length = seq_len_from_behind[b] | |||
if length==0: | |||
if length == 0: | |||
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b] | |||
else: | |||
sent_embeds[tuple(words_list[b][:seq_len[b]])] = word_embeds[b, :-length] | |||
except Exception as e: | |||
print(f"Exception happens at {index} dataset.") | |||
logger.error(f"Exception happens at {index} dataset.") | |||
raise e | |||
print("Finish calculating sentence representations.") | |||
logger.info("Finish calculating sentence representations.") | |||
self.sent_embeds = sent_embeds | |||
if delete_weights: | |||
self._delete_model_weights() | |||
def _get_sent_reprs(self, words): | |||
""" | |||
获取sentence的表示,如果有缓存,则返回缓存的值; 没有缓存则返回None | |||
@@ -85,12 +95,12 @@ class ContextualEmbedding(TokenEmbedding): | |||
embeds[i, :len(embed)] = torch.FloatTensor(embed).to(words.device) | |||
return embeds | |||
return None | |||
@abstractmethod | |||
def _delete_model_weights(self): | |||
"""删除计算表示的模型以节省资源""" | |||
raise NotImplementedError | |||
def remove_sentence_cache(self): | |||
""" | |||
删除缓存的句子表示. 删除之后如果模型权重没有被删除,将开始使用动态计算权重。 | |||
@@ -1,6 +1,13 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
import os | |||
__all__ = [ | |||
"ElmoEmbedding" | |||
] | |||
import os | |||
import torch | |||
import torch.nn as nn | |||
import torch.nn.functional as F | |||
@@ -8,19 +15,20 @@ import json | |||
import codecs | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import cached_path, _get_base_url, PRETRAINED_ELMO_MODEL_DIR | |||
from ..io.file_utils import cached_path, _get_embedding_url, PRETRAINED_ELMO_MODEL_DIR | |||
from ..modules.encoder._elmo import ElmobiLm, ConvTokenEmbedder | |||
from .contextual_embedding import ContextualEmbedding | |||
from ..core import logger | |||
class ElmoEmbedding(ContextualEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.ElmoEmbedding` :class:`fastNLP.embeddings.elmo_embedding.ElmoEmbedding` | |||
使用ELMo的embedding。初始化之后,只需要传入words就可以得到对应的embedding。当前支持的使用名称初始化的模型有以下的这些(待补充) | |||
Example:: | |||
>>> import torch | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import ElmoEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> # 使用不同层的concat的结果 | |||
>>> embed = ElmoEmbedding(vocab, model_dir_or_name='en', layers='1,2', requires_grad=False) | |||
@@ -37,7 +45,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的ELMo embedding:第一种是传入ELMo所在文件夹,该文件夹下面应该有两个文件, | |||
其中一个是以json为后缀的配置文件,另一个是以pkl为后缀的权重文件;第二种是传入ELMo版本的名称,将自动查看缓存中是否存在该模型, | |||
没有的话将自动下载并缓存。 | |||
:param layers: str, 指定返回的层数, 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | |||
:param layers: str, 指定返回的层数(从0开始), 以,隔开不同的层。如果要返回第二层的结果'2', 返回后两层的结果'1,2'。不同的层的结果 | |||
按照这个顺序concat起来,默认为'2'。'mix'会使用可学习的权重结合不同层的表示(权重是否可训练与requires_grad保持一致, | |||
初始化权重对三层结果进行mean-pooling, 可以通过ElmoEmbedding.set_mix_weights_requires_grad()方法只将mix weights设置为可学习。) | |||
:param requires_grad: bool, 该层是否需要gradient, 默认为False. | |||
@@ -46,24 +54,23 @@ class ElmoEmbedding(ContextualEmbedding): | |||
:param cache_word_reprs: 可以选择对word的表示进行cache; 设置为True的话,将在初始化的时候为每个word生成对应的embedding, | |||
并删除character encoder,之后将直接使用cache的embedding。默认为False。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', layers: str = '2', requires_grad: bool = False, | |||
word_dropout=0.0, dropout=0.0, cache_word_reprs: bool = False): | |||
super(ElmoEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
# 根据model_dir_or_name检查是否存在并下载 | |||
if model_dir_or_name.lower() in PRETRAINED_ELMO_MODEL_DIR: | |||
PRETRAIN_URL = _get_base_url('elmo') | |||
model_name = PRETRAINED_ELMO_MODEL_DIR[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_dir = cached_path(model_url) | |||
model_url = _get_embedding_url('elmo', model_dir_or_name.lower()) | |||
model_dir = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_dir = model_dir_or_name | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
self.model = _ElmoModel(model_dir, vocab, cache_word_reprs=cache_word_reprs) | |||
num_layers = self.model.encoder.num_layers | |||
if layers == 'mix': | |||
self.layer_weights = nn.Parameter(torch.zeros(self.model.config['lstm']['n_layers'] + 1), | |||
requires_grad=requires_grad) | |||
@@ -72,22 +79,22 @@ class ElmoEmbedding(ContextualEmbedding): | |||
self._embed_size = self.model.config['lstm']['projection_dim'] * 2 | |||
else: | |||
layers = list(map(int, layers.split(','))) | |||
assert len(layers) > 0, "Must choose one output" | |||
assert len(layers) > 0, "Must choose at least one output, but got None." | |||
for layer in layers: | |||
assert 0 <= layer <= 2, "Layer index should be in range [0, 2]." | |||
assert 0 <= layer <= num_layers, f"Layer index should be in range [0, {num_layers}], but got {layer}." | |||
self.layers = layers | |||
self._get_outputs = self._get_layer_outputs | |||
self._embed_size = len(self.layers) * self.model.config['lstm']['projection_dim'] * 2 | |||
self.requires_grad = requires_grad | |||
def _get_mixed_outputs(self, outputs): | |||
# outputs: num_layers x batch_size x max_len x hidden_size | |||
# return: batch_size x max_len x hidden_size | |||
weights = F.softmax(self.layer_weights + 1 / len(outputs), dim=0).to(outputs) | |||
outputs = torch.einsum('l,lbij->bij', weights, outputs) | |||
return self.gamma.to(outputs) * outputs | |||
def set_mix_weights_requires_grad(self, flag=True): | |||
""" | |||
当初始化ElmoEmbedding时layers被设置为mix时,可以通过调用该方法设置mix weights是否可训练。如果layers不是mix,调用 | |||
@@ -99,15 +106,15 @@ class ElmoEmbedding(ContextualEmbedding): | |||
if hasattr(self, 'layer_weights'): | |||
self.layer_weights.requires_grad = flag | |||
self.gamma.requires_grad = flag | |||
def _get_layer_outputs(self, outputs): | |||
if len(self.layers) == 1: | |||
outputs = outputs[self.layers[0]] | |||
else: | |||
outputs = torch.cat(tuple([*outputs[self.layers]]), dim=-1) | |||
return outputs | |||
def forward(self, words: torch.LongTensor): | |||
""" | |||
计算words的elmo embedding表示。根据elmo文章中介绍的ELMO实际上是有2L+1层结果,但是为了让结果比较容易拆分,token的 | |||
@@ -124,12 +131,12 @@ class ElmoEmbedding(ContextualEmbedding): | |||
outputs = self.model(words) | |||
outputs = self._get_outputs(outputs) | |||
return self.dropout(outputs) | |||
def _delete_model_weights(self): | |||
for name in ['layers', 'model', 'layer_weights', 'gamma']: | |||
if hasattr(self, name): | |||
delattr(self, name) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -143,7 +150,7 @@ class ElmoEmbedding(ContextualEmbedding): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
@@ -161,7 +168,7 @@ class _ElmoModel(nn.Module): | |||
(4) 设计一个保存token的embedding,允许缓存word的表示。 | |||
""" | |||
def __init__(self, model_dir: str, vocab: Vocabulary = None, cache_word_reprs: bool = False): | |||
super(_ElmoModel, self).__init__() | |||
self.model_dir = model_dir | |||
@@ -182,18 +189,18 @@ class _ElmoModel(nn.Module): | |||
raise Exception(f"Multiple config files(*.json) or weight files(*.hdf5) detected in {model_dir}.") | |||
elif config_count == 0 or weight_count == 0: | |||
raise Exception(f"No config file or weight file found in {model_dir}") | |||
config = json.load(open(os.path.join(model_dir, config_file), 'r')) | |||
with open(os.path.join(model_dir, config_file), 'r') as config_f: | |||
config = json.load(config_f) | |||
self.weight_file = os.path.join(model_dir, weight_file) | |||
self.config = config | |||
OOV_TAG = '<oov>' | |||
PAD_TAG = '<pad>' | |||
BOS_TAG = '<bos>' | |||
EOS_TAG = '<eos>' | |||
BOW_TAG = '<bow>' | |||
EOW_TAG = '<eow>' | |||
# For the model trained with character-based word encoder. | |||
char_lexicon = {} | |||
with codecs.open(os.path.join(model_dir, 'char.dic'), 'r', encoding='utf-8') as fpi: | |||
@@ -203,29 +210,29 @@ class _ElmoModel(nn.Module): | |||
tokens.insert(0, '\u3000') | |||
token, i = tokens | |||
char_lexicon[token] = int(i) | |||
# 做一些sanity check | |||
for special_word in [PAD_TAG, OOV_TAG, BOW_TAG, EOW_TAG]: | |||
assert special_word in char_lexicon, f"{special_word} not found in char.dic." | |||
# 从vocab中构建char_vocab | |||
char_vocab = Vocabulary(unknown=OOV_TAG, padding=PAD_TAG) | |||
# 需要保证<bow>与<eow>在里面 | |||
char_vocab.add_word_lst([BOW_TAG, EOW_TAG, BOS_TAG, EOS_TAG]) | |||
for word, index in vocab: | |||
char_vocab.add_word_lst(list(word)) | |||
self.bos_index, self.eos_index, self._pad_index = len(vocab), len(vocab) + 1, vocab.padding_idx | |||
# 根据char_lexicon调整, 多设置一位,是预留给word padding的(该位置的char表示为全0表示) | |||
char_emb_layer = nn.Embedding(len(char_vocab) + 1, int(config['char_cnn']['embedding']['dim']), | |||
padding_idx=len(char_vocab)) | |||
# 读入预训练权重 这里的elmo_model 包含char_cnn和 lstm 的 state_dict | |||
elmo_model = torch.load(os.path.join(self.model_dir, weight_file), map_location='cpu') | |||
char_embed_weights = elmo_model["char_cnn"]['char_emb_layer.weight'] | |||
found_char_count = 0 | |||
for char, index in char_vocab: # 调整character embedding | |||
if char in char_lexicon: | |||
@@ -234,15 +241,13 @@ class _ElmoModel(nn.Module): | |||
else: | |||
index_in_pre = char_lexicon[OOV_TAG] | |||
char_emb_layer.weight.data[index] = char_embed_weights[index_in_pre] | |||
print(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||
logger.info(f"{found_char_count} out of {len(char_vocab)} characters were found in pretrained elmo embedding.") | |||
# 生成words到chars的映射 | |||
max_chars = config['char_cnn']['max_characters_per_token'] | |||
self.words_to_chars_embedding = nn.Parameter(torch.full((len(vocab) + 2, max_chars), | |||
self.register_buffer('words_to_chars_embedding', torch.full((len(vocab) + 2, max_chars), | |||
fill_value=len(char_vocab), | |||
dtype=torch.long), | |||
requires_grad=False) | |||
dtype=torch.long)) | |||
for word, index in list(iter(vocab)) + [(BOS_TAG, len(vocab)), (EOS_TAG, len(vocab) + 1)]: | |||
if len(word) + 2 > max_chars: | |||
word = word[:max_chars - 2] | |||
@@ -257,29 +262,29 @@ class _ElmoModel(nn.Module): | |||
char_vocab.to_index(EOW_TAG)] | |||
char_ids += [char_vocab.to_index(PAD_TAG)] * (max_chars - len(char_ids)) | |||
self.words_to_chars_embedding[index] = torch.LongTensor(char_ids) | |||
self.char_vocab = char_vocab | |||
self.token_embedder = ConvTokenEmbedder( | |||
config, self.weight_file, None, char_emb_layer) | |||
elmo_model["char_cnn"]['char_emb_layer.weight'] = char_emb_layer.weight | |||
self.token_embedder.load_state_dict(elmo_model["char_cnn"]) | |||
self.output_dim = config['lstm']['projection_dim'] | |||
# lstm encoder | |||
self.encoder = ElmobiLm(config) | |||
self.encoder.load_state_dict(elmo_model["lstm"]) | |||
if cache_word_reprs: | |||
if config['char_cnn']['embedding']['dim'] > 0: # 只有在使用了chars的情况下有用 | |||
print("Start to generate cache word representations.") | |||
logger.info("Start to generate cache word representations.") | |||
batch_size = 320 | |||
# bos eos | |||
word_size = self.words_to_chars_embedding.size(0) | |||
num_batches = word_size // batch_size + \ | |||
int(word_size % batch_size != 0) | |||
self.cached_word_embedding = nn.Embedding(word_size, | |||
config['lstm']['projection_dim']) | |||
with torch.no_grad(): | |||
@@ -290,12 +295,12 @@ class _ElmoModel(nn.Module): | |||
word_reprs = self.token_embedder(words.unsqueeze(1), | |||
chars).detach() # batch_size x 1 x config['encoder']['projection_dim'] | |||
self.cached_word_embedding.weight.data[words] = word_reprs.squeeze(1) | |||
print("Finish generating cached word representations. Going to delete the character encoder.") | |||
logger.info("Finish generating cached word representations. Going to delete the character encoder.") | |||
del self.token_embedder, self.words_to_chars_embedding | |||
else: | |||
print("There is no need to cache word representations, since no character information is used.") | |||
logger.info("There is no need to cache word representations, since no character information is used.") | |||
def forward(self, words): | |||
""" | |||
@@ -320,7 +325,7 @@ class _ElmoModel(nn.Module): | |||
else: | |||
chars = None | |||
token_embedding = self.token_embedder(expanded_words, chars) # batch_size x max_len x embed_dim | |||
encoder_output = self.encoder(token_embedding, seq_len) | |||
if encoder_output.size(2) < max_len + 2: | |||
num_layers, _, output_len, hidden_size = encoder_output.size() | |||
@@ -331,7 +336,7 @@ class _ElmoModel(nn.Module): | |||
token_embedding = token_embedding.masked_fill(mask, 0) | |||
token_embedding = torch.cat((token_embedding, token_embedding), dim=2).view(1, sz[1], sz[2], sz[3]) | |||
encoder_output = torch.cat((token_embedding, encoder_output), dim=0) | |||
# 删除<eos>, <bos>. 这里没有精确地删除,但应该也不会影响最后的结果了。 | |||
encoder_output = encoder_output[:, :, 1:-1] | |||
return encoder_output |
@@ -3,6 +3,10 @@ | |||
""" | |||
__all__ = [ | |||
"Embedding", | |||
"TokenEmbedding" | |||
] | |||
import torch.nn as nn | |||
from abc import abstractmethod | |||
@@ -13,13 +17,12 @@ from .utils import get_embeddings | |||
class Embedding(nn.Module): | |||
""" | |||
别名::class:`fastNLP.embeddings.Embedding` :class:`fastNLP.embeddings.embedding.Embedding` | |||
词向量嵌入,支持输入多种方式初始化. 可以通过self.num_embeddings获取词表大小; self.embedding_dim获取embedding的维度. | |||
Example:: | |||
>>> import numpy as np | |||
>>> from fastNLP.embeddings import Embedding | |||
>>> init_embed = (2000, 100) | |||
>>> embed = Embedding(init_embed) # 随机初始化一个具有2000个词,每个词表示为100维的词向量 | |||
>>> init_embed = np.zeros((2000, 100)) | |||
@@ -32,54 +35,59 @@ class Embedding(nn.Module): | |||
:param float dropout: 对Embedding的输出的dropout。 | |||
:param int unk_index: drop word时替换为的index。fastNLP的Vocabulary的unk_index默认为1。 | |||
""" | |||
def __init__(self, init_embed, word_dropout=0, dropout=0.0, unk_index=None): | |||
super(Embedding, self).__init__() | |||
self.embed = get_embeddings(init_embed) | |||
self.dropout = nn.Dropout(dropout) | |||
if not isinstance(self.embed, TokenEmbedding): | |||
self._embed_size = self.embed.weight.size(1) | |||
if word_dropout>0 and not isinstance(unk_index, int): | |||
if hasattr(self.embed, 'embed_size'): | |||
self._embed_size = self.embed.embed_size | |||
elif hasattr(self.embed, 'embedding_dim'): | |||
self._embed_size = self.embed.embedding_dim | |||
else: | |||
self._embed_size = self.embed.weight.size(1) | |||
if word_dropout > 0 and not isinstance(unk_index, int): | |||
raise ValueError("When drop word is set, you need to pass in the unk_index.") | |||
else: | |||
self._embed_size = self.embed.embed_size | |||
unk_index = self.embed.get_word_vocab().unknown_idx | |||
self.unk_index = unk_index | |||
self.word_dropout = word_dropout | |||
def forward(self, words): | |||
""" | |||
:param torch.LongTensor words: [batch, seq_len] | |||
:return: torch.Tensor : [batch, seq_len, embed_dim] | |||
""" | |||
if self.word_dropout>0 and self.training: | |||
if self.word_dropout > 0 and self.training: | |||
mask = torch.ones_like(words).float() * self.word_dropout | |||
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | |||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||
words = words.masked_fill(mask, self.unk_index) | |||
words = self.embed(words) | |||
return self.dropout(words) | |||
@property | |||
def num_embedding(self)->int: | |||
def num_embedding(self) -> int: | |||
if isinstance(self.embed, nn.Embedding): | |||
return self.embed.weight.size(0) | |||
else: | |||
return self.embed.num_embedding | |||
def __len__(self): | |||
return len(self.embed) | |||
@property | |||
def embed_size(self) -> int: | |||
return self._embed_size | |||
@property | |||
def embedding_dim(self) -> int: | |||
return self._embed_size | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -90,14 +98,14 @@ class Embedding(nn.Module): | |||
return self.embed.weight.requires_grad | |||
else: | |||
return self.embed.requires_grad | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
if not isinstance(self.embed, TokenEmbedding): | |||
self.embed.weight.requires_grad = value | |||
else: | |||
self.embed.requires_grad = value | |||
@property | |||
def size(self): | |||
if isinstance(self.embed, TokenEmbedding): | |||
@@ -114,12 +122,12 @@ class TokenEmbedding(nn.Module): | |||
assert vocab.padding is not None, "Vocabulary must have a padding entry." | |||
self._word_vocab = vocab | |||
self._word_pad_index = vocab.padding_idx | |||
if word_dropout>0: | |||
if word_dropout > 0: | |||
assert vocab.unknown is not None, "Vocabulary must have unknown entry when you want to drop a word." | |||
self.word_dropout = word_dropout | |||
self._word_unk_index = vocab.unknown_idx | |||
self.dropout_layer = nn.Dropout(dropout) | |||
def drop_word(self, words): | |||
""" | |||
按照设定随机将words设置为unknown_index。 | |||
@@ -128,11 +136,13 @@ class TokenEmbedding(nn.Module): | |||
:return: | |||
""" | |||
if self.word_dropout > 0 and self.training: | |||
mask = torch.ones_like(words).float() * self.word_dropout | |||
mask = torch.bernoulli(mask).byte() # dropout_word越大,越多位置为1 | |||
mask = torch.full_like(words, fill_value=self.word_dropout, dtype=torch.float, device=words.device) | |||
mask = torch.bernoulli(mask).eq(1) # dropout_word越大,越多位置为1 | |||
pad_mask = words.ne(self._word_pad_index) | |||
mask = mask.__and__(pad_mask) | |||
words = words.masked_fill(mask, self._word_unk_index) | |||
return words | |||
def dropout(self, words): | |||
""" | |||
对embedding后的word表示进行drop。 | |||
@@ -141,7 +151,7 @@ class TokenEmbedding(nn.Module): | |||
:return: | |||
""" | |||
return self.dropout_layer(words) | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -153,23 +163,23 @@ class TokenEmbedding(nn.Module): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for param in self.parameters(): | |||
param.requires_grad = value | |||
def __len__(self): | |||
return len(self._word_vocab) | |||
@property | |||
def embed_size(self) -> int: | |||
return self._embed_size | |||
@property | |||
def embedding_dim(self) -> int: | |||
return self._embed_size | |||
@property | |||
def num_embedding(self) -> int: | |||
""" | |||
@@ -177,7 +187,7 @@ class TokenEmbedding(nn.Module): | |||
:return: | |||
""" | |||
return len(self._word_vocab) | |||
def get_word_vocab(self): | |||
""" | |||
返回embedding的词典。 | |||
@@ -185,11 +195,11 @@ class TokenEmbedding(nn.Module): | |||
:return: Vocabulary | |||
""" | |||
return self._word_vocab | |||
@property | |||
def size(self): | |||
return torch.Size(self.num_embedding, self._embed_size) | |||
@abstractmethod | |||
def forward(self, words): | |||
raise NotImplementedError |
@@ -1,3 +1,12 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"StackEmbedding", | |||
] | |||
from typing import List | |||
import torch | |||
@@ -8,8 +17,6 @@ from .embedding import TokenEmbedding | |||
class StackEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.StackEmbedding` :class:`fastNLP.embeddings.stack_embedding.StackEmbedding` | |||
支持将多个embedding集合成一个embedding。 | |||
Example:: | |||
@@ -17,7 +24,7 @@ class StackEmbedding(TokenEmbedding): | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import StaticEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50', requires_grad=True) | |||
>>> embed_1 = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d', requires_grad=True) | |||
>>> embed_2 = StaticEmbedding(vocab, model_dir_or_name='en-word2vec-300', requires_grad=True) | |||
:param embeds: 一个由若干个TokenEmbedding组成的list,要求每一个TokenEmbedding的词表都保持一致 | |||
@@ -26,6 +33,7 @@ class StackEmbedding(TokenEmbedding): | |||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||
""" | |||
def __init__(self, embeds: List[TokenEmbedding], word_dropout=0, dropout=0): | |||
vocabs = [] | |||
for embed in embeds: | |||
@@ -34,14 +42,14 @@ class StackEmbedding(TokenEmbedding): | |||
_vocab = vocabs[0] | |||
for vocab in vocabs[1:]: | |||
assert vocab == _vocab, "All embeddings in StackEmbedding should use the same word vocabulary." | |||
super(StackEmbedding, self).__init__(_vocab, word_dropout=word_dropout, dropout=dropout) | |||
assert isinstance(embeds, list) | |||
for embed in embeds: | |||
assert isinstance(embed, TokenEmbedding), "Only TokenEmbedding type is supported." | |||
self.embeds = nn.ModuleList(embeds) | |||
self._embed_size = sum([embed.embed_size for embed in self.embeds]) | |||
def append(self, embed: TokenEmbedding): | |||
""" | |||
添加一个embedding到结尾。 | |||
@@ -50,18 +58,18 @@ class StackEmbedding(TokenEmbedding): | |||
""" | |||
assert isinstance(embed, TokenEmbedding) | |||
self.embeds.append(embed) | |||
def pop(self): | |||
""" | |||
弹出最后一个embed | |||
:return: | |||
""" | |||
return self.embeds.pop() | |||
@property | |||
def embed_size(self): | |||
return self._embed_size | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -73,12 +81,12 @@ class StackEmbedding(TokenEmbedding): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for embed in self.embeds(): | |||
embed.requires_grad = value | |||
def forward(self, words): | |||
""" | |||
得到多个embedding的结果,并把结果按照顺序concat起来。 | |||
@@ -91,4 +99,4 @@ class StackEmbedding(TokenEmbedding): | |||
for embed in self.embeds: | |||
outputs.append(embed(words)) | |||
outputs = self.dropout(torch.cat(outputs, dim=-1)) | |||
return outputs | |||
return outputs |
@@ -1,4 +1,11 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
"StaticEmbedding" | |||
] | |||
import os | |||
import torch | |||
@@ -7,25 +14,29 @@ import numpy as np | |||
import warnings | |||
from ..core.vocabulary import Vocabulary | |||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_base_url, cached_path | |||
from ..io.file_utils import PRETRAIN_STATIC_FILES, _get_embedding_url, cached_path | |||
from .embedding import TokenEmbedding | |||
from ..modules.utils import _get_file_name_base_on_postfix | |||
from copy import deepcopy | |||
from collections import defaultdict | |||
from ..core import logger | |||
class StaticEmbedding(TokenEmbedding): | |||
""" | |||
别名::class:`fastNLP.embeddings.StaticEmbedding` :class:`fastNLP.embeddings.static_embedding.StaticEmbedding` | |||
StaticEmbedding组件. 给定预训练embedding的名称或路径,根据vocab从embedding中抽取相应的数据(只会将出现在vocab中的词抽取出来, | |||
如果没有找到,则会随机初始化一个值(但如果该word是被标记为no_create_entry的话,则不会单独创建一个值,而是会被指向unk的index))。 | |||
当前支持自动下载的预训练vector有以下的几种(待补充); | |||
Example:: | |||
>>> from fastNLP import Vocabulary | |||
>>> from fastNLP.embeddings import StaticEmbedding | |||
>>> vocab = Vocabulary().add_word_lst("The whether is good .".split()) | |||
>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50') | |||
>>> embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-50d') | |||
>>> vocab = Vocabulary().add_word_lst(["The", 'the', "THE"]) | |||
>>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50", lower=True) | |||
>>> embed = StaticEmbedding(vocab, model_dir_or_name="en-glove-50d", lower=True) | |||
>>> # "the", "The", "THE"它们共用一个vector,且将使用"the"在预训练词表中寻找它们的初始化表示。 | |||
>>> vocab = Vocabulary().add_word_lst(["The", "the", "THE"]) | |||
@@ -41,85 +52,120 @@ class StaticEmbedding(TokenEmbedding): | |||
:param model_dir_or_name: 可以有两种方式调用预训练好的static embedding:第一种是传入embedding文件夹(文件夹下应该只有一个 | |||
以.txt作为后缀的文件)或文件路径;第二种是传入embedding的名称,第二种情况将自动查看缓存中是否存在该模型,没有的话将自动下载。 | |||
如果输入为None则使用embedding_dim的维度随机初始化一个embedding。 | |||
:param int embedding_dim: 随机初始化的embedding的维度,仅在model_dir_or_name为None时有效。 | |||
:param int embedding_dim: 随机初始化的embedding的维度,当该值为大于0的值时,将忽略model_dir_or_name。 | |||
:param bool requires_grad: 是否需要gradient. 默认为True | |||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对象。 | |||
:param callable init_method: 如何初始化没有找到的值。可以使用torch.nn.init.*中各种方法。调用该方法时传入一个tensor对 | |||
:param bool lower: 是否将vocab中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 | |||
为大写的词语开辟一个vector表示,则将lower设置为False。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param float dropout: 以多大的概率对embedding的表示进行Dropout。0.1即随机将10%的值置为0。 | |||
:param float word_dropout: 以多大的概率将一个词替换为unk。这样既可以训练unk也是一定的regularize。 | |||
:param bool normalize: 是否对vector进行normalize,使得每个vector的norm为1。 | |||
:param int min_freq: Vocabulary词频数小于这个数量的word将被指向unk。 | |||
""" | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str='en', embedding_dim=100, requires_grad: bool=True, | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False): | |||
def __init__(self, vocab: Vocabulary, model_dir_or_name: str = 'en', embedding_dim=-1, requires_grad: bool = True, | |||
init_method=None, lower=False, dropout=0, word_dropout=0, normalize=False, min_freq=1, **kwargs): | |||
super(StaticEmbedding, self).__init__(vocab, word_dropout=word_dropout, dropout=dropout) | |||
if embedding_dim > 0: | |||
model_dir_or_name = None | |||
# 得到cache_path | |||
if model_dir_or_name is None: | |||
assert embedding_dim>=1, "The dimension of embedding should be larger than 1." | |||
assert embedding_dim >= 1, "The dimension of embedding should be larger than 1." | |||
embedding_dim = int(embedding_dim) | |||
model_path = None | |||
elif model_dir_or_name.lower() in PRETRAIN_STATIC_FILES: | |||
PRETRAIN_URL = _get_base_url('static') | |||
model_name = PRETRAIN_STATIC_FILES[model_dir_or_name] | |||
model_url = PRETRAIN_URL + model_name | |||
model_path = cached_path(model_url) | |||
model_url = _get_embedding_url('static', model_dir_or_name.lower()) | |||
model_path = cached_path(model_url, name='embedding') | |||
# 检查是否存在 | |||
elif os.path.isfile(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_path = model_dir_or_name | |||
elif os.path.isdir(os.path.expanduser(os.path.abspath(model_dir_or_name))): | |||
model_path = _get_file_name_base_on_postfix(model_dir_or_name, '.txt') | |||
elif os.path.isfile(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_path = os.path.abspath(os.path.expanduser(model_dir_or_name)) | |||
elif os.path.isdir(os.path.abspath(os.path.expanduser(model_dir_or_name))): | |||
model_path = _get_file_name_base_on_postfix(os.path.abspath(os.path.expanduser(model_dir_or_name)), '.txt') | |||
else: | |||
raise ValueError(f"Cannot recognize {model_dir_or_name}.") | |||
# 根据min_freq缩小vocab | |||
truncate_vocab = (vocab.min_freq is None and min_freq > 1) or (vocab.min_freq and vocab.min_freq < min_freq) | |||
if truncate_vocab: | |||
truncated_vocab = deepcopy(vocab) | |||
truncated_vocab.min_freq = min_freq | |||
truncated_vocab.word2idx = None | |||
if lower: # 如果有lower,将大小写的的freq需要同时考虑到 | |||
lowered_word_count = defaultdict(int) | |||
for word, count in truncated_vocab.word_count.items(): | |||
lowered_word_count[word.lower()] += count | |||
for word in truncated_vocab.word_count.keys(): | |||
word_count = truncated_vocab.word_count[word] | |||
if lowered_word_count[word.lower()] >= min_freq and word_count < min_freq: | |||
truncated_vocab.add_word_lst([word] * (min_freq - word_count), | |||
no_create_entry=truncated_vocab._is_word_no_create_entry(word)) | |||
# 只限制在train里面的词语使用min_freq筛选 | |||
if kwargs.get('only_train_min_freq', False) and model_dir_or_name is not None: | |||
for word in truncated_vocab.word_count.keys(): | |||
if truncated_vocab._is_word_no_create_entry(word) and truncated_vocab.word_count[word] < min_freq: | |||
truncated_vocab.add_word_lst([word] * (min_freq - truncated_vocab.word_count[word]), | |||
no_create_entry=True) | |||
truncated_vocab.build_vocab() | |||
truncated_words_to_words = torch.arange(len(vocab)).long() | |||
for word, index in vocab: | |||
truncated_words_to_words[index] = truncated_vocab.to_index(word) | |||
logger.info(f"{len(vocab) - len(truncated_vocab)} out of {len(vocab)} words have frequency less than {min_freq}.") | |||
vocab = truncated_vocab | |||
self.only_norm_found_vector = kwargs.get('only_norm_found_vector', False) | |||
# 读取embedding | |||
if lower: | |||
lowered_vocab = Vocabulary(padding=vocab.padding, unknown=vocab.unknown) | |||
for word, index in vocab: | |||
if not vocab._is_word_no_create_entry(word): | |||
if vocab._is_word_no_create_entry(word): | |||
lowered_vocab.add_word(word.lower(), no_create_entry=True) | |||
else: | |||
lowered_vocab.add_word(word.lower()) # 先加入需要创建entry的 | |||
for word in vocab._no_create_word.keys(): # 不需要创建entry的 | |||
if word in vocab: | |||
lowered_word = word.lower() | |||
if lowered_word not in lowered_vocab.word_count: | |||
lowered_vocab.add_word(lowered_word) | |||
lowered_vocab._no_create_word[lowered_word] += 1 | |||
print(f"All word in vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} unique lowered " | |||
f"words.") | |||
logger.info(f"All word in the vocab have been lowered. There are {len(vocab)} words, {len(lowered_vocab)} " | |||
f"unique lowered words.") | |||
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=lowered_vocab, init_method=init_method) | |||
else: | |||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||
# 需要适配一下 | |||
if not hasattr(self, 'words_to_words'): | |||
self.words_to_words = torch.arange(len(lowered_vocab, )).long() | |||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||
if lowered_vocab.unknown: | |||
unknown_idx = lowered_vocab.unknown_idx | |||
else: | |||
unknown_idx = embedding.size(0) - 1 # 否则是最后一个为unknow | |||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||
words_to_words = torch.full((len(vocab),), fill_value=unknown_idx).long() | |||
for word, index in vocab: | |||
if word not in lowered_vocab: | |||
word = word.lower() | |||
if lowered_vocab._is_word_no_create_entry(word): # 如果不需要创建entry,已经默认unknown了 | |||
continue | |||
if word not in lowered_vocab and lowered_vocab._is_word_no_create_entry(word): | |||
continue # 如果不需要创建entry,已经默认unknown了 | |||
words_to_words[index] = self.words_to_words[lowered_vocab.to_index(word)] | |||
self.words_to_words = words_to_words | |||
self.register_buffer('words_to_words', words_to_words) | |||
self._word_unk_index = lowered_vocab.unknown_idx # 替换一下unknown的index | |||
else: | |||
if model_path: | |||
embedding = self._load_with_vocab(model_path, vocab=vocab, init_method=init_method) | |||
else: | |||
embedding = self._randomly_init_embed(len(vocab), embedding_dim, init_method) | |||
if normalize: | |||
self.register_buffer('words_to_words', torch.arange(len(vocab)).long()) | |||
if not self.only_norm_found_vector and normalize: | |||
embedding /= (torch.norm(embedding, dim=1, keepdim=True) + 1e-12) | |||
if truncate_vocab: | |||
for i in range(len(truncated_words_to_words)): | |||
index_in_truncated_vocab = truncated_words_to_words[i] | |||
truncated_words_to_words[i] = self.words_to_words[index_in_truncated_vocab] | |||
del self.words_to_words | |||
self.register_buffer('words_to_words', truncated_words_to_words) | |||
self.embedding = nn.Embedding(num_embeddings=embedding.shape[0], embedding_dim=embedding.shape[1], | |||
padding_idx=vocab.padding_idx, | |||
max_norm=None, norm_type=2, scale_grad_by_freq=False, | |||
sparse=False, _weight=embedding) | |||
self._embed_size = self.embedding.weight.size(1) | |||
self.requires_grad = requires_grad | |||
def _randomly_init_embed(self, num_embedding, embedding_dim, init_embed=None): | |||
""" | |||
@@ -129,14 +175,14 @@ class StaticEmbedding(TokenEmbedding): | |||
:return: torch.FloatTensor | |||
""" | |||
embed = torch.zeros(num_embedding, embedding_dim) | |||
if init_embed is None: | |||
nn.init.uniform_(embed, -np.sqrt(3/embedding_dim), np.sqrt(3/embedding_dim)) | |||
nn.init.uniform_(embed, -np.sqrt(3 / embedding_dim), np.sqrt(3 / embedding_dim)) | |||
else: | |||
init_embed(embed) | |||
return embed | |||
@property | |||
def requires_grad(self): | |||
""" | |||
@@ -150,14 +196,14 @@ class StaticEmbedding(TokenEmbedding): | |||
return requires_grads.pop() | |||
else: | |||
return None | |||
@requires_grad.setter | |||
def requires_grad(self, value): | |||
for name, param in self.named_parameters(): | |||
if 'words_to_words' in name: | |||
continue | |||
param.requires_grad = value | |||
def _load_with_vocab(self, embed_filepath, vocab, dtype=np.float32, padding='<pad>', unknown='<unk>', | |||
error='ignore', init_method=None): | |||
""" | |||
@@ -189,7 +235,12 @@ class StaticEmbedding(TokenEmbedding): | |||
dim = len(parts) - 1 | |||
f.seek(0) | |||
matrix = {} | |||
if vocab.padding: | |||
matrix[vocab.padding_idx] = torch.zeros(dim) | |||
if vocab.unknown: | |||
matrix[vocab.unknown_idx] = torch.zeros(dim) | |||
found_count = 0 | |||
found_unknown = False | |||
for idx, line in enumerate(f, start_idx): | |||
try: | |||
parts = line.strip().split() | |||
@@ -200,46 +251,42 @@ class StaticEmbedding(TokenEmbedding): | |||
word = vocab.padding | |||
elif word == unknown and vocab.unknown is not None: | |||
word = vocab.unknown | |||
found_unknown = True | |||
if word in vocab: | |||
index = vocab.to_index(word) | |||
matrix[index] = torch.from_numpy(np.fromstring(' '.join(nums), sep=' ', dtype=dtype, count=dim)) | |||
if self.only_norm_found_vector: | |||
matrix[index] = matrix[index] / np.linalg.norm(matrix[index]) | |||
found_count += 1 | |||
except Exception as e: | |||
if error == 'ignore': | |||
warnings.warn("Error occurred at the {} line.".format(idx)) | |||
else: | |||
print("Error occurred at the {} line.".format(idx)) | |||
logger.error("Error occurred at the {} line.".format(idx)) | |||
raise e | |||
print("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
logger.info("Found {} out of {} words in the pre-training embedding.".format(found_count, len(vocab))) | |||
for word, index in vocab: | |||
if index not in matrix and not vocab._is_word_no_create_entry(word): | |||
if vocab.unknown_idx in matrix: # 如果有unkonwn,用unknown初始化 | |||
if found_unknown: # 如果有unknown,用unknown初始化 | |||
matrix[index] = matrix[vocab.unknown_idx] | |||
else: | |||
matrix[index] = None | |||
# matrix中代表是需要建立entry的词 | |||
vectors = self._randomly_init_embed(len(matrix), dim, init_method) | |||
if vocab._no_create_word_length>0: | |||
if vocab.unknown is None: # 创建一个专门的unknown | |||
unknown_idx = len(matrix) | |||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||
else: | |||
unknown_idx = vocab.unknown_idx | |||
words_to_words = nn.Parameter(torch.full((len(vocab),), fill_value=unknown_idx).long(), | |||
requires_grad=False) | |||
for order, (index, vec) in enumerate(matrix.items()): | |||
if vec is not None: | |||
vectors[order] = vec | |||
words_to_words[index] = order | |||
self.words_to_words = words_to_words | |||
if vocab.unknown is None: # 创建一个专门的unknown | |||
unknown_idx = len(matrix) | |||
vectors = torch.cat((vectors, torch.zeros(1, dim)), dim=0).contiguous() | |||
else: | |||
for index, vec in matrix.items(): | |||
if vec is not None: | |||
vectors[index] = vec | |||
unknown_idx = vocab.unknown_idx | |||
self.register_buffer('words_to_words', torch.full((len(vocab), ), fill_value=unknown_idx).long()) | |||
for index, (index_in_vocab, vec) in enumerate(matrix.items()): | |||
if vec is not None: | |||
vectors[index] = vec | |||
self.words_to_words[index_in_vocab] = index | |||
return vectors | |||
def forward(self, words): | |||
""" | |||
传入words的index | |||
@@ -1,13 +1,19 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
import numpy as np | |||
import torch | |||
from torch import nn as nn | |||
from ..core.vocabulary import Vocabulary | |||
__all__ = ['get_embeddings'] | |||
__all__ = [ | |||
'get_embeddings' | |||
] | |||
def _construct_char_vocab_from_vocab(vocab:Vocabulary, min_freq:int=1): | |||
def _construct_char_vocab_from_vocab(vocab: Vocabulary, min_freq: int = 1): | |||
""" | |||
给定一个word的vocabulary生成character的vocabulary. | |||
@@ -31,13 +37,13 @@ def get_embeddings(init_embed): | |||
:param init_embed: 可以是 tuple:(num_embedings, embedding_dim), 即embedding的大小和每个词的维度;也可以传入 | |||
nn.Embedding 对象, 此时就以传入的对象作为embedding; 传入np.ndarray也行,将使用传入的ndarray作为作为Embedding初始化; | |||
传入torch.Tensor, 将使用传入的值作为Embedding初始化。 | |||
:return nn.Embedding embeddings: | |||
:return nn.Embedding: embeddings | |||
""" | |||
if isinstance(init_embed, tuple): | |||
res = nn.Embedding( | |||
num_embeddings=init_embed[0], embedding_dim=init_embed[1]) | |||
nn.init.uniform_(res.weight.data, a=-np.sqrt(3/res.weight.data.size(1)), | |||
b=np.sqrt(3/res.weight.data.size(1))) | |||
nn.init.uniform_(res.weight.data, a=-np.sqrt(3 / res.weight.data.size(1)), | |||
b=np.sqrt(3 / res.weight.data.size(1))) | |||
elif isinstance(init_embed, nn.Module): | |||
res = init_embed | |||
elif isinstance(init_embed, torch.Tensor): | |||
@@ -48,4 +54,4 @@ def get_embeddings(init_embed): | |||
else: | |||
raise TypeError( | |||
'invalid init_embed type: {}'.format((type(init_embed)))) | |||
return res | |||
return res |
@@ -3,45 +3,92 @@ | |||
1. 用于读入 embedding 的 :doc:`EmbedLoader <fastNLP.io.embed_loader>` 类, | |||
2. 用于读入不同格式数据的 :doc:`DataSetLoader <fastNLP.io.dataset_loader>` 类 | |||
2. 用于读入不同格式数据的 :doc:`Loader <fastNLP.io.loader>` 类 | |||
3. 用于读入不同数据集并进行预处理的 :doc:`DataLoader <fastNLP.io.data_loader>` 类 | |||
3. 用于处理读入数据的 :doc:`Pipe <fastNLP.io.pipe>` 类 | |||
4. 用于保存和载入模型的类, 参考 :doc:`model_io文档</fastNLP.io.model_io>` | |||
这些类的使用方法如下: | |||
""" | |||
__all__ = [ | |||
'DataBundle', | |||
'EmbedLoader', | |||
'Loader', | |||
'YelpLoader', | |||
'YelpFullLoader', | |||
'YelpPolarityLoader', | |||
'IMDBLoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
"ChnSentiCorpLoader", | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'Conll2003NERLoader', | |||
'OntoNotesNERLoader', | |||
'CTBLoader', | |||
"MsraNERLoader", | |||
"WeiboNERLoader", | |||
"PeopleDailyNERLoader", | |||
'CSVLoader', | |||
'JsonLoader', | |||
'DataBundle', | |||
'DataSetLoader', | |||
'CWSLoader', | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'IMDBLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'MNLILoader', | |||
'MTL16Loader', | |||
'PeopleDailyCorpusLoader', | |||
'QNLILoader', | |||
'QuoraLoader', | |||
'RTELoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
'YelpLoader', | |||
"QuoraLoader", | |||
"SNLILoader", | |||
"QNLILoader", | |||
"RTELoader", | |||
"Pipe", | |||
"YelpFullPipe", | |||
"YelpPolarityPipe", | |||
"SSTPipe", | |||
"SST2Pipe", | |||
"IMDBPipe", | |||
"ChnSentiCorpPipe", | |||
"Conll2003Pipe", | |||
"Conll2003NERPipe", | |||
"OntoNotesNERPipe", | |||
"MsraNERPipe", | |||
"PeopleDailyPipe", | |||
"WeiboNERPipe", | |||
"CWSPipe", | |||
"MatchingBertPipe", | |||
"RTEBertPipe", | |||
"SNLIBertPipe", | |||
"QuoraBertPipe", | |||
"QNLIBertPipe", | |||
"MNLIBertPipe", | |||
"MatchingPipe", | |||
"RTEPipe", | |||
"SNLIPipe", | |||
"QuoraPipe", | |||
"QNLIPipe", | |||
"MNLIPipe", | |||
'ModelLoader', | |||
'ModelSaver', | |||
] | |||
from .embed_loader import EmbedLoader | |||
from .base_loader import DataBundle, DataSetLoader | |||
from .dataset_loader import CSVLoader, JsonLoader | |||
from .data_bundle import DataBundle | |||
from .model_io import ModelLoader, ModelSaver | |||
from .data_loader import * | |||
from .loader import * | |||
from .pipe import * | |||
import sys | |||
from ..doc_utils import doc_process | |||
doc_process(sys.modules[__name__]) |
@@ -1,220 +0,0 @@ | |||
__all__ = [ | |||
"BaseLoader", | |||
'DataBundle', | |||
'DataSetLoader', | |||
] | |||
import _pickle as pickle | |||
import os | |||
from typing import Union, Dict | |||
import os | |||
from ..core.dataset import DataSet | |||
class BaseLoader(object): | |||
""" | |||
各个 Loader 的基类,提供了 API 的参考。 | |||
""" | |||
def __init__(self): | |||
super(BaseLoader, self).__init__() | |||
@staticmethod | |||
def load_lines(data_path): | |||
""" | |||
按行读取,舍弃每行两侧空白字符,返回list of str | |||
:param data_path: 读取数据的路径 | |||
""" | |||
with open(data_path, "r", encoding="utf=8") as f: | |||
text = f.readlines() | |||
return [line.strip() for line in text] | |||
@classmethod | |||
def load(cls, data_path): | |||
""" | |||
先按行读取,去除一行两侧空白,再提取每行的字符。返回list of list of str | |||
:param data_path: | |||
""" | |||
with open(data_path, "r", encoding="utf-8") as f: | |||
text = f.readlines() | |||
return [[word for word in sent.strip()] for sent in text] | |||
@classmethod | |||
def load_with_cache(cls, data_path, cache_path): | |||
"""缓存版的load | |||
""" | |||
if os.path.isfile(cache_path) and os.path.getmtime(data_path) < os.path.getmtime(cache_path): | |||
with open(cache_path, 'rb') as f: | |||
return pickle.load(f) | |||
else: | |||
obj = cls.load(data_path) | |||
with open(cache_path, 'wb') as f: | |||
pickle.dump(obj, f) | |||
return obj | |||
def _download_from_url(url, path): | |||
try: | |||
from tqdm.auto import tqdm | |||
except: | |||
from ..core.utils import _pseudo_tqdm as tqdm | |||
import requests | |||
"""Download file""" | |||
r = requests.get(url, headers={'User-Agent': 'Mozilla/5.0'}, stream=True) | |||
chunk_size = 16 * 1024 | |||
total_size = int(r.headers.get('Content-length', 0)) | |||
with open(path, "wb") as file, \ | |||
tqdm(total=total_size, unit='B', unit_scale=1, desc=path.split('/')[-1]) as t: | |||
for chunk in r.iter_content(chunk_size): | |||
if chunk: | |||
file.write(chunk) | |||
t.update(len(chunk)) | |||
def _uncompress(src, dst): | |||
import zipfile | |||
import gzip | |||
import tarfile | |||
import os | |||
def unzip(src, dst): | |||
with zipfile.ZipFile(src, 'r') as f: | |||
f.extractall(dst) | |||
def ungz(src, dst): | |||
with gzip.open(src, 'rb') as f, open(dst, 'wb') as uf: | |||
length = 16 * 1024 # 16KB | |||
buf = f.read(length) | |||
while buf: | |||
uf.write(buf) | |||
buf = f.read(length) | |||
def untar(src, dst): | |||
with tarfile.open(src, 'r:gz') as f: | |||
f.extractall(dst) | |||
fn, ext = os.path.splitext(src) | |||
_, ext_2 = os.path.splitext(fn) | |||
if ext == '.zip': | |||
unzip(src, dst) | |||
elif ext == '.gz' and ext_2 != '.tar': | |||
ungz(src, dst) | |||
elif (ext == '.gz' and ext_2 == '.tar') or ext_2 == '.tgz': | |||
untar(src, dst) | |||
else: | |||
raise ValueError('unsupported file {}'.format(src)) | |||
class DataBundle: | |||
""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。 | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
""" | |||
def __init__(self, vocabs: dict = None, datasets: dict = None): | |||
self.vocabs = vocabs or {} | |||
self.datasets = datasets or {} | |||
def __repr__(self): | |||
_str = 'In total {} datasets:\n'.format(len(self.datasets)) | |||
for name, dataset in self.datasets.items(): | |||
_str += '\t{} has {} instances.\n'.format(name, len(dataset)) | |||
_str += 'In total {} vocabs:\n'.format(len(self.vocabs)) | |||
for name, vocab in self.vocabs.items(): | |||
_str += '\t{} has {} entries.\n'.format(name, len(vocab)) | |||
return _str | |||
class DataSetLoader: | |||
""" | |||
别名::class:`fastNLP.io.DataSetLoader` :class:`fastNLP.io.dataset_loader.DataSetLoader` | |||
定义了各种 DataSetLoader 所需的API 接口,开发者应该继承它实现各种的 DataSetLoader。 | |||
开发者至少应该编写如下内容: | |||
- _load 函数:从一个数据文件中读取数据到一个 :class:`~fastNLP.DataSet` | |||
- load 函数(可以使用基类的方法):从一个或多个数据文件中读取数据到一个或多个 :class:`~fastNLP.DataSet` | |||
- process 函数:一个或多个从数据文件中读取数据,并处理成可以训练的一个或多个 :class:`~fastNLP.DataSet` | |||
**process 函数中可以 调用load 函数或 _load 函数** | |||
""" | |||
URL = '' | |||
DATA_DIR = '' | |||
ROOT_DIR = '.fastnlp/datasets/' | |||
UNCOMPRESS = True | |||
def _download(self, url: str, pdir: str, uncompress=True) -> str: | |||
""" | |||
从 ``url`` 下载数据到 ``path``, 如果 ``uncompress`` 为 ``True`` ,自动解压。 | |||
:param url: 下载的网站 | |||
:param pdir: 下载到的目录 | |||
:param uncompress: 是否自动解压缩 | |||
:return: 数据的存放路径 | |||
""" | |||
fn = os.path.basename(url) | |||
path = os.path.join(pdir, fn) | |||
"""check data exists""" | |||
if not os.path.exists(path): | |||
os.makedirs(pdir, exist_ok=True) | |||
_download_from_url(url, path) | |||
if uncompress: | |||
dst = os.path.join(pdir, 'data') | |||
if not os.path.exists(dst): | |||
_uncompress(path, dst) | |||
return dst | |||
return path | |||
def download(self): | |||
return self._download( | |||
self.URL, | |||
os.path.join(self.ROOT_DIR, self.DATA_DIR), | |||
uncompress=self.UNCOMPRESS) | |||
def load(self, paths: Union[str, Dict[str, str]]) -> Union[DataSet, Dict[str, DataSet]]: | |||
""" | |||
从指定一个或多个路径中的文件中读取数据,返回一个或多个数据集 :class:`~fastNLP.DataSet` 。 | |||
如果处理多个路径,传入的 dict 中的 key 与返回的 dict 中的 key 保存一致。 | |||
:param Union[str, Dict[str, str]] paths: 文件路径 | |||
:return: :class:`~fastNLP.DataSet` 类的对象或存储多个 :class:`~fastNLP.DataSet` 的字典 | |||
""" | |||
if isinstance(paths, str): | |||
return self._load(paths) | |||
return {name: self._load(path) for name, path in paths.items()} | |||
def _load(self, path: str) -> DataSet: | |||
"""从指定路径的文件中读取数据,返回 :class:`~fastNLP.DataSet` 类型的对象 | |||
:param str path: 文件路径 | |||
:return: 一个 :class:`~fastNLP.DataSet` 类型的对象 | |||
""" | |||
raise NotImplementedError | |||
def process(self, paths: Union[str, Dict[str, str]], **options) -> DataBundle: | |||
""" | |||
对于特定的任务和数据集,读取并处理数据,返回处理DataInfo类对象或字典。 | |||
从指定一个或多个路径中的文件中读取数据,DataInfo对象中可以包含一个或多个数据集 。 | |||
如果处理多个路径,传入的 dict 的 key 与返回DataInfo中的 dict 中的 key 保存一致。 | |||
返回的 :class:`DataBundle` 对象有如下属性: | |||
- vocabs: 由从数据集中获取的词表组成的字典,每个词表 | |||
- datasets: 一个dict,包含一系列 :class:`~fastNLP.DataSet` 类型的对象。其中 field 的命名参考 :mod:`~fastNLP.core.const` | |||
:param paths: 原始数据读取的路径 | |||
:param options: 根据不同的任务和数据集,设计自己的参数 | |||
:return: 返回一个 DataBundle | |||
""" | |||
raise NotImplementedError |
@@ -1,311 +0,0 @@ | |||
""" | |||
用于读入和处理和保存 config 文件 | |||
.. todo:: | |||
这个模块中的类可能被抛弃? | |||
""" | |||
__all__ = [ | |||
"ConfigLoader", | |||
"ConfigSection", | |||
"ConfigSaver" | |||
] | |||
import configparser | |||
import json | |||
import os | |||
from .base_loader import BaseLoader | |||
class ConfigLoader(BaseLoader): | |||
""" | |||
别名::class:`fastNLP.io.ConfigLoader` :class:`fastNLP.io.config_io.ConfigLoader` | |||
读取配置文件的Loader | |||
:param str data_path: 配置文件的路径 | |||
""" | |||
def __init__(self, data_path=None): | |||
super(ConfigLoader, self).__init__() | |||
if data_path is not None: | |||
self.config = self.parse(super(ConfigLoader, self).load(data_path)) | |||
@staticmethod | |||
def parse(string): | |||
raise NotImplementedError | |||
@staticmethod | |||
def load_config(file_path, sections): | |||
""" | |||
把配置文件的section 存入提供的 ``sections`` 中 | |||
:param str file_path: 配置文件的路径 | |||
:param dict sections: 符合如下键值对组成的字典 `section_name(string)` : :class:`~fastNLP.io.ConfigSection` | |||
Example:: | |||
test_args = ConfigSection() | |||
ConfigLoader("config.cfg").load_config("./data_for_tests/config", {"POS_test": test_args}) | |||
""" | |||
assert isinstance(sections, dict) | |||
cfg = configparser.ConfigParser() | |||
if not os.path.exists(file_path): | |||
raise FileNotFoundError("config file {} not found. ".format(file_path)) | |||
cfg.read(file_path) | |||
for s in sections: | |||
attr_list = [i for i in sections[s].__dict__.keys() if | |||
not callable(getattr(sections[s], i)) and not i.startswith("__")] | |||
if s not in cfg: | |||
print('section %s not found in config file' % (s)) | |||
continue | |||
gen_sec = cfg[s] | |||
for attr in gen_sec.keys(): | |||
try: | |||
val = json.loads(gen_sec[attr]) | |||
# print(s, attr, val, type(val)) | |||
if attr in attr_list: | |||
assert type(val) == type(getattr(sections[s], attr)), \ | |||
'type not match, except %s but got %s' % \ | |||
(type(getattr(sections[s], attr)), type(val)) | |||
""" | |||
if attr in attr_list then check its type and | |||
update its value. | |||
else add a new attr in sections[s] | |||
""" | |||
setattr(sections[s], attr, val) | |||
except Exception as e: | |||
print("cannot load attribute %s in section %s" | |||
% (attr, s)) | |||
pass | |||
class ConfigSection(object): | |||
""" | |||
别名::class:`fastNLP.io.ConfigSection` :class:`fastNLP.io.config_io.ConfigSection` | |||
ConfigSection是一个存储了一个section中所有键值对的数据结构,推荐使用此类的实例来配合 :meth:`ConfigLoader.load_config` 使用 | |||
""" | |||
def __init__(self): | |||
super(ConfigSection, self).__init__() | |||
def __getitem__(self, key): | |||
""" | |||
:param key: str, the name of the attribute | |||
:return attr: the value of this attribute | |||
if key not in self.__dict__.keys(): | |||
return self[key] | |||
else: | |||
raise AttributeError | |||
""" | |||
if key in self.__dict__.keys(): | |||
return getattr(self, key) | |||
raise AttributeError("do NOT have attribute %s" % key) | |||
def __setitem__(self, key, value): | |||
""" | |||
:param key: str, the name of the attribute | |||
:param value: the value of this attribute | |||
if key not in self.__dict__.keys(): | |||
self[key] will be added | |||
else: | |||
self[key] will be updated | |||
""" | |||
if key in self.__dict__.keys(): | |||
if not isinstance(value, type(getattr(self, key))): | |||
raise AttributeError("attr %s except %s but got %s" % | |||
(key, str(type(getattr(self, key))), str(type(value)))) | |||
setattr(self, key, value) | |||
def __contains__(self, item): | |||
""" | |||
:param item: The key of item. | |||
:return: True if the key in self.__dict__.keys() else False. | |||
""" | |||
return item in self.__dict__.keys() | |||
def __eq__(self, other): | |||
"""Overwrite the == operator | |||
:param other: Another ConfigSection() object which to be compared. | |||
:return: True if value of each key in each ConfigSection() object are equal to the other, else False. | |||
""" | |||
for k in self.__dict__.keys(): | |||
if k not in other.__dict__.keys(): | |||
return False | |||
if getattr(self, k) != getattr(self, k): | |||
return False | |||
for k in other.__dict__.keys(): | |||
if k not in self.__dict__.keys(): | |||
return False | |||
if getattr(self, k) != getattr(self, k): | |||
return False | |||
return True | |||
def __ne__(self, other): | |||
"""Overwrite the != operator | |||
:param other: | |||
:return: | |||
""" | |||
return not self.__eq__(other) | |||
@property | |||
def data(self): | |||
return self.__dict__ | |||
class ConfigSaver(object): | |||
""" | |||
别名::class:`fastNLP.io.ConfigSaver` :class:`fastNLP.io.config_io.ConfigSaver` | |||
ConfigSaver 是用来存储配置文件并解决相关冲突的类 | |||
:param str file_path: 配置文件的路径 | |||
""" | |||
def __init__(self, file_path): | |||
self.file_path = file_path | |||
if not os.path.exists(self.file_path): | |||
raise FileNotFoundError("file {} NOT found!".__format__(self.file_path)) | |||
def _get_section(self, sect_name): | |||
""" | |||
This is the function to get the section with the section name. | |||
:param sect_name: The name of section what wants to load. | |||
:return: The section. | |||
""" | |||
sect = ConfigSection() | |||
ConfigLoader().load_config(self.file_path, {sect_name: sect}) | |||
return sect | |||
def _read_section(self): | |||
""" | |||
This is the function to read sections from the config file. | |||
:return: sect_list, sect_key_list | |||
sect_list: A list of ConfigSection(). | |||
sect_key_list: A list of names in sect_list. | |||
""" | |||
sect_name = None | |||
sect_list = {} | |||
sect_key_list = [] | |||
single_section = {} | |||
single_section_key = [] | |||
with open(self.file_path, 'r') as f: | |||
lines = f.readlines() | |||
for line in lines: | |||
if line.startswith('[') and line.endswith(']\n'): | |||
if sect_name is None: | |||
pass | |||
else: | |||
sect_list[sect_name] = single_section, single_section_key | |||
single_section = {} | |||
single_section_key = [] | |||
sect_key_list.append(sect_name) | |||
sect_name = line[1: -2] | |||
continue | |||
if line.startswith('#'): | |||
single_section[line] = '#' | |||
single_section_key.append(line) | |||
continue | |||
if line.startswith('\n'): | |||
single_section_key.append('\n') | |||
continue | |||
if '=' not in line: | |||
raise RuntimeError("can NOT load config file {}".__format__(self.file_path)) | |||
key = line.split('=', maxsplit=1)[0].strip() | |||
value = line.split('=', maxsplit=1)[1].strip() + '\n' | |||
single_section[key] = value | |||
single_section_key.append(key) | |||
if sect_name is not None: | |||
sect_list[sect_name] = single_section, single_section_key | |||
sect_key_list.append(sect_name) | |||
return sect_list, sect_key_list | |||
def _write_section(self, sect_list, sect_key_list): | |||
""" | |||
This is the function to write config file with section list and name list. | |||
:param sect_list: A list of ConfigSection() need to be writen into file. | |||
:param sect_key_list: A list of name of sect_list. | |||
:return: | |||
""" | |||
with open(self.file_path, 'w') as f: | |||
for sect_key in sect_key_list: | |||
single_section, single_section_key = sect_list[sect_key] | |||
f.write('[' + sect_key + ']\n') | |||
for key in single_section_key: | |||
if key == '\n': | |||
f.write('\n') | |||
continue | |||
if single_section[key] == '#': | |||
f.write(key) | |||
continue | |||
f.write(key + ' = ' + single_section[key]) | |||
f.write('\n') | |||
def save_config_file(self, section_name, section): | |||
""" | |||
这个方法可以用来修改并保存配置文件中单独的一个 section | |||
:param str section_name: 需要保存的 section 的名字. | |||
:param section: 你需要修改并保存的 section, :class:`~fastNLP.io.ConfigSaver` 类型 | |||
""" | |||
section_file = self._get_section(section_name) | |||
if len(section_file.__dict__.keys()) == 0: # the section not in the file before | |||
# append this section to config file | |||
with open(self.file_path, 'a') as f: | |||
f.write('[' + section_name + ']\n') | |||
for k in section.__dict__.keys(): | |||
f.write(k + ' = ') | |||
if isinstance(section[k], str): | |||
f.write('\"' + str(section[k]) + '\"\n\n') | |||
else: | |||
f.write(str(section[k]) + '\n\n') | |||
else: | |||
# the section exists | |||
change_file = False | |||
for k in section.__dict__.keys(): | |||
if k not in section_file: | |||
# find a new key in this section | |||
change_file = True | |||
break | |||
if section_file[k] != section[k]: | |||
change_file = True | |||
break | |||
if not change_file: | |||
return | |||
sect_list, sect_key_list = self._read_section() | |||
if section_name not in sect_key_list: | |||
raise AttributeError() | |||
sect, sect_key = sect_list[section_name] | |||
for k in section.__dict__.keys(): | |||
if k not in sect_key: | |||
if sect_key[-1] != '\n': | |||
sect_key.append('\n') | |||
sect_key.append(k) | |||
sect[k] = str(section[k]) | |||
if isinstance(section[k], str): | |||
sect[k] = "\"" + sect[k] + "\"" | |||
sect[k] = sect[k] + "\n" | |||
sect_list[section_name] = sect, sect_key | |||
self._write_section(sect_list, sect_key_list) |
@@ -0,0 +1,320 @@ | |||
""" | |||
.. todo:: | |||
doc | |||
""" | |||
__all__ = [ | |||
'DataBundle', | |||
] | |||
from ..core.dataset import DataSet | |||
from ..core.vocabulary import Vocabulary | |||
from typing import Union | |||
class DataBundle: | |||
""" | |||
经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 | |||
Loader的load函数生成,可以通过以下的方法获取里面的内容 | |||
Example:: | |||
data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'}) | |||
train_vocabs = data_bundle.vocabs['train'] | |||
train_data = data_bundle.datasets['train'] | |||
dev_data = data_bundle.datasets['dev'] | |||
:param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict | |||
:param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict | |||
""" | |||
def __init__(self, vocabs: dict = None, datasets: dict = None): | |||
self.vocabs = vocabs or {} | |||
self.datasets = datasets or {} | |||
def set_vocab(self, vocab, field_name): | |||
""" | |||
向DataBundle中增加vocab | |||
:param ~fastNLP.Vocabulary vocab: 词表 | |||
:param str field_name: 这个vocab对应的field名称 | |||
:return: self | |||
""" | |||
assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports." | |||
self.vocabs[field_name] = vocab | |||
return self | |||
def set_dataset(self, dataset, name): | |||
""" | |||
:param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet | |||
:param str name: dataset的名称 | |||
:return: self | |||
""" | |||
self.datasets[name] = dataset | |||
return self | |||
def get_dataset(self, name: str) -> DataSet: | |||
""" | |||
获取名为name的dataset | |||
:param str name: dataset的名称,一般为'train', 'dev', 'test' | |||
:return: DataSet | |||
""" | |||
return self.datasets[name] | |||
def delete_dataset(self, name: str): | |||
""" | |||
删除名为name的DataSet | |||
:param str name: | |||
:return: self | |||
""" | |||
self.datasets.pop(name, None) | |||
return self | |||
def get_vocab(self, field_name: str) -> Vocabulary: | |||
""" | |||
获取field名为field_name对应的vocab | |||
:param str field_name: 名称 | |||
:return: Vocabulary | |||
""" | |||
return self.vocabs[field_name] | |||
def delete_vocab(self, field_name: str): | |||
""" | |||
删除vocab | |||
:param str field_name: | |||
:return: self | |||
""" | |||
self.vocabs.pop(field_name, None) | |||
return self | |||
def set_input(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | |||
""" | |||
将field_names中的field设置为input, 对data_bundle中所有的dataset执行该操作:: | |||
data_bundle.set_input('words', 'seq_len') # 将words和seq_len这两个field的input属性设置为True | |||
data_bundle.set_input('words', flag=False) # 将words这个field的input属性设置为False | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的input状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for field_name in field_names: | |||
for name, dataset in self.datasets.items(): | |||
if not ignore_miss_dataset and not dataset.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") | |||
if not dataset.has_field(field_name): | |||
continue | |||
else: | |||
dataset.set_input(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||
return self | |||
def set_target(self, *field_names, flag=True, use_1st_ins_infer_dim_type=True, ignore_miss_dataset=True): | |||
""" | |||
将field_names中的field设置为target, 对data_bundle中所有的dataset执行该操作:: | |||
data_bundle.set_target('target', 'seq_len') # 将target和seq_len这两个field的target属性设置为True | |||
data_bundle.set_target('target', flag=False) # 将target这个field的target属性设置为False | |||
:param str field_names: field的名称 | |||
:param bool flag: 将field_name的target状态设置为flag | |||
:param bool use_1st_ins_infer_dim_type: 如果为True,将不会check该列是否所有数据都是同样的维度,同样的类型。将直接使用第一 | |||
行的数据进行类型和维度推断本列的数据的类型和维度。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for field_name in field_names: | |||
for name, dataset in self.datasets.items(): | |||
if not ignore_miss_dataset and not dataset.has_field(field_name): | |||
raise KeyError(f"Field:{field_name} was not found in DataSet:{name}") | |||
if not dataset.has_field(field_name): | |||
continue | |||
else: | |||
dataset.set_target(field_name, flag=flag, use_1st_ins_infer_dim_type=use_1st_ins_infer_dim_type) | |||
return self | |||
def set_pad_val(self, field_name, pad_val, ignore_miss_dataset=True): | |||
""" | |||
将DataBundle中所有的DataSet中名为field_name的Field的padding值设置为pad_val. | |||
:param str field_name: | |||
:param int pad_val: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.set_pad_val(field_name=field_name, pad_val=pad_val) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def set_ignore_type(self, *field_names, flag=True, ignore_miss_dataset=True): | |||
""" | |||
将DataBundle中所有的DataSet中名为*field_names的Field的ignore_type设置为flag状态 | |||
:param str field_names: | |||
:param bool flag: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
for field_name in field_names: | |||
if dataset.has_field(field_name=field_name): | |||
dataset.set_ignore_type(field_name, flag=flag) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def copy_field(self, field_name, new_field_name, ignore_miss_dataset=True): | |||
""" | |||
将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为new_field_name. | |||
:param str field_name: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.copy_field(field_name=field_name, new_field_name=new_field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def rename_field(self, field_name, new_field_name, ignore_miss_dataset=True, rename_vocab=True): | |||
""" | |||
将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. | |||
:param str field_name: | |||
:param str new_field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.rename_field(field_name=field_name, new_field_name=new_field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if rename_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs[new_field_name] = self.vocabs.pop(field_name) | |||
return self | |||
def delete_field(self, field_name, ignore_miss_dataset=True, delete_vocab=True): | |||
""" | |||
将DataBundle中所有DataSet中名为field_name的field删除掉. | |||
:param str field_name: | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 | |||
:return: self | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.delete_field(field_name=field_name) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
if delete_vocab: | |||
if field_name in self.vocabs: | |||
self.vocabs.pop(field_name) | |||
return self | |||
def iter_datasets(self)->Union[str, DataSet]: | |||
""" | |||
迭代data_bundle中的DataSet | |||
Example:: | |||
for name, dataset in data_bundle.iter_datasets(): | |||
pass | |||
:return: | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
yield name, dataset | |||
def iter_vocabs(self)->Union[str, Vocabulary]: | |||
""" | |||
迭代data_bundle中的Vocabulary | |||
Example:: | |||
for field_name, vocab in data_bundle.iter_vocabs(): | |||
pass | |||
:return: | |||
""" | |||
for field_name, vocab in self.vocabs.items(): | |||
yield field_name, vocab | |||
def apply_field(self, func, field_name:str, new_field_name:str, ignore_miss_dataset=True, **kwargs): | |||
""" | |||
对DataBundle中所有的dataset使用apply_field方法 | |||
:param callable func: input是instance中名为 `field_name` 的field的内容。 | |||
:param str field_name: 传入func的是哪个field。 | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param optional kwargs: 支持输入is_input,is_target,ignore_type | |||
1. is_input: bool, 如果为True则将名为 `new_field_name` 的field设置为input | |||
2. is_target: bool, 如果为True则将名为 `new_field_name` 的field设置为target | |||
3. ignore_type: bool, 如果为True则将名为 `new_field_name` 的field的ignore_type设置为true, 忽略其类型 | |||
""" | |||
for name, dataset in self.datasets.items(): | |||
if dataset.has_field(field_name=field_name): | |||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, **kwargs) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def apply(self, func, new_field_name:str, **kwargs):
    """
    Call ``apply`` with ``func`` on every DataSet in this DataBundle.

    :param callable func: receives each instance of every dataset.
    :param str new_field_name: the field that stores what ``func`` returns; an existing
        field of the same name is overwritten, and ``None`` creates no new field.
    :param optional kwargs: may carry is_input, is_target and ignore_type
        1. is_input: bool, mark the ``new_field_name`` field as input when True
        2. is_target: bool, mark the ``new_field_name`` field as target when True
        3. ignore_type: bool, set ignore_type on the ``new_field_name`` field when True
    :return: self, so calls can be chained.
    """
    # the dataset names are not needed here, so iterate values only
    for ds in self.datasets.values():
        ds.apply(func, new_field_name=new_field_name, **kwargs)
    return self
def __repr__(self):
    """Summarise the bundle: one line per dataset/vocabulary with its size."""
    parts = []
    if len(self.datasets):
        parts.append('In total {} datasets:\n'.format(len(self.datasets)))
        for ds_name, ds in self.datasets.items():
            parts.append('\t{} has {} instances.\n'.format(ds_name, len(ds)))
    if len(self.vocabs):
        parts.append('In total {} vocabs:\n'.format(len(self.vocabs)))
        for vocab_name, vocab in self.vocabs.items():
            parts.append('\t{} has {} entries.\n'.format(vocab_name, len(vocab)))
    # join once instead of repeated string concatenation
    return ''.join(parts)
@@ -1,35 +0,0 @@ | |||
""" | |||
用于读数据集的模块, 可以读取文本分类、序列标注、Matching任务的数据集 | |||
这些模块的具体介绍如下,您可以通过阅读 :doc:`教程</tutorials/tutorial_2_load_dataset>` 来进行了解。 | |||
""" | |||
__all__ = [ | |||
'ConllLoader', | |||
'Conll2003Loader', | |||
'IMDBLoader', | |||
'MatchingLoader', | |||
'SNLILoader', | |||
'MNLILoader', | |||
'MTL16Loader', | |||
'PeopleDailyCorpusLoader', | |||
'QNLILoader', | |||
'QuoraLoader', | |||
'RTELoader', | |||
'SSTLoader', | |||
'SST2Loader', | |||
'YelpLoader', | |||
] | |||
from .conll import ConllLoader, Conll2003Loader | |||
from .imdb import IMDBLoader | |||
from .matching import MatchingLoader | |||
from .mnli import MNLILoader | |||
from .mtl import MTL16Loader | |||
from .people_daily import PeopleDailyCorpusLoader | |||
from .qnli import QNLILoader | |||
from .quora import QuoraLoader | |||
from .rte import RTELoader | |||
from .snli import SNLILoader | |||
from .sst import SSTLoader, SST2Loader | |||
from .yelp import YelpLoader |
@@ -1,73 +0,0 @@ | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ..base_loader import DataSetLoader | |||
from ..file_reader import _read_conll | |||
class ConllLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.ConllLoader` :class:`fastNLP.io.data_loader.ConllLoader`

    Reads data in the CoNLL format; see http://conll.cemantix.org/2012/data.html
    for details of the format. Lines starting with "-DOCSTART-" are ignored,
    because that token is used as the document separator in CoNLL 2003.

    Columns are numbered from 0 and contain::

        Column  Type
        0       Document ID
        1       Part number
        2       Word number
        3       Word itself
        4       Part-of-Speech
        5       Parse bit
        6       Predicate lemma
        7       Predicate Frameset ID
        8       Word sense
        9       Speaker/Author
        10      Named Entities
        11:N    Predicate Arguments
        N       Coreference

    :param headers: the name of each kept column; must be a list or tuple of str,
        pairing one-to-one with ``indexes``
    :param indexes: 0-based indexes of the columns to keep; ``None`` keeps every
        column. Default: ``None``
    :param dropna: whether to skip malformed lines; when ``False`` a ``ValueError``
        is raised on malformed data. Default: ``False``
    """

    def __init__(self, headers, indexes=None, dropna=False):
        super(ConllLoader, self).__init__()
        if not isinstance(headers, (list, tuple)):
            raise TypeError(
                'invalid headers: {}, should be list of strings'.format(headers))
        self.headers = headers
        self.dropna = dropna
        if indexes is None:
            # keep every column, one per header
            self.indexes = list(range(len(self.headers)))
        else:
            if len(indexes) != len(headers):
                # BUG FIX: this used to be a bare ``raise ValueError`` with no
                # message, which made the failure impossible to diagnose.
                raise ValueError(
                    'the number of indexes ({}) must equal the number of '
                    'headers ({})'.format(len(indexes), len(headers)))
            self.indexes = indexes

    def _load(self, path):
        """Read ``path`` and return a DataSet with one Instance per sentence."""
        ds = DataSet()
        for idx, data in _read_conll(path, indexes=self.indexes, dropna=self.dropna):
            # map each kept column onto its header name
            ins = {h: data[i] for i, h in enumerate(self.headers)}
            ds.append(Instance(**ins))
        return ds
class Conll2003Loader(ConllLoader):
    """
    Alias: :class:`fastNLP.io.Conll2003Loader` :class:`fastNLP.io.data_loader.Conll2003Loader`

    Loader for the CoNLL-2003 dataset.

    More information about the dataset:
    https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
    """

    def __init__(self):
        # Fixed column layout of CoNLL-2003: token, POS tag, chunk tag, NER tag.
        super(Conll2003Loader, self).__init__(
            headers=['tokens', 'pos', 'chunks', 'ner'])
@@ -1,99 +0,0 @@ | |||
from typing import Union, Dict | |||
from ..embed_loader import EmbeddingOption, EmbedLoader | |||
from ..base_loader import DataSetLoader, DataBundle | |||
from ...core.vocabulary import VocabularyOption, Vocabulary | |||
from ...core.dataset import DataSet | |||
from ...core.instance import Instance | |||
from ...core.const import Const | |||
from ..utils import get_tokenizer | |||
class IMDBLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.IMDBLoader` :class:`fastNLP.io.data_loader.IMDBLoader`

    Reads the IMDB dataset. The loaded DataSet contains the fields:

        words: list(str), the tokenized text to classify
        target: str, the label of the text
    """

    def __init__(self):
        super(IMDBLoader, self).__init__()
        # tokenizer used to split raw review text into word tokens
        self.tokenizer = get_tokenizer()

    def _load(self, path):
        """Read one tab-separated file where each line is ``label<TAB>text``."""
        dataset = DataSet()
        with open(path, 'r', encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    # skip blank lines
                    continue
                parts = line.split('\t')
                target = parts[0]
                # the text is lower-cased before tokenizing; the label keeps its case
                words = self.tokenizer(parts[1].lower())
                dataset.append(Instance(words=words, target=target))
        if len(dataset) == 0:
            raise RuntimeError(f"{path} has no valid data.")
        return dataset

    def process(self,
                paths: Union[str, Dict[str, str]],
                src_vocab_opt: VocabularyOption = None,
                tgt_vocab_opt: VocabularyOption = None,
                char_level_op=False):
        """
        Load every split named in ``paths``, optionally derive a character-level
        field, carve a dev split off the train split, and build the word/target
        vocabularies from the train split only.

        :param paths: mapping from split name to file path
            (NOTE(review): a plain str would crash on ``paths.items()`` below,
            despite the annotation — confirm callers always pass a dict).
        :param src_vocab_opt: options used to build the words vocabulary.
        :param tgt_vocab_opt: options used to build the target vocabulary.
        :param char_level_op: when True, add a 'chars' field derived from 'words'.
        :return: a DataBundle holding the datasets and the two vocabularies.
        """
        datasets = {}
        info = DataBundle()
        for name, path in paths.items():
            dataset = self.load(path)
            datasets[name] = dataset

        def wordtochar(words):
            # Flatten a word list into a character list, inserting an
            # empty-string token between words; the trailing separator is
            # popped off at the end.
            # NOTE(review): ``chars.pop()`` raises IndexError when ``words``
            # is empty — confirm empty word lists cannot occur here.
            chars = []
            for word in words:
                word = word.lower()
                for char in word:
                    chars.append(char)
                chars.append('')
            chars.pop()
            return chars

        if char_level_op:
            for dataset in datasets.values():
                dataset.apply_field(wordtochar, field_name="words", new_field_name='chars')

        # NOTE(review): assumes a 'train' split exists; 10% of it (taken in
        # order, no shuffling) becomes the 'dev' split.
        datasets["train"], datasets["dev"] = datasets["train"].split(0.1, shuffle=False)

        # vocabularies are built from the (already reduced) train split only,
        # then used to index every split
        src_vocab = Vocabulary() if src_vocab_opt is None else Vocabulary(**src_vocab_opt)
        src_vocab.from_dataset(datasets['train'], field_name='words')
        src_vocab.index_dataset(*datasets.values(), field_name='words')

        tgt_vocab = Vocabulary(unknown=None, padding=None) \
            if tgt_vocab_opt is None else Vocabulary(**tgt_vocab_opt)
        tgt_vocab.from_dataset(datasets['train'], field_name='target')
        tgt_vocab.index_dataset(*datasets.values(), field_name='target')

        info.vocabs = {
            Const.INPUT: src_vocab,
            Const.TARGET: tgt_vocab
        }

        info.datasets = datasets

        for name, dataset in info.datasets.items():
            dataset.set_input(Const.INPUT)
            dataset.set_target(Const.TARGET)

        return info
@@ -1,248 +0,0 @@ | |||
import os | |||
from typing import Union, Dict, List | |||
from ...core.const import Const | |||
from ...core.vocabulary import Vocabulary | |||
from ..base_loader import DataBundle, DataSetLoader | |||
from ..file_utils import _get_base_url, cached_path, PRETRAINED_BERT_MODEL_DIR | |||
from ...modules.encoder.bert import BertTokenizer | |||
class MatchingLoader(DataSetLoader):
    """
    Alias: :class:`fastNLP.io.MatchingLoader` :class:`fastNLP.io.data_loader.MatchingLoader`

    Base loader for matching (sentence-pair) task datasets.

    :param dict paths: maps split names (train/dev/test) to file names
    """

    def __init__(self, paths: dict=None):
        # NOTE(review): DataSetLoader.__init__ is not called here — confirm the
        # base class tolerates that before changing it.
        self.paths = paths

    def _load(self, path):
        """
        :param str path: path of the data file to read
        :return: fastNLP.DataSet ds: a DataSet that must contain three fields —
            the raw texts of the two sentences plus the label. Subclasses
            implement this.
        """
        raise NotImplementedError

    def process(self, paths: Union[str, Dict[str, str]], dataset_name: str=None,
                to_lower=False, seq_len_type: str=None, bert_tokenizer: str=None,
                cut_text: int = None, get_index=True, auto_pad_length: int=None,
                auto_pad_token: str='<pad>', set_input: Union[list, str, bool]=True,
                set_target: Union[list, str, bool]=True, concat: Union[str, list, bool]=None,
                extra_split: List[str]=None, ) -> DataBundle:
        """
        :param paths: str or Dict[str, str]. A str is either a directory (each
            split's file name is looked up in ``self.paths``) or a single data
            file. A dict maps split names (train/dev/test) to full file paths.
        :param str dataset_name: when ``paths`` is a single file, the name given
            to that split; defaults to ``train``.
        :param bool to_lower: whether to lower-case the texts. Default: False.
        :param str seq_len_type: kind of length information to add. ``seq_len``:
            one number per sentence; ``mask``: a 0/1 mask matrix; ``bert``:
            segment-type ids (0 for the first sentence, 1 for the second) plus a
            0/1 attention mask. Default: None, i.e. no length field is added.
        :param str bert_tokenizer: model name or directory holding the vocabulary
            used by the BERT tokenizer.
        :param int cut_text: truncate content longer than this. Default: None (no cut).
        :param bool get_index: whether to map tokens to vocabulary indexes.
        :param int auto_pad_length: pad texts to this length (longer texts are
            truncated). Default: no automatic padding.
        :param str auto_pad_token: the token used for automatic padding.
        :param set_input: True automatically marks every field whose name contains
            Const.INPUT as input; False marks none. A str or List[str] marks
            exactly the named fields (and no others). Default: True.
        :param set_target: controls which fields become target; same usage as
            ``set_input``. Default: True.
        :param concat: whether to join the two sentences. False: no join. True:
            insert a <sep> between them. A 4-element list gives the markers
            inserted before/after the first sentence and before/after the second.
            The string ``bert`` uses the BERT scheme, i.e.
            ['[CLS]', '[SEP]', '', '[SEP]'].
        :param extra_split: extra separator characters (besides whitespace) used
            for tokenization.
        :return: a DataBundle with the processed datasets and vocabularies.
        """
        # normalise set_input/set_target: a str means a single explicit field,
        # a bool means "mark everything input-like automatically (or nothing)"
        if isinstance(set_input, str):
            set_input = [set_input]
        if isinstance(set_target, str):
            set_target = [set_target]
        if isinstance(set_input, bool):
            auto_set_input = set_input
        else:
            auto_set_input = False
        if isinstance(set_target, bool):
            auto_set_target = set_target
        else:
            auto_set_target = False

        # resolve paths into a {split_name: file_path} dict
        if isinstance(paths, str):
            if os.path.isdir(paths):
                path = {n: os.path.join(paths, self.paths[n]) for n in self.paths.keys()}
            else:
                path = {dataset_name if dataset_name is not None else 'train': paths}
        else:
            path = paths

        data_info = DataBundle()
        for data_name in path.keys():
            data_info.datasets[data_name] = self._load(path[data_name])

        for data_name, data_set in data_info.datasets.items():
            if auto_set_input:
                data_set.set_input(Const.INPUTS(0), Const.INPUTS(1))
            if auto_set_target:
                if Const.TARGET in data_set.get_field_names():
                    data_set.set_target(Const.TARGET)

        if extra_split is not None:
            # re-join tokens, surround every extra separator with spaces, then
            # re-split and drop the empty tokens the surrounding introduced
            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(0)]), new_field_name=Const.INPUTS(0))
                data_set.apply(lambda x: ' '.join(x[Const.INPUTS(1)]), new_field_name=Const.INPUTS(1))

                for s in extra_split:
                    data_set.apply(lambda x: x[Const.INPUTS(0)].replace(s, ' ' + s + ' '),
                                   new_field_name=Const.INPUTS(0))
                    # BUG FIX: the second apply used to read and write
                    # Const.INPUTS(0) again (copy-paste), so the second
                    # sentence never received the extra-split spacing.
                    data_set.apply(lambda x: x[Const.INPUTS(1)].replace(s, ' ' + s + ' '),
                                   new_field_name=Const.INPUTS(1))

                _filt = lambda x: x
                data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(0)].split(' '))),
                               new_field_name=Const.INPUTS(0), is_input=auto_set_input)
                data_set.apply(lambda x: list(filter(_filt, x[Const.INPUTS(1)].split(' '))),
                               new_field_name=Const.INPUTS(1), is_input=auto_set_input)
                _filt = None

        if to_lower:
            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(0)]], new_field_name=Const.INPUTS(0),
                               is_input=auto_set_input)
                data_set.apply(lambda x: [w.lower() for w in x[Const.INPUTS(1)]], new_field_name=Const.INPUTS(1),
                               is_input=auto_set_input)

        if bert_tokenizer is not None:
            if bert_tokenizer.lower() in PRETRAINED_BERT_MODEL_DIR:
                PRETRAIN_URL = _get_base_url('bert')
                # NOTE(review): the membership check lowercases the name but this
                # lookup does not — a mixed-case name would raise KeyError here.
                model_name = PRETRAINED_BERT_MODEL_DIR[bert_tokenizer]
                model_url = PRETRAIN_URL + model_name
                model_dir = cached_path(model_url)
                # the model is downloaded (or fetched from the cache) above
            elif os.path.isdir(bert_tokenizer):
                model_dir = bert_tokenizer
            else:
                raise ValueError(f"Cannot recognize BERT tokenizer from {bert_tokenizer}.")

            words_vocab = Vocabulary(padding='[PAD]', unknown='[UNK]')
            with open(os.path.join(model_dir, 'vocab.txt'), 'r') as f:
                lines = f.readlines()
            lines = [line.strip() for line in lines]
            words_vocab.add_word_lst(lines)
            words_vocab.build_vocab()

            tokenizer = BertTokenizer.from_pretrained(model_dir)
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: tokenizer.tokenize(' '.join(x[fields])), new_field_name=fields,
                                       is_input=auto_set_input)

        if isinstance(concat, bool):
            concat = 'default' if concat else None
        if concat is not None:
            if isinstance(concat, str):
                CONCAT_MAP = {'bert': ['[CLS]', '[SEP]', '', '[SEP]'],
                              'default': ['', '<sep>', '', '']}
                if concat.lower() in CONCAT_MAP:
                    concat = CONCAT_MAP[concat]
                else:
                    concat = 4 * [concat]
            assert len(concat) == 4, \
                f'Please choose a list with 4 symbols which at the beginning of first sentence ' \
                f'the end of first sentence, the begin of second sentence, and the end of second' \
                f'sentence. Your input is {concat}'

            for data_name, data_set in data_info.datasets.items():
                data_set.apply(lambda x: [concat[0]] + x[Const.INPUTS(0)] + [concat[1]] + [concat[2]] +
                               x[Const.INPUTS(1)] + [concat[3]], new_field_name=Const.INPUT)
                # empty markers ('') are dropped from the concatenated sequence
                data_set.apply(lambda x: [w for w in x[Const.INPUT] if len(w) > 0], new_field_name=Const.INPUT,
                               is_input=auto_set_input)

        if seq_len_type is not None:
            if seq_len_type == 'seq_len':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: len(x[fields]),
                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'mask':
                for data_name, data_set in data_info.datasets.items():
                    for fields in data_set.get_field_names():
                        if Const.INPUT in fields:
                            data_set.apply(lambda x: [1] * len(x[fields]),
                                           new_field_name=fields.replace(Const.INPUT, Const.INPUT_LEN),
                                           is_input=auto_set_input)
            elif seq_len_type == 'bert':
                for data_name, data_set in data_info.datasets.items():
                    if Const.INPUT not in data_set.get_field_names():
                        raise KeyError(f'Field ``{Const.INPUT}`` not in {data_name} data set: '
                                       f'got {data_set.get_field_names()}')
                    # segment ids: first sentence + [CLS]/[SEP] get 0, second sentence + [SEP] gets 1
                    data_set.apply(lambda x: [0] * (len(x[Const.INPUTS(0)]) + 2) + [1] * (len(x[Const.INPUTS(1)]) + 1),
                                   new_field_name=Const.INPUT_LENS(0), is_input=auto_set_input)
                    data_set.apply(lambda x: [1] * len(x[Const.INPUT_LENS(0)]),
                                   new_field_name=Const.INPUT_LENS(1), is_input=auto_set_input)

        if auto_pad_length is not None:
            cut_text = min(auto_pad_length, cut_text if cut_text is not None else auto_pad_length)
        if cut_text is not None:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if (Const.INPUT in fields) or ((Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len')):
                        data_set.apply(lambda x: x[fields][: cut_text], new_field_name=fields,
                                       is_input=auto_set_input)

        data_set_list = [d for n, d in data_info.datasets.items()]
        assert len(data_set_list) > 0, f'There are NO data sets in data info!'

        # vocabularies are built from the train split(s) only; other splits do
        # not create new entries
        if bert_tokenizer is None:
            words_vocab = Vocabulary(padding=auto_pad_token)
            words_vocab = words_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                   field_name=[n for n in data_set_list[0].get_field_names()
                                                               if (Const.INPUT in n)],
                                                   no_create_entry_dataset=[d for n, d in data_info.datasets.items()
                                                                            if 'train' not in n])
        target_vocab = Vocabulary(padding=None, unknown=None)
        target_vocab = target_vocab.from_dataset(*[d for n, d in data_info.datasets.items() if 'train' in n],
                                                 field_name=Const.TARGET)
        data_info.vocabs = {Const.INPUT: words_vocab, Const.TARGET: target_vocab}

        if get_index:
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: [words_vocab.to_index(w) for w in x[fields]], new_field_name=fields,
                                       is_input=auto_set_input)

                if Const.TARGET in data_set.get_field_names():
                    data_set.apply(lambda x: target_vocab.to_index(x[Const.TARGET]), new_field_name=Const.TARGET,
                                   is_input=auto_set_input, is_target=auto_set_target)

        if auto_pad_length is not None:
            if seq_len_type == 'seq_len':
                raise RuntimeError(f'the sequence will be padded with the length {auto_pad_length}, '
                                   f'so the seq_len_type cannot be `{seq_len_type}`!')
            for data_name, data_set in data_info.datasets.items():
                for fields in data_set.get_field_names():
                    if Const.INPUT in fields:
                        data_set.apply(lambda x: x[fields] + [words_vocab.to_index(words_vocab.padding)] *
                                       (auto_pad_length - len(x[fields])), new_field_name=fields,
                                       is_input=auto_set_input)
                    elif (Const.INPUT_LEN in fields) and (seq_len_type != 'seq_len'):
                        data_set.apply(lambda x: x[fields] + [0] * (auto_pad_length - len(x[fields])),
                                       new_field_name=fields, is_input=auto_set_input)

        # finally honour explicit set_input / set_target field lists
        for data_name, data_set in data_info.datasets.items():
            if isinstance(set_input, list):
                data_set.set_input(*[inputs for inputs in set_input if inputs in data_set.get_field_names()])
            if isinstance(set_target, list):
                data_set.set_target(*[target for target in set_target if target in data_set.get_field_names()])
        return data_info
@@ -1,62 +0,0 @@ | |||
from ...core.const import Const | |||
from .matching import MatchingLoader | |||
from ..dataset_loader import CSVLoader | |||
class MNLILoader(MatchingLoader, CSVLoader):
    """
    Alias: :class:`fastNLP.io.MNLILoader` :class:`fastNLP.io.data_loader.MNLILoader`

    Loader for the MNLI dataset. The loaded DataSet contains the fields::

        words1: list(str), first sentence (premise)
        words2: list(str), second sentence (hypothesis)
        target: str, gold label

    """

    def __init__(self, paths: dict=None):
        if paths is None:
            paths = {
                'train': 'train.tsv',
                'dev_matched': 'dev_matched.tsv',
                'dev_mismatched': 'dev_mismatched.tsv',
                'test_matched': 'test_matched.tsv',
                'test_mismatched': 'test_mismatched.tsv',
                # 'test_0.9_matched': 'multinli_0.9_test_matched_unlabeled.txt',
                # 'test_0.9_mismatched': 'multinli_0.9_test_mismatched_unlabeled.txt',
                # the test_0.9_matched / mismatched splits belong to MNLI 0.9 (source: kaggle)
            }
        MatchingLoader.__init__(self, paths=paths)
        CSVLoader.__init__(self, sep='\t')
        # raw column name -> standard field name
        self.fields = {
            'sentence1_binary_parse': Const.INPUTS(0),
            'sentence2_binary_parse': Const.INPUTS(1),
            'gold_label': Const.TARGET,
        }

    def _load(self, path):
        """Read one TSV split, rename columns, strip parse parentheses, drop '-' labels."""
        ds = CSVLoader._load(self, path)

        for raw_name, std_name in self.fields.items():
            if raw_name in ds.get_field_names():
                ds.rename_field(raw_name, std_name)
        if Const.TARGET in ds.get_field_names():
            # unlabeled test splits carry the placeholder label 'hidden'
            if ds[0][Const.TARGET] == 'hidden':
                ds.delete_field(Const.TARGET)

        # remove the binary-parse parentheses and split on whitespace
        strip_parens = str.maketrans({'(': None, ')': None})
        ds.apply(lambda ins: ins[Const.INPUTS(0)].translate(strip_parens).strip().split(),
                 new_field_name=Const.INPUTS(0))
        ds.apply(lambda ins: ins[Const.INPUTS(1)].translate(strip_parens).strip().split(),
                 new_field_name=Const.INPUTS(1))

        if Const.TARGET in ds.get_field_names():
            ds.drop(lambda x: x[Const.TARGET] == '-')
        return ds