diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md
new file mode 100644
index 00000000..9f550edf
--- /dev/null
+++ b/.github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,24 @@
+Description:简要描述这次PR的内容
+
+Main reason: 做出这次修改的原因
+
+
+Checklist 检查下面各项是否完成
+
+Please feel free to remove inapplicable items for your PR.
+
+- [ ] The PR title starts with [$CATEGORY] (例如[bugfix]修复bug,[new]添加新功能,[test]修改测试,[rm]删除旧代码)
+- [ ] Changes are complete (i.e. I finished coding on this PR) 修改完成才提PR
+- [ ] All changes have test coverage 修改的部分顺利通过测试。对于fastnlp/fastnlp/*的修改,测试代码**必须**提供在fastnlp/test/*。
+- [ ] Code is well-documented 注释写好,API文档会从注释中抽取
+- [ ] To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change 修改导致例子或tutorial有变化,请找核心开发人员
+
+Changes: 逐项描述修改的内容
+- 添加了新模型;用于句子分类的CNN,来自Yoon Kim的Convolutional Neural Networks for Sentence Classification
+- 修改dataset.py中过时的和不合规则的注释 #286
+- 添加对var-LSTM的测试代码
+
+Mention: 找人review你的PR
+
+@修改过这个文件的人
+@核心开发人员
diff --git a/README.md b/README.md
index c9c934eb..65d713e6 100644
--- a/README.md
+++ b/README.md
@@ -30,6 +30,7 @@ A deep learning NLP model is the composition of three types of modules:
decode the representation into the output |
MLP, CRF |
+
For example:
@@ -37,9 +38,11 @@ For example:
## Requirements
+- Python>=3.6
- numpy>=1.14.2
- torch>=0.4.0
- tensorboardX
+- tqdm>=4.28.1
## Resources
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 294a44d0..c7d94486 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,5 +1,8 @@
numpy>=1.14.2
-http://download.pytorch.org/whl/cpu/torch-0.4.1-cp35-cp35m-linux_x86_64.whl
+http://download.pytorch.org/whl/cpu/torch-0.4.1-cp36-cp36m-linux_x86_64.whl
torchvision>=0.1.8
sphinx-rtd-theme==0.4.1
-tensorboardX>=1.4
\ No newline at end of file
+tensorboardX>=1.4
+tqdm>=4.28.1
+ipython>=6.4.0
+ipython-genutils>=0.2.0
\ No newline at end of file
diff --git a/docs/source/fastNLP.api.rst b/docs/source/fastNLP.api.rst
new file mode 100644
index 00000000..eb9192da
--- /dev/null
+++ b/docs/source/fastNLP.api.rst
@@ -0,0 +1,36 @@
+fastNLP.api
+============
+
+fastNLP.api.api
+----------------
+
+.. automodule:: fastNLP.api.api
+ :members:
+
+fastNLP.api.converter
+----------------------
+
+.. automodule:: fastNLP.api.converter
+ :members:
+
+fastNLP.api.model\_zoo
+-----------------------
+
+.. automodule:: fastNLP.api.model_zoo
+ :members:
+
+fastNLP.api.pipeline
+---------------------
+
+.. automodule:: fastNLP.api.pipeline
+ :members:
+
+fastNLP.api.processor
+----------------------
+
+.. automodule:: fastNLP.api.processor
+ :members:
+
+
+.. automodule:: fastNLP.api
+ :members:
diff --git a/docs/source/fastNLP.core.rst b/docs/source/fastNLP.core.rst
index b70b6798..b9f6c89f 100644
--- a/docs/source/fastNLP.core.rst
+++ b/docs/source/fastNLP.core.rst
@@ -13,8 +13,8 @@ fastNLP.core.dataset
.. automodule:: fastNLP.core.dataset
:members:
-fastNLP.core.fieldarray
--------------------
+fastNLP.core.fieldarray
+------------------------
.. automodule:: fastNLP.core.fieldarray
:members:
@@ -25,8 +25,8 @@ fastNLP.core.instance
.. automodule:: fastNLP.core.instance
:members:
-fastNLP.core.losses
-------------------
+fastNLP.core.losses
+--------------------
.. automodule:: fastNLP.core.losses
:members:
@@ -67,6 +67,12 @@ fastNLP.core.trainer
.. automodule:: fastNLP.core.trainer
:members:
+fastNLP.core.utils
+-------------------
+
+.. automodule:: fastNLP.core.utils
+ :members:
+
fastNLP.core.vocabulary
------------------------
diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst
new file mode 100644
index 00000000..d91e0d1c
--- /dev/null
+++ b/docs/source/fastNLP.io.rst
@@ -0,0 +1,42 @@
+fastNLP.io
+===========
+
+fastNLP.io.base\_loader
+------------------------
+
+.. automodule:: fastNLP.io.base_loader
+ :members:
+
+fastNLP.io.config\_io
+----------------------
+
+.. automodule:: fastNLP.io.config_io
+ :members:
+
+fastNLP.io.dataset\_loader
+---------------------------
+
+.. automodule:: fastNLP.io.dataset_loader
+ :members:
+
+fastNLP.io.embed\_loader
+-------------------------
+
+.. automodule:: fastNLP.io.embed_loader
+ :members:
+
+fastNLP.io.logger
+------------------
+
+.. automodule:: fastNLP.io.logger
+ :members:
+
+fastNLP.io.model\_io
+---------------------
+
+.. automodule:: fastNLP.io.model_io
+ :members:
+
+
+.. automodule:: fastNLP.io
+ :members:
diff --git a/docs/source/fastNLP.loader.rst b/docs/source/fastNLP.loader.rst
deleted file mode 100644
index 658e07ff..00000000
--- a/docs/source/fastNLP.loader.rst
+++ /dev/null
@@ -1,36 +0,0 @@
-fastNLP.loader
-===============
-
-fastNLP.loader.base\_loader
-----------------------------
-
-.. automodule:: fastNLP.loader.base_loader
- :members:
-
-fastNLP.loader.config\_loader
-------------------------------
-
-.. automodule:: fastNLP.loader.config_loader
- :members:
-
-fastNLP.loader.dataset\_loader
--------------------------------
-
-.. automodule:: fastNLP.loader.dataset_loader
- :members:
-
-fastNLP.loader.embed\_loader
------------------------------
-
-.. automodule:: fastNLP.loader.embed_loader
- :members:
-
-fastNLP.loader.model\_loader
------------------------------
-
-.. automodule:: fastNLP.loader.model_loader
- :members:
-
-
-.. automodule:: fastNLP.loader
- :members:
diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst
index f17b1d49..7452fdf6 100644
--- a/docs/source/fastNLP.models.rst
+++ b/docs/source/fastNLP.models.rst
@@ -7,6 +7,12 @@ fastNLP.models.base\_model
.. automodule:: fastNLP.models.base_model
:members:
+fastNLP.models.biaffine\_parser
+--------------------------------
+
+.. automodule:: fastNLP.models.biaffine_parser
+ :members:
+
fastNLP.models.char\_language\_model
-------------------------------------
@@ -25,6 +31,12 @@ fastNLP.models.sequence\_modeling
.. automodule:: fastNLP.models.sequence_modeling
:members:
+fastNLP.models.snli
+--------------------
+
+.. automodule:: fastNLP.models.snli
+ :members:
+
.. automodule:: fastNLP.models
:members:
diff --git a/docs/source/fastNLP.modules.encoder.rst b/docs/source/fastNLP.modules.encoder.rst
index 41b4ce13..ea8fc699 100644
--- a/docs/source/fastNLP.modules.encoder.rst
+++ b/docs/source/fastNLP.modules.encoder.rst
@@ -43,6 +43,12 @@ fastNLP.modules.encoder.masked\_rnn
.. automodule:: fastNLP.modules.encoder.masked_rnn
:members:
+fastNLP.modules.encoder.transformer
+------------------------------------
+
+.. automodule:: fastNLP.modules.encoder.transformer
+ :members:
+
fastNLP.modules.encoder.variational\_rnn
-----------------------------------------
diff --git a/docs/source/fastNLP.modules.interactor.rst b/docs/source/fastNLP.modules.interactor.rst
deleted file mode 100644
index 5eb3bdef..00000000
--- a/docs/source/fastNLP.modules.interactor.rst
+++ /dev/null
@@ -1,5 +0,0 @@
-fastNLP.modules.interactor
-===========================
-
-.. automodule:: fastNLP.modules.interactor
- :members:
diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst
index eda85aa7..965fb27d 100644
--- a/docs/source/fastNLP.modules.rst
+++ b/docs/source/fastNLP.modules.rst
@@ -6,7 +6,12 @@ fastNLP.modules
fastNLP.modules.aggregator
fastNLP.modules.decoder
fastNLP.modules.encoder
- fastNLP.modules.interactor
+
+fastNLP.modules.dropout
+------------------------
+
+.. automodule:: fastNLP.modules.dropout
+ :members:
fastNLP.modules.other\_modules
-------------------------------
diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst
index bb5037ce..61882359 100644
--- a/docs/source/fastNLP.rst
+++ b/docs/source/fastNLP.rst
@@ -3,18 +3,11 @@ fastNLP
.. toctree::
+ fastNLP.api
fastNLP.core
- fastNLP.loader
+ fastNLP.io
fastNLP.models
fastNLP.modules
- fastNLP.saver
-
-fastNLP.fastnlp
-----------------
-
-.. automodule:: fastNLP.fastnlp
- :members:
-
.. automodule:: fastNLP
:members:
diff --git a/docs/source/fastNLP.saver.rst b/docs/source/fastNLP.saver.rst
deleted file mode 100644
index 1a02572d..00000000
--- a/docs/source/fastNLP.saver.rst
+++ /dev/null
@@ -1,24 +0,0 @@
-fastNLP.saver
-==============
-
-fastNLP.saver.config\_saver
-----------------------------
-
-.. automodule:: fastNLP.saver.config_saver
- :members:
-
-fastNLP.saver.logger
----------------------
-
-.. automodule:: fastNLP.saver.logger
- :members:
-
-fastNLP.saver.model\_saver
----------------------------
-
-.. automodule:: fastNLP.saver.model_saver
- :members:
-
-
-.. automodule:: fastNLP.saver
- :members:
diff --git a/docs/source/index.rst b/docs/source/index.rst
index b58f712a..9f410f41 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,33 +1,35 @@
fastNLP documentation
=====================
-fastNLP,目前仍在孵化中。
+A Modularized and Extensible Toolkit for Natural Language Processing. Currently still in incubation.
Introduction
------------
-fastNLP是一个基于PyTorch的模块化自然语言处理系统,用于快速开发NLP工具。
-它将基于深度学习的NLP模型划分为不同的模块。
-这些模块分为4类:encoder(编码),interaction(交互), aggregration(聚合) and decoder(解码),
-而每个类别包含不同的实现模块。
+FastNLP is a modular Natural Language Processing system based on
+PyTorch, built for fast development of NLP models.
-大多数当前的NLP模型可以构建在这些模块上,这极大地简化了开发NLP模型的过程。
-fastNLP的架构如图所示:
+A deep learning NLP model is the composition of three types of modules:
-.. image:: figures/procedures.PNG
++-----------------------+-----------------------+-----------------------+
+| module type | functionality | example |
++=======================+=======================+=======================+
+| encoder | encode the input into | embedding, RNN, CNN, |
+| | some abstract | transformer |
+| | representation | |
++-----------------------+-----------------------+-----------------------+
+| aggregator | aggregate and reduce | self-attention, |
+| | information | max-pooling |
++-----------------------+-----------------------+-----------------------+
+| decoder | decode the | MLP, CRF |
+| | representation into | |
+| | the output | |
++-----------------------+-----------------------+-----------------------+
-在constructing model部分,以序列标注和文本分类为例进行说明:
-.. image:: figures/text_classification.png
-.. image:: figures/sequence_labeling.PNG
- :width: 400
-
-* encoder module:将输入编码为一些抽象表示,输入的是单词序列,输出向量序列。
-* interaction module:使表示中的信息相互交互,输入的是向量序列,输出的也是向量序列。
-* aggregation module:聚合和减少信息,输入向量序列,输出一个向量。
-* decoder module:将表示解码为输出,输出一个label(文本分类)或者输出label序列(序列标注)
+For example:
-其中interaction module和aggregation module在模型中不一定存在,例如上面的序列标注模型。
+.. image:: figures/text_classification.png
diff --git a/docs/source/tutorials/fastnlp_10tmin_tutorial.rst b/docs/source/tutorials/fastnlp_10tmin_tutorial.rst
new file mode 100644
index 00000000..30293796
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_10tmin_tutorial.rst
@@ -0,0 +1,375 @@
+
+fastNLP上手教程
+===============
+
+fastNLP提供方便的数据预处理,训练和测试模型的功能
+
+DataSet & Instance
+------------------
+
+fastNLP用DataSet和Instance保存和处理数据。每个DataSet表示一个数据集,每个Instance表示一个数据样本。一个DataSet存有多个Instance,每个Instance可以自定义存哪些内容。
+
+有一些read\_\*方法,可以轻松从文件读取数据,存成DataSet。
+
+.. code:: ipython3
+
+ from fastNLP import DataSet
+ from fastNLP import Instance
+
+ # 从csv读取数据到DataSet
+ win_path = "C:\\Users\zyfeng\Desktop\FudanNLP\\fastNLP\\test\\data_for_tests\\tutorial_sample_dataset.csv"
+ dataset = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\t')
+ print(dataset[0])
+
+
+.. parsed-literal::
+
+ {'raw_sentence': A series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .,
+ 'label': 1}
+
+
+.. code:: ipython3
+
+ # DataSet.append(Instance)加入新数据
+
+ dataset.append(Instance(raw_sentence='fake data', label='0'))
+ dataset[-1]
+
+
+
+
+.. parsed-literal::
+
+ {'raw_sentence': fake data,
+ 'label': 0}
+
+
+
+.. code:: ipython3
+
+ # DataSet.apply(func, new_field_name)对数据预处理
+
+ # 将所有数字转为小写
+ dataset.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
+ # label转int
+ dataset.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)
+ # 使用空格分割句子
+ dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
+ def split_sent(ins):
+ return ins['raw_sentence'].split()
+ dataset.apply(split_sent, new_field_name='words', is_input=True)
+
+.. code:: ipython3
+
+ # DataSet.drop(func)筛除数据
+ # 删除低于某个长度的词语
+ dataset.drop(lambda x: len(x['words']) <= 3)
+
+.. code:: ipython3
+
+ # 分出测试集、训练集
+
+ test_data, train_data = dataset.split(0.3)
+ print("Train size: ", len(test_data))
+ print("Test size: ", len(train_data))
+
+
+.. parsed-literal::
+
+ Train size: 54
+ Test size:
+
+Vocabulary
+----------
+
+fastNLP中的Vocabulary轻松构建词表,将词转成数字
+
+.. code:: ipython3
+
+ from fastNLP import Vocabulary
+
+ # 构建词表, Vocabulary.add(word)
+ vocab = Vocabulary(min_freq=2)
+ train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
+ vocab.build_vocab()
+
+ # index句子, Vocabulary.to_index(word)
+ train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)
+ test_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)
+
+
+ print(test_data[0])
+
+
+.. parsed-literal::
+
+ {'raw_sentence': the plot is romantic comedy boilerplate from start to finish .,
+ 'label': 2,
+ 'label_seq': 2,
+ 'words': ['the', 'plot', 'is', 'romantic', 'comedy', 'boilerplate', 'from', 'start', 'to', 'finish', '.'],
+ 'word_seq': [2, 13, 9, 24, 25, 26, 15, 27, 11, 28, 3]}
+
+
+.. code:: ipython3
+
+ # 假设你们需要做强化学习或者gan之类的项目,也许你们可以使用这里的dataset
+ from fastNLP.core.batch import Batch
+ from fastNLP.core.sampler import RandomSampler
+
+ batch_iterator = Batch(dataset=train_data, batch_size=2, sampler=RandomSampler())
+ for batch_x, batch_y in batch_iterator:
+ print("batch_x has: ", batch_x)
+ print("batch_y has: ", batch_y)
+ break
+
+
+.. parsed-literal::
+
+ batch_x has: {'words': array([list(['this', 'kind', 'of', 'hands-on', 'storytelling', 'is', 'ultimately', 'what', 'makes', 'shanghai', 'ghetto', 'move', 'beyond', 'a', 'good', ',', 'dry', ',', 'reliable', 'textbook', 'and', 'what', 'allows', 'it', 'to', 'rank', 'with', 'its', 'worthy', 'predecessors', '.']),
+ list(['the', 'entire', 'movie', 'is', 'filled', 'with', 'deja', 'vu', 'moments', '.'])],
+ dtype=object), 'word_seq': tensor([[ 19, 184, 6, 1, 481, 9, 206, 50, 91, 1210, 1609, 1330,
+ 495, 5, 63, 4, 1269, 4, 1, 1184, 7, 50, 1050, 10,
+ 8, 1611, 16, 21, 1039, 1, 2],
+ [ 3, 711, 22, 9, 1282, 16, 2482, 2483, 200, 2, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0]])}
+ batch_y has: {'label_seq': tensor([3, 2])}
+
+
+Model
+-----
+
+.. code:: ipython3
+
+ # 定义一个简单的Pytorch模型
+
+ from fastNLP.models import CNNText
+ model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
+ model
+
+
+
+
+.. parsed-literal::
+
+ CNNText(
+ (embed): Embedding(
+ (embed): Embedding(77, 50, padding_idx=0)
+ (dropout): Dropout(p=0.0)
+ )
+ (conv_pool): ConvMaxpool(
+ (convs): ModuleList(
+ (0): Conv1d(50, 3, kernel_size=(3,), stride=(1,), padding=(2,))
+ (1): Conv1d(50, 4, kernel_size=(4,), stride=(1,), padding=(2,))
+ (2): Conv1d(50, 5, kernel_size=(5,), stride=(1,), padding=(2,))
+ )
+ )
+ (dropout): Dropout(p=0.1)
+ (fc): Linear(
+ (linear): Linear(in_features=12, out_features=5, bias=True)
+ )
+ )
+
+
+
+Trainer & Tester
+----------------
+
+使用fastNLP的Trainer训练模型
+
+.. code:: ipython3
+
+ from fastNLP import Trainer
+ from copy import deepcopy
+ from fastNLP import CrossEntropyLoss
+ from fastNLP import AccuracyMetric
+
+.. code:: ipython3
+
+ # 进行overfitting测试
+ copy_model = deepcopy(model)
+ overfit_trainer = Trainer(model=copy_model,
+ train_data=test_data,
+ dev_data=test_data,
+ loss=CrossEntropyLoss(pred="output", target="label_seq"),
+ metrics=AccuracyMetric(),
+ n_epochs=10,
+ save_path=None)
+ overfit_trainer.train()
+
+
+.. parsed-literal::
+
+ training epochs started 2018-12-07 14:07:20
+
+
+
+
+.. parsed-literal::
+
+ HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=20), HTML(value='')), layout=Layout(display='…
+
+
+
+.. parsed-literal::
+
+ Epoch 1/10. Step:2/20. AccuracyMetric: acc=0.037037
+ Epoch 2/10. Step:4/20. AccuracyMetric: acc=0.296296
+ Epoch 3/10. Step:6/20. AccuracyMetric: acc=0.333333
+ Epoch 4/10. Step:8/20. AccuracyMetric: acc=0.555556
+ Epoch 5/10. Step:10/20. AccuracyMetric: acc=0.611111
+ Epoch 6/10. Step:12/20. AccuracyMetric: acc=0.481481
+ Epoch 7/10. Step:14/20. AccuracyMetric: acc=0.62963
+ Epoch 8/10. Step:16/20. AccuracyMetric: acc=0.685185
+ Epoch 9/10. Step:18/20. AccuracyMetric: acc=0.722222
+ Epoch 10/10. Step:20/20. AccuracyMetric: acc=0.777778
+
+
+.. code:: ipython3
+
+ # 实例化Trainer,传入模型和数据,进行训练
+ trainer = Trainer(model=model,
+ train_data=train_data,
+ dev_data=test_data,
+ loss=CrossEntropyLoss(pred="output", target="label_seq"),
+ metrics=AccuracyMetric(),
+ n_epochs=5)
+ trainer.train()
+ print('Train finished!')
+
+
+.. parsed-literal::
+
+ training epochs started 2018-12-07 14:08:10
+
+
+
+
+.. parsed-literal::
+
+ HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=5), HTML(value='')), layout=Layout(display='i…
+
+
+
+.. parsed-literal::
+
+ Epoch 1/5. Step:1/5. AccuracyMetric: acc=0.037037
+ Epoch 2/5. Step:2/5. AccuracyMetric: acc=0.037037
+ Epoch 3/5. Step:3/5. AccuracyMetric: acc=0.037037
+ Epoch 4/5. Step:4/5. AccuracyMetric: acc=0.185185
+ Epoch 5/5. Step:5/5. AccuracyMetric: acc=0.240741
+ Train finished!
+
+
+.. code:: ipython3
+
+ from fastNLP import Tester
+
+ tester = Tester(data=test_data, model=model, metrics=AccuracyMetric())
+ acc = tester.test()
+
+
+.. parsed-literal::
+
+ [tester]
+ AccuracyMetric: acc=0.240741
+
+
+In summary
+----------
+
+fastNLP Trainer的伪代码逻辑
+---------------------------
+
+1. 准备DataSet,假设DataSet中共有如下的fields
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+::
+
+ ['raw_sentence', 'word_seq1', 'word_seq2', 'raw_label','label']
+ 通过
+ DataSet.set_input('word_seq1', word_seq2', flag=True)将'word_seq1', 'word_seq2'设置为input
+ 通过
+ DataSet.set_target('label', flag=True)将'label'设置为target
+
+2. 初始化模型
+~~~~~~~~~~~~~
+
+::
+
+ class Model(nn.Module):
+ def __init__(self):
+ xxx
+ def forward(self, word_seq1, word_seq2):
+ # (1) 这里使用的形参名必须和DataSet中的input field的名称对应。因为我们是通过形参名, 进行赋值的
+ # (2) input field的数量可以多于这里的形参数量。但是不能少于。
+ xxxx
+ # 输出必须是一个dict
+
+3. Trainer的训练过程
+~~~~~~~~~~~~~~~~~~~~
+
+::
+
+ (1) 从DataSet中按照batch_size取出一个batch,调用Model.forward
+ (2) 将 Model.forward的结果 与 标记为target的field 传入Losser当中。
+ 由于每个人写的Model.forward的output的dict可能key并不一样,比如有人是{'pred':xxx}, {'output': xxx};
+ 另外每个人将target可能也会设置为不同的名称, 比如有人是label, 有人设置为target;
+ 为了解决以上的问题,我们的loss提供映射机制
+ 比如CrossEntropyLosser的需要的输入是(prediction, target)。但是forward的output是{'output': xxx}; 'label'是target
+ 那么初始化losser的时候写为CrossEntropyLosser(prediction='output', target='label')即可
+ (3) 对于Metric是同理的
+ Metric计算也是从 forward的结果中取值 与 设置target的field中取值。 也是可以通过映射找到对应的值
+
+一些问题.
+---------
+
+1. DataSet中为什么需要设置input和target
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+::
+
+ 只有被设置为input或者target的数据才会在train的过程中被取出来
+ (1.1) 我们只会在设置为input的field中寻找传递给Model.forward的参数。
+ (1.2) 我们在传递值给losser或者metric的时候会使用来自:
+ (a)Model.forward的output
+ (b)被设置为target的field
+
+
+2. 我们是通过forwad中的形参名将DataSet中的field赋值给对应的参数
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+::
+
+ (1.1) 构建模型过程中,
+ 例如:
+ DataSet中x,seq_lens是input,那么forward就应该是
+ def forward(self, x, seq_lens):
+ pass
+ 我们是通过形参名称进行匹配的field的
+
+
+1. 加载数据到DataSet
+~~~~~~~~~~~~~~~~~~~~
+
+2. 使用apply操作对DataSet进行预处理
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+::
+
+ (2.1) 处理过程中将某些field设置为input,某些field设置为target
+
+3. 构建模型
+~~~~~~~~~~~
+
+::
+
+ (3.1) 构建模型过程中,需要注意forward函数的形参名需要和DataSet中设置为input的field名称是一致的。
+ 例如:
+ DataSet中x,seq_lens是input,那么forward就应该是
+ def forward(self, x, seq_lens):
+ pass
+ 我们是通过形参名称进行匹配的field的
+ (3.2) 模型的forward的output需要是dict类型的。
+ 建议将输出设置为{"pred": xx}.
+
diff --git a/docs/source/tutorials/fastnlp_1_minute_tutorial.rst b/docs/source/tutorials/fastnlp_1_minute_tutorial.rst
new file mode 100644
index 00000000..b4471e00
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_1_minute_tutorial.rst
@@ -0,0 +1,111 @@
+
+FastNLP 1分钟上手教程
+=====================
+
+step 1
+------
+
+读取数据集
+
+.. code:: ipython3
+
+ from fastNLP import DataSet
+ # linux_path = "../test/data_for_tests/tutorial_sample_dataset.csv"
+ win_path = "C:\\Users\zyfeng\Desktop\FudanNLP\\fastNLP\\test\\data_for_tests\\tutorial_sample_dataset.csv"
+ ds = DataSet.read_csv(win_path, headers=('raw_sentence', 'label'), sep='\t')
+
+step 2
+------
+
+数据预处理 1. 类型转换 2. 切分验证集 3. 构建词典
+
+.. code:: ipython3
+
+ # 将所有数字转为小写
+ ds.apply(lambda x: x['raw_sentence'].lower(), new_field_name='raw_sentence')
+ # label转int
+ ds.apply(lambda x: int(x['label']), new_field_name='label_seq', is_target=True)
+
+ def split_sent(ins):
+ return ins['raw_sentence'].split()
+ ds.apply(split_sent, new_field_name='words', is_input=True)
+
+
+.. code:: ipython3
+
+ # 分割训练集/验证集
+ train_data, dev_data = ds.split(0.3)
+ print("Train size: ", len(train_data))
+ print("Test size: ", len(dev_data))
+
+
+.. parsed-literal::
+
+ Train size: 54
+ Test size: 23
+
+
+.. code:: ipython3
+
+ from fastNLP import Vocabulary
+ vocab = Vocabulary(min_freq=2)
+ train_data.apply(lambda x: [vocab.add(word) for word in x['words']])
+
+ # index句子, Vocabulary.to_index(word)
+ train_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)
+ dev_data.apply(lambda x: [vocab.to_index(word) for word in x['words']], new_field_name='word_seq', is_input=True)
+
+
+step 3
+------
+
+定义模型
+
+.. code:: ipython3
+
+ from fastNLP.models import CNNText
+ model = CNNText(embed_num=len(vocab), embed_dim=50, num_classes=5, padding=2, dropout=0.1)
+
+
+step 4
+------
+
+开始训练
+
+.. code:: ipython3
+
+ from fastNLP import Trainer, CrossEntropyLoss, AccuracyMetric
+ trainer = Trainer(model=model,
+ train_data=train_data,
+ dev_data=dev_data,
+ loss=CrossEntropyLoss(),
+ metrics=AccuracyMetric()
+ )
+ trainer.train()
+ print('Train finished!')
+
+
+
+.. parsed-literal::
+
+ training epochs started 2018-12-07 14:03:41
+
+
+
+
+.. parsed-literal::
+
+ HBox(children=(IntProgress(value=0, layout=Layout(flex='2'), max=6), HTML(value='')), layout=Layout(display='i…
+
+
+
+.. parsed-literal::
+
+ Epoch 1/3. Step:2/6. AccuracyMetric: acc=0.26087
+ Epoch 2/3. Step:4/6. AccuracyMetric: acc=0.347826
+ Epoch 3/3. Step:6/6. AccuracyMetric: acc=0.608696
+ Train finished!
+
+
+本教程结束。更多操作请参考进阶教程。
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
diff --git a/docs/source/user/installation.rst b/docs/source/user/installation.rst
index 0655041b..7dc39b3b 100644
--- a/docs/source/user/installation.rst
+++ b/docs/source/user/installation.rst
@@ -6,26 +6,11 @@ Installation
:local:
-Cloning From GitHub
-~~~~~~~~~~~~~~~~~~~
-
-If you just want to use fastNLP, use:
+Run the following commands to install fastNLP package:
.. code:: shell
- git clone https://github.com/fastnlp/fastNLP
- cd fastNLP
+ pip install fastNLP
-PyTorch Installation
-~~~~~~~~~~~~~~~~~~~~
-
-Visit the [PyTorch official website] for installation instructions based
-on your system. In general, you could use:
-
-.. code:: shell
- # using conda
- conda install pytorch torchvision -c pytorch
- # or using pip
- pip3 install torch torchvision
diff --git a/docs/source/user/quickstart.rst b/docs/source/user/quickstart.rst
index 24c7363c..baa49eef 100644
--- a/docs/source/user/quickstart.rst
+++ b/docs/source/user/quickstart.rst
@@ -1,84 +1,9 @@
-==========
Quickstart
==========
-Example
--------
-
-Basic Usage
-~~~~~~~~~~~
-
-A typical fastNLP routine is composed of four phases: loading dataset,
-pre-processing data, constructing model and training model.
-
-.. code:: python
-
- from fastNLP.models.base_model import BaseModel
- from fastNLP.modules import encoder
- from fastNLP.modules import aggregation
- from fastNLP.modules import decoder
-
- from fastNLP.loader.dataset_loader import ClassDataSetLoader
- from fastNLP.loader.preprocess import ClassPreprocess
- from fastNLP.core.trainer import ClassificationTrainer
- from fastNLP.core.inference import ClassificationInfer
-
-
- class ClassificationModel(BaseModel):
- """
- Simple text classification model based on CNN.
- """
-
- def __init__(self, num_classes, vocab_size):
- super(ClassificationModel, self).__init__()
-
- self.emb = encoder.Embedding(nums=vocab_size, dims=300)
- self.enc = encoder.Conv(
- in_channels=300, out_channels=100, kernel_size=3)
- self.agg = aggregation.MaxPool()
- self.dec = decoder.MLP([100, num_classes])
-
- def forward(self, x):
- x = self.emb(x) # [N,L] -> [N,L,C]
- x = self.enc(x) # [N,L,C_in] -> [N,L,C_out]
- x = self.agg(x) # [N,L,C] -> [N,C]
- x = self.dec(x) # [N,C] -> [N, N_class]
- return x
-
-
- data_dir = 'data' # directory to save data and model
- train_path = 'test/data_for_tests/text_classify.txt' # training set file
-
- # load dataset
- ds_loader = ClassDataSetLoader("train", train_path)
- data = ds_loader.load()
-
- # pre-process dataset
- pre = ClassPreprocess(data_dir)
- vocab_size, n_classes = pre.process(data, "data_train.pkl")
-
- # construct model
- model_args = {
- 'num_classes': n_classes,
- 'vocab_size': vocab_size
- }
- model = ClassificationModel(num_classes=n_classes, vocab_size=vocab_size)
+.. toctree::
+ :maxdepth: 1
- # train model
- train_args = {
- "epochs": 20,
- "batch_size": 50,
- "pickle_path": data_dir,
- "validate": False,
- "save_best_dev": False,
- "model_saved_path": None,
- "use_cuda": True,
- "learn_rate": 1e-3,
- "momentum": 0.9}
- trainer = ClassificationTrainer(train_args)
- trainer.train(model)
+ ../tutorials/fastnlp_1_minute_tutorial
+ ../tutorials/fastnlp_10tmin_tutorial
- # predict using model
- seqs = [x[0] for x in data]
- infer = ClassificationInfer(data_dir)
- labels_pred = infer.predict(model, seqs)
\ No newline at end of file
diff --git a/fastNLP/io/dataset_loader.py b/fastNLP/io/dataset_loader.py
index 641a631e..76b9584d 100644
--- a/fastNLP/io/dataset_loader.py
+++ b/fastNLP/io/dataset_loader.py
@@ -417,6 +417,55 @@ class PeopleDailyCorpusLoader(DataSetLoader):
data_set.set_input("seq_len")
return data_set
+
+class Conll2003Loader(DataSetLoader):
+ """Self-defined loader of conll2003 dataset
+
+ More information about the given dataset cound be found on
+ https://sites.google.com/site/ermasoftware/getting-started/ne-tagging-conll2003-data
+
+ """
+
+ def __init__(self):
+ super(Conll2003Loader, self).__init__()
+
+ def load(self, dataset_path):
+ with open(dataset_path, "r", encoding="utf-8") as f:
+ lines = f.readlines()
+
+ ##Parse the dataset line by line
+ parsed_data = []
+ sentence = []
+ tokens = []
+ for line in lines:
+ if '-DOCSTART- -X- -X- O' in line or line == '\n':
+ if sentence != []:
+ parsed_data.append((sentence, tokens))
+ sentence = []
+ tokens = []
+ continue
+
+ temp = line.strip().split(" ")
+ sentence.append(temp[0])
+ tokens.append(temp[1:4])
+
+ return self.convert(parsed_data)
+
+ def convert(self, parsed_data):
+ dataset = DataSet()
+ for sample in parsed_data:
+ label0_list = list(map(
+ lambda labels: labels[0], sample[1]))
+ label1_list = list(map(
+ lambda labels: labels[1], sample[1]))
+ label2_list = list(map(
+ lambda labels: labels[2], sample[1]))
+ dataset.append(Instance(token_list=sample[0],
+ label0_list=label0_list,
+ label1_list=label1_list,
+ label2_list=label2_list))
+
+ return dataset
class SNLIDataSetLoader(DataSetLoader):
"""A data set loader for SNLI data set.
diff --git a/readthedocs.yml b/readthedocs.yml
new file mode 100644
index 00000000..9b172987
--- /dev/null
+++ b/readthedocs.yml
@@ -0,0 +1,6 @@
+build:
+ image: latest
+
+python:
+ version: 3.6
+ setup_py_install: true
\ No newline at end of file
diff --git a/test/data_for_tests/conll_2003_example.txt b/test/data_for_tests/conll_2003_example.txt
new file mode 100644
index 00000000..d11a8264
--- /dev/null
+++ b/test/data_for_tests/conll_2003_example.txt
@@ -0,0 +1,442 @@
+-DOCSTART- -X- -X- O
+
+SOCCER NN B-NP O
+- : O O
+JAPAN NNP B-NP B-LOC
+GET VB B-VP O
+LUCKY NNP B-NP O
+WIN NNP I-NP O
+, , O O
+CHINA NNP B-NP B-PER
+IN IN B-PP O
+SURPRISE DT B-NP O
+DEFEAT NN I-NP O
+. . O O
+
+Nadim NNP B-NP B-PER
+Ladki NNP I-NP I-PER
+
+AL-AIN NNP B-NP B-LOC
+, , O O
+United NNP B-NP B-LOC
+Arab NNP I-NP I-LOC
+Emirates NNPS I-NP I-LOC
+1996-12-06 CD I-NP O
+
+Japan NNP B-NP B-LOC
+began VBD B-VP O
+the DT B-NP O
+defence NN I-NP O
+of IN B-PP O
+their PRP$ B-NP O
+Asian JJ I-NP B-MISC
+Cup NNP I-NP I-MISC
+title NN I-NP O
+with IN B-PP O
+a DT B-NP O
+lucky JJ I-NP O
+2-1 CD I-NP O
+win VBP B-VP O
+against IN B-PP O
+Syria NNP B-NP B-LOC
+in IN B-PP O
+a DT B-NP O
+Group NNP I-NP O
+C NNP I-NP O
+championship NN I-NP O
+match NN I-NP O
+on IN B-PP O
+Friday NNP B-NP O
+. . O O
+
+But CC O O
+China NNP B-NP B-LOC
+saw VBD B-VP O
+their PRP$ B-NP O
+luck NN I-NP O
+desert VB B-VP O
+them PRP B-NP O
+in IN B-PP O
+the DT B-NP O
+second NN I-NP O
+match NN I-NP O
+of IN B-PP O
+the DT B-NP O
+group NN I-NP O
+, , O O
+crashing VBG B-VP O
+to TO B-PP O
+a DT B-NP O
+surprise NN I-NP O
+2-0 CD I-NP O
+defeat NN I-NP O
+to TO B-PP O
+newcomers NNS B-NP O
+Uzbekistan NNP I-NP B-LOC
+. . O O
+
+China NNP B-NP B-LOC
+controlled VBD B-VP O
+most JJS B-NP O
+of IN B-PP O
+the DT B-NP O
+match NN I-NP O
+and CC O O
+saw VBD B-VP O
+several JJ B-NP O
+chances NNS I-NP O
+missed VBD B-VP O
+until IN B-SBAR O
+the DT B-NP O
+78th JJ I-NP O
+minute NN I-NP O
+when WRB B-ADVP O
+Uzbek NNP B-NP B-MISC
+striker NN I-NP O
+Igor JJ B-NP B-PER
+Shkvyrin NNP I-NP I-PER
+took VBD B-VP O
+advantage NN B-NP O
+of IN B-PP O
+a DT B-NP O
+misdirected JJ I-NP O
+defensive JJ I-NP O
+header NN I-NP O
+to TO B-VP O
+lob VB I-VP O
+the DT B-NP O
+ball NN I-NP O
+over IN B-PP O
+the DT B-NP O
+advancing VBG I-NP O
+Chinese JJ I-NP B-MISC
+keeper NN I-NP O
+and CC O O
+into IN B-PP O
+an DT B-NP O
+empty JJ I-NP O
+net NN I-NP O
+. . O O
+
+Oleg NNP B-NP B-PER
+Shatskiku NNP I-NP I-PER
+made VBD B-VP O
+sure JJ B-ADJP O
+of IN B-PP O
+the DT B-NP O
+win VBP B-VP O
+in IN B-PP O
+injury NN B-NP O
+time NN I-NP O
+, , O O
+hitting VBG B-VP O
+an DT B-NP O
+unstoppable JJ I-NP O
+left VBD B-VP O
+foot NN B-NP O
+shot NN I-NP O
+from IN B-PP O
+just RB B-NP O
+outside IN B-PP O
+the DT B-NP O
+area NN I-NP O
+. . O O
+
+The DT B-NP O
+former JJ I-NP O
+Soviet JJ I-NP B-MISC
+republic NN I-NP O
+was VBD B-VP O
+playing VBG I-VP O
+in IN B-PP O
+an DT B-NP O
+Asian NNP I-NP B-MISC
+Cup NNP I-NP I-MISC
+finals NNS I-NP O
+tie NN I-NP O
+for IN B-PP O
+the DT B-NP O
+first JJ I-NP O
+time NN I-NP O
+. . O O
+
+Despite IN B-PP O
+winning VBG B-VP O
+the DT B-NP O
+Asian JJ I-NP B-MISC
+Games NNPS I-NP I-MISC
+title NN I-NP O
+two CD B-NP O
+years NNS I-NP O
+ago RB B-ADVP O
+, , O O
+Uzbekistan NNP B-NP B-LOC
+are VBP B-VP O
+in IN B-PP O
+the DT B-NP O
+finals NNS I-NP O
+as IN B-SBAR O
+outsiders NNS B-NP O
+. . O O
+
+Two CD B-NP O
+goals NNS I-NP O
+from IN B-PP O
+defensive JJ B-NP O
+errors NNS I-NP O
+in IN B-PP O
+the DT B-NP O
+last JJ I-NP O
+six CD I-NP O
+minutes NNS I-NP O
+allowed VBD B-VP O
+Japan NNP B-NP B-LOC
+to TO B-VP O
+come VB I-VP O
+from IN B-PP O
+behind NN B-NP O
+and CC O O
+collect VB B-VP O
+all DT B-NP O
+three CD I-NP O
+points NNS I-NP O
+from IN B-PP O
+their PRP$ B-NP O
+opening NN I-NP O
+meeting NN I-NP O
+against IN B-PP O
+Syria NNP B-NP B-LOC
+. . O O
+
+Takuya NNP B-NP B-PER
+Takagi NNP I-NP I-PER
+scored VBD B-VP O
+the DT B-NP O
+winner NN I-NP O
+in IN B-PP O
+the DT B-NP O
+88th JJ I-NP O
+minute NN I-NP O
+, , O O
+rising VBG B-VP O
+to TO I-VP O
+head VB I-VP O
+a DT B-NP O
+Hiroshige NNP I-NP B-PER
+Yanagimoto NNP I-NP I-PER
+cross VB B-VP O
+towards IN B-PP O
+the DT B-NP O
+Syrian JJ I-NP B-MISC
+goal NN I-NP O
+which WDT B-NP O
+goalkeeper VBD B-VP O
+Salem NNP B-NP B-PER
+Bitar NNP I-NP I-PER
+appeared VBD B-VP O
+to TO I-VP O
+have VB I-VP O
+covered VBN I-VP O
+but CC O O
+then RB B-VP O
+allowed VBN I-VP O
+to TO I-VP O
+slip VB I-VP O
+into IN B-PP O
+the DT B-NP O
+net NN I-NP O
+. . O O
+
+It PRP B-NP O
+was VBD B-VP O
+the DT B-NP O
+second JJ I-NP O
+costly JJ I-NP O
+blunder NN I-NP O
+by IN B-PP O
+Syria NNP B-NP B-LOC
+in IN B-PP O
+four CD B-NP O
+minutes NNS I-NP O
+. . O O
+
+Defender NNP B-NP O
+Hassan NNP I-NP B-PER
+Abbas NNP I-NP I-PER
+rose VBD B-VP O
+to TO I-VP O
+intercept VB I-VP O
+a DT B-NP O
+long JJ I-NP O
+ball NN I-NP O
+into IN B-PP O
+the DT B-NP O
+area NN I-NP O
+in IN B-PP O
+the DT B-NP O
+84th JJ I-NP O
+minute NN I-NP O
+but CC O O
+only RB B-ADVP O
+managed VBD B-VP O
+to TO I-VP O
+divert VB I-VP O
+it PRP B-NP O
+into IN B-PP O
+the DT B-NP O
+top JJ I-NP O
+corner NN I-NP O
+of IN B-PP O
+Bitar NN B-NP B-PER
+'s POS B-NP O
+goal NN I-NP O
+. . O O
+
+Nader NNP B-NP B-PER
+Jokhadar NNP I-NP I-PER
+had VBD B-VP O
+given VBN I-VP O
+Syria NNP B-NP B-LOC
+the DT B-NP O
+lead NN I-NP O
+with IN B-PP O
+a DT B-NP O
+well-struck NN I-NP O
+header NN I-NP O
+in IN B-PP O
+the DT B-NP O
+seventh JJ I-NP O
+minute NN I-NP O
+. . O O
+
+Japan NNP B-NP B-LOC
+then RB B-ADVP O
+laid VBD B-VP O
+siege NN B-NP O
+to TO B-PP O
+the DT B-NP O
+Syrian JJ I-NP B-MISC
+penalty NN I-NP O
+area NN I-NP O
+for IN B-PP O
+most JJS B-NP O
+of IN B-PP O
+the DT B-NP O
+game NN I-NP O
+but CC O O
+rarely RB B-VP O
+breached VBD I-VP O
+the DT B-NP O
+Syrian JJ I-NP B-MISC
+defence NN I-NP O
+. . O O
+
+Bitar NN B-NP B-PER
+pulled VBD B-VP O
+off RP B-PRT O
+fine JJ B-NP O
+saves VBZ B-VP O
+whenever WRB B-ADVP O
+they PRP B-NP O
+did VBD B-VP O
+. . O O
+
+Japan NNP B-NP B-LOC
+coach NN I-NP O
+Shu NNP I-NP B-PER
+Kamo NNP I-NP I-PER
+said VBD B-VP O
+: : O O
+' '' O O
+' POS B-NP O
+The DT I-NP O
+Syrian JJ I-NP B-MISC
+own JJ I-NP O
+goal NN I-NP O
+proved VBD B-VP O
+lucky JJ B-ADJP O
+for IN B-PP O
+us PRP B-NP O
+. . O O
+
+The DT B-NP O
+Syrians NNPS I-NP B-MISC
+scored VBD B-VP O
+early JJ B-NP O
+and CC O O
+then RB B-VP O
+played VBN I-VP O
+defensively RB B-ADVP O
+and CC O O
+adopted VBD B-VP O
+long RB I-VP O
+balls VBZ I-VP O
+which WDT B-NP O
+made VBD B-VP O
+it PRP B-NP O
+hard JJ B-ADJP O
+for IN B-PP O
+us PRP B-NP O
+. . O O
+' '' O O
+
+' '' O O
+
+Japan NNP B-NP B-LOC
+, , O O
+co-hosts VBZ B-VP O
+of IN B-PP O
+the DT B-NP O
+World NNP I-NP B-MISC
+Cup NNP I-NP I-MISC
+in IN B-PP O
+2002 CD B-NP O
+and CC O O
+ranked VBD B-VP O
+20th JJ B-NP O
+in IN B-PP O
+the DT B-NP O
+world NN I-NP O
+by IN B-PP O
+FIFA NNP B-NP B-ORG
+, , O O
+are VBP B-VP O
+favourites JJ B-ADJP O
+to TO B-VP O
+regain VB I-VP O
+their PRP$ B-NP O
+title NN I-NP O
+here RB B-ADVP O
+. . O O
+
+Hosts NNPS B-NP O
+UAE NNP I-NP B-LOC
+play NN I-NP O
+Kuwait NNP I-NP B-LOC
+and CC O O
+South NNP B-NP B-LOC
+Korea NNP I-NP I-LOC
+take VBP B-VP O
+on IN B-PP O
+Indonesia NNP B-NP B-LOC
+on IN B-PP O
+Saturday NNP B-NP O
+in IN B-PP O
+Group NNP B-NP O
+A NNP I-NP O
+matches VBZ B-VP O
+. . O O
+
+All DT B-NP O
+four CD I-NP O
+teams NNS I-NP O
+are VBP B-VP O
+level NN B-NP O
+with IN B-PP O
+one CD B-NP O
+point NN I-NP O
+each DT B-NP O
+from IN B-PP O
+one CD B-NP O
+game NN I-NP O
+. . O O
\ No newline at end of file
diff --git a/test/io/test_dataset_loader.py b/test/io/test_dataset_loader.py
new file mode 100644
index 00000000..9bee175b
--- /dev/null
+++ b/test/io/test_dataset_loader.py
@@ -0,0 +1,23 @@
+import os
+import unittest
+
+from fastNLP.io.dataset_loader import Conll2003Loader
+class TestDatasetLoader(unittest.TestCase):
+
+ def test_case_1(self):
+ '''
+ Test the the loader of Conll2003 dataset
+ '''
+
+ dataset_path = "test/data_for_tests/conll_2003_example.txt"
+ loader = Conll2003Loader()
+ dataset_2003 = loader.load(dataset_path)
+
+ for item in dataset_2003:
+ len0 = len(item["label0_list"])
+ len1 = len(item["label1_list"])
+ len2 = len(item["label2_list"])
+ lentoken = len(item["token_list"])
+ self.assertNotEqual(len0, 0)
+ self.assertEqual(len0, len1)
+ self.assertEqual(len1, len2)
\ No newline at end of file