diff --git a/README.md b/README.md
index 74090646..9c03fe30 100644
--- a/README.md
+++ b/README.md
@@ -1,110 +1,208 @@
# fastNLP
-[![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP)
-[![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP)
-[![Pypi](https://img.shields.io/pypi/v/fastNLP.svg)](https://pypi.org/project/fastNLP)
-![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
-[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
-fastNLP is a lightweight natural language processing (NLP) toolkit whose goal is to implement NLP tasks quickly and to build complex models.
+[//]: # ([![Build Status](https://travis-ci.org/fastnlp/fastNLP.svg?branch=master)](https://travis-ci.org/fastnlp/fastNLP))
-fastNLP has the following features:
-
-- A unified tabular data container that simplifies data preprocessing;
-- Built-in Loaders and Pipes for many datasets, eliminating preprocessing code;
-- Handy NLP utilities, such as embedding loading (including ELMo and BERT) and intermediate-data caching;
-- Automatic download of some [datasets and pretrained models](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0);
-- A rich set of neural network components and reimplemented models (covering Chinese word segmentation, named entity recognition, syntactic parsing, text classification, text matching, coreference resolution, summarization, and more);
-- Trainer ships with many built-in callbacks for convenient experiment logging, exception catching, and more.
-
-## Installation
-
-fastNLP depends on the following packages:
+[//]: # ([![codecov](https://codecov.io/gh/fastnlp/fastNLP/branch/master/graph/badge.svg)](https://codecov.io/gh/fastnlp/fastNLP))
-+ numpy>=1.14.2
-+ torch>=1.0.0
-+ tqdm>=4.28.1
-+ nltk>=3.4.1
-+ requests
-+ spacy
-+ prettytable>=0.7.2
+[//]: # ([![Pypi](https://img.shields.io/pypi/v/fastNLP.svg)](https://pypi.org/project/fastNLP))
-Installing torch may depend on your operating system and CUDA version; see the [PyTorch website](https://pytorch.org/).
-Once the dependencies are installed, run the following commands to complete the installation:
-
-```shell
-pip install fastNLP
-python -m spacy download en
-```
+[//]: # (![Hex.pm](https://img.shields.io/hexpm/l/plug.svg))
+[//]: # ([![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest))
-## fastNLP Tutorials
-Chinese [documentation](https://fastnlp.readthedocs.io/) and [tutorials](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)
-### Quick Start
+fastNLP is a lightweight natural language processing (NLP) toolkit that aims to reduce the engineering code in user projects, such as data-processing loops, training loops, and multi-GPU execution.
-- [0. Quick start](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
-
-### Detailed Tutorials
-
-- [1. Preprocessing text with DataSet](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
-- [2. Converting text to and from indices with Vocabulary](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html)
-- [3. Turning text into vectors with the Embedding module](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
-- [4. Loading and processing datasets with Loader and Pipe](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)
-- [5. Building a text classifier I: fast training and testing with Trainer and Tester](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html)
-- [6. Building a text classifier II: customizing the training process with DataSetIter](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html)
-- [7. Evaluating your model quickly with Metric](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html)
-- [8. Building custom models quickly with Modules and Models](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html)
-- [9. Implementing a sequence labeling model quickly](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html)
-- [10. Customizing your training process with Callback](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html)
-
-### Extended Tutorials
-
-- [Extend-1. Various usages of BertEmbedding](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
-- [Extend-2. An introduction to distributed training](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
-- [Extend-3. Using fitlog with fastNLP for research](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)
+fastNLP has the following features:
+- Convenient. In data processing, apply-style functions replace explicit loops and multiprocessing speeds things up (see the sketch below); the training loop is easy to customize.
+- Efficient. Switch to fp16, multi-GPU training, or ZeRO optimization without changing your code.
+- Compatible. fastNLP supports multiple deep learning frameworks as backends.
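+
+As a minimal sketch of the apply-style preprocessing mentioned above (a toy dataset; `apply_field` and its `num_proc` argument appear again in the full examples below):
+
+```python
+from fastNLP import DataSet
+
+ds = DataSet({'raw_words': ['hello world', 'fastNLP is handy']})
+# apply_field maps a function over one field without an explicit loop;
+# setting num_proc > 1 runs the mapping in multiple worker processes.
+ds.apply_field(str.split, field_name='raw_words', new_field_name='words', num_proc=2)
+print(ds)
+```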
-## Built-in Components
+> :warning: **To support different deep learning frameworks, fastNLP versions from 1.0.0 onward use a redesigned architecture and are therefore not fully compatible with earlier fastNLP releases;
+> code based on earlier fastNLP versions will need some adjustments.**
-Most neural networks used for NLP tasks can be seen as composed of word embeddings and two kinds of modules: encoders and decoders.
+## Installation
+fastNLP can be installed with the following command:
+```shell
+pip install fastNLP
+```
+To install an earlier version of fastNLP, specify the version number, for example:
+```shell
+pip install fastNLP==0.7.1
+```
+In addition, install the deep learning framework you intend to use as the backend.
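+For example, a minimal sketch for the PyTorch backend (the exact command depends on your OS and CUDA version; see the PyTorch installation guide):
+```shell
+pip install torch
+```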
+
+
+### PyTorch
+Below is an example of text classification with the PyTorch backend. It requires torch>=1.6.0.
+
+```python
+from fastNLP.io import ChnSentiCorpLoader
+from functools import partial
+from fastNLP import cache_results
+from fastNLP.transformers.torch import BertTokenizer
+
+# Decorating the function with cache_results caches the return value of prepare_data in caches/cache.pkl.
+# On later runs, if that file still exists, the cache is read automatically instead of rerunning the preprocessing.
+@cache_results('caches/cache.pkl')
+def prepare_data():
+    # Automatically downloads the ChnSentiCorp data; as the documentation shows, the returned
+    # datasets contain the "raw_chars" and "target" fields.
+    data_bundle = ChnSentiCorpLoader().load()
+    # Tokenize the data.
+    tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
+    tokenize = partial(tokenizer, max_length=256)  # cap the maximum length of the data
+    data_bundle.apply_field_more(tokenize, field_name='raw_chars', num_proc=4)  # adds "input_ids", "attention_mask", etc. to the datasets
+    data_bundle.apply_field(int, field_name='target', new_field_name='labels')  # apply int to every target and store the result in a new labels field
+    return data_bundle
+data_bundle = prepare_data()
+print(data_bundle.get_dataset('train')[:4])
+
+# Initialize the model and optimizer
+from fastNLP.transformers.torch import BertForSequenceClassification
+from torch import optim
+model = BertForSequenceClassification.from_pretrained('hfl/chinese-bert-wwm')
+optimizer = optim.AdamW(model.parameters(), lr=2e-5)
+
+# Prepare the dataloaders
+from fastNLP import prepare_dataloader
+dls = prepare_dataloader(data_bundle, batch_size=32)
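+# prepare_dataloader returns a dict of dataloaders keyed by dataset name ('train', 'dev', 'test').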
+
+# Prepare for training
+from fastNLP import Trainer, Accuracy, LoadBestModelCallback, TorchWarmupCallback, Event
+callbacks = [
+    TorchWarmupCallback(warmup=0.1, schedule='linear'),  # adjust the learning rate during training
+    LoadBestModelCallback()  # after training ends, load the best-performing model
+]
+# Hook extra operations into specific moments of training. Different moments expose different
+# arguments; see the documentation of Trainer.on for the arguments available at each moment.
+@Trainer.on(Event.on_before_backward())
+def print_loss(trainer, outputs):
+    if trainer.global_forward_batches % 10 == 0:  # print the loss every 10 batches
+        print(outputs.loss.item())
+
+trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,
+                  device=0, evaluate_dataloaders=dls['dev'], metrics={'acc': Accuracy()},
+                  callbacks=callbacks, monitor='acc#acc', n_epochs=5,
+                  # Accuracy's update() expects the pred and target arguments; they correspond to the fields below.
+                  evaluate_input_mapping={'labels': 'target'},  # during evaluation, rename labels from the dataloader to target
+                  evaluate_output_mapping={'logits': 'pred'}  # during evaluation, rename logits in the model output to pred
+                  )
+trainer.run()
+
+# Evaluate on the test set
+from fastNLP import Evaluator
+evaluator = Evaluator(model=model, dataloaders=dls['test'], metrics={'acc': Accuracy()},
+                      # Accuracy's update() expects the pred and target arguments; they correspond to the fields below.
+                      output_mapping={'logits': 'pred'},
+                      input_mapping={'labels': 'target'})
+evaluator.run()
+```
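+
+The same script can be switched to mixed precision or multi-GPU training without touching the model or data code, for example by passing options such as `fp16=True` or `device=[0, 1]` to `Trainer` (treat these option names as a sketch; see the `Trainer` documentation for the parameters each driver supports).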
-Taking text classification as an example, the figure below shows the model flow of a BiLSTM+Attention text classifier:
+
+
+
+### Paddle
+Below is an example of text classification with the Paddle backend. It requires paddle>=2.2.0 and paddlenlp>=2.3.3.
+
+```python
+from fastNLP.io import ChnSentiCorpLoader
+from functools import partial
+
+# Automatically downloads the ChnSentiCorp data; as the documentation shows, the returned datasets contain the "raw_chars" and "target" fields.
+data_bundle = ChnSentiCorpLoader().load()
+
+# Tokenize the data with the tokenizer
+from paddlenlp.transformers import BertTokenizer
+tokenizer = BertTokenizer.from_pretrained('hfl/chinese-bert-wwm')
+tokenize = partial(tokenizer, max_length=256)  # cap the maximum length
+data_bundle.apply_field_more(tokenize, field_name='raw_chars', num_proc=4)  # adds "input_ids", "attention_mask", etc. to the datasets
+data_bundle.apply_field(int, field_name='target', new_field_name='labels')  # apply int to every target and store the result in a new labels field
+print(data_bundle.get_dataset('train')[:4])
+
+# Initialize the model
+from paddlenlp.transformers import BertForSequenceClassification, LinearDecayWithWarmup
+from paddle import optimizer, nn
+class SeqClsModel(nn.Layer):
+    def __init__(self, model_checkpoint, num_labels):
+        super(SeqClsModel, self).__init__()
+        self.num_labels = num_labels
+        self.bert = BertForSequenceClassification.from_pretrained(model_checkpoint)
+
+    def forward(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+        logits = self.bert(input_ids, token_type_ids, position_ids, attention_mask)
+        return logits
+
+    def train_step(self, input_ids, labels, token_type_ids=None, position_ids=None, attention_mask=None):
+        logits = self(input_ids, token_type_ids, position_ids, attention_mask)
+        loss_fct = nn.CrossEntropyLoss()
+        loss = loss_fct(logits.reshape((-1, self.num_labels)), labels.reshape((-1, )))
+        return {
+            "logits": logits,
+            "loss": loss,
+        }
+
+    def evaluate_step(self, input_ids, token_type_ids=None, position_ids=None, attention_mask=None):
+        logits = self(input_ids, token_type_ids, position_ids, attention_mask)
+        return {
+            "logits": logits,
+        }
+
+model = SeqClsModel('hfl/chinese-bert-wwm', num_labels=2)
+
+# Prepare the dataloaders
+from fastNLP import prepare_dataloader
+dls = prepare_dataloader(data_bundle, batch_size=16)
+
+# Adjust the learning rate during training.
+scheduler = LinearDecayWithWarmup(2e-5, total_steps=20 * len(dls['train']), warmup=0.1)
+optimizer = optimizer.AdamW(parameters=model.parameters(), learning_rate=scheduler)
+
+# Prepare for training
+from fastNLP import Trainer, Accuracy, LoadBestModelCallback, Event
+callbacks = [
+    LoadBestModelCallback()  # after training ends, load the best-performing model
+]
+# Hook extra operations into specific moments of training. Different moments expose different
+# arguments; see the documentation of Trainer.on for the arguments available at each moment.
+@Trainer.on(Event.on_before_backward())
+def print_loss(trainer, outputs):
+    if trainer.global_forward_batches % 10 == 0:  # print the loss every 10 batches
+        print(outputs["loss"].item())
+
+trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer,
+                  device=0, evaluate_dataloaders=dls['dev'], metrics={'acc': Accuracy()},
+                  callbacks=callbacks, monitor='acc#acc',
+                  # Accuracy's update() expects the pred and target arguments; they correspond to the fields below.
+                  evaluate_output_mapping={'logits': 'pred'},
+                  evaluate_input_mapping={'labels': 'target'}
+                  )
+trainer.run()
+
+# Evaluate on the test set
+from fastNLP import Evaluator
+evaluator = Evaluator(model=model, dataloaders=dls['test'], metrics={'acc': Accuracy()},
+                      # Accuracy's update() expects the pred and target arguments; they correspond to the fields below.
+                      output_mapping={'logits': 'pred'},
+                      input_mapping={'labels': 'target'})
+evaluator.run()
+```
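+
+A note on the design above: when a model defines `train_step` / `evaluate_step`, fastNLP's `Trainer` and `Evaluator` call those methods instead of `forward`, which is how this example routes loss computation and inference differently.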
+
-![](./docs/source/figures/text_classification.png)
+
+### OneFlow
+
-fastNLP ships several kinds of embeddings in the embeddings module: static embeddings (GloVe, word2vec), contextual embeddings
-(ELMo, BERT), and character embeddings (CNN- or LSTM-based CharEmbedding).
-Meanwhile, the modules module ships many components of the two module types, helping users quickly build the networks they need. The functions and common components of the two module types are as follows:
-
-| Type | Function | Examples |
-| --- | --- | --- |
-| encoder | encodes the input into a vector with representational power | Embedding, RNN, CNN, Transformer, ... |
-| decoder | decodes a vector carrying some representation into the desired output form | MLP, CRF, ... |
-
+
+### Jittor
+
## Project Structure
-
-
-
-
-The overall workflow of fastNLP is shown in the figure above, and the project structure is as follows:
+The project structure of fastNLP is as follows:
@@ -135,4 +233,3 @@ The overall workflow of fastNLP is shown in the figure above, and the project structure is as follows:
-*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py
index 04b35984..26de0a04 100644
--- a/fastNLP/core/callbacks/more_evaluate_callback.py
+++ b/fastNLP/core/callbacks/more_evaluate_callback.py
@@ -127,10 +127,6 @@ class MoreEvaluateCallback(HasMonitorCallback):
assert trainer.evaluator is not None, f"You set `watch_monitor={self.monitor}`, but no " \
f"evaluate_dataloaders is provided in Trainer."
- if trainer.evaluate_fn is self.evaluate_fn:
- logger.warning_once("The `evaluate_fn` is the same as in Trainer, there seems no need to use "
- "`MoreEvaluateCallback`.")
-
        # Initialize the evaluator, while avoiding the super call that assigns monitor
kwargs = {
'model': self.kwargs.get('model', trainer.model),
diff --git a/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
index 24474b64..07c3c612 100644
--- a/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
+++ b/fastNLP/core/callbacks/torch_callbacks/torch_lr_sched_callback.py
@@ -18,7 +18,7 @@ class TorchWarmupCallback(Callback):
    1. *linear* -- ramp up to the specified learning rate (taken from the optimizer in Trainer) over the first ``warmup`` steps, then decay to 0 over the remaining steps;
    2. *constant* -- ramp up to the specified learning rate over the first ``warmup`` steps, then keep it unchanged for the remaining steps.
"""
- def __init__(self, warmup:Union[int, float]=0.1, schedule:str='constant'):
+ def __init__(self, warmup:Union[int, float]=0.1, schedule:str='linear'):
super().__init__()
self.warmup = max(warmup, 0.)
diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index 96ba5833..597d13de 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -42,7 +42,8 @@ class Evaluator:
        ``Trainer``; for these parameters, you can refer to the ``Trainer`` documentation for more detail; see :class:`~fastNLP.core.controllers.trainer.Trainer`;
    :param model: the model to evaluate, e.g. a ``torch.nn.Module``; equivalent to the ``model`` parameter of ``Trainer``;
-    :param dataloaders: the dataset(s) to evaluate on. For multiple datasets, pass a ``dict`` that labels each dataset with an identifying key;
+    :param dataloaders: the dataset(s) to evaluate on. For multiple datasets, pass a ``dict`` that labels each dataset with an identifying key; the name
+        evaluate_dataloaders may also be used for this parameter.
    :param metrics: the metrics used during evaluation. Note that this parameter must be a ``dict``, where each ``key`` is a metric name and each ``value`` is a concrete ``Metric`` object. The following metrics are currently supported:
        1. fastNLP's own ``metric``: see :class:`~fastNLP.core.metrics.Metric`;
@@ -82,13 +83,14 @@ class Evaluator:
        2. if it is a ``str``, e.g. ``'my_evaluate_step_fn'``, then :meth:`model.my_evaluate_step_fn` is looked up, and an error is raised if it cannot be found;
    :param input_mapping: equivalent to the ``input_mapping`` parameter of ``Trainer``; the data used to evaluate one batch is processed with ``input_mapping`` before being fed into the ``model`` and the ``metric``. If
-        ``model`` and ``metric`` require different ``mapping``s, consider customizing via the ``evaluate_batch_step_fn`` parameter;
+        ``model`` and ``metric`` require different ``mapping``s, consider customizing via the ``evaluate_batch_step_fn`` parameter; the name evaluate_input_mapping may also be used for this parameter.
.. todo::
        link the parameter-matching documentation here later;
    :param output_mapping: equivalent to the ``output_mapping`` parameter of ``Trainer``; the output of the ``model`` is processed with ``output_mapping`` before being fed into the ``metric``;
+        the name evaluate_output_mapping may also be used for this parameter.
    :param model_wo_auto_param_call: equivalent to the ``model_wo_auto_param_call`` parameter of ``Trainer``;
.. note::
@@ -128,7 +130,7 @@ class Evaluator:
    driver: Driver
    _evaluate_batch_loop: Loop
-    def __init__(self, model, dataloaders, metrics: Optional[Dict] = None,
+    def __init__(self, model, dataloaders=None, metrics: Optional[Dict] = None,
                 driver: Union[str, Driver] = 'auto', device: Optional[Union[int, List[int], str]] = None,
                 evaluate_batch_step_fn: Optional[Callable] = None, evaluate_fn: Optional[str] = None,
                 input_mapping: Optional[Union[Callable, Dict]] = None,
@@ -139,6 +141,7 @@
        self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call,
                                    **kwargs)
+        dataloaders = dataloaders if dataloaders is not None else kwargs.get('evaluate_dataloaders')
        if dataloaders is None:
            raise ValueError("Parameter `dataloaders` cannot be None.")
        self.dataloaders = dataloaders
@@ -151,8 +154,8 @@
            _check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn')
        self.evaluate_batch_step_fn = evaluate_batch_step_fn
-        self.input_mapping = input_mapping
-        self.output_mapping = output_mapping
+        self.input_mapping = input_mapping if input_mapping is not None else kwargs.get('evaluate_input_mapping')
+        self.output_mapping = output_mapping if output_mapping is not None else kwargs.get('evaluate_output_mapping')
        # check dataloader
        if not isinstance(dataloaders, dict):
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 41bac374..14dad89b 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -595,7 +595,7 @@ class Trainer(TrainerEventTrigger):
    :param resume_from: the path from which to restore the trainer's state. Note that this value must be a folder, e.g. a saved subfolder created for you when using ``CheckpointCallback``;
    :param resume_training: whether to resume from the training state in the checkpoint. If False, only the states of the model and the optimizers are restored; if this parameter is ``True``,
        the next resumed run trains from the exact sample at which the previous run stopped; otherwise we only restore the states of the model and the optimizers, while the
-        rest of the state in ``Trainer`` stays as it was at initialization;
+        rest of the state in ``Trainer`` stays as it was at initialization. Only meaningful when the resume_from parameter is passed.
    :param catch_KeyboardInterrupt: whether to catch :class:`KeyboardInterrupt`. If this parameter is ``True`` and you use ``ctrl+c`` to stop the program during training,
        ``Trainer`` will not raise an exception but will exit early, and the code after ``trainer.run()`` keeps running. Note that this parameter has no effect with distributed-training ``Driver``s,
        e.g. ``TorchDDPDriver``; for non-distributed ``Driver``s it defaults to True;
diff --git a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py
index 1dd40500..15fea5d4 100644
--- a/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/oneflow_dataloader/fdl.py
@@ -287,6 +287,8 @@ def prepare_oneflow_dataloader(ds_or_db,
from fastNLP.io import DataBundle
if isinstance(ds_or_db, DataBundle):
+        assert sampler is None, "sampler must be None when multiple datasets are presented."
+        assert batch_sampler is None, "batch_sampler must be None when multiple datasets are presented."
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
@@ -313,6 +315,8 @@ def prepare_oneflow_dataloader(ds_or_db,
return dl_bundle
elif isinstance(ds_or_db, Mapping):
+        assert sampler is None, "sampler must be None when multiple datasets are presented."
+        assert batch_sampler is None, "batch_sampler must be None when multiple datasets are presented."
dl_bundle = {}
for name, ds in ds_or_db.items():
if 'train' in name:
diff --git a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
index 549e6c36..529f23aa 100644
--- a/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/paddle_dataloader/fdl.py
@@ -320,6 +320,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
from fastNLP.io.data_bundle import DataBundle
if isinstance(ds_or_db, DataBundle):
+        assert batch_sampler is None, "batch_sampler must be None when multiple datasets are presented."
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
@@ -346,6 +347,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None,
return dl_bundle
elif isinstance(ds_or_db, Dict):
+        assert batch_sampler is None, "batch_sampler must be None when multiple datasets are presented."
ds_dict = {}
for name, ds in ds_or_db.items():
if 'train' in name:
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 3bba0476..5ae72367 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -287,6 +287,8 @@ def prepare_torch_dataloader(ds_or_db,
from fastNLP.io import DataBundle
if isinstance(ds_or_db, DataBundle):
+        assert sampler is None, "sampler must be None when multiple datasets are presented."
+        assert batch_sampler is None, "batch_sampler must be None when multiple datasets are presented."
dl_bundle = {}
for name, ds in ds_or_db.iter_datasets():
if 'train' in name:
@@ -313,6 +315,8 @@ def prepare_torch_dataloader(ds_or_db,
return dl_bundle
elif isinstance(ds_or_db, Mapping):
+        assert sampler is None, "sampler must be None when multiple datasets are presented."
+        assert batch_sampler is None, "batch_sampler must be None when multiple datasets are presented."
dl_bundle = {}
for name, ds in ds_or_db.items():
if 'train' in name:
diff --git a/fastNLP/core/metrics/accuracy.py b/fastNLP/core/metrics/accuracy.py
index 7af788ce..5475046e 100644
--- a/fastNLP/core/metrics/accuracy.py
+++ b/fastNLP/core/metrics/accuracy.py
@@ -34,8 +34,7 @@ class Accuracy(Metric):
-        :return: a dict containing ``{"acc": float, 'total': float, 'correct': float}``;
+        :return: a dict of the form ``{"acc": float}``;
"""
- evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6),
- 'total': self.total.item(), 'correct': self.correct.item()}
+ evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
return evaluate_result
def update(self, pred, target, seq_len=None):
diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
index 5ec8b9ad..a3c5a722 100644
--- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py
+++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py
@@ -328,7 +328,7 @@ class SpanFPreRecMetric(Metric):
return evaluate_result
- def update(self, pred, target, seq_len: Optional[List] = None) -> None:
+ def update(self, pred, target, seq_len) -> None:
r"""
        :meth:`update` accumulates the evaluation metric over one batch of prediction results.
diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py
index 4bcc6e3b..3a6ab650 100644
--- a/fastNLP/core/vocabulary.py
+++ b/fastNLP/core/vocabulary.py
@@ -349,15 +349,18 @@ class Vocabulary(object):
        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
-                try:
-                    for f_n, n_f_n in zip(field_name, new_field_name):
-                        dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n,
-                                            progress_bar=None)
-                except Exception as e:
-                    logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
-                    raise e
+                ds_lst = [dataset]
+            elif _is_iterable(dataset):
+                ds_lst = list(dataset)
            else:
-                raise RuntimeError("Only DataSet type is allowed.")
+                raise TypeError(f"Only DataSet or an iterable of DataSet is allowed, got {type(dataset)}.")
+            try:
+                for ds in ds_lst:
+                    for f_n, n_f_n in zip(field_name, new_field_name):
+                        ds.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n, progress_bar=None)
+            except Exception as e:
+                logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
+                raise e
        return self
@property
@@ -408,13 +411,18 @@ class Vocabulary(object):
        for idx, dataset in enumerate(datasets):
            if isinstance(dataset, DataSet):
-                try:
-                    dataset.apply(construct_vocab, progress_bar=None)
-                except BaseException as e:
-                    logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
-                    raise e
+                ds_lst = [dataset]
+            elif _is_iterable(dataset):
+                ds_lst = list(dataset)
            else:
-                raise TypeError("Only DataSet type is allowed.")
+                raise TypeError(f"Only DataSet or an iterable of DataSet is allowed, got {type(dataset)}.")
+
+            try:
+                for ds in ds_lst:
+                    ds.apply(construct_vocab, progress_bar=None)
+            except BaseException as e:
+                logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
+                raise e
if no_create_entry_dataset is not None:
partial_construct_vocab = partial(construct_vocab, no_create_entry=True)
diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py
index 4029e092..a53f00a5 100644
--- a/fastNLP/io/data_bundle.py
+++ b/fastNLP/io/data_bundle.py
@@ -6,7 +6,7 @@ __all__ = [
'DataBundle',
]
-from typing import Union, List, Callable
+from typing import Union, List, Callable, Dict
from ..core.dataset import DataSet
from fastNLP.core.vocabulary import Vocabulary
@@ -34,8 +34,16 @@ class DataBundle:
        :param datasets: a dict from name (string) to :class:`~fastNLP.DataSet`. Avoid passing the same DataSet object more than once, as this may cause
            problems when processing data with a Pipe; if several datasets really must be identical, deepcopy them manually before passing them in.
"""
- self.vocabs = vocabs or {}
- self.datasets = datasets or {}
+ self._vocabs = vocabs or {}
+ self._datasets = datasets or {}
+
+    @property
+    def datasets(self) -> Dict:
+        return self._datasets
+
+    @property
+    def vocabs(self) -> Dict:
+        return self._vocabs
def set_vocab(self, vocab: Vocabulary, field_name: str):
r"""
diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py
index 79806794..5c0b63ce 100644
--- a/fastNLP/io/utils.py
+++ b/fastNLP/io/utils.py
@@ -57,16 +57,16 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]:
raise FileExistsError(f"Two files contain `{path_pair[0]}` were found, please specify the "
f"filepath for `{path_pair[0]}`.")
files[path_pair[0]] = os.path.join(paths, path_pair[1])
- if 'train' not in files:
- raise KeyError(f"There is no train file in {paths}.")
return files
else:
raise FileNotFoundError(f"{paths} is not a valid file path.")
elif isinstance(paths, dict):
if paths:
- if 'train' not in paths:
- raise KeyError("You have to include `train` in your dict.")
for key, value in paths.items():
if isinstance(key, str) and isinstance(value, str):
value = os.path.abspath(os.path.expanduser(value))
diff --git a/setup.py b/setup.py
index cde5680c..029c8696 100644
--- a/setup.py
+++ b/setup.py
@@ -16,14 +16,14 @@ print(pkgs)
setup(
name='FastNLP',
- version='0.8.0alpha',
+ version='1.0.0alpha',
url='https://gitee.com/fastnlp/fastNLP',
description='fastNLP: Deep Learning Toolkit for NLP, developed by Fudan FastNLP Team',
long_description=readme,
long_description_content_type='text/markdown',
license='Apache License',
author='Fudan FastNLP Team',
- python_requires='>=3.7',
+ python_requires='>=3.6',
packages=pkgs,
install_requires=reqs.strip().split('\n'),
)