diff --git a/README.md b/README.md
index 2fd27048..74090646 100644
--- a/README.md
+++ b/README.md
@@ -6,4 +6,133 @@
![Hex.pm](https://img.shields.io/hexpm/l/plug.svg)
[![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest)
-dev0.8.0 is under development
\ No newline at end of file
+fastNLP is a lightweight natural language processing (NLP) toolkit. Its goal is to make it fast to implement NLP tasks and to build complex models.
+
+fastNLP has the following features:
+
+- A unified tabular data container that simplifies data preprocessing (see the sketch after this list);
+- Built-in Loaders and Pipes for many datasets, sparing you the preprocessing code;
+- A variety of convenient NLP tools, such as Embedding loading (including ELMo and BERT) and caching of intermediate data;
+- Automatic download of some [datasets and pretrained models](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0);
+- A rich set of neural network components and reproduced models (covering Chinese word segmentation, named entity recognition, syntactic parsing, text classification, text matching, coreference resolution, summarization, and more);
+- A Trainer with many built-in Callbacks, which makes experiment logging, exception catching, and similar chores easy.
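+
+As a quick taste of the tabular workflow, here is a minimal sketch following the tutorials' API (the field names and toy data are illustrative only):
+
+```python
+from fastNLP import DataSet, Vocabulary
+
+# a DataSet is a table: every key is a column (field), every sample a row
+ds = DataSet({'raw_words': ['This is a sentence .', 'Another one .'],
+              'label': [1, 0]})
+# derive a tokenized column from an existing one
+ds.apply(lambda ins: ins['raw_words'].split(), new_field_name='words')
+
+# build a vocabulary over the 'words' field, then replace tokens by indices in place
+vocab = Vocabulary()
+vocab.from_dataset(ds, field_name='words')
+vocab.index_dataset(ds, field_name='words')
+```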
+
+## Installation
+
+fastNLP depends on the following packages:
+
++ numpy>=1.14.2
++ torch>=1.0.0
++ tqdm>=4.28.1
++ nltk>=3.4.1
++ requests
++ spacy
++ prettytable>=0.7.2
+
+The installation of torch may depend on your operating system and CUDA version; see the [PyTorch website](https://pytorch.org/) for details.
+Once the dependencies are installed, run the following commands to complete the installation:
+
+```shell
+pip install fastNLP
+python -m spacy download en
+```
+
+
+## fastNLP Tutorials
+Chinese [documentation](https://fastnlp.readthedocs.io/) and [tutorials](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html)
+
+### Quick Start
+
+- [0. Quick start](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html)
+
+### Detailed Tutorials
+
+- [1. Preprocessing text with DataSet](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html)
+- [2. Converting between text and indices with Vocabulary](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html)
+- [3. Turning text into vectors with the Embedding module](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html)
+- [4. Loading and processing datasets with Loader and Pipe](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html)
+- [5. Building a text classifier I: fast training and testing with Trainer and Tester](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html)
+- [6. Building a text classifier II: a custom training loop with DataSetIter](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html)
+- [7. Evaluating your model quickly with Metric](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html)
+- [8. Building custom models quickly with Modules and Models](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html)
+- [9. Implementing a sequence labeling model quickly](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html)
+- [10. Customizing your training loop with Callback](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html)
+
+### Extended Tutorials
+
+- [Extend-1. The many uses of BertEmbedding](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html)
+- [Extend-2. An introduction to distributed training](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html)
+- [Extend-3. Using fitlog to support research with fastNLP](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html)
+
+
+## Built-in Components
+
+Most neural networks for NLP tasks can be viewed as a combination of word embeddings and two kinds of modules: encoders and decoders.
+
+Taking text classification as an example, the figure below shows the flow of a BiLSTM+Attention text classifier:
+
+
+![](./docs/source/figures/text_classification.png)
+
+fastNLP ships several kinds of embeddings in its embeddings module: static embeddings (GloVe, word2vec), contextual
+embeddings (ELMo, BERT), and character embeddings (CNN- or LSTM-based CharEmbedding).
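+
+For example, a static and a contextual embedding can be loaded over the same vocabulary as follows (a minimal sketch; it assumes the named pretrained weights are available for automatic download):
+
+```python
+from fastNLP import Vocabulary
+from fastNLP.embeddings import StaticEmbedding, BertEmbedding
+
+vocab = Vocabulary()
+vocab.add_word_lst("this is an example".split())
+
+# 50-dimensional GloVe vectors, fetched automatically on first use
+glove_embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
+# a contextual BERT embedding over the same vocabulary
+bert_embed = BertEmbedding(vocab, model_dir_or_name='en-base-uncased')
+```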
+
+At the same time, fastNLP's modules module provides many components of both kinds, helping users quickly assemble the networks they need. The roles and common components of the two kinds are:
+
+| Type | Function | Examples |
+|:-----|:---------|:---------|
+| encoder | encodes the input into a vector with representational power | Embedding, RNN, CNN, Transformer, ... |
+| decoder | decodes a vector carrying some representation into the desired output form | MLP, CRF, ... |
+
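+Composing an encoder and a decoder into a model is then mostly wiring. Below is a minimal sketch of a BiLSTM classifier built from these components (the hidden sizes and the dict-style output are illustrative conventions from the tutorials, not fixed requirements):
+
+```python
+from torch import nn
+from fastNLP.modules import LSTM, MLP
+
+class BiLSTMText(nn.Module):
+    def __init__(self, embed, num_classes):
+        super().__init__()
+        self.embed = embed                    # embedding: token indices -> vectors
+        self.lstm = LSTM(embed.embedding_dim, hidden_size=64,
+                         bidirectional=True)  # encoder
+        self.mlp = MLP([2 * 64, num_classes])  # decoder
+
+    def forward(self, words):
+        x = self.embed(words)
+        x, _ = self.lstm(x)
+        # max-pool over the sequence dimension, then classify
+        return {'pred': self.mlp(x.max(dim=1)[0])}
+```
+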
+## Project Structure
+
+*(workflow figure)*
+
+The overall workflow of fastNLP is shown in the figure above, and the project structure is as follows:
+
+| Module | Description |
+|:-------|:------------|
+| fastNLP | an open-source natural language processing library |
+| fastNLP.core | the core functionality: data handling components, Trainer, Tester, and so on |
+| fastNLP.models | a collection of complete neural network models |
+| fastNLP.modules | the many components used to build neural network models |
+| fastNLP.embeddings | turns index sequences into vector sequences, including loading pretrained embeddings |
+| fastNLP.io | reading and writing: data loading and preprocessing, model saving and loading, and automatic download of data and models |
+
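+A typical run touches most of these packages. The following sketch follows the quick-start tutorial's API (it assumes the SST-2 pipe and the GloVe weights named here can be downloaded automatically):
+
+```python
+from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
+from fastNLP.embeddings import StaticEmbedding
+from fastNLP.io import SST2Pipe         # fastNLP.io: download and preprocess the data
+from fastNLP.models import CNNText     # fastNLP.models: a ready-made text classifier
+
+data_bundle = SST2Pipe().process_from_file()
+vocab = data_bundle.get_vocab('words')
+
+embed = StaticEmbedding(vocab, model_dir_or_name='en-glove-6b-50d')
+model = CNNText(embed, num_classes=2)
+
+# fastNLP.core: the Trainer drives training and evaluates on the dev set
+trainer = Trainer(data_bundle.get_dataset('train'), model,
+                  loss=CrossEntropyLoss(), metrics=AccuracyMetric(),
+                  dev_data=data_bundle.get_dataset('dev'))
+trainer.train()
+```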
+
+*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!*
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index 0fd74795..d2d548f5 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -19,7 +19,7 @@ from fastNLP.core.utils import (
     paddle_move_data_to_device,
     is_in_paddle_dist,
)
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler
from fastNLP.envs.env import FASTNLP_DISTRIBUTED_CHECK, USER_CUDA_VISIBLE_DEVICES
from fastNLP.core.log import logger
@@ -362,7 +362,7 @@ class PaddleFleetDriver(PaddleDriver):
             return dataloader
         # evaluator
         elif dist == "unrepeatdist":
-            sampler = UnrepeatedDistributedSampler(
+            sampler = UnrepeatedSampler(
                 dataset=dataloader.dataset,
                 shuffle=shuffle,
                 seed=int(os.environ.get("FASTNLP_SEED", 0))
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 2d393dab..9e5e16fd 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -28,7 +28,7 @@ from fastNLP.core.drivers.torch_driver.utils import (
)
from fastNLP.core.drivers.utils import distributed_open_proc
from fastNLP.core.utils import auto_param_call, check_user_specific_params
-from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedDistributedSampler, ReproducibleBatchSampler
+from fastNLP.core.samplers import ReproducibleIterator, RandomSampler, UnrepeatedSampler, ReproducibleBatchSampler
from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED
from fastNLP.core.log import logger
from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
@@ -507,7 +507,7 @@ class TorchDDPDriver(TorchDriver):
             args = self.get_dataloader_args(dataloader)
             # todo: check batch_sampler;
-            sampler = UnrepeatedDistributedSampler(
+            sampler = UnrepeatedSampler(
                 dataset=args.dataset,
                 shuffle=args.shuffle,
             )
diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py
index 68928b66..bb2ee661 100644
--- a/fastNLP/core/samplers/__init__.py
+++ b/fastNLP/core/samplers/__init__.py
@@ -3,19 +3,24 @@ __all__ = [
     'SortedSampler',
     'ConstTokenNumSampler',
     'ConstantTokenNumSampler',
-    'UnrepeatedDistributedSampler',
+
     'MixSampler',
-    'InnerSampler',
     'DopedSampler',
     'MixSequentialSampler',
     'PollingSampler',
+
     'ReproducibleIterator',
     'RandomSampler',
-    're_instantiate_sampler'
+
+    're_instantiate_sampler',
+
+    'UnrepeatedSampler',
+    "UnrepeatedSortedSampler"
 ]
-from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler
-from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler
+from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
+from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedSortedSampler
+from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler
from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler
diff --git a/fastNLP/core/samplers/mix_sampler.py b/fastNLP/core/samplers/mix_sampler.py
index e219b6e2..f53c06a5 100644
--- a/fastNLP/core/samplers/mix_sampler.py
+++ b/fastNLP/core/samplers/mix_sampler.py
@@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict
 __all__ = [
     'MixSampler',
-    'InnerSampler',
     'DopedSampler',
     'MixSequentialSampler',
     'PollingSampler'
diff --git a/fastNLP/core/samplers/sampler.py b/fastNLP/core/samplers/sampler.py
index e41472bf..89751884 100644
--- a/fastNLP/core/samplers/sampler.py
+++ b/fastNLP/core/samplers/sampler.py
@@ -7,7 +7,6 @@ __all__ = [
     "SortedSampler",
     'ConstTokenNumSampler',
     "ConstantTokenNumSampler",
-    "UnrepeatedDistributedSampler",
 ]
from itertools import chain
@@ -18,7 +17,7 @@ import numpy as np
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 if _NEED_IMPORT_TORCH:
-    from torch.utils.data import SequentialSampler, Sampler, RandomSampler
+    from torch.utils.data import Sampler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Sampler
@@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets):
         if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]:
             bucket_data[bucket_id].append(idx)
     return bucket_data
-
-
-class UnrepeatedDistributedSampler:
-    def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
-        """
-        For multi-GPU evaluation scenarios, where no sample may be drawn twice.
-
-        :param dataset:
-        :param shuffle:
-        :param seed:
-        """
-        self.dataset = dataset
-        self.shuffle = shuffle
-        self.seed = seed
-
-        # parameters related to the multi-GPU setting
-        self.num_replicas = 1
-        self.rank = 0
-        self.epoch = -1
-
-    def __len__(self):
-        """
-        Returns how many indices one full iteration of this sampler yields; in the multi-GPU
-        case, only the current rank is considered.
-        :return:
-        """
-        num_common = len(self.dataset)//self.num_replicas
-        self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
-        return self.num_samples
-
-    def __iter__(self):
-        r"""
-        The current num_consumed_samples approach runs into problems when two iterators are
-        used alternately;
-        Example:
-            >>> sampler = RandomSampler()
-            >>> iter1 = iter(sampler)
-            >>> iter2 = iter(sampler)
-            >>> next(iter1)
-            >>> next(iter2)  # num_consumed_samples changes at this point
-        """
-
-        indices = self.generate_indices()
-
-        # subsample
-        indices = indices[self.rank:len(indices):self.num_replicas]
-        assert len(indices) == len(self)
-
-        for index in indices:
-            yield index
-
-    def generate_indices(self) -> List[int]:
-        """
-        Generate the sequence of indices.
-
-        :return:
-        """
-        if self.shuffle:
-            indices = list(range(len(self.dataset)))
-            seed = self.seed + self.epoch
-            rng = np.random.default_rng(abs(seed))
-            rng.shuffle(indices)
-            if self.epoch < 0:  # guards against the user forgetting to call set_epoch; at least the index order still differs from epoch to epoch.
-                self.epoch -= 1
-        else:
-            indices = list(range(len(self.dataset)))
-        return indices
-
-    def set_epoch(self, epoch: int) -> None:
-        self.epoch = epoch
-
-    def set_distributed(self, num_replicas, rank):
-        """
-        This method essentially completes the initialization that ddp would otherwise leave
-        unfinished; it should be called immediately after the sampler itself is constructed.
-
-        :param num_replicas:
-        :param rank:
-        :return:
-        """
-        assert num_replicas>0 and isinstance(num_replicas, int)
-        assert isinstance(rank, int) and 0<=rank<num_replicas
-        self.num_replicas = num_replicas
-        self.rank = rank
-        return self
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
new file mode 100644
--- /dev/null
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
+__all__ = [
+    'UnrepeatedSampler',
+    'UnrepeatedSortedSampler'
+]
+
+from typing import List
+
+import numpy as np
+
+
+class UnrepeatedSampler:
+    def __init__(self, dataset, shuffle: bool = False, seed: int = 0):
+        """
+        For multi-GPU evaluation scenarios, where no sample may be drawn twice.
+
+        :param dataset:
+        :param shuffle:
+        :param seed:
+        """
+        self.dataset = dataset
+        self.shuffle = shuffle
+        self.seed = seed
+
+        # parameters related to the multi-GPU setting
+        self.num_replicas = 1
+        self.rank = 0
+        self.epoch = -1
+
+    def __len__(self):
+        num_common = len(self.dataset)//self.num_replicas
+        self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas))
+        return self.num_samples
+
+    def __iter__(self):
+        indices = self.generate_indices()
+
+        # subsample: each rank takes every num_replicas-th index, starting from its own rank
+        indices = indices[self.rank:len(indices):self.num_replicas]
+        assert len(indices) == len(self)
+
+        for index in indices:
+            yield index
+
+    def generate_indices(self) -> List[int]:
+        """
+        Generate the sequence of indices.
+
+        :return:
+        """
+        if self.shuffle:
+            indices = list(range(len(self.dataset)))
+            seed = self.seed + self.epoch
+            rng = np.random.default_rng(abs(seed))
+            rng.shuffle(indices)
+            if self.epoch < 0:  # guards against the user forgetting to call set_epoch; at least the index order still differs from epoch to epoch.
+                self.epoch -= 1
+        else:
+            indices = list(range(len(self.dataset)))
+        return indices
+
+    def set_epoch(self, epoch: int) -> None:
+        self.epoch = epoch
+
+    def set_distributed(self, num_replicas, rank):
+        """
+        This method essentially completes the initialization that ddp would otherwise leave
+        unfinished; it should be called immediately after the sampler itself is constructed.
+
+        :param num_replicas:
+        :param rank:
+        :return:
+        """
+        assert num_replicas>0 and isinstance(num_replicas, int)
+        assert isinstance(rank, int) and 0<=rank<num_replicas
+        self.num_replicas = num_replicas
+        self.rank = rank
+        return self
+
+
+class UnrepeatedSortedSampler(UnrepeatedSampler):
+    def __init__(self, dataset, length: List[int], **kwargs):
+        """
+        Iterates over dataset from the longest element to the shortest, while still
+        guaranteeing that no sample is repeated across cards.
+
+        :param length: the length of each element in dataset
+        """
+        super().__init__(dataset=dataset, **kwargs)
+        # sort the indices by length, longest first
+        self.sorted_indices = np.argsort(length)[::-1].tolist()
+
+    def generate_indices(self) -> List[int]:
+        return self.sorted_indices
diff --git a/tests/core/metrics/test_span_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py
index 5908663a..bc711a54 100644
--- a/tests/core/metrics/test_span_f1_rec_acc_torch.py
+++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py
@@ -1,5 +1,5 @@
import pytest
-import unittest
+
from collections import Counter
import os, sys
import copy
diff --git a/tests/core/samplers/test_unrepeated_sampler.py b/tests/core/samplers/test_unrepeated_sampler.py
new file mode 100644
index 00000000..3e2f79ed
--- /dev/null
+++ b/tests/core/samplers/test_unrepeated_sampler.py
@@ -0,0 +1,64 @@
+from itertools import chain
+
+import pytest
+
+from fastNLP.core.samplers import UnrepeatedSampler, UnrepeatedSortedSampler
+
+
+class DatasetWithVaryLength:
+    def __init__(self, num_of_data=100):
+        self.data = list(range(num_of_data))
+
+    def __getitem__(self, item):
+        return self.data[item]
+
+    def __len__(self):
+        return len(self.data)
+
+
+class TestUnrepeatedSampler:
+    @pytest.mark.parametrize('shuffle', [True, False])
+    def test_single(self, shuffle):
+        num_of_data = 100
+        data = DatasetWithVaryLength(num_of_data)
+        sampler = UnrepeatedSampler(data, shuffle)
+        indexes = set(sampler)
+        assert indexes==set(range(num_of_data))
+
+    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
+    @pytest.mark.parametrize('shuffle', [False, True])
+    def test_multi(self, num_replica, num_of_data, shuffle):
+        data = DatasetWithVaryLength(num_of_data=num_of_data)
+        samplers = []
+        for i in range(num_replica):
+            sampler = UnrepeatedSampler(dataset=data, shuffle=shuffle)
+            sampler.set_distributed(num_replica, rank=i)
+            samplers.append(sampler)
+
+        indexes = set(chain(*samplers))
+        assert indexes==set(range(num_of_data))
+
+
+class TestUnrepeatedSortedSampler:
+    @pytest.mark.parametrize('shuffle', [True, False])
+    def test_single(self, shuffle):
+        num_of_data = 100
+        data = DatasetWithVaryLength(num_of_data)
+        sampler = UnrepeatedSortedSampler(data, length=data.data)
+        indexes = list(sampler)
+        assert indexes==list(range(num_of_data-1, -1, -1))
+
+    @pytest.mark.parametrize('num_replica', [2, 3])
+    @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
+    @pytest.mark.parametrize('shuffle', [False, True])
+    def test_multi(self, num_replica, num_of_data, shuffle):
+        data = DatasetWithVaryLength(num_of_data=num_of_data)
+        samplers = []
+        for i in range(num_replica):
+            sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
+            sampler.set_distributed(num_replica, rank=i)
+            samplers.append(sampler)
+
+        indexes = set(chain(*samplers))
+        assert indexes==set(range(num_of_data))