diff --git a/docs/source/conf.py b/docs/source/conf.py index 01884ef7..812fb0ec 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,9 +24,9 @@ copyright = '2022, fastNLP' author = 'fastNLP' # The short X.Y version -version = '0.8' +version = '1.0' # The full version, including alpha/beta/rc tags -release = '0.8.0' +release = '1.0.0-alpha' # -- General configuration --------------------------------------------------- @@ -45,6 +45,7 @@ extensions = [ 'sphinx.ext.todo', 'sphinx_autodoc_typehints', 'sphinx_multiversion', + 'nbsphinx', ] autodoc_default_options = { @@ -169,7 +170,7 @@ man_pages = [ # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'fastNLP', 'fastNLP Documentation', - author, 'fastNLP', 'One line description of project.', + author, 'fastNLP', 'A fast NLP tool for programming.', 'Miscellaneous'), ] diff --git a/docs/source/fastNLP.core.callbacks.rst b/docs/source/fastNLP.core.callbacks.rst index 89d85f52..d0f3d210 100644 --- a/docs/source/fastNLP.core.callbacks.rst +++ b/docs/source/fastNLP.core.callbacks.rst @@ -31,5 +31,6 @@ Submodules fastNLP.core.callbacks.lr_scheduler_callback fastNLP.core.callbacks.more_evaluate_callback fastNLP.core.callbacks.progress_callback + fastNLP.core.callbacks.timer_callback fastNLP.core.callbacks.topk_saver fastNLP.core.callbacks.utils diff --git a/docs/source/fastNLP.core.callbacks.timer_callback.rst b/docs/source/fastNLP.core.callbacks.timer_callback.rst new file mode 100644 index 00000000..884fa604 --- /dev/null +++ b/docs/source/fastNLP.core.callbacks.timer_callback.rst @@ -0,0 +1,7 @@ +fastNLP.core.callbacks.timer\_callback module +============================================= + +.. 
automodule:: fastNLP.core.callbacks.timer_callback + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst b/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst new file mode 100644 index 00000000..ced75ccb --- /dev/null +++ b/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst @@ -0,0 +1,7 @@ +fastNLP.core.collators.padders.oneflow\_padder module +===================================================== + +.. automodule:: fastNLP.core.collators.padders.oneflow_padder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.collators.padders.rst b/docs/source/fastNLP.core.collators.padders.rst index 6f40becb..0c50dd4c 100644 --- a/docs/source/fastNLP.core.collators.padders.rst +++ b/docs/source/fastNLP.core.collators.padders.rst @@ -16,6 +16,7 @@ Submodules fastNLP.core.collators.padders.get_padder fastNLP.core.collators.padders.jittor_padder fastNLP.core.collators.padders.numpy_padder + fastNLP.core.collators.padders.oneflow_padder fastNLP.core.collators.padders.padder fastNLP.core.collators.padders.paddle_padder fastNLP.core.collators.padders.raw_padder diff --git a/docs/source/fastNLP.core.dataloaders.mix_dataloader.rst b/docs/source/fastNLP.core.dataloaders.mix_dataloader.rst deleted file mode 100644 index d2ffa234..00000000 --- a/docs/source/fastNLP.core.dataloaders.mix_dataloader.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.core.dataloaders.mix\_dataloader module -=============================================== - -.. 
automodule:: fastNLP.core.dataloaders.mix_dataloader - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst new file mode 100644 index 00000000..5a8939b0 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst @@ -0,0 +1,7 @@ +fastNLP.core.dataloaders.oneflow\_dataloader.fdl module +======================================================= + +.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader.fdl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst new file mode 100644 index 00000000..2b2081e5 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst @@ -0,0 +1,15 @@ +fastNLP.core.dataloaders.oneflow\_dataloader package +==================================================== + +.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.core.dataloaders.oneflow_dataloader.fdl diff --git a/docs/source/fastNLP.core.dataloaders.rst b/docs/source/fastNLP.core.dataloaders.rst index e8c6b799..db53dbe0 100644 --- a/docs/source/fastNLP.core.dataloaders.rst +++ b/docs/source/fastNLP.core.dataloaders.rst @@ -13,6 +13,7 @@ Subpackages :maxdepth: 4 fastNLP.core.dataloaders.jittor_dataloader + fastNLP.core.dataloaders.oneflow_dataloader fastNLP.core.dataloaders.paddle_dataloader fastNLP.core.dataloaders.torch_dataloader @@ -22,6 +23,5 @@ Submodules .. 
toctree:: :maxdepth: 4 - fastNLP.core.dataloaders.mix_dataloader fastNLP.core.dataloaders.prepare_dataloader fastNLP.core.dataloaders.utils diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst new file mode 100644 index 00000000..cd8bd865 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst @@ -0,0 +1,7 @@ +fastNLP.core.dataloaders.torch\_dataloader.mix\_dataloader module +================================================================= + +.. automodule:: fastNLP.core.dataloaders.torch_dataloader.mix_dataloader + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst index c9acca23..a3aeb1bf 100644 --- a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst +++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst @@ -13,3 +13,4 @@ Submodules :maxdepth: 4 fastNLP.core.dataloaders.torch_dataloader.fdl + fastNLP.core.dataloaders.torch_dataloader.mix_dataloader diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst new file mode 100644 index 00000000..c7618619 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.ddp module +=============================================== + +.. 
automodule:: fastNLP.core.drivers.oneflow_driver.ddp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst new file mode 100644 index 00000000..9eae5d19 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.dist\_utils module +======================================================= + +.. automodule:: fastNLP.core.drivers.oneflow_driver.dist_utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst new file mode 100644 index 00000000..d7272c8e --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.initialize\_oneflow\_driver module +======================================================================= + +.. automodule:: fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst new file mode 100644 index 00000000..1f5d159e --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.oneflow\_driver module +=========================================================== + +.. 
automodule:: fastNLP.core.drivers.oneflow_driver.oneflow_driver + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.rst new file mode 100644 index 00000000..213dd24b --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.rst @@ -0,0 +1,20 @@ +fastNLP.core.drivers.oneflow\_driver package +============================================ + +.. automodule:: fastNLP.core.drivers.oneflow_driver + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.core.drivers.oneflow_driver.ddp + fastNLP.core.drivers.oneflow_driver.dist_utils + fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver + fastNLP.core.drivers.oneflow_driver.oneflow_driver + fastNLP.core.drivers.oneflow_driver.single_device + fastNLP.core.drivers.oneflow_driver.utils diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst new file mode 100644 index 00000000..a54e74ec --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.single\_device module +========================================================== + +.. automodule:: fastNLP.core.drivers.oneflow_driver.single_device + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst new file mode 100644 index 00000000..1eda7794 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.utils module +================================================= + +.. 
automodule:: fastNLP.core.drivers.oneflow_driver.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.rst b/docs/source/fastNLP.core.drivers.rst index bb168c76..30652fec 100644 --- a/docs/source/fastNLP.core.drivers.rst +++ b/docs/source/fastNLP.core.drivers.rst @@ -13,6 +13,7 @@ Subpackages :maxdepth: 4 fastNLP.core.drivers.jittor_driver + fastNLP.core.drivers.oneflow_driver fastNLP.core.drivers.paddle_driver fastNLP.core.drivers.torch_driver diff --git a/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst b/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst new file mode 100644 index 00000000..2944ffec --- /dev/null +++ b/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.torch\_driver.deepspeed module +=================================================== + +.. automodule:: fastNLP.core.drivers.torch_driver.deepspeed + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst b/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst new file mode 100644 index 00000000..e68972a7 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.torch\_driver.fairscale module +=================================================== + +.. automodule:: fastNLP.core.drivers.torch_driver.fairscale + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.torch_driver.fairscale_sharded.rst b/docs/source/fastNLP.core.drivers.torch_driver.fairscale_sharded.rst deleted file mode 100644 index 765ac4ae..00000000 --- a/docs/source/fastNLP.core.drivers.torch_driver.fairscale_sharded.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.core.drivers.torch\_driver.fairscale\_sharded module -============================================================ - -.. 
automodule:: fastNLP.core.drivers.torch_driver.fairscale_sharded - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.torch_driver.rst b/docs/source/fastNLP.core.drivers.torch_driver.rst index 9a0109a2..c9080a86 100644 --- a/docs/source/fastNLP.core.drivers.torch_driver.rst +++ b/docs/source/fastNLP.core.drivers.torch_driver.rst @@ -13,9 +13,11 @@ Submodules :maxdepth: 4 fastNLP.core.drivers.torch_driver.ddp + fastNLP.core.drivers.torch_driver.deepspeed fastNLP.core.drivers.torch_driver.dist_utils - fastNLP.core.drivers.torch_driver.fairscale_sharded + fastNLP.core.drivers.torch_driver.fairscale fastNLP.core.drivers.torch_driver.initialize_torch_driver fastNLP.core.drivers.torch_driver.single_device fastNLP.core.drivers.torch_driver.torch_driver + fastNLP.core.drivers.torch_driver.torch_fsdp fastNLP.core.drivers.torch_driver.utils diff --git a/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst b/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst new file mode 100644 index 00000000..a799b7fc --- /dev/null +++ b/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.torch\_driver.torch\_fsdp module +===================================================== + +.. automodule:: fastNLP.core.drivers.torch_driver.torch_fsdp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst new file mode 100644 index 00000000..2389250b --- /dev/null +++ b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst @@ -0,0 +1,7 @@ +fastNLP.core.metrics.backend.oneflow\_backend.backend module +============================================================ + +.. 
automodule:: fastNLP.core.metrics.backend.oneflow_backend.backend + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst new file mode 100644 index 00000000..cb9e9653 --- /dev/null +++ b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst @@ -0,0 +1,15 @@ +fastNLP.core.metrics.backend.oneflow\_backend package +===================================================== + +.. automodule:: fastNLP.core.metrics.backend.oneflow_backend + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.core.metrics.backend.oneflow_backend.backend diff --git a/docs/source/fastNLP.core.metrics.backend.rst b/docs/source/fastNLP.core.metrics.backend.rst index 5a8cf4ad..4466a54a 100644 --- a/docs/source/fastNLP.core.metrics.backend.rst +++ b/docs/source/fastNLP.core.metrics.backend.rst @@ -13,6 +13,7 @@ Subpackages :maxdepth: 4 fastNLP.core.metrics.backend.jittor_backend + fastNLP.core.metrics.backend.oneflow_backend fastNLP.core.metrics.backend.paddle_backend fastNLP.core.metrics.backend.torch_backend diff --git a/docs/source/fastNLP.core.utils.oneflow_utils.rst b/docs/source/fastNLP.core.utils.oneflow_utils.rst new file mode 100644 index 00000000..f9d11510 --- /dev/null +++ b/docs/source/fastNLP.core.utils.oneflow_utils.rst @@ -0,0 +1,7 @@ +fastNLP.core.utils.oneflow\_utils module +======================================== + +.. 
automodule:: fastNLP.core.utils.oneflow_utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.utils.rst b/docs/source/fastNLP.core.utils.rst index 2d682010..9bf76a23 100644 --- a/docs/source/fastNLP.core.utils.rst +++ b/docs/source/fastNLP.core.utils.rst @@ -16,7 +16,10 @@ Submodules fastNLP.core.utils.dummy_class fastNLP.core.utils.exceptions fastNLP.core.utils.jittor_utils + fastNLP.core.utils.oneflow_utils fastNLP.core.utils.paddle_utils fastNLP.core.utils.rich_progress + fastNLP.core.utils.seq_len_to_mask fastNLP.core.utils.torch_utils + fastNLP.core.utils.tqdm_progress fastNLP.core.utils.utils diff --git a/docs/source/fastNLP.core.utils.seq_len_to_mask.rst b/docs/source/fastNLP.core.utils.seq_len_to_mask.rst new file mode 100644 index 00000000..55188a65 --- /dev/null +++ b/docs/source/fastNLP.core.utils.seq_len_to_mask.rst @@ -0,0 +1,7 @@ +fastNLP.core.utils.seq\_len\_to\_mask module +============================================ + +.. automodule:: fastNLP.core.utils.seq_len_to_mask + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.utils.tqdm_progress.rst b/docs/source/fastNLP.core.utils.tqdm_progress.rst new file mode 100644 index 00000000..cfcdc655 --- /dev/null +++ b/docs/source/fastNLP.core.utils.tqdm_progress.rst @@ -0,0 +1,7 @@ +fastNLP.core.utils.tqdm\_progress module +======================================== + +.. automodule:: fastNLP.core.utils.tqdm_progress + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.rst b/docs/source/fastNLP.embeddings.rst new file mode 100644 index 00000000..1b220f59 --- /dev/null +++ b/docs/source/fastNLP.embeddings.rst @@ -0,0 +1,15 @@ +fastNLP.embeddings package +========================== + +.. automodule:: fastNLP.embeddings + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. 
toctree:: + :maxdepth: 4 + + fastNLP.embeddings.torch diff --git a/docs/source/fastNLP.embeddings.torch.char_embedding.rst b/docs/source/fastNLP.embeddings.torch.char_embedding.rst new file mode 100644 index 00000000..f0d1dad7 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.char_embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.char\_embedding module +=============================================== + +.. automodule:: fastNLP.embeddings.torch.char_embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.embedding.rst b/docs/source/fastNLP.embeddings.torch.embedding.rst new file mode 100644 index 00000000..1804a70e --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.embedding module +========================================= + +.. automodule:: fastNLP.embeddings.torch.embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.rst b/docs/source/fastNLP.embeddings.torch.rst new file mode 100644 index 00000000..6294e8a2 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.rst @@ -0,0 +1,19 @@ +fastNLP.embeddings.torch package +================================ + +.. automodule:: fastNLP.embeddings.torch + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.embeddings.torch.char_embedding + fastNLP.embeddings.torch.embedding + fastNLP.embeddings.torch.stack_embedding + fastNLP.embeddings.torch.static_embedding + fastNLP.embeddings.torch.utils diff --git a/docs/source/fastNLP.embeddings.torch.stack_embedding.rst b/docs/source/fastNLP.embeddings.torch.stack_embedding.rst new file mode 100644 index 00000000..dab50088 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.stack_embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.stack\_embedding module +================================================ + +.. 
automodule:: fastNLP.embeddings.torch.stack_embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.static_embedding.rst b/docs/source/fastNLP.embeddings.torch.static_embedding.rst new file mode 100644 index 00000000..fc1a2bb9 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.static_embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.static\_embedding module +================================================= + +.. automodule:: fastNLP.embeddings.torch.static_embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.utils.rst b/docs/source/fastNLP.embeddings.torch.utils.rst new file mode 100644 index 00000000..9d1fc5b5 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.utils.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.utils module +===================================== + +.. automodule:: fastNLP.embeddings.torch.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst index bd91b795..20be532a 100644 --- a/docs/source/fastNLP.io.loader.rst +++ b/docs/source/fastNLP.io.loader.rst @@ -14,7 +14,6 @@ Submodules fastNLP.io.loader.classification fastNLP.io.loader.conll - fastNLP.io.loader.coreference fastNLP.io.loader.csv fastNLP.io.loader.cws fastNLP.io.loader.json diff --git a/docs/source/fastNLP.io.model_io.rst b/docs/source/fastNLP.io.model_io.rst deleted file mode 100644 index fe13d1d7..00000000 --- a/docs/source/fastNLP.io.model_io.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.io.model\_io module -=========================== - -.. 
automodule:: fastNLP.io.model_io - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst index 9ad7e539..53a62918 100644 --- a/docs/source/fastNLP.io.pipe.rst +++ b/docs/source/fastNLP.io.pipe.rst @@ -15,7 +15,6 @@ Submodules fastNLP.io.pipe.classification fastNLP.io.pipe.conll fastNLP.io.pipe.construct_graph - fastNLP.io.pipe.coreference fastNLP.io.pipe.cws fastNLP.io.pipe.matching fastNLP.io.pipe.pipe diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index 5f025bba..7e1a5a67 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -25,5 +25,4 @@ Submodules fastNLP.io.embed_loader fastNLP.io.file_reader fastNLP.io.file_utils - fastNLP.io.model_io fastNLP.io.utils diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst new file mode 100644 index 00000000..eaef5a5e --- /dev/null +++ b/docs/source/fastNLP.models.rst @@ -0,0 +1,15 @@ +fastNLP.models package +====================== + +.. automodule:: fastNLP.models + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.models.torch diff --git a/docs/source/fastNLP.models.torch.biaffine_parser.rst b/docs/source/fastNLP.models.torch.biaffine_parser.rst new file mode 100644 index 00000000..c75d7079 --- /dev/null +++ b/docs/source/fastNLP.models.torch.biaffine_parser.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.biaffine\_parser module +============================================ + +.. 
automodule:: fastNLP.models.torch.biaffine_parser + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.cnn_text_classification.rst b/docs/source/fastNLP.models.torch.cnn_text_classification.rst new file mode 100644 index 00000000..a0b4e1bd --- /dev/null +++ b/docs/source/fastNLP.models.torch.cnn_text_classification.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.cnn\_text\_classification module +===================================================== + +.. automodule:: fastNLP.models.torch.cnn_text_classification + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.rst b/docs/source/fastNLP.models.torch.rst new file mode 100644 index 00000000..7196f3f7 --- /dev/null +++ b/docs/source/fastNLP.models.torch.rst @@ -0,0 +1,19 @@ +fastNLP.models.torch package +============================ + +.. automodule:: fastNLP.models.torch + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.models.torch.biaffine_parser + fastNLP.models.torch.cnn_text_classification + fastNLP.models.torch.seq2seq_generator + fastNLP.models.torch.seq2seq_model + fastNLP.models.torch.sequence_labeling diff --git a/docs/source/fastNLP.models.torch.seq2seq_generator.rst b/docs/source/fastNLP.models.torch.seq2seq_generator.rst new file mode 100644 index 00000000..bc1e4ca0 --- /dev/null +++ b/docs/source/fastNLP.models.torch.seq2seq_generator.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.seq2seq\_generator module +============================================== + +.. 
automodule:: fastNLP.models.torch.seq2seq_generator + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.seq2seq_model.rst b/docs/source/fastNLP.models.torch.seq2seq_model.rst new file mode 100644 index 00000000..802b8793 --- /dev/null +++ b/docs/source/fastNLP.models.torch.seq2seq_model.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.seq2seq\_model module +========================================== + +.. automodule:: fastNLP.models.torch.seq2seq_model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.sequence_labeling.rst b/docs/source/fastNLP.models.torch.sequence_labeling.rst new file mode 100644 index 00000000..af834f53 --- /dev/null +++ b/docs/source/fastNLP.models.torch.sequence_labeling.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.sequence\_labeling module +============================================== + +.. automodule:: fastNLP.models.torch.sequence_labeling + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst index fa1d95de..b686105d 100644 --- a/docs/source/fastNLP.modules.rst +++ b/docs/source/fastNLP.modules.rst @@ -13,3 +13,4 @@ Subpackages :maxdepth: 4 fastNLP.modules.mix_modules + fastNLP.modules.torch diff --git a/docs/source/fastNLP.modules.torch.attention.rst b/docs/source/fastNLP.modules.torch.attention.rst new file mode 100644 index 00000000..52b7bf8c --- /dev/null +++ b/docs/source/fastNLP.modules.torch.attention.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.attention module +====================================== + +.. 
automodule:: fastNLP.modules.torch.attention + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.crf.rst b/docs/source/fastNLP.modules.torch.decoder.crf.rst new file mode 100644 index 00000000..2d9e3460 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.crf.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.crf module +======================================== + +.. automodule:: fastNLP.modules.torch.decoder.crf + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.mlp.rst b/docs/source/fastNLP.modules.torch.decoder.mlp.rst new file mode 100644 index 00000000..6bb9cc5c --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.mlp.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.mlp module +======================================== + +.. automodule:: fastNLP.modules.torch.decoder.mlp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.rst b/docs/source/fastNLP.modules.torch.decoder.rst new file mode 100644 index 00000000..999ab01d --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.rst @@ -0,0 +1,18 @@ +fastNLP.modules.torch.decoder package +===================================== + +.. automodule:: fastNLP.modules.torch.decoder + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. 
toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.decoder.crf + fastNLP.modules.torch.decoder.mlp + fastNLP.modules.torch.decoder.seq2seq_decoder + fastNLP.modules.torch.decoder.seq2seq_state diff --git a/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst b/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst new file mode 100644 index 00000000..43c77fea --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.seq2seq\_decoder module +===================================================== + +.. automodule:: fastNLP.modules.torch.decoder.seq2seq_decoder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst b/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst new file mode 100644 index 00000000..05f730e4 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.seq2seq\_state module +=================================================== + +.. automodule:: fastNLP.modules.torch.decoder.seq2seq_state + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.loader.coreference.rst b/docs/source/fastNLP.modules.torch.dropout.rst similarity index 52% rename from docs/source/fastNLP.io.loader.coreference.rst rename to docs/source/fastNLP.modules.torch.dropout.rst index 58dfb880..8e4b591b 100644 --- a/docs/source/fastNLP.io.loader.coreference.rst +++ b/docs/source/fastNLP.modules.torch.dropout.rst @@ -1,7 +1,7 @@ -fastNLP.io.loader.coreference module +fastNLP.modules.torch.dropout module ==================================== -.. automodule:: fastNLP.io.loader.coreference +.. 
automodule:: fastNLP.modules.torch.dropout :members: :undoc-members: :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst b/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst new file mode 100644 index 00000000..438ec076 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.conv\_maxpool module +================================================== + +.. automodule:: fastNLP.modules.torch.encoder.conv_maxpool + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.lstm.rst b/docs/source/fastNLP.modules.torch.encoder.lstm.rst new file mode 100644 index 00000000..918e13cb --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.lstm.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.lstm module +========================================= + +.. automodule:: fastNLP.modules.torch.encoder.lstm + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.rst b/docs/source/fastNLP.modules.torch.encoder.rst new file mode 100644 index 00000000..14120ed1 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.rst @@ -0,0 +1,20 @@ +fastNLP.modules.torch.encoder package +===================================== + +.. automodule:: fastNLP.modules.torch.encoder + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. 
toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.encoder.conv_maxpool + fastNLP.modules.torch.encoder.lstm + fastNLP.modules.torch.encoder.seq2seq_encoder + fastNLP.modules.torch.encoder.star_transformer + fastNLP.modules.torch.encoder.transformer + fastNLP.modules.torch.encoder.variational_rnn diff --git a/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst b/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst new file mode 100644 index 00000000..152fc091 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.seq2seq\_encoder module +===================================================== + +.. automodule:: fastNLP.modules.torch.encoder.seq2seq_encoder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst b/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst new file mode 100644 index 00000000..3257cf13 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.star\_transformer module +====================================================== + +.. automodule:: fastNLP.modules.torch.encoder.star_transformer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.transformer.rst b/docs/source/fastNLP.modules.torch.encoder.transformer.rst new file mode 100644 index 00000000..0a3c893f --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.transformer.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.transformer module +================================================ + +.. 
automodule:: fastNLP.modules.torch.encoder.transformer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst b/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst new file mode 100644 index 00000000..71a70c3a --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.variational\_rnn module +===================================================== + +.. automodule:: fastNLP.modules.torch.encoder.variational_rnn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.generator.rst b/docs/source/fastNLP.modules.torch.generator.rst new file mode 100644 index 00000000..783db61d --- /dev/null +++ b/docs/source/fastNLP.modules.torch.generator.rst @@ -0,0 +1,15 @@ +fastNLP.modules.torch.generator package +======================================= + +.. automodule:: fastNLP.modules.torch.generator + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.generator.seq2seq_generator diff --git a/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst b/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst new file mode 100644 index 00000000..4abc102f --- /dev/null +++ b/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.generator.seq2seq\_generator module +========================================================= + +.. automodule:: fastNLP.modules.torch.generator.seq2seq_generator + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.rst b/docs/source/fastNLP.modules.torch.rst new file mode 100644 index 00000000..8e1fb0f5 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.rst @@ -0,0 +1,26 @@ +fastNLP.modules.torch package +============================= + +.. 
automodule:: fastNLP.modules.torch + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.decoder + fastNLP.modules.torch.encoder + fastNLP.modules.torch.generator + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.attention + fastNLP.modules.torch.dropout diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 89c8e058..f3e245fe 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -13,6 +13,9 @@ Subpackages :maxdepth: 4 fastNLP.core + fastNLP.embeddings fastNLP.envs fastNLP.io + fastNLP.models fastNLP.modules + fastNLP.transformers \ No newline at end of file diff --git a/docs/source/fastNLP.transformers.rst b/docs/source/fastNLP.transformers.rst new file mode 100644 index 00000000..023da63d --- /dev/null +++ b/docs/source/fastNLP.transformers.rst @@ -0,0 +1,14 @@ +fastNLP.transformers package +============================ +.. automodule:: fastNLP.transformers + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.transformers.torch diff --git a/docs/source/fastNLP.io.pipe.coreference.rst b/docs/source/fastNLP.transformers.torch.rst similarity index 53% rename from docs/source/fastNLP.io.pipe.coreference.rst rename to docs/source/fastNLP.transformers.torch.rst index bccdb0a7..9d5f0d65 100644 --- a/docs/source/fastNLP.io.pipe.coreference.rst +++ b/docs/source/fastNLP.transformers.torch.rst @@ -1,7 +1,7 @@ -fastNLP.io.pipe.coreference module +fastNLP.transformers.torch package ================================== -.. automodule:: fastNLP.io.pipe.coreference +.. 
automodule:: fastNLP.transformers.torch :members: :undoc-members: :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst index 10b0d74c..abe73d8a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -2,18 +2,18 @@ fastNLP 中文文档 ===================== -用户手册 +快速上手 ---------------- .. toctree:: - :maxdepth: 1 + :maxdepth: 2 - 语法样例 + tutorials API 文档 ------------- -除了用户手册之外,你还可以通过查阅 API 文档来找到你所需要的工具。 +您可以通过查阅 API 文档来找到您所需要的工具。 .. toctree:: :titlesonly: diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst new file mode 100644 index 00000000..bc32d10f --- /dev/null +++ b/docs/source/tutorials.rst @@ -0,0 +1,8 @@ +fastNLP 教程系列 +================ + +.. toctree:: + :maxdepth: 1 + :glob: + + tutorials/* diff --git a/docs/source/tutorials/fastnlp_torch_tutorial.ipynb b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb new file mode 100644 index 00000000..9633ac7f --- /dev/null +++ b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb @@ -0,0 +1,869 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6011adf8", + "metadata": {}, + "source": [ + "# 10 分钟快速上手 fastNLP torch\n", + "\n", + "在这个例子中,我们将使用BERT来解决conll2003数据集中的命名实体识别任务。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e166c051", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2022-07-07 10:12:29-- https://data.deepai.org/conll2003.zip\n", + "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n", + "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n", + "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n", + " Issued certificate has expired.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 982975 (960K) [application/x-zip-compressed]\n", + "Saving to: ‘conll2003.zip’\n", + "\n", + "conll2003.zip 100%[===================>] 959.94K 653KB/s in 1.5s \n", + "\n", + "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n", + "\n", + "Archive: conll2003.zip\n", + " inflating: conll2003/metadata \n", + " inflating: conll2003/test.txt \n", + " inflating: conll2003/train.txt \n", + " inflating: conll2003/valid.txt \n" + ] + } + ], + "source": [ + "# Linux/Mac 下载数据,并解压\n", + "import platform\n", + "if platform.system() != \"Windows\":\n", + " !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n", + " !unzip conll2003.zip -d conll2003\n", + "# Windows用户请通过复制该url到浏览器下载该数据并解压" + ] + }, + { + "cell_type": "markdown", + "id": "f7acbf1f", + "metadata": {}, + "source": [ + "## 目录\n", + "接下来我们将按照以下的内容介绍如何通过fastNLP减少工程性代码的撰写 \n", + "- 1. 数据加载\n", + "- 2. 数据预处理、数据缓存\n", + "- 3. DataLoader\n", + "- 4. 模型准备\n", + "- 5. Trainer的使用\n", + "- 6. Evaluator的使用\n", + "- 7. 其它【待补充】\n", + " - 7.1 使用多卡进行训练、评测\n", + " - 7.2 使用ZeRO优化\n", + " - 7.3 通过overfit测试快速验证模型\n", + " - 7.4 复杂Monitor的使用\n", + " - 7.5 训练过程中,使用不同的测试函数\n", + " - 7.6 更有效率的Sampler\n", + " - 7.7 保存模型\n", + " - 7.8 断点重训\n", + " - 7.9 使用huggingface datasets\n", + " - 7.10 使用torchmetrics来作为metric\n", + " - 7.11 将预测结果写出到文件\n", + " - 7.12 混合 dataset 训练\n", + " - 7.13 logger的使用\n", + " - 7.14 自定义分布式 Metric\n", + " - 7.15 通过batch_step_fn实现R-Drop" + ] + }, + { + "cell_type": "markdown", + "id": "0657dfba", + "metadata": {}, + "source": [ + "#### 1. 
数据加载\n", + "目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件,文件的格式为[conll格式](https://universaldependencies.org/format.html),其编码格式为 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 类型。可以通过继承 fastNLP.io.Loader 来简化加载过程,继承 Loader 类之后,只需要实现读取单个文件的 _load() 函数即可。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c557f0ba", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6f59e438", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 3 datasets:\n", + "\ttrain has 14987 instances.\n", + "\ttest has 3684 instances.\n", + "\tdev has 3466 instances.\n", + "\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet, Instance\n", + "from fastNLP.io import Loader\n", + "\n", + "\n", + "# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n", + "class ConllLoader(Loader):\n", + " def _load(self, path):\n", + " ds = DataSet()\n", + " with open(path, 'r') as f:\n", + " segments = []\n", + " for line in f:\n", + " line = line.strip()\n", + " if line == '': # 如果为空行,说明需要切换到下一句了。\n", + " if segments:\n", + " raw_words = [s[0] for s in segments]\n", + " raw_target = [s[1] for s in segments]\n", + " # 将一个 sample 插入到 DataSet中\n", + " ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n", + " segments = []\n", + " else:\n", + " parts = line.split()\n", + " assert len(parts)==4\n", + " segments.append([parts[0], parts[-1]])\n", + " return ds\n", + " \n", + "\n", + "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n", + "# 可以方便之后的预处理,DataBundle 支持的接口可以在 !!! 
查看。\n", + "data_bundle = ConllLoader().load({\n", + " 'train': 'conll2003/train.txt',\n", + " 'test': 'conll2003/test.txt',\n", + " 'dev': 'conll2003/valid.txt'\n", + "})\n", + "\"\"\"\n", + "也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n", + "'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n", + "此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n", + "该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n", + "\"\"\"\n", + "\n", + "print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n", + "# data_bundle.get_dataset('train') # 可以获取单个 dataset" + ] + }, + { + "cell_type": "markdown", + "id": "57ae314d", + "metadata": {}, + "source": [ + "#### 2. 数据预处理\n", + "接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有: \n", + "(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state; \n", + "(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号。 " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "96389988", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "c3bd41a323c94a41b409d29a5d4079b6",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "IOPub message rate exceeded.\n",
+      "The notebook server will temporarily stop sending output\n",
+      "to the client in order to avoid crashing it.\n",
+      "To change this limit, set the config variable\n",
+      "`--NotebookApp.iopub_msg_rate_limit`.\n",
+      "\n",
+      "Current values:\n",
+      "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
+      "NotebookApp.rate_limit_window=3.0 (secs)\n",
+      "\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
[10:48:13] INFO     Save cache to /remote-home/hyan01/exps/fastNLP/fastN cache_results.py:332\n",
+       "                    LP/demo/torch_tutorial/caches/c7f74559_cache.pkl.                        \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n", + "from fastNLP.transformers.torch import BertTokenizer\n", + "from fastNLP import cache_results, Vocabulary\n", + "\n", + "# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n", + "# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n", + "# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n", + "# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n", + "@cache_results('caches/cache.pkl')\n", + "def process_data(data_bundle, model_name):\n", + " tokenizer = BertTokenizer.from_pretrained(model_name)\n", + " def bpe(raw_words):\n", + " bpes = [tokenizer.cls_token_id]\n", + " first = [0]\n", + " first_index = 1 # 记录第一个bpe的位置\n", + " for word in raw_words:\n", + " bpe = tokenizer.encode(word, add_special_tokens=False)\n", + " bpes.extend(bpe)\n", + " first.append(first_index)\n", + " first_index += len(bpe)\n", + " bpes.append(tokenizer.sep_token_id)\n", + " first.append(first_index)\n", + " return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': len(raw_words)}\n", + " # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n", + " data_bundle.apply_field_more(bpe, 
field_name='raw_words', num_proc=4)\n", + " # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n", + " # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n", + " # 更多信息可以参考对应的文档。\n", + " \n", + " # tag的词表,由于这是词表,所以不需要有padding和unk\n", + " tag_vocab = Vocabulary(padding=None, unknown=None)\n", + " # 从 train 数据的 raw_target 中获取建立词表\n", + " tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n", + " # 使用词表将每个 dataset 中的raw_target转为数字,并且将写入到target这个field中\n", + " tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n", + " \n", + " # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n", + " data_bundle.set_vocab(tag_vocab, field_name='target')\n", + " \n", + " return data_bundle, tokenizer\n", + "\n", + "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长,第二次调用则会直接读取缓存的文件\n", + "# data_bundle = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 " + ] + }, + { + "cell_type": "markdown", + "id": "80036fcd", + "metadata": {}, + "source": [ + "### 3. DataLoader \n", + "由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ,我们使用了 !!!fastNLP.Collator!!! 
对象来进行 pad ,Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad ,可以通过 Collator.set_pad() 函数修改某个 field 的 pad 行为。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "09494695", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_dataloader\n", + "\n", + "# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ,包含了 'train', 'test', 'dev' 三个\n", + "# fastNLP.TorchDataLoader 对象。\n", + "dls = prepare_dataloader(data_bundle, batch_size=24) \n", + "\n", + "\n", + "# fastNLP 将默认尝试对所有 field 都进行 pad ,如果当前 field 是不可 pad 的类型,则不进行pad;如果是可以 pad 的类型\n", + "# 默认使用 0 进行 pad 。\n", + "for dl in dls.values():\n", + " # 可以通过 set_pad 修改 padding 的行为。\n", + " dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n", + " # 如果希望忽略某个 field ,可以通过 set_ignore 方法。\n", + " dl.set_ignore('raw_target')\n", + " dl.set_pad('target', pad_val=-100)\n", + "# 另一种设置的方法是,可以在 dls = prepare_dataloader(data_bundle, batch_size=32) 之前直接调用 \n", + "# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n", + "# DataSet 也支持这两个方法。\n", + "# 若此时调用 batch = next(dls['train']),则 batch 是一个 dict ,其中包含了\n", + "# 'input_ids': torch.LongTensor([batch_size, max_len])\n", + "# 'input_len': torch.LongTensor([batch_size])\n", + "# 'first': torch.LongTensor([batch_size, max_len'])\n", + "# 'first_len': torch.LongTensor([batch_size])\n", + "# 'target': torch.LongTensor([batch_size, max_len'-2])\n", + "# 'raw_words': List[List[str]] # 因为无法判断,所以 Collator 不会做任何处理" + ] + }, + { + "cell_type": "markdown", + "id": "3583df6d", + "metadata": {}, + "source": [ + "### 4. 
模型准备\n", + "传入给fastNLP的模型,需要有两个特殊的方法``train_step``、``evaluate_step``,前者默认在 fastNLP.Trainer 中进行调用,后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法,则Trainer会直接使用模型的``forward``函数;如果模型没有``evaluate_step``方法,则Evaluator会直接使用模型的``forward``函数。``train_step``方法(或当其不存在时,``forward``方法)的返回值必须为 dict 类型,并且必须包含``loss``这个 key 。\n", + "\n", + "此外fastNLP会使用形参名匹配的方式进行参数传递,例如以下模型\n", + "```python\n", + "class Model(nn.Module):\n", + " def train_step(self, x, y):\n", + " return {'loss': (x-y).abs().mean()}\n", + "```\n", + "fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ,如果没有找到则会报错。有以下的方法可以解决报错\n", + "- 修改 train_step 的参数为(input_ids, target),以保证和 DataLoader 返回的 batch 中的 key 匹配\n", + "- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n", + "- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射,train_input_mapping 也可以是一个函数,更多 train_input_mapping 的介绍可以参考文档。\n", + "\n", + "``evaluate_step``也是使用同样的匹配方式,前两条解决方法是一致的,第三种解决方案中,需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f131c1a3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[10:48:21] WARNING  Some weights of the model checkpoint at            modeling_utils.py:1490\n",
+       "                    bert-base-uncased were not used when initializing                        \n",
+       "                    BertModel: ['cls.predictions.bias',                                      \n",
+       "                    'cls.predictions.transform.LayerNorm.weight',                            \n",
+       "                    'cls.seq_relationship.weight',                                           \n",
+       "                    'cls.predictions.decoder.weight',                                        \n",
+       "                    'cls.predictions.transform.dense.weight',                                \n",
+       "                    'cls.predictions.transform.LayerNorm.bias',                              \n",
+       "                    'cls.predictions.transform.dense.bias',                                  \n",
+       "                    'cls.seq_relationship.bias']                                             \n",
+       "                    - This IS expected if you are initializing                               \n",
+       "                    BertModel from the checkpoint of a model trained                         \n",
+       "                    on another task or with another architecture (e.g.                       \n",
+       "                    initializing a BertForSequenceClassification model                       \n",
+       "                    from a BertForPreTraining model).                                        \n",
+       "                    - This IS NOT expected if you are initializing                           \n",
+       "                    BertModel from the checkpoint of a model that you                        \n",
+       "                    expect to be exactly identical (initializing a                           \n",
+       "                    BertForSequenceClassification model from a                               \n",
+       "                    BertForSequenceClassification model).                                    \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. 
\u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
           INFO     All the weights of BertModel were initialized from modeling_utils.py:1507\n",
+       "                    the model checkpoint at bert-base-uncased.                               \n",
+       "                    If your task is similar to the task the model of                         \n",
+       "                    the checkpoint was trained on, you can already use                       \n",
+       "                    BertModel for predictions without further                                \n",
+       "                    training.                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m training. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "from fastNLP.transformers.torch import BertModel\n", + "from fastNLP import seq_len_to_mask\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "class BertNER(nn.Module):\n", + " def __init__(self, model_name, num_class, tag_vocab=None):\n", + " super().__init__()\n", + " self.bert = BertModel.from_pretrained(model_name)\n", + " self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n", + " nn.Dropout(0.3),\n", + " nn.Linear(self.bert.config.hidden_size, num_class))\n", + " self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrined_decode \n", + " if tag_vocab is not None:\n", + " self._init_constrained_transition()\n", + " \n", + " def forward(self, input_ids, input_len, first):\n", + " attention_mask = seq_len_to_mask(input_len)\n", + " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", + " last_hidden_state = 
outputs.last_hidden_state\n", + " first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n", + " first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n", + " first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n", + " \n", + " pred = self.mlp(first_bpe_state)\n", + " return {'pred': pred}\n", + " \n", + " def train_step(self, input_ids, input_len, first, target):\n", + " pred = self(input_ids, input_len, first)['pred']\n", + " loss = F.cross_entropy(pred.transpose(1, 2), target)\n", + " return {'loss': loss}\n", + " \n", + " def evaluate_step(self, input_ids, input_len, first):\n", + " pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n", + " return {'pred': pred}\n", + " \n", + " def constrained_decode(self, input_ids, input_len, first, first_len):\n", + " # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 的情况】\n", + " # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n", + " pred = self(input_ids, input_len, first)['pred']\n", + " cons_pred = []\n", + " for _pred, _len in zip(pred, first_len):\n", + " _pred = _pred[:_len]\n", + " tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n", + " for i in range(1, _len):\n", + " tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n", + " cons_pred.append(torch.LongTensor(tags))\n", + " cons_pred = pad_sequence(cons_pred, batch_first=True)\n", + " return {'pred': cons_pred}\n", + " \n", + " def _init_constrained_transition(self):\n", + " from fastNLP.modules.torch import allowed_transitions\n", + " allowed_trans = allowed_transitions(self.tag_vocab)\n", + " transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n", + " for s, e in allowed_trans:\n", + " transition[s, e] = 0\n", + " self.register_buffer('transition', transition)\n", + "\n", + "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))" + ] + }, + { + "cell_type": "markdown", + "id": 
"5aeee1e9", + "metadata": {}, + "source": [ + "### 5. Trainer 的使用\n", + "fastNLP 的 Trainer 是用于对模型进行训练的部件。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f4250f0b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[10:49:22] INFO     Running evaluator sanity check for 2 batches.              trainer.py:661\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
+++++++++++++++++++++++++++++ Eval. results on Epoch:1, Batch:0 +++++++++++++++++++++++++++++\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#f\": 0.402447,\n",
+       "  \"pre#f\": 0.447906,\n",
+       "  \"rec#f\": 0.365365\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n", + " \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n", + " \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10:51:15] INFO     The best performance for monitor f#f:0.402447 was progress_callback.py:37\n",
+       "                    achieved in Epoch:1, Global Batch:625. The                               \n",
+       "                    evaluation result:                                                       \n",
+       "                    {'f#f': 0.402447, 'pre#f': 0.447906, 'rec#f':                            \n",
+       "                    0.365365}                                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
           INFO     Loading best model from buffer with f#f:  load_best_model_callback.py:115\n",
+       "                    0.402447...                                                              \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from torch import optim\n", + "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n", + "from fastNLP import SpanFPreRecMetric\n", + "\n", + "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n", + "callbacks = [\n", + " LoadBestModelCallback(), # 用于在训练结束之后加载性能最好的model的权重\n", + " TorchWarmupCallback()\n", + "] \n", + "\n", + "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n", + " evaluate_dataloaders=dls['dev'], \n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n", + " n_epochs=1, callbacks=callbacks, \n", + " # 在评测时将 dataloader 中的 first_len 映射为 seq_len, 因为 SpanFPreRecMetric.update 接口需要输入一个名为 seq_len 的参数\n", + " evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n", + " device=0, monitor='f#f', fp16=False) # fp16 为 True 的话,将使用 float16 进行训练。\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c600a450", + "metadata": {}, + "source": [ + "### 6. Evaluator的使用\n", + "fastNLP中用于评测数据的对象。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b19f0ba", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + 
"data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP import SpanFPreRecMetric\n", + "\n", + "evaluator = Evaluator(model=model, dataloaders=dls['test'], \n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n", + " evaluate_input_mapping={'first_len': 'seq_len'}, \n", + " device=0)\n", + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52f87770", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f723fe399df34917875ad74c2542508c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# 如果想评测一下使用 constrained decoding的性能,则可以通过传入 evaluate_fn 指定使用的函数\n", + "def input_mapping(x):\n", + " x['seq_len'] = x['first_len']\n", + " return x\n", + "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n", + " evaluate_fn='constrained_decode',\n", + " # 如果将 first_len 重新命名为了 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数,因此\n", + " # 额外重复一下 'first_len': 'first_len',使得这个参数不会消失。\n", + " evaluate_input_mapping=input_mapping)\n", + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419e718b", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": 
"Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_0.ipynb b/docs/source/tutorials/fastnlp_tutorial_0.ipynb new file mode 100644 index 00000000..09667794 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_0.ipynb @@ -0,0 +1,1352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aec0fde7", + "metadata": {}, + "source": [ + "# T0. trainer 和 evaluator 的基本使用\n", + "\n", + "  1   trainer 和 evaluator 的基本关系\n", + " \n", + "    1.1   trainer 和 evaluater 的初始化\n", + "\n", + "    1.2   driver 的含义与使用要求\n", + "\n", + "    1.3   trainer 内部初始化 evaluater\n", + "\n", + "  2   使用 fastNLP 搭建 argmax 模型\n", + "\n", + "    2.1   trainer_step 和 evaluator_step\n", + "\n", + "    2.2   trainer 和 evaluator 的参数匹配\n", + "\n", + "    2.3   示例:argmax 模型的搭建\n", + "\n", + "  3   使用 fastNLP 训练 argmax 模型\n", + " \n", + "    3.1   trainer 外部初始化的 evaluator\n", + "\n", + "    3.2   trainer 内部初始化的 evaluator " + ] + }, + { + "cell_type": "markdown", + "id": "09ea669a", + "metadata": {}, + "source": [ + "## 1. 
trainer 和 evaluator 的基本关系\n", + "\n", + "### 1.1  trainer 和 evaluator 的初始化\n", + "\n", + "在`fastNLP 1.0`中,`Trainer`模块和`Evaluator`模块分别表示 **“训练器”和“评测器”**\n", + "\n", + "  对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n", + "\n", + "在`fastNLP 1.0`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n", + "\n", + "  非常关键的问题在于**如何正确设置二者的 driver**。这就引入了另一个问题:什么是 `driver`?\n", + "\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model, # 模型基于 torch.nn.Module\n", + " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", + " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", + " ...\n", + " driver=\"torch\", # 使用 pytorch 模块进行训练 \n", + " device='cuda', # 使用 GPU:0 显卡执行训练\n", + " ...\n", + " )\n", + "...\n", + "evaluator = Evaluator(\n", + " model=model, # 模型基于 torch.nn.Module\n", + " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", + " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", + " ...\n", + " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", + " device=None,\n", + " ...\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "3c11fe1a", + "metadata": {}, + "source": [ + "### 1.2  driver 的含义与使用要求\n", + "\n", + "在`fastNLP 1.0`中,**driver**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n", + "\n", + "  例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n", + "\n", + "在`fastNLP 1.0`中,**Trainer 和 Evaluator 都依赖于具体的 driver 来完成整体的工作流程**\n", + "\n", + "  具体`driver`与`Trainer`以及`Evaluator`之间的关系,详见之后`tutorial 4`中的详细介绍\n", + "\n", + "注:这里给出一条建议:**在同一脚本中**,**所有的** Trainer **和** Evaluator **使用的** driver **应当保持一致**\n", + "\n", + "  尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n", + "\n", + "  多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦" + ] + }, + { + "cell_type": "markdown", + "id": "2cac4a1a", + "metadata": {}, + "source": [ + "### 1.3  Trainer 内部初始化 Evaluator\n", + "\n", + "在`fastNLP 1.0`中,如果在**初始化 Trainer 时**,**传入参数 evaluate_dataloaders 和 metrics**\n", + "\n", + "  
则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluate_dataloaders\n", + " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", + " ...\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "0c9c7dda", + "metadata": {}, + "source": [ + "## 2. argmax 模型的搭建实例" + ] + }, + { + "cell_type": "markdown", + "id": "524ac200", + "metadata": {}, + "source": [ + "### 2.1  train_step 和 evaluate_step\n", + "\n", + "在`fastNLP 1.0`中,使用`torch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", + "\n", + "  添加`pytorch`要求的`forward`方法外,还需要添加 `train_step` 和 `evaluate_step` 这两个方法\n", + "\n", + "```python\n", + "class Model(torch.nn.Module):\n", + " def __init__(self):\n", + " super(Model, self).__init__()\n", + " self.loss_fn = torch.nn.CrossEntropyLoss()\n", + " pass\n", + "\n", + " def forward(self, x):\n", + " pass\n", + "\n", + " def train_step(self, x, y):\n", + " pred = self(x)\n", + " return {\"loss\": self.loss_fn(pred, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " pred = self(x)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": y}\n", + "```\n", + "***\n", + "在`fastNLP 1.0`中,**函数 train_step 是 Trainer 中参数 train_fn 的默认值**\n", + "\n", + "  由于,在`Trainer`训练时,**Trainer 通过参数 train_fn 对应的模型方法获得当前数据批次的损失值**\n", + "\n", + "  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n", + "\n", + "    如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n", + "\n", + "注:在`fastNLP 1.0`中,**Trainer 要求模型通过 train_step 来返回一个字典**,**满足如 {\"loss\": loss} 的形式**\n", + "\n", + "  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n", + "\n", + "同样,在`fastNLP 1.0`中,**函数 evaluate_step 是 Evaluator 中参数 evaluate_fn 的默认值**\n", + "\n", + "  在`Evaluator`测试时,**Evaluator 通过参数 
evaluate_fn 对应的模型方法获得当前数据批次的评测结果**\n", + "\n", + "  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n", + "\n", + "  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "fb3272eb", + "metadata": {}, + "source": [ + "### 2.2  trainer 和 evaluator 的参数匹配\n", + "\n", + "在`fastNLP 1.0`中,参数匹配涉及到两个方面,分别是\n", + "\n", + "  一方面,**在模型的前向传播中**,**dataloader 向 train_step 或 evaluate_step 函数传递 batch**\n", + "\n", + "  另一方面,**在模型的评测过程中**,**evaluate_dataloader 向 metric 的 update 函数传递 batch**\n", + "\n", + "对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n", + "\n", + "    **fastNLP 1.0 要求 dataloader 生成的每个 batch**,**满足如 {\"x\": x, \"y\": y} 的形式**\n", + "\n", + "  同时,`fastNLP 1.0`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n", + "\n", + "    **字典形式的定义**,**对应在 Dataset 定义的 \\_\\_getitem\\_\\_ 方法中**,例如下方的`ArgMaxDataset`\n", + "\n", + "  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", + "\n", + "    `fastNLP 1.0`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", + "\n", + "```python\n", + "class Dataset(torch.utils.data.Dataset):\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " def __len__(self):\n", + " return len(self.x)\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "f5f1a6aa", + "metadata": {}, + "source": [ + "对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n", + "\n", + "    **update 函数**,**针对一个 batch 的预测结果**,计算其累计的评价指标\n", + "\n", + "    **get_metric 函数**,**统计 update 函数累计的评价指标**,来计算最终的评价结果\n", + "\n", + "  例如对于`Accuracy`来说,`update`函数会更新一个`batch`的预测正确的样本数量`right_num`和样本总数`total_num`\n", + "\n", + "    而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n", + "\n", + "  在此基础上,**fastNLP 1.0 要求 evaluate_dataloader 生成的每个 batch 传递给对应的 metric**\n", + "\n", + "    
**以 {\"pred\": y_pred, \"target\": y_true} 的形式**,对应其`update`函数的函数签名\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "f62b7bb1", + "metadata": {}, + "source": [ + "### 2.3  示例:argmax 模型的搭建\n", + "\n", + "下文将通过训练`argmax`模型,简单介绍`Trainer`模块的使用方式\n", + "\n", + "  首先,使用`torch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5314482b", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class ArgMaxModel(nn.Module):\n", + " def __init__(self, num_labels, feature_dimension):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + "\n", + " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", + " self.ac1 = nn.ReLU()\n", + " self.linear2 = nn.Linear(in_features=10, out_features=10)\n", + " self.ac2 = nn.ReLU()\n", + " self.output = nn.Linear(in_features=10, out_features=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " pred = self.ac1(self.linear1(x))\n", + " pred = self.ac2(self.linear2(pred))\n", + " pred = self.output(pred)\n", + " return pred\n", + "\n", + " def train_step(self, x, y):\n", + " pred = self(x)\n", + " return {\"loss\": self.loss_fn(pred, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " pred = self(x)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": y}" + ] + }, + { + "cell_type": "markdown", + "id": "71f3fa6b", + "metadata": {}, + "source": [ + "  接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n", + "\n", + "    数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n", + "\n", + "    数据集初始化时,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fe612e61", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from 
torch.utils.data import Dataset\n", + "\n", + "class ArgMaxDataset(Dataset):\n", + " def __init__(self, feature_dimension, data_num=1000, seed=0):\n", + " self.num_labels = feature_dimension\n", + " self.feature_dimension = feature_dimension\n", + " self.data_num = data_num\n", + " self.seed = seed\n", + "\n", + " g = torch.Generator()\n", + " g.manual_seed(1000)\n", + " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n", + " self.y = torch.max(self.x, dim=-1)[1]\n", + "\n", + " def __len__(self):\n", + " return self.data_num\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}" + ] + }, + { + "cell_type": "markdown", + "id": "2cb96332", + "metadata": {}, + "source": [ + "  然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n", + "\n", + "    再根据`ArgMaxDataset`类初始化两个数据集实例,分别用于模型训练和模型评测,数据量分别为1000笔和100笔" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76172ef8", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n", + "\n", + "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n", + "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)" + ] + }, + { + "cell_type": "markdown", + "id": "4e7d25ee", + "metadata": {}, + "source": [ + "  此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块,批量大小同为8,分别用于训练和测评" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "363b5b09", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n", + "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)" + ] + }, + { + "cell_type": "markdown", + "id": "c8d4443f", + "metadata": {}, + "source": [ + "  最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法" + ] + }, + { + "cell_type": "code", + 
"execution_count": 5, + "id": "dc28a2d9", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.optim import SGD\n", + "\n", + "optimizer = SGD(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "eb8ca6cf", + "metadata": {}, + "source": [ + "## 3. 使用 fastNLP 1.0 训练 argmax 模型\n", + "\n", + "### 3.1 trainer 外部初始化的 evaluator" + ] + }, + { + "cell_type": "markdown", + "id": "55145553", + "metadata": {}, + "source": [ + "通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n", + "\n", + "  需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n", + "\n", + "  通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n", + "\n", + "    但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n", + "\n", + "  通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b51b7a2d", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 设定迭代轮数 \n", + " progress_bar=\"auto\" # 设定进度条格式\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6e202d6e", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + "  其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", + "\n", + "  `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ba047ead", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c16c5fa4", + "metadata": {}, + "source": [ + "通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n", + "\n", + "  需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n", + "\n", + "  需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n", + "\n", + "  类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1c6b6b36", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP import Accuracy\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n", + " device=None,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "8157bb9b", + "metadata": {}, + "source": [ + "通过使用`Evaluator`类的`run`函数,进行评测\n", + "\n", + "  其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n", + "\n", + "  最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后被丢弃" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f7cb0165", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "dd9f68fa", + "metadata": {}, + "source": [ + "### 3.2 trainer 内部初始化的 evaluator \n", + "\n", + "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n", + "\n", + "  通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n", + "\n", + "  但是中间的评估结果仍会保留;**通过 evaluate_every 设定评估频率**,可以为负数、正数或者函数:\n", + "\n", + "    **为负数时**,**表示每隔几个 epoch 评估一次**;**为正数时**,**则表示每隔几个 batch 评估一次**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "183c7d19", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " optimizers=optimizer,\n", + " n_epochs=10, \n", + " evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "714cc404", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + "  还可以通过**参数 num_eval_sanity_batch 决定每次训练前运行多少个 evaluate_batch 进行评测**,**默认为 2 **\n", + "\n", + "  之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2e4daa2c", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[18:28:25] INFO     Running evaluator sanity check for 2 batches.              trainer.py:592\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.31,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 31.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.33,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 33.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.34,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 34.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.36,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 36.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.37,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 37.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.4,\n",
+       "  \"total#acc\": 100.0,\n",
+       "  \"correct#acc\": 40.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c4e9c619", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}"
+      ]
+     },
+     "execution_count": 12,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1bc7cb4a",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_1.ipynb b/docs/source/tutorials/fastnlp_tutorial_1.ipynb
new file mode 100644
index 00000000..cff81a21
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_1.ipynb
@@ -0,0 +1,1333 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "cdc25fcd",
+   "metadata": {},
+   "source": [
+    "# T1. dataset 和 vocabulary 的基本使用\n",
+    "\n",
+    "  1   dataset 的使用与结构\n",
+    " \n",
+    "    1.1   dataset 的结构与创建\n",
+    "\n",
+    "    1.2   dataset 的数据预处理\n",
+    "\n",
+    "    1.3   延伸:instance 和 field\n",
+    "\n",
+    "  2   vocabulary 的结构与使用\n",
+    "\n",
+    "    2.1   vocabulary 的创建与修改\n",
+    "\n",
+    "    2.2   vocabulary 与 OOV 问题\n",
+    "\n",
+    "  3   dataset 和 vocabulary 的组合使用\n",
+    " \n",
+    "    3.1   从 dataframe 中加载 dataset\n",
+    "\n",
+    "    3.2   从 dataset 中获取 vocabulary"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0eb18a22",
+   "metadata": {},
+   "source": [
+    "## 1. dataset 的使用与结构\n",
+    "\n",
+    "### 1.1  dataset 的结构与创建\n",
+    "\n",
+    "在`fastNLP 1.0`中,使用`DataSet`模块表示数据集,**dataset 类似于关系型数据库中的数据表**(下文统一为小写 `dataset`)\n",
+    "\n",
+    "  **主要包含 field 字段和 instance 实例两个元素**,对应 table 中的 field 字段和`record`记录\n",
+    "\n",
+    "在`fastNLP 1.0`中,`DataSet`模块被定义在`fastNLP.core.dataset`路径下,导入该模块后,最简单的\n",
+    "\n",
+    "  初始化方法,即将字典形式的表格 **{'field1': column1, 'field2': column2, ...}** 传入构造函数"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "a1d69ad2",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"],\n", + " 'words': [['This', 'is', 'an', 'apple', '.'], \n", + " ['I', 'like', 'apples', '.'], \n", + " ['Apples', 'are', 'good', 'for', 'our', 'health', '.']],\n", + " 'num': [5, 4, 7]}\n", + "\n", + "dataset = DataSet(data)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9260fdc6", + "metadata": {}, + "source": [ + "  在`dataset`的实例中,字段`field`的名称和实例`instance`中的字符串也可以中文" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3d72ef00", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------+--------------------+------------------------+------+\n", + "| 序号 | 句子 | 字符 | 长度 |\n", + "+------+--------------------+------------------------+------+\n", + "| 0 | 生活就像海洋, | ['生', '活', '就', ... | 7 |\n", + "| 1 | 只有意志坚强的人, | ['只', '有', '意', ... | 9 |\n", + "| 2 | 才能到达彼岸。 | ['才', '能', '到', ... 
| 7 |\n", + "+------+--------------------+------------------------+------+\n" + ] + } + ], + "source": [ + "temp = {'序号': [0, 1, 2], \n", + " '句子':[\"生活就像海洋,\", \"只有意志坚强的人,\", \"才能到达彼岸。\"],\n", + " '字符': [['生', '活', '就', '像', '海', '洋', ','], \n", + " ['只', '有', '意', '志', '坚', '强', '的', '人', ','], \n", + " ['才', '能', '到', '达', '彼', '岸', '。']],\n", + " '长度': [7, 9, 7]}\n", + "\n", + "chinese = DataSet(temp)\n", + "print(chinese)" + ] + }, + { + "cell_type": "markdown", + "id": "202e5490", + "metadata": {}, + "source": [ + "在`dataset`中,使用`drop`方法可以删除满足条件的实例,这里使用了python中的`lambda`表达式\n", + "\n", + "  注一:在`drop`方法中,通过设置`inplace`参数将删除对应实例后的`dataset`作为一个新的实例生成" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "09b478f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2492313174344 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped = dropped.drop(lambda ins:ins['num'] < 5, inplace=False)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "aa277674", + "metadata": {}, + "source": [ + "  注二:**对对象使用等号一般表示传引用**,所以对`dataset`使用等号,是传引用而不是赋值\n", + "\n", + "    如下所示,**dropped 和 dataset 具有相同 id**,**对 dropped 执行删除操作 dataset 同时会被修改**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "77c8583a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2491986424200 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped.drop(lambda ins:ins['num'] < 5)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "a76199dc", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delete_instance`方法可以删除对应序号的`instance`实例,序号从0开始" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d8824b40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+--------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "+-----+--------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.delete_instance(2)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "f4fa9f33", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delete_field`方法可以删除对应名称的`field`字段" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f68ddb40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+--------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... 
|\n", + "+-----+--------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset.delete_field('num')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "b1e9d42c", + "metadata": {}, + "source": [ + "### 1.2 dataset 的数据预处理\n", + "\n", + "在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n", + "\n", + "  **apply 和 apply_more 输入整条实例**,**apply_field 和 apply_field_more 仅输入实例的部分字段**\n", + "\n", + "  **apply 和 apply_field 仅输出单个字段**,**apply_more 和 apply_field_more 则是输出多个字段**\n", + "\n", + "  **apply 和 apply_field 返回的是个列表**,**apply_more 和 apply_field_more 返回的是个字典**\n", + "\n", + "    预处理过程中,通过`progress_bar`参数设置显示进度条类型,通过`num_proc`设置多进程\n", + "***\n", + "\n", + "`apply`的参数包括一个函数`func`和一个新字段名`new_field_name`,函数`func`的处理对象是`dataset`模块中\n", + "\n", + "  的每个`instance`实例,函数`func`的处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "72a0b5f9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00,\n", + " 'words': ,\n", + " 'num': }" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.get_all_fields()" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "5433815c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['num', 'sentence', 'words']" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.get_field_names()" + ] + }, + { + "cell_type": "markdown", + "id": "4964eeed", + "metadata": {}, + "source": [ + "其他`dataset`的基本使用:通过`in`或者`has_field`方法可以判断`dataset`的是否包含某种字段\n", + "\n", + "  通过`rename_field`方法可以更改`dataset`中的字段名称;通过`concat`方法可以实现两个`dataset`中的拼接\n", + "\n", + "  
通过`len`可以统计`dataset`中的实例数目;`dataset`的全部变量与函数可以通过`dir(dataset)`查询" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "25ce5488", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3 False\n", + "6 True\n", + "+------------------------------+------------------------------+--------+\n", + "| sentence | words | length |\n", + "+------------------------------+------------------------------+--------+\n", + "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", + "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", + "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", + "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", + "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", + "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", + "+------------------------------+------------------------------+--------+\n" + ] + } + ], + "source": [ + "print(len(dataset), dataset.has_field('length')) \n", + "if 'num' in dataset:\n", + " dataset.rename_field('num', 'length')\n", + "elif 'length' in dataset:\n", + " dataset.rename_field('length', 'num')\n", + "dataset.concat(dataset)\n", + "print(len(dataset), dataset.has_field('length')) \n", + "print(dataset) " + ] + }, + { + "cell_type": "markdown", + "id": "e30a6cd7", + "metadata": {}, + "source": [ + "## 2. 
vocabulary 的结构与使用\n", + "\n", + "### 2.1 vocabulary 的创建与修改\n", + "\n", + "在`fastNLP 1.0`中,使用`Vocabulary`模块表示词汇表,**vocabulary 的核心是从单词到序号的映射**\n", + "\n", + "  可以直接通过构造函数实例化,通过查找`word2idx`属性,可以找到`vocabulary`映射对应的字典实现\n", + "\n", + "  **默认补零 padding 用 \\ 表示**,**对应序号为0**;**未知单词 unknown 用 \\ 表示**,**对应序号1**\n", + "\n", + "  通过打印`vocabulary`可以看到词汇表中的单词列表,其中,`padding`和`unknown`不会显示" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "3515e096", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vocabulary([]...)\n", + "{'': 0, '': 1}\n", + " 0\n", + " 1\n" + ] + } + ], + "source": [ + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "print(vocab)\n", + "print(vocab.word2idx)\n", + "print(vocab.padding, vocab.padding_idx)\n", + "print(vocab.unknown, vocab.unknown_idx)" + ] + }, + { + "cell_type": "markdown", + "id": "640be126", + "metadata": {}, + "source": [ + "在`vocabulary`中,通过`add_word`方法或`add_word_lst`方法,可以单独或批量添加单词\n", + "\n", + "  通过`len`或`word_count`属性,可以显示`vocabulary`的单词量和每个单词添加的次数" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "88c7472a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5 Counter({'生活': 1, '就像': 1, '海洋': 1})\n", + "6 Counter({'生活': 1, '就像': 1, '海洋': 1, '只有': 1})\n", + "6 {'': 0, '': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5}\n" + ] + } + ], + "source": [ + "vocab.add_word_lst(['生活', '就像', '海洋'])\n", + "print(len(vocab), vocab.word_count)\n", + "vocab.add_word('只有')\n", + "print(len(vocab), vocab.word_count)\n", + "print(len(vocab), vocab.word2idx)" + ] + }, + { + "cell_type": "markdown", + "id": "f9ec8b28", + "metadata": {}, + "source": [ + "  **通过 to_word 方法可以找到单词对应的序号**,**通过 to_index 方法可以找到序号对应的单词**\n", + "\n", + "    由于序号0和序号1已经被占用,所以**新加入的词的序号从2开始计数**,如`'生活'`对应2\n", + "\n", + "    通过`has_word`方法可以判断单词是否在词汇表中,没有的单词被判做``" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": 
"3447acde", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " 0\n", + " 1\n", + "生活 2\n", + "彼岸 1 False\n" + ] + } + ], + "source": [ + "print(vocab.to_word(0), vocab.to_index(''))\n", + "print(vocab.to_word(1), vocab.to_index(''))\n", + "print(vocab.to_word(2), vocab.to_index('生活'))\n", + "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))" + ] + }, + { + "cell_type": "markdown", + "id": "b4e36850", + "metadata": {}, + "source": [ + "**vocabulary 允许反复添加相同单词**,**可以通过 word_count 方法看到相应单词被添加的次数**\n", + "\n", + "  但其中没有``和``,`vocabulary`的全部变量与函数可以通过`dir(vocabulary)`查询\n", + "\n", + "  注:**使用 add_word_lst 添加单词**,**单词对应序号不会动态调整**,**使用 dataset 添加单词的情况不同**" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "490b101c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "生活 2\n", + "彼岸 12 True\n", + "13 Counter({'人': 4, '生活': 2, '就像': 2, '海洋': 2, '只有': 2, '意志': 1, '坚强的': 1, '才': 1, '能': 1, '到达': 1, '彼岸': 1})\n", + "13 {'': 0, '': 1, '生活': 2, '就像': 3, '海洋': 4, '只有': 5, '人': 6, '意志': 7, '坚强的': 8, '才': 9, '能': 10, '到达': 11, '彼岸': 12}\n" + ] + } + ], + "source": [ + "vocab.add_word_lst(['生活', '就像', '海洋', '只有', '意志', '坚强的', '人', '人', '人', '人', '才', '能', '到达', '彼岸'])\n", + "print(vocab.to_word(2), vocab.to_index('生活'))\n", + "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))\n", + "print(len(vocab), vocab.word_count)\n", + "print(len(vocab), vocab.word2idx)" + ] + }, + { + "cell_type": "markdown", + "id": "23e32a63", + "metadata": {}, + "source": [ + "### 2.2 vocabulary 与 OOV 问题\n", + "\n", + "在`vocabulary`模块初始化的时候,可以通过指定`unknown`和`padding`为`None`,限制其存在\n", + "\n", + "  此时添加单词直接从0开始标号,如果遇到未知单词会直接报错,即 out of vocabulary" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "a99ff909", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'positive': 0, 'negative': 1}\n", + "ValueError: word 
`neutral` not in vocabulary\n" + ] + } + ], + "source": [ + "vocab = Vocabulary(unknown=None, padding=None)\n", + "\n", + "vocab.add_word_lst(['positive', 'negative'])\n", + "print(vocab.word2idx)\n", + "\n", + "try:\n", + " print(vocab.to_index('neutral'))\n", + "except ValueError:\n", + " print(\"ValueError: word `neutral` not in vocabulary\")" + ] + }, + { + "cell_type": "markdown", + "id": "618da6bd", + "metadata": {}, + "source": [ + "  相应的,如果只指定其中的`unknown`,则编号会后移一个,同时遇到未知单词全部当做``" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "432f74c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'': 0, 'positive': 1, 'negative': 2}\n", + "0 \n" + ] + } + ], + "source": [ + "vocab = Vocabulary(unknown='', padding=None)\n", + "\n", + "vocab.add_word_lst(['positive', 'negative'])\n", + "print(vocab.word2idx)\n", + "\n", + "print(vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral')))" + ] + }, + { + "cell_type": "markdown", + "id": "b6263f73", + "metadata": {}, + "source": [ + "## 3 dataset 和 vocabulary 的组合使用\n", + " \n", + "### 3.1 从 dataframe 中加载 dataset\n", + "\n", + "以下通过 [NLP-beginner](https://github.com/FudanNLP/nlp-beginner) 实践一中 [Rotten Tomatoes 影评数据集](https://www.kaggle.com/c/sentiment-analysis-on-movie-reviews) 的部分训练数据组成`test4dataset.tsv`文件\n", + "\n", + "  介绍如何使用`dataset`、`vocabulary`简单加载并处理数据集,首先使用`pandas`模块,读取原始数据的`dataframe`" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "3dbd985d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
SentenceIdSentenceSentiment
01A series of escapades demonstrating the adage ...negative
12This quiet , introspective and entertaining in...positive
23Even fans of Ismail Merchant 's work , I suspe...negative
34A positively thrilling combination of ethnogra...neutral
45A comedy-drama of nearly epic proportions root...positive
56The Importance of Being Earnest , so thick wit...neutral
\n", + "
" + ], + "text/plain": [ + " SentenceId Sentence Sentiment\n", + "0 1 A series of escapades demonstrating the adage ... negative\n", + "1 2 This quiet , introspective and entertaining in... positive\n", + "2 3 Even fans of Ismail Merchant 's work , I suspe... negative\n", + "3 4 A positively thrilling combination of ethnogra... neutral\n", + "4 5 A comedy-drama of nearly epic proportions root... positive\n", + "5 6 The Importance of Being Earnest , so thick wit... neutral" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.read_csv('./data/test4dataset.tsv', sep='\\t')\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "919ab350", + "metadata": {}, + "source": [ + "接着,通过`dataset`中的`from_pandas`方法填充数据集,并使用`apply_more`方法对文本进行分词操作" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "4f634586", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6 [00:00': 0, '': 1, 'a': 2, 'of': 3, ',': 4, 'the': 5, '.': 6, 'is': 7, 'and': 8, 'good': 9, 'for': 10, 'which': 11, 'this': 12, \"'s\": 13, 'series': 14, 'escapades': 15, 'demonstrating': 16, 'adage': 17, 'that': 18, 'what': 19, 'goose': 20, 'also': 21, 'gander': 22, 'some': 23, 'occasionally': 24, 'amuses': 25, 'but': 26, 'none': 27, 'amounts': 28, 'to': 29, 'much': 30, 'story': 31, 'quiet': 32, 'introspective': 33, 'entertaining': 34, 'independent': 35, 'worth': 36, 'seeking': 37, 'even': 38, 'fans': 39, 'ismail': 40, 'merchant': 41, 'work': 42, 'i': 43, 'suspect': 44, 'would': 45, 'have': 46, 'hard': 47, 'time': 48, 'sitting': 49, 'through': 50, 'one': 51, 'positively': 52, 'thrilling': 53, 'combination': 54, 'ethnography': 55, 'all': 56, 'intrigue': 57, 'betrayal': 58, 'deceit': 59, 'murder': 60, 'shakespearean': 61, 'tragedy': 62, 'or': 63, 
'juicy': 64, 'soap': 65, 'opera': 66, 'comedy-drama': 67, 'nearly': 68, 'epic': 69, 'proportions': 70, 'rooted': 71, 'in': 72, 'sincere': 73, 'performance': 74, 'by': 75, 'title': 76, 'character': 77, 'undergoing': 78, 'midlife': 79, 'crisis': 80, 'importance': 81, 'being': 82, 'earnest': 83, 'so': 84, 'thick': 85, 'with': 86, 'wit': 87, 'it': 88, 'plays': 89, 'like': 90, 'reading': 91, 'from': 92, 'bartlett': 93, 'familiar': 94, 'quotations': 95} \n", + "\n", + "Vocabulary(['a', 'series', 'of', 'escapades', 'demonstrating']...)\n" + ] + } + ], + "source": [ + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab = vocab.from_dataset(dataset, field_name='Sentence')\n", + "print(vocab.word_count, '\\n')\n", + "print(vocab.word2idx, '\\n')\n", + "print(vocab)" + ] + }, + { + "cell_type": "markdown", + "id": "f0857ccb", + "metadata": {}, + "source": [ + "之后,**通过 vocabulary 的 index_dataset 方法**,**调整 dataset 中指定字段的元素**,**使用编号将之代替**\n", + "\n", + "  使用上述方法,可以将影评数据集中的单词序列转化为词编号序列,为接下来转化为词嵌入序列做准备" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2f9a04b2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------+------------------------------+-----------+\n", + "| SentenceId | Sentence | Sentiment |\n", + "+------------+------------------------------+-----------+\n", + "| 1 | [2, 14, 3, 15, 16, 5, 17,... | negative |\n", + "| 2 | [12, 32, 4, 33, 8, 34, 35... | positive |\n", + "| 3 | [38, 39, 3, 40, 41, 13, 4... | negative |\n", + "| 4 | [2, 52, 53, 54, 3, 55, 8,... | neutral |\n", + "| 5 | [2, 67, 3, 68, 69, 70, 71... | positive |\n", + "| 6 | [5, 81, 3, 82, 83, 4, 84,... 
| neutral |\n", + "+------------+------------------------------+-----------+\n" + ] + } + ], + "source": [ + "vocab.index_dataset(dataset, field_name='Sentence')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "6b26b707", + "metadata": {}, + "source": [ + "最后,使用相同方法,再将`dataset`中`Sentiment`字段中的`negative`、`neutral`、`positive`转化为数字编号" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "5f5eed18", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'negative': 0, 'positive': 1, 'neutral': 2}\n", + "+------------+------------------------------+-----------+\n", + "| SentenceId | Sentence | Sentiment |\n", + "+------------+------------------------------+-----------+\n", + "| 1 | [2, 14, 3, 15, 16, 5, 17,... | 0 |\n", + "| 2 | [12, 32, 4, 33, 8, 34, 35... | 1 |\n", + "| 3 | [38, 39, 3, 40, 41, 13, 4... | 0 |\n", + "| 4 | [2, 52, 53, 54, 3, 55, 8,... | 2 |\n", + "| 5 | [2, 67, 3, 68, 69, 70, 71... | 1 |\n", + "| 6 | [5, 81, 3, 82, 83, 4, 84,... 
| 2 |\n", + "+------------+------------------------------+-----------+\n" + ] + } + ], + "source": [ + "target_vocab = Vocabulary(padding=None, unknown=None)\n", + "\n", + "target_vocab.from_dataset(dataset, field_name='Sentiment')\n", + "print(target_vocab.word2idx)\n", + "target_vocab.index_dataset(dataset, field_name='Sentiment')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "eed7ea64", + "metadata": {}, + "source": [ + "在最后的最后,通过以下的一张图,来总结本章关于`dataset`和`vocabulary`主要知识点的讲解,以及两者的联系\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35b4f0f7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_2.ipynb b/docs/source/tutorials/fastnlp_tutorial_2.ipynb new file mode 100644 index 00000000..546e471d --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_2.ipynb @@ -0,0 +1,884 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# T2. databundle 和 tokenizer 的基本使用\n", + "\n", + "  1   fastNLP 中 dataset 的延伸\n", + "\n", + "    1.1   databundle 的概念与使用\n", + "\n", + "  2   fastNLP 中的 tokenizer\n", + " \n", + "    2.1   PreTrainedTokenizer 的概念\n", + "\n", + "    2.2   BertTokenizer 的基本使用\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. 
fastNLP 中 dataset 的延伸\n", + "\n", + "### 1.1 databundle 的概念与使用\n", + "\n", + "在`fastNLP 1.0`中,在常用的数据加载模块`DataLoader`和数据集`DataSet`模块之间,还存在\n", + "\n", + "  一个中间模块,即 **数据包 DataBundle 模块**,可以从`fastNLP.io`路径中导入该模块\n", + "\n", + "在`fastNLP 1.0`中,**一个 databundle 数据包包含若干 dataset 数据集和 vocabulary 词汇表**\n", + "\n", + "  分别存储在`datasets`和`vocabs`两个变量中,所以了解`databundle`数据包之前\n", + "\n", + "需要首先**复习 dataset 数据集和 vocabulary 词汇表**,**下面的一串代码**,**你知道其大概含义吗?**\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6 [00:00': 0, '': 1, 'negative': 2, 'positive': 3, 'neutral': 4}\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "from fastNLP import DataSet\n", + "from fastNLP import Vocabulary\n", + "from fastNLP.io import DataBundle\n", + "\n", + "datasets = DataSet.from_pandas(pd.read_csv('./data/test4dataset.tsv', sep='\\t'))\n", + "datasets.rename_field('Sentence', 'text')\n", + "datasets.rename_field('Sentiment', 'label')\n", + "datasets.apply_more(lambda ins:{'label': ins['label'].lower(), \n", + " 'text': ins['text'].lower().split()},\n", + " progress_bar='tqdm')\n", + "datasets.delete_field('SentenceId')\n", + "train_ds, test_ds = datasets.split(ratio=0.7)\n", + "datasets = {'train': train_ds, 'test': test_ds}\n", + "print(datasets['train'])\n", + "print(datasets['test'])\n", + "\n", + "vocabs = {}\n", + "vocabs['label'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='label')\n", + "vocabs['text'] = Vocabulary().from_dataset(datasets['train'].concat(datasets['test'], inplace=False), field_name='text')\n", + "print(vocabs['label'].word2idx)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "上述代码的含义是:从`test4dataset`的 6 条数据中,划分 4 条训练集(`int(6*0.7) = 4`),2 条测试集\n", + "\n", + "    修改相关字段名称,删除序号字段,同时将标签都设为小写,对文本进行分词\n", + "\n", + "  接着通过`concat`方法拼接测试集训练集,注意设置`inplace=False`,生成临时的新数据集\n", + "\n", + "  使用`from_dataset`方法从拼接的数据集中抽取词汇表,为将数据集中的单词替换为序号做准备\n", + "\n", + "由此就可以得到**数据集字典 datasets**(**对应训练集、测试集**)和**词汇表字典 vocabs**(**对应数据集各字段**)\n", + "\n", + "  然后就可以初始化`databundle`了,通过`print`可以观察其大致结构,效果如下" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": 
"stdout", + "output_type": "stream", + "text": [ + "In total 2 datasets:\n", + "\ttrain has 4 instances.\n", + "\ttest has 2 instances.\n", + "In total 2 vocabs:\n", + "\tlabel has 5 entries.\n", + "\ttext has 96 entries.\n", + "\n", + "['train', 'test']\n", + "['label', 'text']\n" + ] + } + ], + "source": [ + "data_bundle = DataBundle(datasets=datasets, vocabs=vocabs)\n", + "print(data_bundle)\n", + "print(data_bundle.get_dataset_names())\n", + "print(data_bundle.get_vocab_names())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "此外,也可以通过`data_bundle`的`num_dataset`和`num_vocab`返回数据表和词汇表个数\n", + "\n", + "  通过`data_bundle`的`iter_datasets`和`iter_vocabs`遍历数据表和词汇表" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 2 datasets:\n", + "\ttrain has 4 instances.\n", + "\ttest has 2 instances.\n", + "In total 2 datasets:\n", + "\tlabel has 5 entries.\n", + "\ttext has 96 entries.\n" + ] + } + ], + "source": [ + "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n", + "for name, dataset in data_bundle.iter_datasets():\n", + " print(\"\\t%s has %d instances.\" % (name, len(dataset)))\n", + "print(\"In total %d datasets:\" % data_bundle.num_dataset)\n", + "for name, vocab in data_bundle.iter_vocabs():\n", + " print(\"\\t%s has %d entries.\" % (name, len(vocab)))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在数据包`databundle`中,也有和数据集`dataset`类似的四个`apply`函数,即\n", + "\n", + "  `apply`函数、`apply_field`函数、`apply_field_more`函数和`apply_more`函数\n", + "\n", + "  负责对数据集进行预处理,如下所示是`apply_more`函数的示例,其他函数类似\n", + "\n", + "此外,通过`get_dataset`函数,可以通过数据表名`name`称找到对应数据表\n", + "\n", + "  通过`get_vocab`函数,可以通过词汇表名`field_name`称找到对应词汇表" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + 
"version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", + "在`fastNLP 1.0`中,**使用 PreTrainedTokenizer 模块来为数据集中的词语进行词向量的标注**\n", + "\n", + "  需要注意的是,`PreTrainedTokenizer`模块的下载和导入**需要确保环境安装了 transformers 模块**\n", + "\n", + "  这是因为 `fastNLP 1.0`中`PreTrainedTokenizer`模块的实现基于`Huggingface Transformers`库\n", + "\n", + "**Huggingface Transformers 是一个开源的**,**基于 transformer 模型结构提供的预训练语言库**\n", + "\n", + "  包含了多种经典的基于`transformer`的预训练模型,如`BERT`、`BART`、`RoBERTa`、`GPT2`、`CPT`\n", + "\n", + "  更多相关内容可以参考`Huggingface Transformers`的[相关论文](https://arxiv.org/pdf/1910.03771.pdf)、[官方文档](https://huggingface.co/transformers/)以及[的代码仓库](https://github.com/huggingface/transformers)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2.2 BertTokenizer 的基本使用\n", + "\n", + "在`fastNLP 1.0`中,以`PreTrainedTokenizer`为基类,泛化出多个子类,实现基于`BERT`等模型的标注\n", + "\n", + "  本节以`BertTokenizer`模块为例,展示`PreTrainedTokenizer`模块的使用方法与应用实例\n", + "\n", + "**BertTokenizer 的初始化包括 导入模块和导入数据 两步**,先通过从`fastNLP.transformers.torch`中\n", + "\n", + "  导入`BertTokenizer`模块,再**通过 from_pretrained 方法指定 tokenizer 参数类型下载**\n", + "\n", + "  其中,**'bert-base-uncased' 指定 tokenizer 使用的预训练 BERT 类型**:单词不区分大小写\n", + "\n", + "    **模块层数 L=12**,**隐藏层维度 H=768**,**自注意力头数 A=12**,**总参数量 110M**\n", + "\n", + "  另外,模型参数自动下载至 home 目录下的`~\\.cache\\huggingface\\transformers`文件夹中" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "from fastNLP.transformers.torch import BertTokenizer\n", + "\n", + "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "通过变量`vocab_size`和`vocab_files_names`可以查看`BertTokenizer`的词汇表的大小和对应文件\n", + "\n", + "  通过变量`vocab`可以访问`BertTokenizer`预训练的词汇表(由于内容过大就不演示了" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "30522 {'vocab_file': 'vocab.txt'}\n" + ] + } + ], + "source": [ + "print(tokenizer.vocab_size, tokenizer.vocab_files_names)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "通过变量`all_special_tokens`或通过变量`special_tokens_map`可以**查看 BertTokenizer 内置的特殊词素**\n", + "\n", + "  包括**未知符 '[UNK]'**, **断句符 '[SEP]'**, **补零符 '[PAD]'**, **分类符 '[CLS]'**, **掩码 '[MASK]'**\n", + "\n", + "通过变量`all_special_ids`可以**查看 BertTokenizer 内置的特殊词素对应的词汇表编号**,相同功能\n", + "\n", + "  也可以直接通过查看`pad_token`,值为`'[PAD]'`,和`pad_token_id`,值为`0`,等变量来实现" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pad_token [PAD] 0\n", + "unk_token [UNK] 100\n", + "cls_token [CLS] 101\n", + "sep_token [SEP] 102\n", + "msk_token [MASK] 103\n", + "all_tokens ['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 102, 0, 101, 103]\n", + "{'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n" + ] + } + ], + "source": [ + "print('pad_token', tokenizer.pad_token, tokenizer.pad_token_id) \n", + "print('unk_token', tokenizer.unk_token, tokenizer.unk_token_id) \n", + "print('cls_token', tokenizer.cls_token, tokenizer.cls_token_id) \n", + "print('sep_token', tokenizer.sep_token, tokenizer.sep_token_id)\n", + "print('msk_token', tokenizer.mask_token, tokenizer.mask_token_id)\n", + "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n", + "print(tokenizer.special_tokens_map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "此外,还可以添加其他特殊字符,例如起始符`[BOS]`、终止符`[EOS]`,添加后词汇表编号也会相应改变\n", + "\n", + "  *但是如何添加这两个之外的字符,并且如何将这两个的编号设置为 [UNK] 之外的编号???*" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "bos_token [BOS] 100\n", + "eos_token [EOS] 100\n", + "all_tokens 
['[BOS]', '[EOS]', '[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]'] [100, 100, 100, 102, 0, 101, 103]\n", + "{'bos_token': '[BOS]', 'eos_token': '[EOS]', 'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}\n" + ] + } + ], + "source": [ + "tokenizer.bos_token = '[BOS]'\n", + "tokenizer.eos_token = '[EOS]'\n", + "# tokenizer.bos_token_id = 104\n", + "# tokenizer.eos_token_id = 105\n", + "print('bos_token', tokenizer.bos_token, tokenizer.bos_token_id)\n", + "print('eos_token', tokenizer.eos_token, tokenizer.eos_token_id)\n", + "print('all_tokens', tokenizer.all_special_tokens, tokenizer.all_special_ids)\n", + "print(tokenizer.special_tokens_map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在`BertTokenizer`中,**使用 tokenize 函数和 convert_tokens_to_string 函数可以实现文本和词素列表的互转**\n", + "\n", + "  此外,**使用 convert_tokens_to_ids 函数和 convert_ids_to_tokens 函数则可以实现词素和词素编号的互转**\n", + "\n", + "  上述四个函数的使用效果如下所示,此处可以明显看出,`tokenizer`分词和传统分词的不同效果,例如`'##cap'`" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012]\n", + "['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n", + "a series of escapades demonstrating the adage that what is good for the goose is also good for the gander , some of which occasionally amuses but none of which amounts to much of a story .\n" + ] + } 
+ ], + "source": [ + "text = \"a series of escapades demonstrating the adage that what is \" \\\n", + " \"good for the goose is also good for the gander , some of which \" \\\n", + " \"occasionally amuses but none of which amounts to much of a story .\" \n", + "tks = ['a', 'series', 'of', 'es', '##cap', '##ades', 'demonstrating', 'the', \n", + " 'ada', '##ge', 'that', 'what', 'is', 'good', 'for', 'the', 'goose', \n", + " 'is', 'also', 'good', 'for', 'the', 'gan', '##der', ',', 'some', 'of', \n", + " 'which', 'occasionally', 'am', '##uses', 'but', 'none', 'of', 'which', \n", + " 'amounts', 'to', 'much', 'of', 'a', 'story', '.']\n", + "ids = [ 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, \n", + " 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204,\n", + " 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572,\n", + " 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037,\n", + " 2466, 1012]\n", + "\n", + "tokens = tokenizer.tokenize(text)\n", + "print(tokenizer.convert_tokens_to_ids(tokens))\n", + "\n", + "ids = tokenizer.convert_tokens_to_ids(tokens)\n", + "print(tokenizer.convert_ids_to_tokens(ids))\n", + "\n", + "print(tokenizer.convert_tokens_to_string(tokens))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在`BertTokenizer`中,还有另外两个函数可以实现分词标注,分别是 **encode 和 decode 函数**,**可以直接实现**\n", + "\n", + "  **文本字符串和词素编号列表的互转**,但是编码过程中会按照`BERT`的规则,**在句子首末加入 [CLS] 和 [SEP]**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 1012, 102]\n", + "[CLS] a series of escapades demonstrating the adage that what is good for the goose is also good for the 
gander, some of which occasionally amuses but none of which amounts to much of a story. [SEP]\n" + ] + } + ], + "source": [ + "enc = tokenizer.encode(text)\n", + "print(tokenizer.encode(text))\n", + "dec = tokenizer.decode(enc)\n", + "print(tokenizer.decode(enc))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在`encode`函数之上,还有`encode_plus`函数,这也是在数据预处理中,`BertTokenizer`模块最常用到的函数\n", + "\n", + "  **encode 函数的参数**,**encode_plus 函数都有**;**encode 函数返回词素编号列表**,**encode_plus 函数返回字典**\n", + "\n", + "在`encode_plus`函数的返回值中,字段`input_ids`表示词素编号,其余两个字段后文有详细解释\n", + "\n", + "  **字段 token_type_ids 详见 text_pairs 的示例**,**字段 attention_mask 详见 batch_text 的示例**\n", + "\n", + "在`encode_plus`函数的参数中,参数`add_special_tokens`表示是否按照`BERT`的规则,加入相关特殊字符\n", + "\n", + "  参数`max_length`表示句子截取最大长度(算特殊字符),在参数`truncation=True`时会自动截取\n", + "\n", + "  参数`return_attention_mask`约定返回的字典中是否包括`attention_mask`字段,以上案例如下" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681, 2572, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + ] + } + ], + "source": [ + "text = \"a series of escapades demonstrating the adage that what is good for the goose is also good for \"\\\n", + " \"the gander , some of which occasionally amuses but none of which amounts to much of a story .\" \n", + "\n", + "encoded = tokenizer.encode_plus(text=text, add_special_tokens=True, max_length=32, \n", + " truncation=True, return_attention_mask=True)\n", + "print(encoded)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, 
+ "source": [ + "在`encode_plus`函数之上,还有`batch_encode_plus`函数(类似地,在`decode`之上,还有`batch_decode`)\n", + "\n", + "  两者参数类似,**batch_encode_plus 函数针对批量文本 batch_text**,**或者批量句对 text_pairs**\n", + "\n", + "在针对批量文本`batch_text`的示例中,注意`batch_encode_plus`函数返回字典中的`attention_mask`字段\n", + "\n", + "  可以发现,**attention_mask 字段通过 01 标注出词素序列中该位置是否为补零**,可以用做自注意力的掩模" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 0, 0], [101, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 0, 0, 0, 0, 0, 0, 0], [101, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0]]}\n" + ] + } + ], + "source": [ + "batch_text = [\"a series of escapades demonstrating the adage that\",\n", + " \"what is good for the goose is also good for the gander\",\n", + " \"some of which occasionally amuses\",\n", + " \"but none of which amounts to much of a story\" ]\n", + "\n", + "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=batch_text, padding=True,\n", + " add_special_tokens=True, max_length=16, truncation=True, \n", + " return_attention_mask=True)\n", + "print(encoded)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "而在针对批量句对`text_pairs`的示例中,注意`batch_encode_plus`函数返回字典中的`token_type_ids`字段\n", + "\n", + "  可以发现,**token_type_ids 字段通过 01 标注出词素序列中该位置为句对中的第几句**,句对用 
[SEP] 分割" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [[101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262, 3351, 2008, 102, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036, 2204, 2005, 1996, 25957, 4063, 102], [101, 2070, 1997, 2029, 5681, 2572, 25581, 102, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997, 1037, 2466, 102, 0, 0, 0, 0, 0, 0, 0, 0]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0]]}\n" + ] + } + ], + "source": [ + "text_pairs = [(\"a series of escapades demonstrating the adage that\",\n", + " \"what is good for the goose is also good for the gander\"),\n", + " (\"some of which occasionally amuses\",\n", + " \"but none of which amounts to much of a story\")]\n", + "\n", + "encoded = tokenizer.batch_encode_plus(batch_text_or_text_pairs=text_pairs, padding=True,\n", + " add_special_tokens=True, max_length=32, truncation=True, \n", + " return_attention_mask=True)\n", + "print(encoded)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "回到`encode_plus`上,在接下来的示例中,**使用内置的 functools.partial 模块构造 encode 函数**\n", + "\n", + "  接着**使用该函数对 databundle 进行数据预处理**,由于`tokenizer.encode_plus`返回的是一个字典\n", + "\n", + "  读入的是一个字段,所以此处使用`apply_field_more`方法,得到结果自动并入`databundle`中如下" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "functools.partial(, max_length=32, truncation=True, return_attention_mask=True)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + 
"model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", + "\n", + "在接下来的`tutorial 3.`中,将会介绍`fastNLP v1.0`中的`dataloader`模块,会涉及本章中\n", + "\n", + "  提到的`collator`模块,`fastNLP`的多框架适应以及完整的数据加载过程,敬请期待" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_3.ipynb b/docs/source/tutorials/fastnlp_tutorial_3.ipynb new file mode 100644 index 00000000..4100105a --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_3.ipynb @@ -0,0 +1,621 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "213d538c", + "metadata": {}, + "source": [ + "# T3. dataloader 的内部结构和基本使用\n", + "\n", + "  1   fastNLP 中的 dataloader\n", + " \n", + "    1.1   dataloader 的基本介绍\n", + "\n", + "    1.2   dataloader 的函数创建\n", + "\n", + "  2   fastNLP 中 dataloader 的延伸\n", + "\n", + "    2.1   collator 的概念与使用\n", + "\n", + "    2.2   结合 datasets 框架" + ] + }, + { + "cell_type": "markdown", + "id": "85857115", + "metadata": {}, + "source": [ + "## 1. 
fastNLP 中的 dataloader\n", + "\n", + "### 1.1 dataloader 的基本介绍\n", + "\n", + "在`fastNLP 1.0`的开发中,最关键的开发目标就是**实现 fastNLP 对当前主流机器学习框架**,例如\n", + "\n", + "  **当下流行的 pytorch**,以及**国产的 paddle 、jittor 和 oneflow 的兼容**,扩大受众的同时,也是助力国产\n", + "\n", + "本着分而治之的思想,我们可以将`fastNLP 1.0`对`pytorch`、`paddle`、`jittor`、`oneflow`框架的兼容,划分为\n", + "\n", + "    **对数据预处理**、**批量 batch 的划分与补齐**、**模型训练**、**模型评测**,**四个部分的兼容**\n", + "\n", + "  针对数据预处理,我们已经在`tutorial-1`中介绍了`dataset`和`vocabulary`的使用\n", + "\n", + "    而结合`tutorial-0`,我们可以发现**数据预处理环节本质上是框架无关的**\n", + "\n", + "    因为在不同框架下,读取的原始数据格式都差异不大,彼此也很容易转换\n", + "\n", + "只有涉及到张量、模型,不同框架才展现出其各自的特色:**pytorch 和 oneflow 中的 tensor 和 nn.Module**\n", + "\n", + "    **在 paddle 中称为 tensor 和 nn.Layer**,**在 jittor 中则称为 Var 和 Module**\n", + "\n", + "    因此,**模型训练、模型评测**,**是兼容的重难点**,我们将会在`tutorial-5`中详细介绍\n", + "\n", + "  针对批量`batch`的处理,作为`fastNLP 1.0`中框架无关部分向框架相关部分的过渡\n", + "\n", + "    就是`dataloader`模块的职责,这也是本篇教程`tutorial-3`讲解的重点\n", + "\n", + "**dataloader 模块的职责**,详细划分可以包含以下三部分,**采样划分、补零对齐、框架匹配**\n", + "\n", + "    第一,确定`batch`大小,确定采样方式,划分后通过迭代器即可得到`batch`序列\n", + "\n", + "    第二,对于序列处理,这也是`fastNLP`主要针对的,将同个`batch`内的数据对齐\n", + "\n", + "    第三,**batch 内数据格式要匹配框架**,**但 batch 结构需保持一致**,**参数匹配机制**\n", + "\n", + "  对此,`fastNLP 1.0`给出了 **TorchDataLoader 、 PaddleDataLoader 、 JittorDataLoader 和 OneflowDataLoader**\n", + "\n", + "    分别针对并匹配不同框架,但彼此之间参数名、属性、方法仍然类似,前两者大致如下表所示\n", + "\n", + "名称|参数|属性|功能|内容\n", + "----|----|----|----|----|\n", + " `dataset` | √ | √ | 指定`dataloader`的数据内容 | |\n", + " `batch_size` | √ | √ | 指定`dataloader`的`batch`大小 | 默认`16` |\n", + " `shuffle` | √ | √ | 指定`dataloader`的数据是否打乱 | 默认`False` |\n", + " `collate_fn` | √ | √ | 指定`dataloader`的`batch`打包方法 | 视框架而定 |\n", + " `sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", + " `batch_sampler` | √ | √ | 指定`dataloader`的`__len__`和`__iter__`函数的实现 | 默认`None` |\n", + " `drop_last` | √ | √ | 指定`dataloader`划分`batch`时是否丢弃剩余的 | 默认`False` |\n", + " `cur_batch_indices` | | √ | 
记录`dataloader`当前遍历批量序号 | |\n", + " `num_workers` | √ | √ | 指定`dataloader`开启子进程数量 | 默认`0` |\n", + " `worker_init_fn` | √ | √ | 指定`dataloader`子进程初始方法 | 默认`None` |\n", + " `generator` | √ | √ | 指定`dataloader`子进程随机种子 | 默认`None` |\n", + " `prefetch_factor` | | √ | 指定为每个`worker`装载的`sampler`数量 | 默认`2` |" + ] + }, + { + "cell_type": "markdown", + "id": "60a8a224", + "metadata": {}, + "source": [ + "  论及`dataloader`的函数,其中,`get_batch_indices`用来获取当前遍历到的`batch`序号,其他函数\n", + "\n", + "    包括`set_ignore`、`set_pad`和`databundle`类似,请参考`tutorial-2`,此处不做更多介绍\n", + "\n", + "    以下是`tutorial-2`中已经介绍过的数据预处理流程,接下来是对相关数据进行`dataloader`处理" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aca72b49", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;5;2m[i 0604 15:44:29.773860 92 log.cc:351] Load log_sync: 1\u001b[m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00\n", + " ['input_ids', 'token_type_ids', 'attention_mask', 'target']\n", + "{'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]),\n", + " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", + " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", + " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", + " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", + " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", + " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", + " 1037, 2466, 1012, 102],\n", + " [ 101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", + " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", + " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", + " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", + " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 
1037, 28900, 7815,\n", + " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0]]),\n", + " 'target': tensor([0, 1, 1, 2]),\n", + " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" + ] + } + ], + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataset = data_bundle.get_dataset('train')\n", + "evaluate_dataset = data_bundle.get_dataset('dev')\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "print(type(train_dataloader))\n", + "\n", + "import pprint\n", + "\n", + "for batch in train_dataloader:\n", + " print(type(batch), type(batch['input_ids']), list(batch))\n", + " pprint.pprint(batch, width=1)" + ] + }, + { + "cell_type": "markdown", + "id": "9f457a6e", + "metadata": {}, + "source": [ + "之所以说`prepare_xx_dataloader`函数更方便,是因为其**导入对象不仅可也是 DataSet 类型**,**还可以**\n", + "\n", + "  **是 DataBundle 类型**,不过数据集名称需要是`'train'`、`'dev'`、`'test'`供`fastNLP`识别\n", + "\n", + "例如下方就是**直接通过 prepare_paddle_dataloader 函数生成基于 PaddleDataLoader 的字典**\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7827557d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "from fastNLP import prepare_paddle_dataloader\n", + "\n", + "dl_bundle = 
prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)\n", + "\n", + "print(type(dl_bundle['train']))" + ] + }, + { + "cell_type": "markdown", + "id": "d898cf40", + "metadata": {}, + "source": [ + "  而在接下来`trainer`的初始化过程中,按如下方式使用即可,除了初始化时`driver='paddle'`外\n", + "\n", + "  这里也可以看出`trainer`模块中,**evaluate_dataloaders 的设计允许评测可以针对多个数据集**\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=dl_bundle['train'],\n", + " optimizers=optimizer,\n", + "\t...\n", + "\tdriver='paddle',\n", + "\tdevice='gpu',\n", + "\t...\n", + " evaluate_dataloaders={'dev': dl_bundle['dev'], 'test': dl_bundle['test']}, \n", + " metrics={'acc': Accuracy()},\n", + "\t...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "d74d0523", + "metadata": {}, + "source": [ + "## 2. fastNLP 中 dataloader 的延伸\n", + "\n", + "### 2.1 collator 的概念与使用\n", + "\n", + "在`fastNLP 1.0`中,在数据加载模块`dataloader`内部,如之前表格所列举的,还存在其他的一些模块\n", + "\n", + "  例如,**实现序列的补零对齐的核对器 collator 模块**;注:`collate vt. 
整理(文件或书等);核对,校勘`\n", + "\n", + "在`fastNLP 1.0`中,虽然`dataloader`随框架不同,但`collator`模块却是统一的,主要属性、方法如下表所示\n", + "\n", + "名称|属性|方法|功能|内容\n", + " ----|----|----|----|----|\n", + " `backend` | √ | | 记录`collator`对应框架 | 字符串型,如`'torch'` |\n", + " `padders` | √ | | 记录各字段对应的`padder`,每个负责具体补零对齐  | 字典类型 |\n", + " `ignore_fields` | √ | | 记录`dataloader`采样`batch`时不予考虑的字段 | 集合类型 |\n", + " `input_fields` | √ | | 记录`collator`每个字段的补零值、数据类型等 | 字典类型 |\n", + " `set_backend` | | √ | 设置`collator`对应框架 | 字符串型,如`'torch'` |\n", + " `set_ignore` | | √ | 设置`dataloader`采样`batch`时不予考虑的字段 | 字符串型,表示`field_name`  |\n", + " `set_pad` | | √ | 设置`collator`每个字段的补零值、数据类型等 | |" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d0795b3e", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "train_dataloader.collate_fn\n", + "\n", + "print(type(train_dataloader.collate_fn))" + ] + }, + { + "cell_type": "markdown", + "id": "5f816ef5", + "metadata": {}, + "source": [ + "此外,还可以 **手动定义 dataloader 中的 collate_fn**,而不是使用`fastNLP 1.0`中自带的`collator`模块\n", + "\n", + "  该函数的定义可以大致如下,需要注意的是,**定义 collate_fn 之前需要了解 batch 作为字典的格式**\n", + "\n", + "  该函数通过`collate_fn`参数传入`dataloader`,**在 batch 分发**(**而不是 batch 划分**)**时调用**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ff8e405e", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "def collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item['input_ids'])\n", + " max_length[0] = max(len(each_item['input_ids']), max_length[0])\n", + " atten_mask.append(each_item['token_type_ids'])\n", + " max_length[1] = max(len(each_item['token_type_ids']), max_length[1])\n", + " labels.append(each_item['attention_mask'])\n", + " max_length[2] = max(len(each_item['attention_mask']), max_length[2])\n", 
+ "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'token_type_ids': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "id": "487b75fb", + "metadata": {}, + "source": [ + "注意:使用自定义的`collate_fn`函数,`trainer`的`collate_fn`变量也会自动调整为`function`类型" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "e916d1ac", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n", + "{'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0]),\n", + " 'input_ids': tensor([[ 101, 1037, 4038, 1011, 3689, 1997, 3053, 8680, 19173, 15685,\n", + " 1999, 1037, 18006, 2836, 2011, 1996, 2516, 2839, 14996, 3054,\n", + " 15509, 5325, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 2186, 1997, 9686, 17695, 18673, 14313, 1996, 15262,\n", + " 3351, 2008, 2054, 2003, 2204, 2005, 1996, 13020, 2003, 2036,\n", + " 2204, 2005, 1996, 25957, 4063, 1010, 2070, 1997, 2029, 5681,\n", + " 2572, 25581, 2021, 3904, 1997, 2029, 8310, 2000, 2172, 1997,\n", + " 1037, 2466, 1012, 102],\n", + " [ 
101, 2130, 4599, 1997, 19214, 6432, 1005, 1055, 2147, 1010,\n", + " 1045, 8343, 1010, 2052, 2031, 1037, 2524, 2051, 3564, 2083,\n", + " 2023, 2028, 1012, 102, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0],\n", + " [ 101, 1037, 13567, 26162, 5257, 1997, 3802, 7295, 9888, 1998,\n", + " 2035, 1996, 20014, 27611, 1010, 14583, 1010, 11703, 20175, 1998,\n", + " 4028, 1997, 1037, 8101, 2319, 10576, 2030, 1037, 28900, 7815,\n", + " 3850, 1012, 102, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0]]),\n", + " 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n", + " [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])}\n" + ] + } + ], + "source": [ + "train_dataloader = prepare_torch_dataloader(train_dataset, collate_fn=collate_fn, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, collate_fn=collate_fn, shuffle=True)\n", + "\n", + "print(type(train_dataloader))\n", + "print(type(train_dataloader.collate_fn))\n", + "\n", + "for batch in train_dataloader:\n", + " pprint.pprint(batch, width=1)" + ] + }, + { + "cell_type": "markdown", + "id": "0bd98365", + "metadata": {}, + "source": [ + "### 2.2 fastNLP 与 datasets 的结合\n", + "\n", + "从`tutorial-1`至`tutorial-3`,我们已经完成了对`fastNLP v1.0`数据读取、预处理、加载,整个流程的介绍\n", + "\n", + "  不过在实际使用中,我们往往也会采取更为简便的方法读取数据,例如使用`huggingface`的`datasets`模块\n", + "\n", + "**使用 datasets 模块中的 load_dataset 函数**,通过指定数据集两级的名称,示例中即是**GLUE 标准中的 SST-2 数据集**\n", + "\n", + "  
即可以快速从网上下载好`SST-2`数据集读入,之后以`pandas.DataFrame`作为中介,再转化成`fastNLP.DataSet`\n", + "\n", + "  之后的步骤就和其他关于`dataset`、`databundle`、`vocabulary`、`dataloader`中介绍的相关使用相同了" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "91879c30", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "639a0ad3c63944c6abef4e8ee1f7bf7c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.525,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 84.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.54375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 87.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.55,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 88.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 100.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.65,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 104.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.69375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 111.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 108.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.66875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 107.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 108.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.68125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 109.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8bc4bfb2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "07538876",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "1b52eafd",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "383"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "del dataset\n",
+    "del sst2data\n",
+    "\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d9443213",
+   "metadata": {},
+   "source": [
+    "## 2. fastNLP 中 models 模块的介绍\n",
+    "\n",
+    "### 2.1  示例一:models 实现 CNN 分类\n",
+    "\n",
+    "  本示例使用`fastNLP 1.0`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n",
+    "\n",
+    "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n",
+    "\n",
+    "模型使用方面,如上所述,这里使用**基于卷积神经网络 CNN 的预定义文本分类模型 CNNText**,结构如下所示\n",
+    "\n",
+    "  首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n",
+    "\n",
+    "    **感受野为 1 、 3 、 5 的卷积算子变换至 30 维、 40 维、 50 维的卷积特征**,再将三者拼接\n",
+    "\n",
+    "  最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n",
+    "\n",
+    "```\n",
+    "CNNText(\n",
+    "  (embed): Embedding(\n",
+    "    (embed): Embedding(5194, 100)\n",
+    "    (dropout): Dropout(p=0.0, inplace=False)\n",
+    "  )\n",
+    "  (conv_pool): ConvMaxpool(\n",
+    "    (convs): ModuleList(\n",
+    "      (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n",
+    "      (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n",
+    "      (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n",
+    "    )\n",
+    "  )\n",
+    "  (dropout): Dropout(p=0.1, inplace=False)\n",
+    "  (fc): Linear(in_features=120, out_features=2, bias=True)\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "对应到代码上,**从 fastNLP.models.torch 路径下导入 CNNText**,初始化`CNNText`和`optimizer`实例\n",
+    "\n",
+    "  注意:初始化`CNNText`时,**二元组参数 embed 、分类数量 num_classes 是必须传入的**,其中\n",
+    "\n",
+    "    **embed 表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100` 维"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "f6e76e2e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP.models.torch import CNNText\n",
+    "\n",
+    "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n",
+    "\n",
+    "from torch.optim import AdamW\n",
+    "\n",
+    "optimizers = AdamW(params=model.parameters(), lr=5e-4)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "0cc5ca10",
+   "metadata": {},
+   "source": [
+    "  最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "50a13ee5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP import Trainer, Accuracy\n",
+    "\n",
+    "trainer = Trainer(\n",
+    "    model=model,\n",
+    "    driver='torch',\n",
+    "    device=0,  # 'cuda'\n",
+    "    n_epochs=10,\n",
+    "    optimizers=optimizers,\n",
+    "    train_dataloader=train_dataloader,\n",
+    "    evaluate_dataloaders=evaluate_dataloader,\n",
+    "    metrics={'acc': Accuracy()}\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "28903a7d",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[16:21:57] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.654444,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 589.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.767778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 691.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.797778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 718.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.803333,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 723.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.807778,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 727.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.812222,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 731.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.804444,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 724.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.811111,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 730.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.811111,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 730.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.806667,\n",
+       "  \"total#acc\": 900.0,\n",
+       "  \"correct#acc\": 726.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}"
+      ]
+     },
+     "execution_count": 13,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5b5c0446",
+   "metadata": {},
+   "source": [
+    "  注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "e9e70f88",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "344"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "import gc\n",
+    "\n",
+    "del model\n",
+    "del trainer\n",
+    "\n",
+    "gc.collect()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6aec2a19",
+   "metadata": {},
+   "source": [
+    "### 2.2  示例二:models 实现 BiLSTM 标注\n",
+    "\n",
+    "  通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n",
+    "\n",
+    "    针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n",
+    "\n",
+    "  避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n",
+    "\n",
+    "模型使用方面,如上所述,这里使用**基于双向 LSTM +条件随机场 CRF 的标注模型 BiLSTMCRF**,结构如下所示\n",
+    "\n",
+    "  其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层丢弃概率、`LSTM`层数可调\n",
+    "\n",
+    "```\n",
+    "BiLSTMCRF(\n",
+    "  (embed): Embedding(7590, 100)\n",
+    "  (lstm): LSTM(\n",
+    "    (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n",
+    "  )\n",
+    "  (dropout): Dropout(p=0.1, inplace=False)\n",
+    "  (fc): Linear(in_features=200, out_features=9, bias=True)\n",
+    "  (crf): ConditionalRandomField()\n",
+    ")\n",
+    "```\n",
+    "\n",
+    "数据使用方面,此处仍然**使用 datasets 模块中的 load_dataset 函数**,以如下形式,加载`CoNLL-2003`数据集\n",
+    "\n",
+    "  首次下载后会保存至`~/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "03e66686",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "593bc03ed5914953ab94268ff2f01710",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "  0%|          | 0/3 [00:00[16:23:41] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "\n"
+      ],
+      "text/plain": [
+       "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO    \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches.              \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.169014,\n",
+       "  \"pre#F1\": 0.170732,\n",
+       "  \"rec#F1\": 0.167331\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.361809,\n",
+       "  \"pre#F1\": 0.312139,\n",
+       "  \"rec#F1\": 0.430279\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.525,\n",
+       "  \"pre#F1\": 0.475728,\n",
+       "  \"rec#F1\": 0.585657\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.627306,\n",
+       "  \"pre#F1\": 0.584192,\n",
+       "  \"rec#F1\": 0.677291\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.710937,\n",
+       "  \"pre#F1\": 0.697318,\n",
+       "  \"rec#F1\": 0.7251\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.739563,\n",
+       "  \"pre#F1\": 0.738095,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.748491,\n",
+       "  \"pre#F1\": 0.756098,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.716763,\n",
+       "  \"pre#F1\": 0.69403,\n",
+       "  \"rec#F1\": 0.741036\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.768293,\n",
+       "  \"pre#F1\": 0.784232,\n",
+       "  \"rec#F1\": 0.752988\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"f#F1\": 0.757692,\n",
+       "  \"pre#F1\": 0.732342,\n",
+       "  \"rec#F1\": 0.784861\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "96bae094",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_5.ipynb b/docs/source/tutorials/fastnlp_tutorial_5.ipynb
new file mode 100644
index 00000000..ab759feb
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_5.ipynb
@@ -0,0 +1,1242 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "fdd7ff16",
+   "metadata": {},
+   "source": [
+    "# T5. trainer 和 evaluator 的深入介绍\n",
+    "\n",
+    "  1   fastNLP 中 driver 的补充介绍\n",
+    " \n",
+    "    1.1   trainer 和 driver 的构想 \n",
+    "\n",
+    "    1.2   device 与 多卡训练\n",
+    "\n",
+    "  2   fastNLP 中的更多 metric 类型\n",
+    "\n",
+    "    2.1   预定义的 metric 类型\n",
+    "\n",
+    "    2.2   自定义的 metric 类型\n",
+    "\n",
+    "  3   fastNLP 中 trainer 的补充介绍\n",
+    "\n",
+    "    3.1   trainer 的内部结构"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "08752c5a",
+   "metadata": {
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   },
+   "source": [
+    "## 1. fastNLP 中 driver 的补充介绍\n",
+    "\n",
+    "### 1.1  trainer 和 driver 的构想\n",
+    "\n",
+    "在`fastNLP 1.0`中,模型训练最关键的模块便是**训练模块 trainer 、评测模块 evaluator 、驱动模块 driver**,\n",
+    "\n",
+    "  在`tutorial 0`中,已经简单介绍过上述三个模块:**driver 用来控制训练评测中的 model 的最终运行**\n",
+    "\n",
+    "    **evaluator 封装评测的 metric**,**trainer 封装训练的 optimizer**,**也可以包括 evaluator**\n",
+    "\n",
+    "之所以做出上述的划分,其根本目的在于要**达成对于多个 python 学习框架**,**例如 pytorch 、 paddle 、 jittor 的兼容**\n",
+    "\n",
+    "  对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n",
+    "\n",
+    "    划分为**框架无关的循环控制、批量分发部分**,**由 trainer 模块负责**实现,对应的伪代码如下方中间一栏所示\n",
+    "\n",
+    "    以及**随框架不同的模型调用、数值优化部分**,**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n",
+    "\n",
+    "|训练过程|框架无关 对应`Trainer`|框架相关 对应`Driver`\n",
+    "|----|----|----|\n",
+    "| try: | try: |  |\n",
+    "| for epoch in 1:n_epochs: | for epoch in 1:n_epochs: |  |\n",
+    "| for step in 1:total_steps: | for step in 1:total_steps: |  |\n",
+    "| batch = fetch_batch() | batch = fetch_batch() |  |\n",
+    "| loss = model.forward(batch)  |  | loss = model.forward(batch)  |\n",
+    "| loss.backward() |  | loss.backward() |\n",
+    "| model.clear_grad() |  | model.clear_grad() |\n",
+    "| model.update() |  | model.update() |\n",
+    "| if need_save: | if need_save: |  |\n",
+    "| model.save() |  | model.save() |\n",
+    "| except: | except: |  |\n",
+    "| process_exception() | process_exception() |  |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "3e55f07b",
+   "metadata": {},
+   "source": [
+    "  对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n",
+    "\n",
+    "    划分为**框架无关的循环控制、分发汇总部分**,**由 evaluator 模块负责**实现,对应的伪代码如下方中间一栏所示\n",
+    "\n",
+    "    以及**随框架不同的模型调用、评测计算部分**,同样**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n",
+    "\n",
+    "|评测过程|框架无关 对应`Evaluator`|框架相关 对应`Driver`\n",
+    "|----|----|----|\n",
+    "| try: | try: |  |\n",
+    "| model.set_eval() | model.set_eval() |  |\n",
+    "| for step in 1:total_steps: | for step in 1:total_steps: |  |\n",
+    "| batch = fetch_batch() | batch = fetch_batch() |  |\n",
+    "| outputs = model.evaluate(batch)  |  | outputs = model.evaluate(batch)  |\n",
+    "| metric.compute(batch, outputs) |  | metric.compute(batch, outputs) |\n",
+    "| results = metric.get_metric() | results = metric.get_metric() |  |\n",
+    "| except: | except: |  |\n",
+    "| process_exception() | process_exception() |  |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "94ba11c6",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "source": [
+    "由此,从程序员的角度,`fastNLP v1.0` **通过一个 driver 让基于 pytorch 、 paddle 、 jittor 、 oneflow 框架的模型**\n",
+    "\n",
+    "    **都能在相同的 trainer 和 evaluator 上运行**,这也**是 fastNLP v1.0 相比于之前版本的一大亮点**\n",
+    "\n",
+    "  而从`driver`的角度,`fastNLP v1.0`通过定义一个`driver`基类,**将所有张量转化为 numpy.tensor**\n",
+    "\n",
+    "    并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n",
+    "\n",
+    "    对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ab1cea7d",
+   "metadata": {},
+   "source": [
+    "### 1.2  device 与 多卡训练\n",
+    "\n",
+    "**fastNLP v1.0 支持多卡训练**,实现方法则是**通过将 trainer 中的 device 设置为对应显卡的序号列表**\n",
+    "\n",
+    "  由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v1.0`保证:\n",
+    "\n",
+    "    数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n",
+    "\n",
+    "    模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n",
+    "\n",
+    "  例如,在评测计算运行`get_metric`函数时,`fastNLP v1.0`将自动按照`self.right`和`self.total`\n",
+    "\n",
+    "    指定的 **aggregate_method 方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n",
+    "\n",
+    "    在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n",
+    "    \n",
+    "```python\n",
+    "trainer = Trainer(\n",
+    "        model=model,                                # model 基于 pytorch 实现 \n",
+    "        train_dataloader=train_dataloader,\n",
+    "        optimizers=optimizer,\n",
+    "        ...\n",
+    "        driver='torch',                             # driver 使用 torch_driver \n",
+    "        device=[0, 1],                              # gpu 选择 cuda:0 + cuda:1\n",
+    "        ...\n",
+    "        evaluate_dataloaders=evaluate_dataloader,\n",
+    "        metrics={'acc': Accuracy()},\n",
+    "        ...\n",
+    "    )\n",
+    "\n",
+    "class Accuracy(Metric):\n",
+    "    def __init__(self):\n",
+    "        super().__init__()\n",
+    "        self.register_element(name='total', value=0, aggregate_method='sum')\n",
+    "        self.register_element(name='right', value=0, aggregate_method='sum')\n",
+    "```\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e2e0a210",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "source": [
+    "注:`fastNLP v1.0`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8d19220c",
+   "metadata": {},
+   "source": [
+    "## 2. fastNLP 中的更多 metric 类型\n",
+    "\n",
+    "### 2.1  预定义的 metric 类型\n",
+    "\n",
+    "在`fastNLP 1.0`中,除了前几篇`tutorial`中经常见到的**正确率 Accuracy**,还有其他**预定义的评测标准 metric**\n",
+    "\n",
+    "  包括**所有 metric 的基类 Metric**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n",
+    "\n",
+    "    **适用于分类语境下的 F1 值 ClassifyFPreRecMetric**(其中也包括召回率`Rec`、精确率`Pre`)\n",
+    "\n",
+    "    **适用于抽取语境下的 F1 值 SpanFPreRecMetric**;相关基本信息内容见下表,之后是详细分析\n",
+    "\n",
+    "代码名称|简要介绍|代码路径\n",
+    "----|----|----|\n",
+    " `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n",
+    " `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n",
+    " `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n",
+    " `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n",
+    " `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fdc083a3",
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "source": [
+    "  如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n",
+    "\n",
+    "    **update 函数更新单个 batch 的统计量**,**get_metric 函数返回最终结果**,并打印显示\n",
+    "\n",
+    "\n",
+    "### 2.1.1  Accuracy 与 TransformersAccuracy\n",
+    "\n",
+    "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`中的占比(公式就不用列了)\n",
+    "\n",
+    "  `get_metric`函数打印格式为 **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**\n",
+    "\n",
+    "  一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n",
+    "\n",
+    "  **update 函数的参数包括 pred 、 target 、 seq_len**,**后者用来标记批次中每笔数据的长度**\n",
+    "\n",
+    "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n",
+    "\n",
+    "  在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n",
+    "\n",
+    "\n",
+    "### 2.1.2  ClassifyFPreRecMetric 与 SpanFPreRecMetric\n",
+    "\n",
+    "`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n",
+    "\n",
+    "  两者的相同之处在于:**第一**,**都包括召回率/查全率 Rec**、**精确率/查准率 Pre**、**F1 值**这三个指标\n",
+    "\n",
+    "    `get_metric`函数打印格式为 **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**\n",
+    "\n",
+    "    三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n",
+    "\n",
+    "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n",
+    "\n",
+    "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n",
+    "\n",
+    "  **第二**,可以通过设置参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值根据参数`f_type`又分为\n",
+    "\n",
+    "    **micro F1**(**直接统计所有类别的 Rec-Pre-F1**)、**macro F1**(**统计各类别的 Rec-Pre-F1 再算术平均**)\n",
+    "\n",
+    "  **第三**,两者在初始化时还可以**传入基于 fastNLP.Vocabulary 的 tag_vocab 参数记录数据集中的标签序号**\n",
+    "\n",
+    "    **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n",
+    "\n",
+    "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n",
+    "\n",
+    "    **SpanFPreRecMetric 针对更复杂的抽取问题**,**规定标签 B-xx 和 I-xx 或 B-xx 和 E-xx 构成标签对**\n",
+    "\n",
+    "  在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n",
+    "\n",
+    "    对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n",
+    "\n",
+    "    因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n",
+    "\n",
+    "      或者`Accuracy`,会发现评测结果显示很高,但这是因为选择的评测方法要求太低\n",
+    "\n",
+    "    最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n",
+    "\n",
+    "```python\n",
+    "from fastNLP import Vocabulary\n",
+    "from fastNLP import ClassifyFPreRecMetric\n",
+    "\n",
+    "tag_vocab = Vocabulary(padding=None, unknown=None)            # 记录序号与标签之间的映射\n",
+    "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n",
+    "                        'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n",
+    "                        'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n",
+    "                        'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n",
+    "                        'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ])  # CoNLL-2003 中的 pos_tags\n",
+    "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n",
+    "\n",
+    "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab,          \n",
+    "                                ignore_labels=ignore_labels,  # 表示评测/优化中不考虑上述标签的正误/损失\n",
+    "                                only_gross=True,              # 默认为 True 表示输出所有类别的综合统计结果\n",
+    "                                f_type='micro')               # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n",
+    "metrics = {'F1': FPreRec}\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8a22f522",
+   "metadata": {},
+   "source": [
+    "### 2.2  自定义的 metric 类型\n",
+    "\n",
+    "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的 metric 类型**\n",
+    "\n",
+    "    也**需要继承自 Metric 类**,同时**内部自定义好 __init__ 、 update 和 get_metric 函数**\n",
+    "\n",
+    "  在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n",
+    "\n",
+    "  在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**update 的参数名**\n",
+    "\n",
+    "    **需要与待评估模型在 evaluate_step 中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n",
+    "\n",
+    "    在`fastNLP v1.0`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n",
+    "\n",
+    "    此处仍然沿用,因为接下来会需要使用`fastNLP`中的预定义模型,其输入参数格式即是如此\n",
+    "\n",
+    "  在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n",
+    "\n",
+    "    其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n",
+    "\n",
+    "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "08a872e9",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Metric\n", + "\n", + "class MyMetric(Metric):\n", + "\n", + " def __init__(self):\n", + " Metric.__init__(self)\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + "\n", + " def update(self, pred, target):\n", + " self.total_num += target.size(0)\n", + " self.right_num += target.eq(pred).sum().item()\n", + "\n", + " def get_metric(self, reset=True):\n", + " acc = self.right_num / self.total_num\n", + " if reset:\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + " return {'prefix': acc}" + ] + }, + { + "cell_type": "markdown", + "id": "0155f447", + "metadata": {}, + "source": [ + "  数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5ad81ac7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef923b90b19847f4916cccda5d33fc36", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00 0: # 如果设置了 num_eval_sanity_batch\n", + "\t\ton_sanity_check_begin(trainer)\n", + "\t\ton_sanity_check_end(trainer, sanity_check_res)\n", + "\ttry:\n", + "\t\ton_train_begin(trainer)\n", + "\t\twhile cur_epoch_idx < n_epochs:\n", + "\t\t\ton_train_epoch_begin(trainer)\n", + "\t\t\twhile batch_idx_in_epoch<=num_batches_per_epoch:\n", + "\t\t\t\ton_fetch_data_begin(trainer)\n", + "\t\t\t\tbatch = next(dataloader)\n", + "\t\t\t\ton_fetch_data_end(trainer)\n", + "\t\t\t\ton_train_batch_begin(trainer, batch, indices)\n", + 
"\t\t\t\ton_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping 后的\n", + "\t\t\t\ton_after_backward(trainer)\n", + "\t\t\t\ton_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响\n", + "\t\t\t\ton_train_batch_end(trainer)\n", + "\t\t\ton_train_epoch_end(trainer)\n", + "\texcept BaseException:\n", + "\t\tself.on_exception(trainer, exception)\n", + "\tfinally:\n", + "\t\ton_train_end(trainer)\n", + "``` -->" + ] + }, + { + "cell_type": "markdown", + "id": "1e21df35", + "metadata": {}, + "source": [ + "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n", + "\n", + "  字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "926a9c50", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " input_mapping=input_mapping,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'suffix': MyMetric()}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b1b2e8b7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n", + "\n", + "|名称|功能|默认值|\n", + "|----|----|----|\n", + "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n", + "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n", + "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 
|\n", + "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n", + "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "43be274f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[09:30:35] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.6875\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8125\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.825\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8125\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.8\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"prefix#suffix\": 0.80625\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1abfa0a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_6.ipynb b/docs/source/tutorials/fastnlp_tutorial_6.ipynb new file mode 100644 index 00000000..63f7481e --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_6.ipynb @@ -0,0 +1,1646 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T6. 
fastNLP 与 paddle 或 jittor 的结合\n", + "\n", + "  1   fastNLP 结合 paddle 训练模型\n", + " \n", + "    1.1   关于 paddle 的简单介绍\n", + "\n", + "    1.2   使用 paddle 搭建并训练模型\n", + "\n", + "  2   fastNLP 结合 jittor 训练模型\n", + "\n", + "    2.1   关于 jittor 的简单介绍\n", + "\n", + "    2.2   使用 jittor 搭建并训练模型\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08752c5a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b13d42c39ba455eb370bf2caaa3a264", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00 True\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", + "print(type(train_dataset), isinstance(train_dataset, DataSet))\n", + "\n", + "from fastNLP.io import DataBundle\n", + "\n", + 
"data_bundle = DataBundle(datasets={'train': train_dataset, 'dev': evaluate_dataset})" + ] + }, + { + "cell_type": "markdown", + "id": "57a3272f", + "metadata": {}, + "source": [ + "## 1. fastNLP 结合 paddle 训练模型\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "e31b3198", + "metadata": {}, + "outputs": [], + "source": [ + "import paddle\n", + "import paddle.nn as nn\n", + "import paddle.nn.functional as F\n", + "\n", + "\n", + "class ClsByPaddle(nn.Layer):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, dropout=0.5):\n", + " nn.Layer.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n", + " \n", + " self.conv1 = nn.Sequential(nn.Conv1D(embedding_dim, 30, 1, padding=0), nn.ReLU())\n", + " self.conv2 = nn.Sequential(nn.Conv1D(embedding_dim, 40, 3, padding=1), nn.ReLU())\n", + " self.conv3 = nn.Sequential(nn.Conv1D(embedding_dim, 50, 5, padding=2), nn.ReLU())\n", + "\n", + " self.mlp = nn.Sequential(('dropout', nn.Dropout(p=dropout)),\n", + " ('linear_1', nn.Linear(120, hidden_dim)),\n", + " ('activate', nn.ReLU()),\n", + " ('linear_2', nn.Linear(hidden_dim, output_dim)))\n", + " \n", + " self.loss_fn = nn.MSELoss()\n", + "\n", + " def forward(self, words):\n", + " output = self.embedding(words).transpose([0, 2, 1])\n", + " conv1, conv2, conv3 = self.conv1(output), self.conv2(output), self.conv3(output)\n", + "\n", + " pool1 = F.max_pool1d(conv1, conv1.shape[-1]).squeeze(2)\n", + " pool2 = F.max_pool1d(conv2, conv2.shape[-1]).squeeze(2)\n", + " pool3 = F.max_pool1d(conv3, conv3.shape[-1]).squeeze(2)\n", + "\n", + " pool = paddle.concat([pool1, pool2, pool3], axis=1)\n", + " output = self.mlp(pool)\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " target = paddle.stack((1 - target, target), axis=1).cast(pred.dtype)\n", + " return {'loss': 
self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = paddle.argmax(pred, axis=-1)\n", + " return {'pred': pred, 'target': target}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "c63b030f", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "W0604 21:02:25.453869 19014 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 6.1, Driver API Version: 11.1, Runtime API Version: 10.2\n", + "W0604 21:02:26.061690 19014 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.\n" + ] + }, + { + "data": { + "text/plain": [ + "ClsByPaddle(\n", + " (embedding): Embedding(8458, 100, sparse=False)\n", + " (conv1): Sequential(\n", + " (0): Conv1D(100, 30, kernel_size=[1], data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (conv2): Sequential(\n", + " (0): Conv1D(100, 40, kernel_size=[3], padding=1, data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (conv3): Sequential(\n", + " (0): Conv1D(100, 50, kernel_size=[5], padding=2, data_format=NCL)\n", + " (1): ReLU()\n", + " )\n", + " (mlp): Sequential(\n", + " (dropout): Dropout(p=0.5, axis=None, mode=upscale_in_train)\n", + " (linear_1): Linear(in_features=120, out_features=64, dtype=float32)\n", + " (activate): ReLU()\n", + " (linear_2): Linear(in_features=64, out_features=2, dtype=float32)\n", + " )\n", + " (loss_fn): MSELoss()\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = ClsByPaddle(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "2997c0aa", + "metadata": {}, + "outputs": [], + "source": [ + "from paddle.optimizer import AdamW\n", + "\n", + "optimizers = AdamW(parameters=model.parameters(), learning_rate=5e-4)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ead35fb8", + 
"metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_paddle_dataloader\n", + "\n", + "train_dataloader = prepare_paddle_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_paddle_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "# dl_bundle = prepare_paddle_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "25e8da83", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='paddle',\n", + " device='gpu', # 'cpu', 'gpu', 'gpu:x'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'], \n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d63c5d74", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[21:03:08] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n",
+       "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n",
+       ".get_parent()\n",
+       "  self.msg_id = ip.kernel._parent_header['header']['msg_id']\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n",
+       "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n",
+       "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n",
+       "safe. \n",
+       "Deprecated in NumPy 1.20; for more details and guidance: \n",
+       "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n",
+       "  if data.dtype == np.object:\n",
+       "
\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.78125,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 126.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.81875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 131.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.80625,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 129.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.79375,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 127.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7875,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 126.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8,\n",
+       "  \"total#acc\": 160.0,\n",
+       "  \"correct#acc\": 128.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10) " + ] + }, + { + "cell_type": "markdown", + "id": "cb9a0b3c", + "metadata": {}, + "source": [ + "## 2. fastNLP 结合 jittor 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c600191d", + "metadata": {}, + "outputs": [], + "source": [ + "import jittor\n", + "import jittor.nn as nn\n", + "\n", + "from jittor import Module\n", + "\n", + "\n", + "class ClsByJittor(Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " Module.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n", + " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", + " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n", + " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim * 2, output_dim),\n", + " nn.Sigmoid(),])\n", + "\n", + " self.loss_fn = nn.MSELoss()\n", + "\n", + " def execute(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " target = jittor.stack((1 - target, target), dim=1)\n", + " return {'loss': self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = jittor.argmax(pred, dim=-1)[0]\n", + " return {'pred': pred, 'target': target}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a94ed8c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"ClsByJittor(\n", + " embedding: Embedding(8458, 100)\n", + " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n", + " mlp: Sequential(\n", + " 0: Dropout(0.5, is_train=False)\n", + " 1: Linear(128, 128, float32[128,], None)\n", + " 2: relu()\n", + " 3: Linear(128, 2, float32[2,], None)\n", + " 4: Sigmoid()\n", + " )\n", + " loss_fn: MSELoss(mean)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6d15ebc1", + "metadata": {}, + "outputs": [], + "source": [ + "from jittor.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "95d8d09e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_jittor_dataloader\n", + "\n", + "train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "917eab81", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='jittor',\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f7c4ac5a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ 
+ "
[21:05:51] INFO     Running evaluator sanity check for 2 batches.              trainer.py:596\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta:    0s \n",
+      "\n",
+      "Compiling Operators(31/31) used: 7.31s eta:    0s \n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.61875,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 99\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 112\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.725,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 116\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.74375,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 119\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 121\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.75625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 121\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.73125,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 117\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 122\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.74375,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 119\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.7625,\n",
+       "  \"total#acc\": 160,\n",
+       "  \"correct#acc\": 122\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df5f425", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb new file mode 100644 index 00000000..af8e60a0 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb @@ -0,0 +1,1280 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "  从这篇开始,我们将开启 **fastNLP v1.0 tutorial 的 example 系列**,在接下来的\n", + "\n", + "  每篇`tutorial`里,我们将会介绍`fastNLP v1.0`在自然语言处理任务上的应用实例" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[, , ]\n" + ] + } + ], + "source": [ + "from pygments.plugin import find_plugin_lexers\n", + "print(list(find_plugin_lexers()))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E1. 使用 Bert + fine-tuning 完成 SST-2 分类\n", + "\n", + "  1   基础介绍:`GLUE`通用语言理解评估、`SST-2`文本情感二分类数据集 \n", + "\n", + "  2   准备工作:加载`tokenizer`、预处理`dataset`、`dataloader`使用\n", + "\n", + "  3   模型训练:加载`distilbert-base`、`fastNLP`参数匹配、`fine-tuning`" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP import Accuracy\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n", + "\n", + "  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n", + "\n", + "    调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n", + "\n", + "**GLUE**,**全称 General Language Understanding Evaluation**,**通用语言理解评估**,\n", + "\n", + "  包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n", + "\n", + "    **CoLA**,文本分类任务,预测单句语法正误分类;**SST-2**,文本分类任务,预测单句情感二分类\n", + "\n", + "    **MRPC**,句对分类任务,预测句对语义一致性;**STS-B**,相似度打分任务,预测句对语义相似度回归\n", + "\n", + "    **QQP**,句对分类任务,预测问题对语义一致性;**MNLI**,文本推理任务,预测句对蕴含/矛盾/中立预测\n", + "\n", + "    **QNLI / RTE / WNLI**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n", + "\n", + "  诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", + "\n", + "    此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", + "\n", + "task = 'sst2'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "**SST**,**全称`Stanford Sentiment Treebank**,**斯坦福情感树库**,**单句情感分类**数据集\n", + "\n", + "  包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 
对应`negative` 负面情感\n", + "\n", + "  数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "\n", + "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n", + "\n", + "  首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5915debacf9443986b5b3b34870b303", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.878125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 281.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.9,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 288.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.8875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 284.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.88125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 282.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 280.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.865625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 277.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.884375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 283.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.878125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 281.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 附:`DistilBertForSequenceClassification`模块结构\n",
+    "\n",
+    "```\n",
+    "\n",
+    "```"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.13 ('fnlp-paddle')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_e2.ipynb
new file mode 100644
index 00000000..588ee8c3
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_e2.ipynb
@@ -0,0 +1,1082 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# E2. 使用 Bert + prompt 完成 SST-2 分类\n",
+    "\n",
+    "  1   基础介绍:`prompt-based model`简介、与`fastNLP`的结合\n",
+    "\n",
+    "  2   准备工作:`P-Tuning v2`原理概述、`P-Tuning v2`模型搭建\n",
+    "\n",
+    "  3   模型训练:加载`tokenizer`、预处理`dataset`、模型训练与分析"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:prompt-based model 简介、与 fastNLP 的结合\n",
+    "\n",
+    "  本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`prompt-based tuning`方式\n",
+    "\n",
+    "    微调`bert-base-uncased`模型,实现文本情感的二分类,在此之前本示例\n",
+    "\n",
+    "    将首先简单介绍提示学习模型的研究,以及与`fastNLP v1.0`结合的优势\n",
+    "\n",
+    "**prompt**,**提示词**,最早出自论文[Exploiting Cloze Questions for Few Shot TC and NLI](https://arxiv.org/pdf/2001.07676.pdf)中的 **PET 模型**\n",
+    "\n",
+    "    全称 **Pattern-Exploiting Training**,虽然文中并没有提到`prompt`的说法,但仍被视为开山之作\n",
+    "\n",
+    "  其大致思路包括,对于文本分类任务,假定输入文本为`\" X . \"`,设计**输入模板 template**,**后来被称为 prompt**\n",
+    "\n",
+    "    将输入重构为`\" X . It is [MASK] . \"`,**诱导或刺激语言模型在 [MASK] 位置生成含有情感倾向的词汇**\n",
+    "\n",
+    "    接着将该词汇**输入分类器中**,**后来被称为 verbalizer**,从而得到该语句对应的情感倾向,实现文本分类\n",
+    "\n",
+    "  其主要贡献在于,通过构造`prompt`,诱导/刺激预训练模型生成期望适应下游任务特征,适合少样本学习的需求\n",
+    "\n",
+    "\n",
+    "\n",
+    "**prompt-based tuning**,**基于提示的微调**,将`prompt`应用于**参数高效微调**,**parameter-efficient tuning**\n",
+    "\n",
+    "  通过**设计模板调整模型输入**或者**调整模型内部状态**,**固定预训练模型**,**诱导/刺激模型**调整输出以适应\n",
+    "\n",
+    "  当前任务,极大降低了训练开销,也省去了`verbalizer`的构造,更多参考[prompt综述](https://arxiv.org/pdf/2107.13586.pdf)、[DeltaTuning综述](https://arxiv.org/pdf/2203.06904.pdf)\n",
+    "\n",
+    "    以下列举些经典的`prompt-based tuning`案例,简单地介绍下`prompt-based tuning`的脉络\n",
+    "\n",
+    "  **案例一**:**PrefixTuning**,详细内容参考[PrefixTuning论文](https://arxiv.org/pdf/2101.00190.pdf)\n",
+    "\n",
+    "    其主要贡献在于,**提出连续的、非人工构造的、任务导向的 prompt**,即**前缀 prefix**,**调整**\n",
+    "\n",
+    "      **模型内部更新状态**,诱导模型在特定任务下生成期望目标,降低优化难度,提升微调效果\n",
+    "\n",
+    "    其主要研究对象,是`GPT2`和`BART`,主要面向生成任务`NLG`,如`table-to-text`和摘要\n",
+    "\n",
+    "  **案例二**:**P-Tuning v1**,详细内容参考[P-Tuning-v1论文](https://arxiv.org/pdf/2103.10385.pdf)\n",
+    "\n",
+    "    其主要贡献在于,**通过连续的、非人工构造的 prompt 调整模型输入**,取代原先基于单词设计的\n",
+    "\n",
+    "      但离散且不易于优化的`prompt`;同时也**证明了 GPT2 在语言理解任务上仍然是可以胜任的**\n",
+    "\n",
+    "    其主要研究对象,是`GPT2`,主要面向知识探测`knowledge probing`和自然语言理解`NLU`\n",
+    "\n",
+    "  **案例三**:**PromptTuning**,详细内容参考[PromptTuning论文](https://arxiv.org/pdf/2104.08691.pdf)\n",
+    "\n",
+    "    其主要贡献在于,通过连续的`prompt`调整模型输入,**证明了 prompt-based tuning 的效果**\n",
+    "\n",
+    "      **随模型参数量的增加而提升**,最终**在 10B 左右追上了全参数微调 fine-tuning 的效果**\n",
+    "\n",
+    "    其主要面向自然语言理解`NLU`,通过为每个任务定义不同的`prompt`,从而支持多任务语境\n",
+    "\n",
+    "通过上述介绍可以发现`prompt-based tuning`只是模型微调方式,独立于预训练模型基础`backbone`\n",
+    "\n",
+    "  目前,加载预训练模型的主流方法是使用**transformers 模块**,而实现微调的框架则\n",
+    "\n",
+    "    可以是`pytorch`、`paddle`、`jittor`等,而不同框架间又存在不兼容的问题\n",
+    "\n",
+    "  因此,**使用 fastNLP v1.0 实现 prompt-based tuning**,可以**很好地解决 paddle 等框架**\n",
+    "\n",
+    "    **和 transformers 模块之间的桥接**(`transformers`模块基于`pytorch`实现)\n",
+    "\n",
+    "本示例仍使用了`tutorial-E1`的`SST-2`数据集、`distilbert-base-uncased`模型(便于比较)\n",
+    "\n",
+    "  使用`pytorch`框架,通过将连续的`prompt`与`model`拼接,解决`SST-2`二分类任务"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "print(transformers.__version__)\n", + "\n", + "task = 'sst2'\n", + "model_checkpoint = 'distilbert-base-uncased' # 'bert-base-uncased'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:P-Tuning v2 原理概述、P-Tuning v2 模型搭建\n", + "\n", + "  本示例使用`P-Tuning v2`作为`prompt-based tuning`与`fastNLP v1.0`结合的案例\n", + "\n", + "    以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v1.0`的代码实践\n", + "\n", + "**P-Tuning v2**出自论文[Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n", + "\n", + "  其主要贡献在于,**在 PrefixTuning 等深度提示学习基础上**,**提升了其在分类标注等 NLU 任务的表现**\n", + "\n", + "    并使之在中等规模模型,主要是**参数量在 100M-1B 区间的模型上**,**获得与全参数微调相同的效果**\n", + "\n", + "  其结构如图所示,通过**在输入序列的分类符 [CLS] 之前**,**加入前缀序列**(**序号对应嵌入是待训练的连续值向量**\n", + "\n", + "    **刺激模型在新任务下**,从`[CLS]`对应位置,**输出符合微调任务的输出**,从而达到适应微调任务的目的\n", + "\n", + "\n", + "\n", + "本示例使用`bert-base-uncased`模型,作为`P-Tuning v2`的基础`backbone`,设置`requires_grad=False`\n", + "\n", + "    固定其参数不参与训练,**设置 pre_seq_len 长的 prefix_tokens 作为输入的提示前缀序列**\n", + "\n", + "  **使用基于 nn.Embedding 的 prefix_encoder 为提示前缀嵌入**,通过`get_prompt`函数获取,再将之\n", + "\n", + "    拼接至批量内每笔数据前得到`inputs_embeds`,同时更新自注意力掩码`attention_mask`\n", + "\n", + "  将`inputs_embeds`、`attention_mask`和`labels`输入`backbone`,**得到输出包括 loss 和 logits**" + ] 
+ }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + "\n", + " for param in self.back_bone.parameters():\n", + " param.requires_grad = False\n", + " \n", + " self.pre_seq_len = pre_seq_len\n", + " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", + " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", + " \n", + " def get_prompt(self, batch_size):\n", + " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", + " prompts = self.prefix_encoder(prefix_tokens)\n", + " return prompts\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " \n", + " batch_size = input_ids.shape[0]\n", + " raw_embedding = self.embeddings(input_ids)\n", + " \n", + " prompts = self.get_prompt(batch_size=batch_size)\n", + " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", + " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", + " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", + "\n", + " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return outputs\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return 
{'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n", + "\n", + "  根据`P-Tuning v2`论文:`Generally, simple classification tasks prefer shorter prompts (less than 20)`\n", + "\n", + "  此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = SeqClsModel(model_checkpoint=model_checkpoint, num_labels=2, pre_seq_len=20)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", + "\n", + "  本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n", + "\n", + "    以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", + "\n", + "    数据集加载相关代码流程见下,内容和`tutorial-E1`中的内容基本一致\n", + "\n", + "首先,使用`datasets.load_dataset`加载数据集,使用`transformers.AutoTokenizer`\n", + "\n", + "  构建`tokenizer`实例,通过`dataset.map`使用`tokenizer`将文本替换为词素序号序列" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "21cbd92c3397497d84dc10f017ec96f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00[22:53:00] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[22:53:00]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=406635;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=951504;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Output()"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.540625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 173.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.540625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m173.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.5,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 160.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.5\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.509375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 163.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.509375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m163.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.634375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 203.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.634375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m203.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.6125,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 196.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.6125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m196.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.675,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 216.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m216.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.64375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 206.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.64375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m206.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.665625,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 213.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.665625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m213.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.659375,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 211.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.659375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m211.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#acc\": 0.696875,\n",
+       "  \"total#acc\": 320.0,\n",
+       "  \"correct#acc\": 223.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.696875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m223.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以发现,其效果远远逊色于`fine-tuning`,这是因为`P-Tuning v2`虽然能够适应参数量\n", + "\n", + "  在`100M-1B`区间的模型,但是,**distilbert-base 的参数量仅为 66M**,无法触及其下限\n", + "\n", + "另一方面,**fastNLP v1.0 不支持 jupyter 多卡**,所以无法在笔者的电脑/服务器上,完成\n", + "\n", + "  合适规模模型的学习,例如`110M`的`bert-base`模型,以及`340M`的`bert-large`模型" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{'acc#acc': 0.737385, 'total#acc': 872.0, 'correct#acc': 643.0}"
+      ]
+     },
+     "execution_count": 10,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "trainer.evaluator.run()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "pycharm": {
+   "stem_cell": {
+    "cell_type": "raw",
+    "metadata": {
+     "collapsed": false
+    },
+    "source": []
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb
new file mode 100644
index 00000000..a5883416
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb
@@ -0,0 +1,1086 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# E3. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析\n",
+    "\n",
+    "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何使用 `paddlenlp` 自然语言处理库和 `fastNLP` 来完成比较简单的情感分析任务。\n",
+    "\n",
+    "1. 基础介绍:飞桨自然语言处理库 ``paddlenlp`` 和语义理解框架 ``ERNIE``\n",
+    "\n",
+    "2. 准备工作:使用 ``tokenizer`` 处理数据并构造 ``dataloader``\n",
+    "\n",
+    "3. 模型训练:加载 ``ERNIE`` 预训练模型,使用 ``fastNLP`` 进行训练"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:飞桨自然语言处理库 paddlenlp 和语义理解框架 ERNIE\n",
+    "\n",
+    "#### 1.1 飞桨自然语言处理库 paddlenlp\n",
+    "\n",
+    "``paddlenlp`` 是由百度以飞桨 ``PaddlePaddle`` 为核心开发的自然语言处理库,集成了多个数据集和 NLP 模型,包括百度自研的语义理解框架 ``ERNIE`` 。在本篇教程中,我们会以 ``paddlenlp`` 为基础,使用模型 ``ERNIE`` 完成中文情感分析任务。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.3.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "\n",
+    "import paddle\n",
+    "import paddlenlp\n",
+    "from paddlenlp.transformers import AutoTokenizer\n",
+    "from paddlenlp.transformers import AutoModelForSequenceClassification\n",
+    "\n",
+    "print(paddlenlp.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 1.2 语义理解框架 ERNIE\n",
+    "\n",
+    "``ERNIE(Enhanced Representation from kNowledge IntEgration)`` 是百度提出的基于知识增强的持续学习语义理解框架,至今已有 ``ERNIE 2.0``、``ERNIE 3.0``、``ERNIE-M``、``ERNIE-tiny`` 等多种预训练模型。``ERNIE 1.0`` 采用``Transformer Encoder`` 作为其语义表示的骨架,并改进了两种 ``mask`` 策略,分别为基于**短语**和**实体**(人名、组织等)的策略。在 ``ERNIE`` 中,由多个字组成的短语或者实体将作为一个统一单元,在训练的时候被统一地 ``mask`` 掉,这样可以潜在地学习到知识的依赖以及更长的语义依赖来让模型更具泛化性。\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "\n",
+    "``ERNIE 2.0`` 则提出了连续学习(``Continual Learning``)的概念,即首先用一个简单的任务来初始化模型,在更新时用前一个任务训练好的参数作为下一个任务模型初始化的参数。这样在训练新的任务时,模型便可以记住之前学习到的知识,使得模型在新任务上获得更好的表现。``ERNIE 2.0`` 分别构建了词法、语法、语义不同级别的预训练任务,并使用不同的 task id 来标示不同的任务,在共计16个中英文任务上都取得了SOTA效果。\n",
+    "\n",
+    "\n",
+    "\n",
+    "``ERNIE 3.0`` 将自回归和自编码网络融合在一起进行预训练,其中自编码网络采用 ``ERNIE 2.0`` 的多任务学习增量式构建预训练任务,持续进行语义理解学习。其中自编码网络增加了知识增强的预训练任务。自回归网络则基于 ``Transformer-XL`` 结构,支持长文本语言模型建模,并在多个自然语言处理任务中取得了SOTA的效果。\n",
+    "\n",
+    "\n",
+    "\n",
+    "接下来,我们将展示如何在 ``fastNLP`` 中使用基于 ``paddle`` 的 ``ERNIE 1.0`` 框架进行中文情感分析。"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. 准备工作:使用 tokenizer 处理数据并构造 dataloader\n",
+    "\n",
+    "#### 2.1 加载中文数据集 ChnSentiCorp\n",
+    "\n",
+    "``ChnSentiCorp`` 数据集是由中国科学院发布的中文句子级情感分析数据集,包含了从网络上获取的酒店、电影、书籍等多个领域的评论,每条评论都被划分为两个标签:消极(``0``)和积极(``1``),可以用于二分类的中文情感分析任务。通过 ``paddlenlp.datasets.load_dataset`` 函数,我们可以加载并查看 ``ChnSentiCorp`` 数据集的内容。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "训练集大小: 9600\n",
+      "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}\n",
+      "{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}\n",
+      "{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}\n"
+     ]
+    }
+   ],
+   "source": [
+    "from paddlenlp.datasets import load_dataset\n",
+    "\n",
+    "train_dataset, val_dataset, test_dataset = load_dataset(\"chnsenticorp\", splits=[\"train\", \"dev\", \"test\"])\n",
+    "print(\"训练集大小:\", len(train_dataset))\n",
+    "for i in range(3):\n",
+    "    print(train_dataset[i])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.2 处理数据\n",
+    "\n",
+    "可以看到,原本的数据集仅包含中文的文本和标签,这样的数据是无法被模型识别的。同英文文本分类任务一样,我们需要使用 ``tokenizer`` 对文本进行分词并转换为数字形式的结果。我们可以加载已经预训练好的中文分词模型 ``ernie-1.0-base-zh``,将分词的过程写在函数 ``_process`` 中,然后调用数据集的 ``map`` 函数对每一条数据进行分词。其中:\n",
+    "- 参数 ``max_length`` 代表句子的最大长度;\n",
+    "- ``padding=\"max_length\"`` 表示将长度不足的结果 padding 至和最大长度相同;\n",
+    "- ``truncation=True`` 表示将长度过长的句子进行截断。\n",
+    "\n",
+    "至此,我们得到了每条数据长度均相同的数据集。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m[2022-06-22 21:31:04,168] [    INFO]\u001b[0m - We are using  to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+      "\u001b[32m[2022-06-22 21:31:04,171] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': '', 'input_ids': [1, 352, 790, 1252, 409, 283, 509, 5, 250, 196, 113, 10, 58, 518, 4, 9, 128, 70, 1495, 1855, 339, 293, 45, 302, 233, 554, 4, 544, 637, 1134, 774, 6, 494, 2068, 6, 278, 191, 6, 634, 99, 6, 2678, 144, 7, 149, 1573, 62, 12043, 661, 737, 371, 435, 7, 689, 4, 255, 201, 559, 407, 1308, 12043, 2275, 1110, 11, 19, 842, 5, 1207, 878, 4, 196, 198, 321, 96, 4, 16, 93, 291, 464, 1099, 10, 692, 811, 12043, 392, 5, 748, 1134, 10, 213, 220, 5, 4, 201, 559, 723, 595, 12043, 231, 112, 1114, 4, 7, 689, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "max_len = 128\n",
+    "model_checkpoint = \"ernie-1.0-base-zh\"\n",
+    "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)\n",
+    "def _process(data):\n",
+    "    data.update(tokenizer(\n",
+    "        data[\"text\"],\n",
+    "        max_length=max_len,\n",
+    "        padding=\"max_length\",\n",
+    "        truncation=True,\n",
+    "        return_attention_mask=True,\n",
+    "    ))\n",
+    "    return data\n",
+    "\n",
+    "train_dataset.map(_process, num_workers=5)\n",
+    "val_dataset.map(_process, num_workers=5)\n",
+    "test_dataset.map(_process, num_workers=5)\n",
+    "\n",
+    "print(train_dataset[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "得到数据集之后,我们便可以将数据集包裹在 ``PaddleDataLoader`` 中,用于之后的训练。``fastNLP`` 提供的 ``PaddleDataLoader`` 拓展了 ``paddle.io.DataLoader`` 的功能,详情可以查看相关的文档。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from fastNLP.core import PaddleDataLoader\n",
+    "import paddle.nn as nn\n",
+    "\n",
+    "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n",
+    "val_dataloader = PaddleDataLoader(val_dataset, batch_size=32, shuffle=False)\n",
+    "test_dataloader = PaddleDataLoader(test_dataset, batch_size=1, shuffle=False)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. 模型训练:加载 ERNIE 预训练模型,使用 fastNLP 进行训练\n",
+    "\n",
+    "#### 3.1 使用 ERNIE 预训练模型\n",
+    "\n",
+    "为了实现文本分类,我们首先需要定义文本分类的模型。``paddlenlp.transformers`` 提供了模型 ``AutoModelForSequenceClassification``,我们可以利用它来加载不同权重的文本分类模型。在 ``fastNLP`` 中,我们可以定义 ``train_step`` 和 ``evaluate_step`` 函数来实现训练和验证过程中的不同行为。\n",
+    "\n",
+    "- ``train_step`` 函数在获得返回值 ``logits`` (大小为 ``(batch_size, num_labels)``)后计算交叉熵损失 ``CrossEntropyLoss``,然后将 ``loss`` 放在字典中返回。``fastNLP`` 也支持返回 ``dataclass`` 类型的训练结果,但二者都需要包含名为 **loss** 的键或成员。\n",
+    "- ``evaluate_step`` 函数在获得返回值 ``logits`` 后,将 ``logits`` 和标签 ``label`` 放在字典中返回。\n",
+    "\n",
+    "这两个函数的参数均为数据集中字典**键**的子集,``fastNLP`` 会自动进行参数匹配然后输入到模型中。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m[2022-06-22 21:31:15,577] [    INFO]\u001b[0m - We are using  to load 'ernie-1.0-base-zh'.\u001b[0m\n",
+      "\u001b[32m[2022-06-22 21:31:15,580] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n"
+     ]
+    }
+   ],
+   "source": [
+    "import paddle.nn as nn\n",
+    "\n",
+    "class SeqClsModel(nn.Layer):\n",
+    "    def __init__(self, model_checkpoint, num_labels):\n",
+    "        super(SeqClsModel, self).__init__()\n",
+    "        self.model = AutoModelForSequenceClassification.from_pretrained(\n",
+    "            model_checkpoint,\n",
+    "            num_classes=num_labels,\n",
+    "        )\n",
+    "\n",
+    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
+    "        logits = self.model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)\n",
+    "        return logits\n",
+    "\n",
+    "    def train_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+    "        logits = self(input_ids, attention_mask, token_type_ids)\n",
+    "        loss = nn.CrossEntropyLoss()(logits, label)\n",
+    "        return {\"loss\": loss}\n",
+    "\n",
+    "    def evaluate_step(self, input_ids, attention_mask, token_type_ids, label):\n",
+    "        logits = self(input_ids, attention_mask, token_type_ids)\n",
+    "        return {'pred': logits, 'target': label}\n",
+    "\n",
+    "model = SeqClsModel(model_checkpoint, num_labels=2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 3.2 设置参数并使用 Trainer 开始训练\n",
+    "\n",
+    "现在我们可以着手使用 ``fastNLP.Trainer`` 进行训练了。\n",
+    "\n",
+    "首先,为了高效地训练 ``ERNIE`` 模型,我们最好为学习率指定一定的策略。``paddlenlp`` 提供的 ``LinearDecayWithWarmup`` 可以令学习率在一段时间内从 0 开始线性地增长(预热),然后再线性地衰减至 0 。在本篇教程中,我们将学习率设置为 ``5e-5``,预热比例(预热步数占总训练步数的比例)为 ``0.1``,然后将得到的 ``lr_scheduler`` 赋值给 ``AdamW`` 优化器。\n",
+    "\n",
+    "其次,我们还可以为 ``Trainer`` 指定多个 ``Callback`` 来在基础的训练过程之外进行额外的定制操作。在本篇教程中,我们使用的 ``Callback`` 有以下两种:\n",
+    "\n",
+    "- ``LRSchedCallback`` - 由于我们使用了 ``Scheduler``,因此需要将 ``lr_scheduler`` 传给该 ``Callback`` 以在训练中进行更新。\n",
+    "- ``LoadBestModelCallback`` - 该 ``Callback`` 会评估结果中的 ``'acc#accuracy'`` 值,保存训练中出现的正确率最高的模型,并在训练结束时加载到模型上,方便对模型进行测试和评估。\n",
+    "\n",
+    "在 ``Trainer`` 中,我们还可以设置 ``metrics`` 来衡量模型的表现。``Accuracy`` 能够根据传入的预测值和真实值计算出模型预测的正确率。还记得模型中 ``evaluate_step`` 函数的返回值吗?键 ``pred`` 和 ``target`` 分别为 ``Accuracy.update`` 的参数名,在验证过程中 ``fastNLP`` 会自动将键和参数名匹配从而计算出正确率,这也是我们规定模型需要返回字典类型数据的原因。\n",
+    "\n",
+    "``Accuracy`` 的返回值包含三个部分:``acc``、``total`` 和 ``correct``,分别代表 ``正确率``、 ``数据总数`` 和 ``预测正确的数目``,这让您能够直观地知晓训练中模型的变化,``LoadBestModelCallback`` 的参数 ``'acc#accuracy'`` 也正是代表了 ``accuracy`` 指标的 ``acc`` 结果。\n",
+    "\n",
+    "在设定好参数之后,调用 ``run`` 函数便可以进行训练和验证了。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
[21:31:16] INFO     Running evaluator sanity check for 2 batches.              trainer.py:631\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.895833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1075.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.8975,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1077.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.911667,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1094.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9225,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1107.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9275,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1113.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.930833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1117.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.935833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1123.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.935833,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1123.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.9375,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1125.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"acc#accuracy\": 0.941667,\n",
+       "  \"total#accuracy\": 1200.0,\n",
+       "  \"correct#accuracy\": 1130.0\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:28] INFO     Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n",
+       "                    6-22-21_29_12_898095/best_so_far with                                    \n",
+       "                    acc#accuracy: 0.941667...                                                \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:34] INFO     Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n",
+       "                    98095/best_so_far...                                                     \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import LRSchedCallback, LoadBestModelCallback\n", + "from fastNLP import Trainer, Accuracy\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 2\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks = [\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"paddle\",\n", + " optimizers=optimizer,\n", + " device=0,\n", + " n_epochs=n_epochs,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " evaluate_every=60,\n", + " metrics={\"accuracy\": Accuracy()},\n", + " callbacks=callbacks,\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 测试和评估\n", + "\n", + "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``fastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluate`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n",
+       "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n",
+       "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n",
+       "集算什么??简直是画蛇添足!!']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n",
+       "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n",
+       "]\n",
+       "
\n" + ], + "text/plain": [ + "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n",
+       "
\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['交通方便;环境很好;服务态度很好 房间较小']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['交通方便;环境很好;服务态度很好 房间较小']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n",
+       "
\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n",
+       "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n",
+       "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n",
+       "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n",
+       "。']\n",
+       "
\n" + ], + "text/plain": [ + "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n",
+       "
\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": [
+       "{}"
+      ]
+     },
+     "execution_count": 14,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from fastNLP import Evaluator\n",
+    "def test_batch_step_fn(evaluator, batch):\n",
+    "    input_ids = batch[\"input_ids\"]\n",
+    "    attention_mask = batch[\"attention_mask\"]\n",
+    "    token_type_ids = batch[\"token_type_ids\"]\n",
+    "    logits = model(input_ids, attention_mask, token_type_ids)\n",
+    "    predict = logits.argmax().item()\n",
+    "    print(\"text:\", batch['text'])\n",
+    "    print(\"labels:\", predict)\n",
+    "\n",
+    "evaluator = Evaluator(\n",
+    "    model=model,\n",
+    "    dataloaders=test_dataloader,\n",
+    "    driver=\"paddle\",\n",
+    "    device=0,\n",
+    "    evaluate_batch_step_fn=test_batch_step_fn,\n",
+    ")\n",
+    "evaluator.run(5)    "
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3.7.13 ('fnlp-paddle')",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.7.13"
+  },
+  "orig_nbformat": 4,
+  "vscode": {
+   "interpreter": {
+    "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb
new file mode 100644
index 00000000..439d7f9f
--- /dev/null
+++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb
@@ -0,0 +1,1510 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# E4. 使用 paddlenlp 和 fastNLP 训练中文阅读理解任务\n",
+    "\n",
+    "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何在 `fastNLP` 中通过自定义 `Metric` 和 损失函数来完成进阶的问答任务。\n",
+    "\n",
+    "1. 基础介绍:自然语言处理中的阅读理解任务\n",
+    "\n",
+    "2. 准备工作:加载 `DuReader-robust` 数据集,并使用 `tokenizer` 处理数据\n",
+    "\n",
+    "3. 模型训练:自己定义评测用的 `Metric` 实现更加自由的任务评测"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. 基础介绍:自然语言处理中的阅读理解任务\n",
+    "\n",
+    "阅读理解任务,顾名思义,就是给出一段文字,然后让模型理解这段文字所含的语义。大部分机器阅读理解任务都采用问答式测评,即设计与文章内容相关的自然语言式问题,让模型理解问题并根据文章作答。与文本分类任务不同的是,在阅读理解任务中我们有时需要输入“一对”句子,分别代表问题和上下文;答案的格式也分为多种:\n",
+    "\n",
+    "- 多项选择:让模型从多个答案选项中选出正确答案\n",
+    "- 区间答案:答案为上下文的一段子句,需要模型给出答案的起始位置\n",
+    "- 自由回答:不做限制,让模型自行生成答案\n",
+    "- 完形填空:在原文中挖空部分关键词,让模型补全;这类答案往往不需要问题\n",
+    "\n",
+    "如果您对 `transformers` 有所了解的话,其中的 `ModelForQuestionAnswering` 系列模型就可以用于这项任务。阅读理解模型的泛用性是衡量该技术能否在实际应用中大规模落地的重要指标之一,随着当前技术的进步,许多模型虽然能够在一些测试集上取得较好的性能,但在实际应用中,这些模型仍然难以让人满意。在本篇教程中,我们将会为您展示如何训练一个问答模型。\n",
+    "\n",
+    "在这一领域,`SQuAD` 数据集是一个影响深远的数据集。它的全称是斯坦福问答数据集(Stanford Question Answering Dataset),每条数据包含 `(问题,上下文,答案)` 三部分,规模大(约十万条,2.0又新增了五万条),在提出之后很快成为训练问答任务的经典数据集之一。`SQuAD` 数据集有两个指标来衡量模型的表现:`EM`(Exact Match,精确匹配)和 `F1`(模糊匹配)。前者反映了模型给出的答案中有多少和正确答案完全一致,后者则反映了模型给出的答案中与正确答案重叠的部分,均为越高越好。"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. 准备工作:加载 DuReader-robust 数据集,并使用 tokenizer 处理数据"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+      "  from .autonotebook import tqdm as notebook_tqdm\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2.3.3\n"
+     ]
+    }
+   ],
+   "source": [
+    "import sys\n",
+    "sys.path.append(\"../\")\n",
+    "import paddle\n",
+    "import paddlenlp\n",
+    "\n",
+    "print(paddlenlp.__version__)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "在数据集方面,我们选用 `DuReader-robust` 中文数据集作为训练数据。它是一种抽取式问答数据集,采用 `SQuAD` 数据格式,能够评估真实应用场景下模型的泛用性。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+      "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n",
+      "\u001b[32m[2022-06-27 19:22:46,998] [    INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n",
+      "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n",
+      "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n",
+      "训练集大小: 14520\n",
+      "验证集大小: 1417\n"
+     ]
+    }
+   ],
+   "source": [
+    "from paddlenlp.datasets import load_dataset\n",
+    "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n",
+    "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n",
+    "for i in range(3):\n",
+    "    print(train_dataset[i])\n",
+    "print(\"训练集大小:\", len(train_dataset))\n",
+    "print(\"验证集大小:\", len(val_dataset))\n",
+    "\n",
+    "MODEL_NAME = \"ernie-1.0-base-zh\"\n",
+    "from paddlenlp.transformers import ErnieTokenizer\n",
+    "tokenizer =ErnieTokenizer.from_pretrained(MODEL_NAME)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.1 处理训练集\n",
+    "\n",
+    "对于阅读理解任务,数据处理的方式较为麻烦。接下来我们会为您详细讲解处理函数 `_process_train` 的功能,同时也将通过实践展示关于 `tokenizer` 的更多功能,让您更加深入地了解自然语言处理任务。首先让我们向 `tokenizer` 输入一条数据(以列表的形式):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "2\n",
+      "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 'overflow_to_sample'])\n"
+     ]
+    }
+   ],
+   "source": [
+    "result = tokenizer(\n",
+    "    [train_dataset[0][\"question\"]],\n",
+    "    [train_dataset[0][\"context\"]],\n",
+    "    stride=128,\n",
+    "    max_length=256,\n",
+    "    padding=\"max_length\",\n",
+    "    return_dict=False\n",
+    ")\n",
+    "\n",
+    "print(len(result))\n",
+    "print(result[0].keys())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "首先不难理解的是,模型必须要同时接受问题(`question`)和上下文(`context`)才能够进行阅读理解,因此我们需要将二者同时进行分词(`tokenize`)。所幸,`Tokenizer` 提供了这一功能,当我们调用 `tokenizer` 的时候,其第一个参数名为 `text`,第二个参数名为 `text_pair`,这使得我们可以同时对一对文本进行分词。同时,`tokenizer` 还需要标记出一条数据中哪些属于问题,哪些属于上下文,这一功能则由 `token_type_ids` 完成。`token_type_ids` 会将输入的第一个文本(问题)标记为 `0`,第二个文本(上下文)标记为 `1`,这样模型在训练时便可以将问题和上下文区分开来:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n",
+      "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n",
+      "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(result[0][\"input_ids\"])\n",
+    "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n",
+    "print(result[0][\"token_type_ids\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "根据上面的输出我们可以看出,`tokenizer` 会将数据开头用 `[CLS]` 标记,用 `[SEP]` 来分割句子。同时,根据 `token_type_ids` 得到的 0、1 串,我们也很容易将问题和上下文区分开。顺带一提,如果一条数据进行了 `padding`,那么这部分会被标记为 `0` 。\n",
+    "\n",
+    "在输出的 `keys` 中还有一项名为 `offset_mapping` 的键。该项数据能够表示分词后的每个 `token` 在原文中对应文字或词语的位置。比如我们可以像下面这样将数据打印出来:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n",
+      "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n",
+      "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(result[0][\"offset_mapping\"][:20])\n",
+    "print(result[0][\"input_ids\"][:20])\n",
+    "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`[CLS]` 由于是 `tokenizer` 自己添加进去用于标记数据的 `token`,因此它在原文中找不到任何对应的词语,所以给出的位置范围就是 `(0, 0)`;第二个 `token` 对应第一个 `“仙”` 字,因此映射的位置就是 `(0, 1)`;同理,后面的 `[SEP]` 也不对应任何文字,映射的位置为 `(0, 0)`;而接下来的 `token` 对应 **上下文** 中的第一个字 `“第”`,映射出的位置为 `(0, 1)`;再后面的 `token` 对应原文中的两个字符 `35`,因此其位置映射为 `(1, 3)` 。通过这种手段,我们可以更方便地获取 `token` 与原文的对应关系。\n",
+    "\n",
+    "最后,您也许会注意到我们获取的 `result` 长度为 2 。这是文本在分词后长度超过了 `max_length` 256 ,`tokenizer` 将数据分成了两部分所致。在阅读理解任务中,我们不可能像文本分类那样轻易地将一条数据截断,因为答案很可能就出现在后面被丢弃的那部分数据中,因此,我们需要保留所有的数据(当然,您也可以直接丢弃这些超长的数据)。`overflow_to_sample` 则可以标识当前数据在原数据的索引:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n",
+      "overflow_to_sample:  0\n",
+      "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n",
+      "overflow_to_sample:  0\n"
+     ]
+    }
+   ],
+   "source": [
+    "for res in result:\n",
+    "    tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n",
+    "    print(\"\".join(tokens))\n",
+    "    print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "将两条数据均输出之后可以看到,它们都出自我们传入的数据,并且存在一部分重合。`tokenizer` 的 `stride` 参数可以设置重合部分的长度,这也可以帮助模型识别被分割开的两条数据;`overflow_to_sample` 的 `0` 则代表它们来自于第 `0` 条数据。\n",
+    "\n",
+    "基于以上信息,我们处理训练集的思路如下:\n",
+    "\n",
+    "1. 通过 `overflow_to_sample` 来获取原来的数据\n",
+    "2. 通过原数据的 `answers` 找到答案的起始位置\n",
+    "3. 通过 `offset_mapping` 给出的映射关系在分词处理后的数据中找到答案的起止位置,分别记录在 `start_pos` 和 `end_pos` 中;如果没有找到答案(比如答案被截断了),那么 `start_pos` 和 `end_pos` 都会被标记为 `[CLS]` 的位置。\n",
+    "\n",
+    "这样 `_process_train` 函数就呼之欲出了,我们调用 `train_dataset.map` 函数,并将 `batched` 参数设置为 `True` ,将所有数据批量地进行更新。有一点需要注意的是,**在处理过后数据量会增加**。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), 
(173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 
49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n",
+      "处理后的训练集大小: 26198\n"
+     ]
+    }
+   ],
+   "source": [
+    "max_length = 256\n",
+    "doc_stride = 128\n",
+    "def _process_train(data):\n",
+    "\n",
+    "    contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+    "    questions = [data[i][\"question\"] for i in range(len(data))]\n",
+    "\n",
+    "    tokenized_data_list = tokenizer(\n",
+    "        questions,\n",
+    "        contexts,\n",
+    "        stride=doc_stride,\n",
+    "        max_length=max_length,\n",
+    "        padding=\"max_length\",\n",
+    "        return_dict=False\n",
+    "    )\n",
+    "\n",
+    "    for i, tokenized_data in enumerate(tokenized_data_list):\n",
+    "        # 获取 [CLS] 对应的位置\n",
+    "        input_ids = tokenized_data[\"input_ids\"]\n",
+    "        cls_index = input_ids.index(tokenizer.cls_token_id)\n",
+    "\n",
+    "        # 在 tokenize 的过程中,汉字和 token 在位置上并非一一对应的\n",
+    "        # 而 offset mapping 记录了每个 token 在原文中对应的起止位置\n",
+    "        offsets = tokenized_data[\"offset_mapping\"]\n",
+    "        # token_type_ids 记录了一条数据中哪些是问题,哪些是上下文\n",
+    "        token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+    "\n",
+    "        # 一条数据可能因为长度过长而在 tokenized_data_list 中存在多个结果\n",
+    "        # overflow_to_sample 表示了当前 tokenized_data 属于 data 中的哪一条数据\n",
+    "        sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+    "        answers = data[sample_index][\"answers\"]\n",
+    "\n",
+    "        # answers 中的 text 和 answer_start 均为长度为 1 的 list\n",
+    "        # 我们可以计算出答案的结束位置\n",
+    "        start_char = answers[\"answer_start\"][0]\n",
+    "        end_char = start_char + len(answers[\"text\"][0])\n",
+    "\n",
+    "        token_start_index = 0\n",
+    "        while token_type_ids[token_start_index] != 1:\n",
+    "            token_start_index += 1\n",
+    "\n",
+    "        token_end_index = len(input_ids) - 1\n",
+    "        while token_type_ids[token_end_index] != 1:\n",
+    "            token_end_index -= 1\n",
+    "        # 上下文部分(token_type_ids 为 1)的最后一个 token 一定是 [SEP],因此还需要减一\n",
+    "        token_end_index -= 1\n",
+    "\n",
+    "        if not (offsets[token_start_index][0] <= start_char and\n",
+    "                offsets[token_end_index][1] >= end_char):\n",
+    "            # 如果答案不在这条数据中,则将答案位置标记为 [CLS] 的位置\n",
+    "            tokenized_data_list[i][\"start_pos\"] = cls_index\n",
+    "            tokenized_data_list[i][\"end_pos\"] = cls_index\n",
+    "        else:\n",
+    "            # 否则,我们可以找到答案对应的 token 的起止位置,记录在 start_pos 和 end_pos 中\n",
+    "            while token_start_index < len(offsets) and offsets[\n",
+    "                    token_start_index][0] <= start_char:\n",
+    "                token_start_index += 1\n",
+    "            tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n",
+    "            while offsets[token_end_index][1] >= end_char:\n",
+    "                token_end_index -= 1\n",
+    "            tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n",
+    "\n",
+    "    return tokenized_data_list\n",
+    "\n",
+    "train_dataset.map(_process_train, batched=True, num_workers=5)\n",
+    "print(train_dataset[0])\n",
+    "print(\"处理后的训练集大小:\", len(train_dataset))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.2 处理验证集\n",
+    "\n",
+    "对于验证集的处理则简单得多,我们只需要保存原数据的 `id` 并将 `offset_mapping` 中不属于上下文的部分设置为 `None` 即可。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       ""
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "def _process_val(data):\n",
+    "\n",
+    "    contexts = [data[i][\"context\"] for i in range(len(data))]\n",
+    "    questions = [data[i][\"question\"] for i in range(len(data))]\n",
+    "\n",
+    "    tokenized_data_list = tokenizer(\n",
+    "        questions,\n",
+    "        contexts,\n",
+    "        stride=doc_stride,\n",
+    "        max_length=max_length,\n",
+    "        return_dict=False\n",
+    "    )\n",
+    "\n",
+    "    for i, tokenized_data in enumerate(tokenized_data_list):\n",
+    "        token_type_ids = tokenized_data[\"token_type_ids\"]\n",
+    "        # 保存数据对应的 id\n",
+    "        sample_index = tokenized_data[\"overflow_to_sample\"]\n",
+    "        tokenized_data_list[i][\"example_id\"] = data[sample_index][\"id\"]\n",
+    "\n",
+    "        # 将不属于 context 的 offset 设置为 None\n",
+    "        tokenized_data_list[i][\"offset_mapping\"] = [\n",
+    "            (o if token_type_ids[k] == 1 else None)\n",
+    "            for k, o in enumerate(tokenized_data[\"offset_mapping\"])\n",
+    "        ]\n",
+    "\n",
+    "    return tokenized_data_list\n",
+    "\n",
+    "val_dataset.map(_process_val, batched=True, num_workers=5)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### 2.3 DataLoader\n",
+    "\n",
+    "最后使用 `PaddleDataLoader` 将数据集包裹起来即可。"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP.core import PaddleDataLoader\n", + "\n", + "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:自己定义评测用的 Metric 实现更加自由的任务评测\n", + "\n", + "#### 3.1 损失函数\n", + "\n", + "对于阅读理解任务,我们使用的是 `ErnieForQuestionAnswering` 模型。该模型在接受输入后会返回两个值:`start_logits` 和 `end_logits` ,大小均为 `(batch_size, sequence_length)`,反映了每条数据每个词语为答案起始位置的可能性,因此我们需要自定义一个损失函数来计算 `loss`。 `CrossEntropyLossForSquad` 会分别对答案起始位置的预测值和真实值计算交叉熵,最后返回其平均值作为最终的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class CrossEntropyLossForSquad(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(CrossEntropyLossForSquad, self).__init__()\n", + "\n", + " def forward(self, start_logits, end_logits, start_pos, end_pos):\n", + " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n", + " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n", + " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=start_logits, label=start_pos)\n", + " start_loss = paddle.mean(start_loss)\n", + " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=end_logits, label=end_pos)\n", + " end_loss = paddle.mean(end_loss)\n", + "\n", + " loss = (start_loss + end_loss) / 2\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 定义模型\n", + "\n", + "模型的核心则是 `ErnieForQuestionAnswering` 的 `ernie-1.0-base-zh` 预训练模型,同时按照 `fastNLP` 的规定定义 `train_step` 和 `evaluate_step` 函数。这里 `evaluate_step` 函数并没有像文本分类那样直接返回该批次数据的评测结果,这一点我们将在下面为您讲解。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ 
+ "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n", + "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n", + "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n" + ] + } + ], + "source": [ + "from paddlenlp.transformers import ErnieForQuestionAnswering\n", + "\n", + "class QAModel(paddle.nn.Layer):\n", + " def __init__(self, model_checkpoint):\n", + " super(QAModel, self).__init__()\n", + " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n", + " self.loss_func = CrossEntropyLossForSquad()\n", + "\n", + " def forward(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self.model(input_ids, token_type_ids)\n", + " return start_logits, end_logits\n", + "\n", + " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n", + " return {\"loss\": loss}\n", + "\n", + " def evaluate_step(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n", + "\n", + "model = QAModel(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 自定义 Metric 进行数据的评估\n", + "\n", + "`paddlenlp` 为我们提供了评测 `SQuAD` 格式数据集的函数 `compute_prediction` 和 `squad_evaluate`:\n", + "- `compute_prediction` 函数要求传入原数据 `examples` 、处理后的数据 `features` 和 `features` 对应的结果 `predictions`(一个包含所有数据 `start_logits` 和 `end_logits` 的元组)\n", + "- `squad_evaluate` 要求传入原数据 `examples` 和预测结果 `all_predictions`(通常来自于 `compute_prediction`)\n", + "\n", + "在使用这两个函数的时候,我们需要向其中传入数据集,但显然根据 `fastNLP` 的设计,我们无法在 `evaluate_step` 里实现这一过程,并且 
`fastNLP` 也并没有提供计算 `F1` 和 `EM` 的 `Metric`,故我们需要自己定义用于评测的 `Metric`。\n", + "\n", + "在初始化之外,一个 `Metric` 还需要实现三个函数:\n", + "\n", + "1. `reset` - 该函数会在验证数据集的迭代之前被调用,用于清空数据;在我们自定义的 `Metric` 中,我们需要将 `all_start_logits` 和 `all_end_logits` 清空,重新收集每个 `batch` 的结果。\n", + "2. `update` - 该函数会在每个 `batch` 得到结果后被调用,用于更新 `Metric` 的状态;它的参数即为 `evaluate_step` 返回的内容。我们在这里将得到的 `start_logits` 和 `end_logits` 收集起来。\n", + "3. `get_metric` - 该函数会在数据集被迭代完毕后调用,用于计算评测的结果。现在我们有了整个验证集的 `all_start_logits` 和 `all_end_logits` ,将它们传入 `compute_prediction` 函数得到预测的结果,并继续使用 `squad_evaluate` 函数得到评测的结果。\n", + " - 注:`squad_evaluate` 函数会自己输出评测结果,为了不让其干扰 `fastNLP` 输出,这里我们使用 `contextlib.redirect_stdout(None)` 将函数的标准输出屏蔽掉。\n", + "\n", + "综上,`SquadEvaluateMetric` 实现的评估过程是:将验证集中所有数据的 `logits` 收集起来,然后统一传入 `compute_prediction` 和 `squad_evaluate` 中进行评估。值得一提的是,`paddlenlp.datasets.load_dataset` 返回的结果是一个 `MapDataset` 类型,其 `data` 成员为加载时的数据,`new_data` 为经过 `map` 函数处理后更新的数据,因此可以分别作为 `examples` 和 `features` 传入。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core import Metric\n", + "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n", + "import contextlib\n", + "\n", + "class SquadEvaluateMetric(Metric):\n", + " def __init__(self, examples, features, testing=False):\n", + " super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n", + " self.examples = examples\n", + " self.features = features\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + " self.testing = testing\n", + "\n", + " def reset(self):\n", + " self.all_start_logits = []\n", + " self.all_end_logits = []\n", + "\n", + " def update(self, start_logits, end_logits):\n", + " for start, end in zip(start_logits, end_logits):\n", + " self.all_start_logits.append(start.numpy())\n", + " self.all_end_logits.append(end.numpy())\n", + "\n", + " def get_metric(self):\n", + " all_predictions, _, _ = compute_prediction(\n", + " self.examples, 
self.features[:len(self.all_start_logits)],\n", + " (self.all_start_logits, self.all_end_logits),\n", + " False, 20, 30\n", + " )\n", + " with contextlib.redirect_stdout(None):\n", + " result = squad_evaluate(\n", + " examples=self.examples,\n", + " preds=all_predictions,\n", + " is_whitespace_splited=False\n", + " )\n", + "\n", + " if self.testing:\n", + " self.print_predictions(all_predictions)\n", + " return result\n", + "\n", + " def print_predictions(self, preds):\n", + " for i, data in enumerate(self.examples):\n", + " if i >= 5:\n", + " break\n", + " print()\n", + " print(\"原文:\", data[\"context\"])\n", + " print(\"问题:\", data[\"question\"], \\\n", + " \"答案:\", preds[data[\"id\"]], \\\n", + " \"正确答案:\", data[\"answers\"][\"text\"])\n", + "\n", + "metric = SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.4 训练\n", + "\n", + "至此所有的准备工作已经完成,可以使用 `Trainer` 进行训练了。学习率我们依旧采用线性预热策略 `LinearDecayWithWarmup`,优化器为 `AdamW`;回调模块我们选择 `LRSchedCallback` 更新学习率和 `LoadBestModelCallback` 监视评测结果的 `f1` 分数。初始化好 `Trainer` 之后,就将训练的过程交给 `fastNLP` 吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[19:04:54] INFO     Running evaluator sanity check for 2 batches.              trainer.py:631\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:100 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 49.25899788285109,\n",
+       "  \"f1#squad\": 66.55559127349602,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 49.25899788285109,\n",
+       "  \"HasAns_f1#squad\": 66.55559127349602,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:200 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 57.37473535638673,\n",
+       "  \"f1#squad\": 70.93036525200617,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 57.37473535638673,\n",
+       "  \"HasAns_f1#squad\": 70.93036525200617,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 63.86732533521524,\n",
+       "  \"f1#squad\": 78.62546663568186,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 63.86732533521524,\n",
+       "  \"HasAns_f1#squad\": 78.62546663568186,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:400 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 64.92589978828511,\n",
+       "  \"f1#squad\": 79.36746074079691,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 64.92589978828511,\n",
+       "  \"HasAns_f1#squad\": 79.36746074079691,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:500 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.70218772053634,\n",
+       "  \"f1#squad\": 80.33295482054824,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.70218772053634,\n",
+       "  \"HasAns_f1#squad\": 80.33295482054824,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:600 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.41990119971771,\n",
+       "  \"f1#squad\": 79.7483487059053,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.41990119971771,\n",
+       "  \"HasAns_f1#squad\": 79.7483487059053,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:700 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 66.61961891319689,\n",
+       "  \"f1#squad\": 80.32432238994133,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 66.61961891319689,\n",
+       "  \"HasAns_f1#squad\": 80.32432238994133,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:800 ----------------------------\n",
+       "
\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n",
+       "  \"exact#squad\": 65.84333098094567,\n",
+       "  \"f1#squad\": 79.23169801265415,\n",
+       "  \"total#squad\": 1417,\n",
+       "  \"HasAns_exact#squad\": 65.84333098094567,\n",
+       "  \"HasAns_f1#squad\": 79.23169801265415,\n",
+       "  \"HasAns_total#squad\": 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[19:20:28] INFO     Loading best model from fnlp-ernie-squad/ load_best_model_callback.py:111\n",
+       "                    2022-06-27-19_00_15_388554/best_so_far                                   \n",
+       "                    with f1#squad: 80.33295482054824...                                      \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
           INFO     Deleting fnlp-ernie-squad/2022-06-27-19_0 load_best_model_callback.py:131\n",
+       "                    0_15_388554/best_so_far...                                               \n",
+       "
\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 1\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks=[\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " device=1,\n", + " optimizers=optimizer,\n", + " n_epochs=n_epochs,\n", + " callbacks=callbacks,\n", + " evaluate_every=100,\n", + " metrics={\"squad\": metric},\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.5 测试\n", + "\n", + "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n",
+       "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n",
+       "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n",
+       "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n",
+       "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n",
+       "行垫,油墨外露容易脱落。 \n",
+       "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n",
+       "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n",
+       "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n",
+       "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n",
+       "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n",
+       "10厘米。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n",
+       "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n",
+       "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n",
+       "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n",
+       "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n",
+       "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n",
+       "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n",
+       "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n",
+       "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n",
+       "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n",
+       "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n",
+       "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n",
+       "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n",
+       "
\n" + ], + "text/plain": [ + "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n",
+       "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n",
+       "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n",
+       "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n",
+       "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n",
+       ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n",
+       "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n",
+       "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n",
+       "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n",
+       "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n",
+       "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n",
+       "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n",
+       "
\n" + ], + "text/plain": [ + "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n",
+       "
\n" + ], + "text/plain": [ + "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
{\n",
+       "    'exact#squad': 65.70218772053634,\n",
+       "    'f1#squad': 80.33295482054824,\n",
+       "    'total#squad': 1417,\n",
+       "    'HasAns_exact#squad': 65.70218772053634,\n",
+       "    'HasAns_f1#squad': 80.33295482054824,\n",
+       "    'HasAns_total#squad': 1417\n",
+       "}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=val_dataloader,\n", + " device=1,\n", + " metrics={\n", + " \"squad\": SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + " testing=True,\n", + " ),\n", + " },\n", + ")\n", + "result = evaluator.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/figures/E1-fig-glue-benchmark.png b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png new file mode 100644 index 00000000..515db700 Binary files /dev/null and b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png differ diff --git a/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png new file 
mode 100644 index 00000000..b5a9c1b8 Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png differ diff --git a/docs/source/tutorials/figures/E2-fig-pet-model.png b/docs/source/tutorials/figures/E2-fig-pet-model.png new file mode 100644 index 00000000..c3c377c0 Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-pet-model.png differ diff --git a/docs/source/tutorials/figures/T0-fig-parameter-matching.png b/docs/source/tutorials/figures/T0-fig-parameter-matching.png new file mode 100644 index 00000000..24013cc1 Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-parameter-matching.png differ diff --git a/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png new file mode 100644 index 00000000..38222ee8 Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png differ diff --git a/docs/source/tutorials/figures/T0-fig-training-structure.png b/docs/source/tutorials/figures/T0-fig-training-structure.png new file mode 100644 index 00000000..edc2e2ff Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-training-structure.png differ diff --git a/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png new file mode 100644 index 00000000..803cf34a Binary files /dev/null and b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png new file mode 100644 index 00000000..ff2519c4 Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png new file mode 100644 index 00000000..ed003a2f Binary files 
/dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png new file mode 100644 index 00000000..d45f65d8 Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png new file mode 100644 index 00000000..f50ddb1c Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png differ diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 9885a175..31249c80 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -2,4 +2,4 @@ from fastNLP.envs import * from fastNLP.core import * -__version__ = '0.8.0beta' +__version__ = '1.0.0alpha' diff --git a/fastNLP/core/callbacks/callback_event.py b/fastNLP/core/callbacks/callback_event.py index 8a51b6de..f632cf3c 100644 --- a/fastNLP/core/callbacks/callback_event.py +++ b/fastNLP/core/callbacks/callback_event.py @@ -35,14 +35,14 @@ class Event: :param value: Trainer 的 callback 时机; :param every: 每触发多少次才真正运行一次; - :param once: 在第一次运行后时候再次执行; + :param once: 是否仅运行一次; :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; """ every: Optional[int] - once: Optional[int] + once: Optional[bool] - def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, + def __init__(self, value: str, every: Optional[int] = None, once: Optional[bool] = None, filter_fn: Optional[Callable] = None): self.every = every self.once = once @@ -68,7 +68,6 @@ class Event: return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_sanity_check_begin(every=None, 
once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_sanity_check_begin` 时触发; @@ -85,7 +84,6 @@ class Event: return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_sanity_check_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_sanity_check_end` 时触发; @@ -101,7 +99,6 @@ class Event: return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_begin` 时触发; @@ -117,7 +114,6 @@ class Event: return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_end` 时触发; @@ -133,7 +129,6 @@ class Event: return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_epoch_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_epoch_begin` 时触发; @@ -149,7 +144,6 @@ class Event: return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_epoch_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_epoch_end` 时触发; @@ -165,7 +159,6 @@ class Event: return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_fetch_data_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_fetch_data_begin` 时触发; @@ -181,7 +174,6 @@ class Event: return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_fetch_data_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_fetch_data_end` 时触发; @@ -197,7 +189,6 @@ class Event: return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_batch_begin(every=None, once=None, filter_fn=None): """ 
当 Trainer 运行到 :func:`on_train_batch_begin` 时触发; @@ -213,7 +204,6 @@ class Event: return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_batch_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_batch_end` 时触发; @@ -229,7 +219,6 @@ class Event: return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_exception(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_exception` 时触发; @@ -245,7 +234,6 @@ class Event: return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_save_model(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_save_model` 时触发; @@ -261,7 +249,6 @@ class Event: return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_load_model(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_load_model` 时触发; @@ -277,7 +264,6 @@ class Event: return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_save_checkpoint(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_save_checkpoint` 时触发; @@ -293,7 +279,6 @@ class Event: return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_load_checkpoint(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; @@ -309,7 +294,6 @@ class Event: return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_load_checkpoint(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; @@ -325,7 +309,6 @@ class Event: return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_before_backward(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_before_backward` 时触发; @@ -341,7 
+324,6 @@ class Event: return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_after_backward(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_after_backward` 时触发; @@ -357,7 +339,6 @@ class Event: return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_before_optimizers_step(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_before_optimizers_step` 时触发; @@ -373,7 +354,6 @@ class Event: return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_after_optimizers_step(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_after_optimizers_step` 时触发; @@ -389,7 +369,6 @@ class Event: return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_before_zero_grad(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_before_zero_grad` 时触发; @@ -405,7 +384,6 @@ class Event: return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_after_zero_grad(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_after_zero_grad` 时触发; @@ -421,7 +399,6 @@ class Event: return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_evaluate_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_evaluate_begin` 时触发; @@ -437,7 +414,6 @@ class Event: return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_evaluate_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_evaluate_end` 时触发; diff --git a/fastNLP/core/collators/padders/oneflow_padder.py b/fastNLP/core/collators/padders/oneflow_padder.py index 5e235a0f..30d73e26 100644 --- a/fastNLP/core/collators/padders/oneflow_padder.py +++ 
b/fastNLP/core/collators/padders/oneflow_padder.py @@ -7,6 +7,7 @@ from inspect import isclass import numpy as np from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.envs.utils import _module_available if _NEED_IMPORT_ONEFLOW: import oneflow diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 14dad89b..454db3d0 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -84,13 +84,13 @@ class Trainer(TrainerEventTrigger): .. warning:: 当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡 - 用以训练的数据是不重叠的。如果你对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 - 你自身保证每张卡上所使用的数据是不同的。 + 用以训练的数据是不重叠的。如果您对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 + 您自身保证每张卡上所使用的数据是不同的。 :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``, - 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 - 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 + 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是您可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 + 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时您也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); device 的可选输入如下所示: @@ -196,7 +196,7 @@ class Trainer(TrainerEventTrigger): 3. 如果此时 batch 此时是其它类型,那么我们将会直接报错; 2. 如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; - 注意该参数会被传进 ``Evaluator`` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时); + 注意该参数会被传进 ``Evaluator`` 中;因此您可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时); 如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``input_mapping``, 请使用 ``train_input_mapping`` 与 ``evaluate_input_mapping`` 分别进行设置。 :param output_mapping: 应当为一个字典或者函数。作用和 ``input_mapping`` 类似,区别在于其用于转换输出: @@ -367,7 +367,7 @@ class Trainer(TrainerEventTrigger): .. 
note:: ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; - ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: + ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,您需要保证这几个参数得到正确的传入: 必须的参数:``metrics`` 与 ``evaluate_dataloaders``; @@ -898,7 +898,7 @@ class Trainer(TrainerEventTrigger): 这段代码意味着 ``fn1`` 和 ``fn2`` 会被加入到 ``trainer1``,``fn3`` 会被加入到 ``trainer2``; - 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; + 注意如果您使用该函数修饰器来为您的训练添加 callback,请务必保证您加入 callback 函数的代码在实例化 `Trainer` 之前; 补充性的解释见 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn`; diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 862c7ee7..8ef521eb 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -584,7 +584,7 @@ class DataSet: 将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name`` 中。 - :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; :param field_name: 传入 ``func`` 的 field 名称; :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; @@ -624,10 +624,9 @@ class DataSet: ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; - :param field_name: 传入 ``func`` 的 fiel` 名称; - :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 - 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 field 名称; + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. 
note:: @@ -751,8 +750,8 @@ class DataSet: 3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。 - :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. note:: diff --git a/fastNLP/core/drivers/torch_driver/torch_fsdp.py b/fastNLP/core/drivers/torch_driver/torch_fsdp.py index 9359615a..0b1948e8 100644 --- a/fastNLP/core/drivers/torch_driver/torch_fsdp.py +++ b/fastNLP/core/drivers/torch_driver/torch_fsdp.py @@ -1,15 +1,17 @@ -from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12 +from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12, _NEED_IMPORT_TORCH if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.nn.parallel import DistributedDataParallel + import os -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel from typing import Optional, Union, List, Dict, Mapping from pathlib import Path diff --git a/fastNLP/core/utils/cache_results.py b/fastNLP/core/utils/cache_results.py index d462c06c..dc114b48 100644 --- a/fastNLP/core/utils/cache_results.py +++ b/fastNLP/core/utils/cache_results.py @@ -1,3 +1,14 @@ +""" +:func:`cache_results` 函数是 **fastNLP** 中用于缓存数据的装饰器,通过该函数您可以省去调试代码过程中一些耗时过长程序 +带来的时间开销。比如在加载并处理较大的数据时,每次修改训练参数都需要从头开始执行处理数据的过程,那么 :func:`cache_results` +便可以跳过这部分漫长的时间。详细的使用方法和原理请参见下面的说明。 + +.. 
warning:: + + 如果您发现对代码进行修改之后程序执行的结果没有变化,很有可能是这个函数的原因;届时删除掉缓存数据即可。 + +""" + from datetime import datetime import hashlib import _pickle diff --git a/fastNLP/embeddings/torch/static_embedding.py b/fastNLP/embeddings/torch/static_embedding.py index 12e7294c..6980c851 100644 --- a/fastNLP/embeddings/torch/static_embedding.py +++ b/fastNLP/embeddings/torch/static_embedding.py @@ -86,7 +86,7 @@ class StaticEmbedding(TokenEmbedding): :param requires_grad: 是否需要梯度。 :param init_method: 如何初始化没有找到的值。可以使用 :mod:`torch.nn.init` 中的各种方法,传入的方法应该接受一个 tensor,并 inplace 地修改其值。 - :param lower: 是否将 ``vocab`` 中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 + :param lower: 是否将 ``vocab`` 中的词语小写后再和预训练的词表进行匹配。如果您的词表中包含大写的词语,或者就是需要单独 为大写的词语开辟一个 vector 表示,则将 ``lower`` 设置为 ``False``。 :param dropout: 以多大的概率对 embedding 的表示进行 Dropout。0.1 即随机将 10% 的值置为 0。 :param word_dropout: 按照一定概率随机将 word 设置为 ``unk_index`` ,这样可以使得 ```` 这个 token 得到足够的训练, diff --git a/fastNLP/io/__init__.py b/fastNLP/io/__init__.py index 3897cb0d..c4fa600e 100644 --- a/fastNLP/io/__init__.py +++ b/fastNLP/io/__init__.py @@ -1,13 +1,13 @@ r""" -用于IO的模块, 具体包括: +用于 **IO** 的模块, 具体包括: -1. 用于读入 embedding 的 :mod:`EmbedLoader ` 类, +1. 用于读入 embedding 的 :mod:`EmbedLoader ` 类 2. 用于读入不同格式数据的 :mod:`Loader ` 类 3. 用于处理读入数据的 :mod:`Pipe ` 类 -4. 用于保存和载入模型的类, 参考 :mod:`model_io模块 ` +4. 用于管理数据集的类 :mod:`DataBundle ` 类 这些类的使用方法如下: """ @@ -105,7 +105,7 @@ __all__ = [ "BQCorpusPipe", "RenamePipe", "GranularizePipe", - "MachingTruncatePipe", + "TruncateBertPipe", "CMRC2018BertPipe", diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index a53f00a5..03e58e56 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -1,7 +1,7 @@ -r""" -.. 
todo:: - doc """ +:class:`DataBundle` 是 **fastNLP** 提供的用于方便快捷地管理多个数据集的工具,并有诸多接口来进行批量的数据处理。 +""" + __all__ = [ 'DataBundle', ] @@ -15,25 +15,20 @@ from fastNLP.core import logger class DataBundle: r""" - 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个field对应的vocabulary。该对象一般由fastNLP中各种 - Loader的load函数生成,可以通过以下的方法获取里面的内容 - - Example:: + 经过处理的数据信息,包括一系列数据集(比如:分开的训练集、验证集和测试集)以及各个 field 对应的 vocabulary。该对象一般由 + **fastNLP** 中各种 :class:`~fastNLP.io.loader.Loader` 的 :meth:`load` 函数生成,可以通过以下的方法获取里面的内容:: data_bundle = YelpLoader().load({'train':'/path/to/train', 'dev': '/path/to/dev'}) train_vocabs = data_bundle.vocabs['train'] train_data = data_bundle.datasets['train'] dev_data = data_bundle.datasets['train'] + :param vocabs: 从名称(字符串)到 :class:`~fastNLP.core.Vocabulary` 类型的字典 + :param datasets: 从名称(字符串)到 :class:`~fastNLP.core.dataset.DataSet` 类型的字典。建议不要将相同的 ``DataSet`` 对象重复传入, + 否则可能会在使用 :class:`~fastNLP.io.pipe.Pipe` 处理数据的时候遇到问题,若多个数据集确需一致,请手动 ``deepcopy`` 后传入。 """ def __init__(self, vocabs=None, datasets=None): - r""" - - :param vocabs: 从名称(字符串)到 :class:`~fastNLP.Vocabulary` 类型的dict - :param datasets: 从名称(字符串)到 :class:`~fastNLP.DataSet` 类型的dict。建议不要将相同的DataSet对象重复传入,可能会在 - 使用Pipe处理数据的时候遇到问题,若多个数据集确需一致,请手动deepcopy后传入。 - """ self._vocabs = vocabs or {} self._datasets = datasets or {} @@ -47,21 +42,21 @@ class DataBundle: def set_vocab(self, vocab: Vocabulary, field_name: str): r""" - 向DataBunlde中增加vocab + 向 :class:`DataBunlde` 中增加 ``vocab`` - :param ~fastNLP.Vocabulary vocab: 词表 - :param str field_name: 这个vocab对应的field名称 + :param vocab: :class:`~fastNLP.core.Vocabulary` 类型的词表 + :param field_name: 这个 vocab 对应的 field 名称 :return: self """ - assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary supports." + assert isinstance(vocab, Vocabulary), "Only fastNLP.core.Vocabulary supports." 
self.vocabs[field_name] = vocab return self def set_dataset(self, dataset: DataSet, name: str): r""" - :param ~fastNLP.DataSet dataset: 传递给DataBundle的DataSet - :param str name: dataset的名称 + :param dataset: 传递给 :class:`DataBundle` 的 :class:`~fastNLP.core.dataset.DataSet` + :param name: ``dataset`` 的名称 :return: self """ assert isinstance(dataset, DataSet), "Only fastNLP.DataSet supports." @@ -70,10 +65,10 @@ class DataBundle: def get_dataset(self, name: str) -> DataSet: r""" - 获取名为name的dataset + 获取名为 ``name`` 的 dataset - :param str name: dataset的名称,一般为'train', 'dev', 'test' - :return: DataSet + :param name: dataset的名称,一般为 'train', 'dev', 'test' 。 + :return: """ if name in self.datasets.keys(): return self.datasets[name] @@ -85,33 +80,34 @@ class DataBundle: def delete_dataset(self, name: str): r""" - 删除名为name的DataSet + 删除名为 ``name`` 的 dataset - :param str name: + :param name: :return: self """ self.datasets.pop(name, None) return self - def get_vocab(self, field_name: str) -> Vocabulary: + def get_vocab(self, name: str) -> Vocabulary: r""" - 获取field名为field_name对应的vocab + 获取 field 名为 ``name`` 对应的词表 - :param str field_name: 名称 - :return: Vocabulary + :param name: 名称 + :return: :class:`~fastNLP.core.Vocabulary` """ - if field_name in self.vocabs.keys(): - return self.vocabs[field_name] + if name in self.vocabs.keys(): + return self.vocabs[name] else: - error_msg = f'DataBundle do NOT have Vocabulary named {field_name}. ' \ + error_msg = f'DataBundle do NOT have Vocabulary named {name}. ' \ f'It should be one of {self.vocabs.keys()}.' 
logger.error(error_msg) raise KeyError(error_msg) def delete_vocab(self, field_name: str): r""" - 删除vocab - :param str field_name: + 删除名为 ``field_name`` 的 vocab + + :param field_name: :return: self """ self.vocabs.pop(field_name, None) @@ -125,14 +121,14 @@ class DataBundle: def num_vocab(self): return len(self.vocabs) - def copy_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True): + def copy_field(self, field_name: str, new_field_name: str, ignore_miss_dataset: bool=True): r""" - 将DataBundle中所有的DataSet中名为field_name的Field复制一份并命名为叫new_field_name. + 将所有的 dataset 中名为 ``field_name`` 的 Field 复制一份并命名为 ``new_field_name``。 - :param str field_name: - :param str new_field_name: - :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; - 如果为False,则报错 + :param field_name: + :param new_field_name: + :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, + 如果为 ``False`` 则会报错。 :return: self """ for name, dataset in self.datasets.items(): @@ -142,15 +138,15 @@ class DataBundle: raise KeyError(f"{field_name} not found DataSet:{name}.") return self - def rename_field(self, field_name: str, new_field_name: str, ignore_miss_dataset=True, rename_vocab=True): + def rename_field(self, field_name: str, new_field_name: str, ignore_miss_dataset: bool=True, rename_vocab: bool=True): r""" - 将DataBundle中所有DataSet中名为field_name的field重命名为new_field_name. 
+ 将所有的 dataset 中名为 ``field_name`` 的 Field 重命名为 ``new_field_name``。 - :param str field_name: - :param str new_field_name: - :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; - 如果为False,则报错 - :param bool rename_vocab: 如果该field同时也存在于vocabs中,会将该field的名称对应修改 + :param field_name: + :param new_field_name: + :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, + 如果为 ``False`` 则会报错。 + :param rename_vocab: 如果该 ``field_name`` 同时也存在于 vocabs 中,则也会进行重命名 :return: self """ for name, dataset in self.datasets.items(): @@ -164,14 +160,14 @@ class DataBundle: return self - def delete_field(self, field_name: str, ignore_miss_dataset=True, delete_vocab=True): + def delete_field(self, field_name: str, ignore_miss_dataset: bool=True, delete_vocab: bool=True): r""" - 将DataBundle中所有DataSet中名为field_name的field删除掉. + 将所有的 dataset 中名为 ``field_name`` 的 Field 删除。 - :param str field_name: - :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; - 如果为False,则报错 - :param bool delete_vocab: 如果该field也在vocabs中存在,将该值也一并删除 + :param field_name: + :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, + 如果为 ``False`` 则会报错。 + :param delete_vocab: 如果该 ``field_name`` 也在 vocabs 中存在,则也会删除。 :return: self """ for name, dataset in self.datasets.items(): @@ -186,44 +182,38 @@ class DataBundle: def iter_datasets(self) -> Union[str, DataSet]: r""" - 迭代data_bundle中的DataSet + 迭代 dataset Example:: for name, dataset in data_bundle.iter_datasets(): pass - :return: """ for name, dataset in self.datasets.items(): yield name, dataset def get_dataset_names(self) -> List[str]: r""" - 返回DataBundle中DataSet的名称 - - :return: + :return: 所有 dataset 的名称 """ return list(self.datasets.keys()) def get_vocab_names(self) -> List[str]: r""" - 返回DataBundle中Vocabulary的名称 - - :return: + :return: 所有词表的名称 """ return list(self.vocabs.keys()) def iter_vocabs(self): r""" - 迭代data_bundle中的DataSet + 迭代词表 - Example: + 
Example:: for field_name, vocab in data_bundle.iter_vocabs(): pass - :return: """ for field_name, vocab in self.vocabs.items(): yield field_name, vocab @@ -231,25 +221,24 @@ class DataBundle: def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, ignore_miss_dataset: bool = True, progress_desc: str = '', progress_bar: str = 'rich'): r""" - 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 - - :param callable func: input是instance中名为 `field_name` 的field的内容。 - :param str field_name: 传入func的是哪个field。 - :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 - 盖之前的field。如果为None则不创建新的field。 - :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; - 如果为False,则报错 - :param num_proc: 使用进程的数量。 + 对 :class:`DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.core.dataset.DataSet.apply_field` 方法 + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 field 名称; + :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 + 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; + :param num_proc: 使用进程的数量。 + .. 
note:: - + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 - :param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - + :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, + 如果为 ``False`` 则会报错。 + :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 + :return: self """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): @@ -263,32 +252,30 @@ class DataBundle: raise KeyError(f"{field_name} not found DataSet:{name}.") return self - def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, + def apply_field_more(self, func: Callable, field_name: str, modify_fields: str=True, num_proc: int = 0, ignore_miss_dataset=True, progress_bar: str = 'rich', progress_desc: str = ''): r""" - 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 + 对 :class:`DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.core.DataSet.apply_field_more` 方法 .. note:: ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 - :param str field_name: 传入func的是哪个field。 - :param bool modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 field 名称; + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 - + .. 
note:: - + 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; - 如果为False,则报错 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - :param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。 - - :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 - + :param ignore_miss_dataset: 如果为 ``True`` ,则当 ``field_name`` 在某个 dataset 内不存在时,直接忽略该 dataset, + 如果为 ``False`` 则会报错。 + :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称; + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 + :return: 一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 """ res = {} _progress_desc = progress_desc @@ -307,13 +294,11 @@ class DataBundle: def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, progress_desc: str = '', progress_bar: bool = True): r""" - 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 - - 对DataBundle中所有的dataset使用apply方法 + 对 :class:`~DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.core.DataSet.apply` 方法 - :param callable func: input是instance中名为 `field_name` 的field的内容。 - :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 - 盖之前的field。如果为None则不创建新的field。 + :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param new_field_name: 将 ``func`` 返回的内容放入到 ``new_field_name`` 这个 field中 ,如果名称与已有的 field 相同,则覆 + 盖之前的 field。如果为 ``None`` 则不创建新的 field。 :param num_proc: 使用进程的数量。 .. 
note:: @@ -321,9 +306,9 @@ class DataBundle: 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - :param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称 - + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 + :param progress_desc: 如果不为 ``None``,则会显示当前正在处理的进度条的名称。 + :return: self """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): @@ -334,7 +319,7 @@ class DataBundle: progress_desc=progress_desc) return self - def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, + def apply_more(self, func: Callable, modify_fields: bool=True, num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 @@ -343,8 +328,8 @@ class DataBundle: ``apply_more`` 与 ``apply`` 的区别参考 :meth:`fastNLP.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 - :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True + :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. 
note:: @@ -352,10 +337,10 @@ class DataBundle: 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。另外,使用多进程时, ``func`` 函数中的打印将不会输出。 - :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 - :param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称 + :param progress_desc: 当 progress_bar 不为 ``None`` 时,可以显示当前正在处理的进度条名称 + :param progress_bar: 显示进度条的方式,支持 ``["rich", "tqdm", None]``。 - :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 + :return: 一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 """ res = {} _progress_desc = progress_desc @@ -371,19 +356,19 @@ class DataBundle: """ 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 - :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组表示多层次的 key,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); - 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 - 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 - :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 - field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。如果 backend 为 None ,该值 - 无意义。 - :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 - :param backend: 可选['raw', 'numpy', 'torch', 'paddle', 'jittor', 'auto'],分别代表,输出为 list, numpy.ndarray, - torch.Tensor, paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值无意义 。 - :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 - batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch - 形式,输出将被直接作为结果输出。 + :param field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 + 如果该 field 在数据中没有找到,则报错;如果 :meth:`Dataset.__getitem__` 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 
的默认 pad 值。如果设置为 ``None``,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 ``None`` 。如果 ``backend`` 为 ``None``, + 该值无意义。 + :param dtype: 对于需要 pad 的 field ,该 field 数据的 ``dtype`` 。 + :param backend: 可选 ``['raw', 'numpy', 'torch', 'paddle', 'jittor', 'oneflow', 'auto']`` ,分别代表,输出为 :class:`list`, + :class:`numpy.ndarray`, :class:`torch.Tensor`, :class:`paddle.Tensor`, :class:`jittor.Var`, :class:`oneflow.Tensor` 类型。 + 若 ``pad_val`` 为 ``None`` ,该值无意义 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 ``pad_val``, ``dtype``, ``backend`` 等参数失效。``pad_fn`` 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。 :return: self """ for _, ds in self.iter_datasets(): @@ -393,14 +378,14 @@ class DataBundle: def set_ignore(self, *field_names) -> "DataBundle": """ - 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 - Example:: + ``DataSet`` 中想要对绑定的 collator 进行调整可以调用此函数。 ``collator`` 为 :class:`~fastNLP.core.collators.Collator` + 时该函数才有效。调用该函数可以设置忽略输出某些 field 的内容,被设置的 field 将在 batch 的输出中被忽略:: - collator.set_ignore('field1', 'field2') + databundle.set_ignore('field1', 'field2') - :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 - field 的 key 来表示,如果是 nested 的 dict,可以使用元组来表示,例如 {'a': {'b': 1}} 中的使用 ('a', 'b'); 如果 - __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :param field_names: field_name: 需要调整的 field 的名称。如果 :meth:`Dataset.__getitem__` 方法返回的是字典类型,则可以直接使用对应的 + field 的 key 来表示,如果是嵌套字典,可以使用元组表示多层次的 key,例如 ``{'a': {'b': 1}}`` 中可以使用 ``('a', 'b')``; + 如果 :meth:`Dataset.__getitem__` 返回的是 Sequence 类型,则可以使用 ``'_0'``, ``'_1'`` 表示序列中第 **0** 或 **1** 个元素。 :return: self """ for _, ds in self.iter_datasets(): diff --git a/fastNLP/io/embed_loader.py b/fastNLP/io/embed_loader.py index df82643b..650869e2 100644 --- a/fastNLP/io/embed_loader.py +++ b/fastNLP/io/embed_loader.py @@ -1,7 +1,3 @@ -r""" -.. 
todo:: - doc -""" __all__ = [ "EmbedLoader", "EmbeddingOption", @@ -9,6 +5,7 @@ __all__ = [ import logging import os +from typing import Callable import numpy as np @@ -33,30 +30,30 @@ class EmbeddingOption(Option): class EmbedLoader: r""" - 用于读取预训练的embedding, 读取结果可直接载入为模型参数。 + 用于读取预训练的 embedding, 读取结果可直接载入为模型参数。 """ def __init__(self): super(EmbedLoader, self).__init__() @staticmethod - def load_with_vocab(embed_filepath, vocab, dtype=np.float32, padding='', unknown='', normalize=True, - error='ignore', init_method=None): + def load_with_vocab(embed_filepath: str, vocab, dtype=np.float32, padding: str='', unknown: str='', normalize: bool=True, + error: str='ignore', init_method: Callable=None): r""" - 从embed_filepath这个预训练的词向量中抽取出vocab这个词表的词的embedding。EmbedLoader将自动判断embed_filepath是 - word2vec(第一行只有两个元素)还是glove格式的数据。 + 从 ``embed_filepath`` 这个预训练的词向量中抽取出 ``vocab`` 这个词表的词的 embedding。 :class:`EmbedLoader` 将自动判断 ``embed_filepath`` + 是 **word2vec** (第一行只有两个元素) 还是 **glove** 格式的数据。 - :param str embed_filepath: 预训练的embedding的路径。 - :param vocab: 词表 :class:`~fastNLP.Vocabulary` 类型,读取出现在vocab中的词的embedding。 - 没有出现在vocab中的词的embedding将通过找到的词的embedding的正态分布采样出来,以使得整个Embedding是同分布的。 - :param dtype: 读出的embedding的类型 - :param str padding: 词表中padding的token - :param str unknown: 词表中unknown的token - :param bool normalize: 是否将每个vector归一化到norm为1 - :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。 + :param embed_filepath: 预训练的 embedding 的路径。 + :param vocab: 词表 :class:`~fastNLP.core.Vocabulary` 类型,读取出现在 ``vocab`` 中的词的 embedding。 + 没有出现在 ``vocab`` 中的词的 embedding 将通过找到的词的 embedding 的 *正态分布* 采样出来,以使得整个 Embedding 是同分布的。 + :param dtype: 读出的 embedding 的类型 + :param padding: 词表中 *padding* 的 token + :param unknown: 词表中 *unknown* 的 token + :param normalize: 是否将每个 vector 归一化到 norm 为 1 + :param error: 可以为以下值之一: ``['ignore', 'strict']`` 。如果为 ``ignore`` ,错误将自动跳过;如果是 ``strict`` ,错误将抛出。 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 - :param callable init_method: 传入numpy.ndarray, 
返回numpy.ndarray, 用以初始化embedding - :return numpy.ndarray: shape为 [len(vocab), dimension], dimension由pretrain的embedding决定。 + :param init_method: 用于初始化 embedding 的函数。该函数接受一个 :class:`numpy.ndarray` 类型,返回 :class:`numpy.ndarray`。 + :return: 返回类型为 :class:`numpy.ndarray`,形状为 ``[len(vocab), dimension]``,其中 *dimension* 由预训练的 embedding 决定。 """ assert isinstance(vocab, Vocabulary), "Only fastNLP.Vocabulary is supported." if not os.path.exists(embed_filepath): @@ -112,20 +109,21 @@ class EmbedLoader: return matrix @staticmethod - def load_without_vocab(embed_filepath, dtype=np.float32, padding='', unknown='', normalize=True, - error='ignore'): + def load_without_vocab(embed_filepath: str, dtype=np.float32, padding: str='', unknown: str='', normalize: bool=True, + error: str='ignore'): r""" - 从embed_filepath中读取预训练的word vector。根据预训练的词表读取embedding并生成一个对应的Vocabulary。 + 从 ``embed_filepath`` 中读取预训练的 word vector。根据预训练的词表读取 embedding 并生成一个对应的 :class:`~fastNLP.core.Vocabulary` 。 - :param str embed_filepath: 预训练的embedding的路径。 - :param dtype: 读出的embedding的类型 - :param str padding: 词表中的padding的token. 并以此用做vocab的padding。 - :param str unknown: 词表中的unknown的token. 
并以此用做vocab的unknown。 - :param bool normalize: 是否将每个vector归一化到norm为1 - :param str error: `ignore` , `strict` ; 如果 `ignore` ,错误将自动跳过; 如果 `strict` , 错误将抛出。这里主要可能出错的地 - 方在于词表有空行或者词表出现了维度不一致。 - :return (numpy.ndarray, Vocabulary): Embedding的shape是[词表大小+x, 词表维度], "词表大小+x"是由于最终的大小还取决与 - 是否使用padding, 以及unknown有没有在词表中找到对应的词。 Vocabulary中的词的顺序与Embedding的顺序是一一对应的。 + :param embed_filepath: 预训练的 embedding 的路径。 + :param dtype: 读出的 embedding 的类型 + :param padding: 词表中的 *padding* 的 token。 + :param unknown: 词表中的 *unknown* 的 token。 + :param normalize: 是否将每个 vector 归一化到 norm 为 1 + :param error: 可以为以下值之一: ``['ignore', 'strict']`` 。如果为 ``ignore`` ,错误将自动跳过;如果是 ``strict`` ,错误将抛出。 + 这里主要可能出错的地方在于词表有空行或者词表出现了维度不一致。 + :return: 返回两个结果,第一个返回值为 :class:`numpy.ndarray`,大小为 ``[词表大小+x, 词表维度]`` 。 ``词表大小+x`` 是由于最终的大小还取决于 + 是否使用 ``padding``,以及 ``unknown`` 有没有在词表中找到对应的词。 第二个返回值为 :class:`~fastNLP.core.Vocabulary` 类型的词表,其中 + 词的顺序与 embedding 的顺序是一一对应的。 """ vocab = Vocabulary(padding=padding, unknown=unknown) diff --git a/fastNLP/io/file_reader.py b/fastNLP/io/file_reader.py index 2df61e17..37bdbce9 100644 --- a/fastNLP/io/file_reader.py +++ b/fastNLP/io/file_reader.py @@ -1,7 +1,3 @@ -r"""undocumented -此模块用于给其它模块提供读取文件的函数,没有为用户提供 API -""" - __all__ = [] import json diff --git a/fastNLP/io/file_utils.py b/fastNLP/io/file_utils.py index b982af54..3f37ff19 100644 --- a/fastNLP/io/file_utils.py +++ b/fastNLP/io/file_utils.py @@ -1,8 +1,3 @@ -r""" -.. todo:: - doc -""" - __all__ = [ "cached_path", "get_filepath", @@ -160,21 +155,16 @@ FASTNLP_EXTEND_EMBEDDING_URL = {'elmo': 'fastnlp_elmo_url.txt', def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: r""" - 给定一个url,尝试通过url中的解析出来的文件名字filename到{cache_dir}/{name}/{filename}下寻找这个文件, - - 1. 如果cache_dir=None, 则cache_dir=~/.fastNLP/; 否则cache_dir=cache_dir - 2. 如果name=None, 则没有中间的{name}这一层结构;否者中间结构就为{name} - - 如果有该文件,就直接返回路径 + 给定一个 url,尝试通过 url 中解析出来的文件名字 filename 到 ``{cache_dir}/{name}/{filename}`` 下寻找这个文件或文件夹: - 如果没有该文件,则尝试用传入的url下载 + 1. 
如果 ``cache_dir`` 为 ``None``,则默认为 ``~/.fastNLP/``; + 2. 如果 ``name=None`` ,则没有中间的 {name} 这一层结构。 - 或者文件名(可以是具体的文件名,也可以是文件夹),先在cache_dir下寻找该文件是否存在,如果不存在则去下载, 并 - 将文件放入到cache_dir中. + 如果有该文件,就直接返回路径;如果没有该文件,则尝试用传入的 url 下载,并将文件放入 ``cache_dir`` 中。 - :param str url_or_filename: 文件的下载url或者文件名称。 - :param str cache_dir: 文件的缓存文件夹。如果为None,将使用"~/.fastNLP"这个默认路径 - :param str name: 中间一层的名称。如embedding, dataset + :param url_or_filename: 文件的下载 url 或者文件名称。 + :param cache_dir: 文件的缓存文件夹。如果为 ``None``,将使用 ``"~/.fastNLP"`` 这个默认路径 + :param name: 中间一层的名称。如 embedding, dataset :return: """ if cache_dir is None: @@ -205,17 +195,10 @@ def cached_path(url_or_filename: str, cache_dir: str = None, name=None) -> Path: def get_filepath(filepath): r""" - 如果filepath为文件夹, - - 如果内含多个文件, 返回filepath - - 如果只有一个文件, 返回filepath + filename - - 如果filepath为文件 - - 返回filepath + 如果 ``filepath`` 为文件夹且包含多个文件,则直接返回 ``filepath``;反之若只包含一个文件,返回该文件的路径 ``{filepath}/{filename}``。 + 如果 ``filepath`` 为一个文件,也会直接返回 ``filepath``。 - :param str filepath: 路径 + :param filepath: 路径 :return: """ if os.path.isdir(filepath): @@ -232,9 +215,7 @@ def get_filepath(filepath): def get_cache_path(): r""" - 获取fastNLP默认cache的存放路径, 如果将FASTNLP_CACHE_PATH设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 - - :return str: 存放路径 + 获取 **fastNLP** 默认 cache 的存放路径, 如果将 ``FASTNLP_CACHE_DIR`` 设置在了环境变量中,将使用环境变量的值,使得不用每个用户都去下载。 """ if 'FASTNLP_CACHE_DIR' in os.environ: fastnlp_cache_dir = os.environ.get('FASTNLP_CACHE_DIR') @@ -342,10 +323,10 @@ def _get_dataset_url(name, dataset_dir: dict = None): def split_filename_suffix(filepath): r""" - 给定filepath 返回对应的name和suffix. 
如果后缀是多个点,仅支持.tar.gz类型 + 给定 ``filepath`` 返回对应的文件名和后缀。如果后缀是多个点,仅支持 **.tar.gz** 类型 :param filepath: 文件路径 - :return: filename, suffix + :return: 文件名 ``filename`` 和后缀 ``suffix`` """ filename = os.path.basename(filepath) if filename.endswith('.tar.gz'): @@ -355,10 +336,10 @@ def split_filename_suffix(filepath): def get_from_cache(url: str, cache_dir: Path = None) -> Path: r""" - 尝试在cache_dir中寻找url定义的资源; 如果没有找到; 则从url下载并将结果放在cache_dir下,缓存的名称由url的结果推断而来。会将下载的 - 文件解压,将解压后的文件全部放在cache_dir文件夹中。 + 尝试在 ``cache_dir`` 中寻找 ``url`` 定义的资源,如果没有找到,则从 ``url`` 下载并将结果放在 ``cache_dir`` 下, + 缓存的名称由 ``url`` 的结果推断而来。会将下载的文件解压,将解压后的文件全部放在 ``cache_dir`` 文件夹中。 - 如果从url中下载的资源解压后有多个文件,则返回目录的路径; 如果只有一个资源文件,则返回具体的路径。 + 如果从 ``url`` 中下载的资源解压后有多个文件,则返回目录的路径;如果只有一个资源文件,则返回具体的路径。 :param url: 资源的 url :param cache_dir: cache 目录 diff --git a/fastNLP/io/loader/__init__.py b/fastNLP/io/loader/__init__.py index ebd56330..04752a96 100644 --- a/fastNLP/io/loader/__init__.py +++ b/fastNLP/io/loader/__init__.py @@ -1,8 +1,8 @@ r""" -Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的Loader都支持以下的 -三个方法: ``__init__`` , ``_load`` , ``loads`` . 其中 ``__init__(...)`` 用于申明读取参数,以及说明该Loader支持的数据格式, -读取后 :class:`~fastNLP.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.DataSet` ; -``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: +**Loader** 用于读取数据,并将内容读取到 :class:`~fastNLP.core.DataSet` 或者 :class:`~fastNLP.io.DataBundle` 中。所有的 ``Loader`` 都支持以下的 +三个方法: ``__init__`` , ``_load`` , ``load`` . 
其中 ``__init__(...)`` 用于申明读取参数,以及说明该 ``Loader`` 支持的数据格式, +读取后 :class:`~fastNLP.core.DataSet` 中的 `field` ; ``_load(path)`` 方法传入文件路径读取单个文件,并返回 :class:`~fastNLP.core.DataSet` ; +``load(paths)`` 用于读取文件夹下的文件,并返回 :class:`~fastNLP.io.DataBundle` 类型的对象 , load()方法支持以下几种类型的参数: 0.传入None 将尝试自动下载数据集并缓存。但不是所有的数据都可以直接下载。 @@ -38,10 +38,7 @@ Loader用于读取数据,并将内容读取到 :class:`~fastNLP.DataSet` 或 在 Loader().load(paths) 返回的 `data_bundle` 中可以用 ``data_bundle.get_dataset('train')`` , ``data_bundle.get_dataset('dev')`` , ``data_bundle.get_dataset('test')`` 来获取对应的 `dataset` -fastNLP 目前提供了如下的 Loader - - - +**fastNLP** 目前提供了如下的 Loader: """ __all__ = [ diff --git a/fastNLP/io/loader/classification.py b/fastNLP/io/loader/classification.py index 2ae0b163..99c42041 100644 --- a/fastNLP/io/loader/classification.py +++ b/fastNLP/io/loader/classification.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "CLSBaseLoader", "YelpFullLoader", @@ -30,22 +28,18 @@ from .loader import Loader from fastNLP.core.dataset import Instance, DataSet from fastNLP.core.log import logger - -# from ...core._logger import log - - class CLSBaseLoader(Loader): r""" - 文本分类Loader的一个基类 + 文本分类 Loader 的一个基类 - 原始数据中内容应该为, 每一行为一个sample,第一个逗号之前为target,第一个逗号之后为文本内容。 + 原始数据中内容应该为:每一行为一个 sample ,第一个逗号之前为 **target** ,第一个逗号之后为 **文本内容** 。 Example:: "1","I got 'new' tires from the..." "1","Don't waste your time..." - 读取的DataSet将具备以下的数据结构 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_words", "target" @@ -124,32 +118,36 @@ def _split_dev(dataset_name, data_dir, dev_ratio=0.0, re_download=False, suffix= class AGsNewsLoader(CLSBaseLoader): - def download(self): - r""" - 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + """ + **AG's News** 数据集的 **Loader**,如果您使用了这个数据集,请引用以下的文章 Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. 
Advances in Neural Information Processing Systems 28 (NIPS 2015) + """ + def download(self): + r""" + 自动下载数据集。 - :return: str, 数据集的目录地址 + :return: 数据集的目录地址 """ return self._get_dataset_path(dataset_name='ag-news') class DBPediaLoader(CLSBaseLoader): - def download(self, dev_ratio: float = 0.0, re_download: bool = False): - r""" - 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + """ + **DBpedia** 数据集的 **Loader**。如果您使用了这个数据集,请引用以下的文章 Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015) + """ + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + r""" + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = 'dbpedia' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -163,15 +161,18 @@ class DBPediaLoader(CLSBaseLoader): class IMDBLoader(CLSBaseLoader): r""" - 原始数据中内容应该为, 每一行为一个sample,制表符之前为target,制表符之后为文本内容。 + **IMDb** 数据集的 **Loader** ,如果您使用了这个数据集,请引用以下的文章 + + http://www.aclweb.org/anthology/P11-1015。 + + 原始数据中内容应该为:每一行为一个 sample ,制表符之前为 **target** ,制表符之后为 **文本内容** 。 Example:: neg Alan Rickman & Emma... neg I have seen this... - IMDBLoader读取后的数据将具有以下两列内容: raw_words: str, 需要分类的文本; target: str, 文本的标签 - 读取的DataSet具备以下的结构: + **IMDBLoader** 读取后的 :class:`~fastNLP.core.DataSet` 将具有以下两列内容: ``raw_words`` 代表需要分类的文本,``target`` 代表文本的标签: .. 
csv-table:: :header: "raw_words", "target" @@ -187,15 +188,13 @@ class IMDBLoader(CLSBaseLoader): def download(self, dev_ratio: float = 0.0, re_download=False): r""" - 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + 自动下载数据集。 - http://www.aclweb.org/anthology/P11-1015 + 根据 ``dev_ratio`` 的值随机将 train 中的数据取出一部分作为 dev 数据。 - 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后不从train中切分dev - - :param float dev_ratio: 如果路径中没有dev.txt。从train划分多少作为dev的数据. 如果为0,则不划分dev - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有 ``dev.txt`` ,从 train 划分多少作为 dev 的数据。 如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = 'aclImdb' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -209,23 +208,25 @@ class IMDBLoader(CLSBaseLoader): class SSTLoader(Loader): r""" - 原始数据中内容应该为: + **SST** 数据集的 **Loader**,如果您使用了这个数据集,请引用以下的文章 - Example:: + https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf + + 原始数据中内容应该为:: (2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)... (3 (3 (2 If) (3 (2 you) (3 (2 sometimes)... - 读取之后的DataSet具有以下的结构 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: - .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + .. csv-table:: 下面是使用 SSTLoader 读取的 DataSet 所具备的 field :header: "raw_words" "(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." "(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." "..." - raw_words列是str。 + ``raw_words`` 列是 :class:`str` 。 """ @@ -249,30 +250,29 @@ class SSTLoader(Loader): def download(self): r""" - 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 - - https://nlp.stanford.edu/~socherr/EMNLP2013_RNTN.pdf + 自动下载数据集。 - :return: str, 数据集的目录地址 + :return: 数据集的目录地址 """ output_dir = self._get_dataset_path(dataset_name='sst') return output_dir class YelpFullLoader(CLSBaseLoader): - def download(self, dev_ratio: float = 0.0, re_download: bool = False): - r""" - 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + """ + **Yelp Review Full** 数据集的 **Loader**,如果您使用了这个数据集,请引用以下的文章 Xiang Zhang, Junbo Zhao, Yann LeCun. 
Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015) + """ + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + r""" + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = 'yelp-review-full' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -285,19 +285,20 @@ class YelpFullLoader(CLSBaseLoader): class YelpPolarityLoader(CLSBaseLoader): - def download(self, dev_ratio: float = 0.0, re_download: bool = False): - r""" - 自动下载数据集,如果你使用了这个数据集,请引用以下的文章 + """ + **Yelp Review Polarity** 数据集的 **Loader**,如果您使用了这个数据集,请引用以下的文章 Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015) + """ + def download(self, dev_ratio: float = 0.0, re_download: bool = False): + r""" + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = 'yelp-review-polarity' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -311,7 +312,12 @@ class YelpPolarityLoader(CLSBaseLoader): class SST2Loader(Loader): r""" - 原始数据中内容为:第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是句子,第一个制表符之后认为是label + **SST-2** 数据集的 **Loader**,如果您使用了该数据集,请引用以下的文章 + + https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf + + 原始数据中内容应该为:第一行为标题(具体内容会被忽略),之后每一行为一个 sample ,第一个制表符之前是 **句子** , + 第一个制表符之后认为是 **label** 。 Example:: @@ -319,7 +325,7 @@ class SST2Loader(Loader): it 's a charming and often affecting journey . 1 unflinchingly bleak and desperate 0 - 读取之后DataSet将如下所示 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_words", "target" @@ -328,7 +334,7 @@ class SST2Loader(Loader): "unflinchingly bleak and desperate", "0" "..." - test的DataSet没有target列。 + 测试集的 :class:`~fastNLP.core.DataSet` 没有 ``target`` 列。 """ def __init__(self): @@ -366,8 +372,8 @@ class SST2Loader(Loader): def download(self): r""" - 自动下载数据集,如果你使用了该数据集,请引用以下的文章 - https://nlp.stanford.edu/pubs/SocherBauerManningNg_ACL2013.pdf + 自动下载数据集。 + :return: """ output_dir = self._get_dataset_path(dataset_name='sst-2') @@ -376,8 +382,11 @@ class SST2Loader(Loader): class ChnSentiCorpLoader(Loader): r""" - 支持读取的数据的格式为,第一行为标题(具体内容会被忽略),之后一行为一个sample,第一个制表符之前被认为是label,第 - 一个制表符之后认为是句子 + **ChnSentiCorp** 数据集的 **Loader**,该数据取自 https://github.com/pengming617/bert_classification/tree/master/data,在 + https://arxiv.org/pdf/1904.09223.pdf 与 https://arxiv.org/pdf/1906.08101.pdf 有使用。 + + 支持读取的数据的格式为:第一行为标题(具体内容会被忽略),之后每一行为一个 sample,第一个制表符之前被认为是 **label** ,第 + 一个制表符之后认为是 **句子** 。 Example:: @@ -385,7 +394,7 @@ class ChnSentiCorpLoader(Loader): 1 基金痛所有投资项目一样,必须先要有所了解... 1 系统很好装,LED屏是不错,就是16比9的比例... 
- 读取后的DataSet具有以下的field + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_chars", "target" @@ -421,10 +430,9 @@ class ChnSentiCorpLoader(Loader): def download(self) -> str: r""" - 自动下载数据,该数据取自https://github.com/pengming617/bert_classification/tree/master/data,在 - https://arxiv.org/pdf/1904.09223.pdf与https://arxiv.org/pdf/1906.08101.pdf有使用 + 自动下载数据。 - :return: + :return: 数据集的目录地址 """ output_dir = self._get_dataset_path('chn-senti-corp') return output_dir @@ -432,14 +440,17 @@ class ChnSentiCorpLoader(Loader): class THUCNewsLoader(Loader): r""" - 数据集简介:document-level分类任务,新闻10分类 - 原始数据内容为:每行一个sample,第一个 "\\t" 之前为target,第一个 "\\t" 之后为raw_words + **THUCNews** 数据集的 **Loader**,该数据取自 + http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews + + 数据用于 document-level 分类任务,新闻 10 分类。 + 原始数据内容为:每行一个 sample,第一个 ``"\t"`` 之前为 **target** ,第一个 ``"\t"`` 之后为 **raw_words** 。 Example:: 体育 调查-您如何评价热火客场胜绿军总分3-1夺赛点?... - 读取后的Dataset将具有以下数据结构: + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_words", "target" @@ -466,11 +477,9 @@ class THUCNewsLoader(Loader): def download(self) -> str: r""" - 自动下载数据,该数据取自 - - http://thuctc.thunlp.org/#%E4%B8%AD%E6%96%87%E6%96%87%E6%9C%AC%E5%88%86%E7%B1%BB%E6%95%B0%E6%8D%AE%E9%9B%86THUCNews + 自动下载数据。 - :return: + :return: 数据集目录地址 """ output_dir = self._get_dataset_path('thuc-news') return output_dir @@ -478,8 +487,8 @@ class THUCNewsLoader(Loader): class WeiboSenti100kLoader(Loader): r""" - 别名: - 数据集简介:微博sentiment classification,二分类 + **WeiboSenti100k** 数据集的 **Loader**,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/, + 在 https://arxiv.org/abs/1906.08101 有使用。微博 sentiment classification,二分类。 Example:: @@ -487,7 +496,7 @@ class WeiboSenti100kLoader(Loader): 1 多谢小莲,好运满满[爱你] 1 能在他乡遇老友真不赖,哈哈,珠儿,我也要用... - 读取后的Dataset将具有以下数据结构: + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. 
csv-table:: :header: "raw_chars", "target" @@ -515,28 +524,29 @@ class WeiboSenti100kLoader(Loader): def download(self) -> str: r""" - 自动下载数据,该数据取自 https://github.com/SophonPlus/ChineseNlpCorpus/ - 在 https://arxiv.org/abs/1906.08101 有使用 - :return: + 自动下载数据。 + + :return: 数据集目录地址 """ output_dir = self._get_dataset_path('weibo-senti-100k') return output_dir class MRLoader(CLSBaseLoader): + """ + **MR** 数据集的 **Loader** + """ def __init__(self): super(MRLoader, self).__init__() def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: r""" - 自动下载数据集 + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = r'mr' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -549,19 +559,20 @@ class MRLoader(CLSBaseLoader): class R8Loader(CLSBaseLoader): + """ + **R8** 数据集的 **Loader** + """ def __init__(self): super(R8Loader, self).__init__() def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: r""" - 自动下载数据集 - - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = r'R8' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -574,19 +585,20 @@ class R8Loader(CLSBaseLoader): class R52Loader(CLSBaseLoader): + """ + **R52** 数据集的 **Loader** + """ def __init__(self): super(R52Loader, self).__init__() def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: r""" - 自动下载数据集 + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv - - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = r'R52' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -599,19 +611,20 @@ class R52Loader(CLSBaseLoader): class NG20Loader(CLSBaseLoader): + """ + **NG20** 数据集的 **Loader** + """ def __init__(self): super(NG20Loader, self).__init__() def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: r""" - 自动下载数据集 - - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = r'20ng' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -624,19 +637,20 @@ class NG20Loader(CLSBaseLoader): class OhsumedLoader(CLSBaseLoader): + """ + **Ohsumed** 数据集的 **Loader** + """ def __init__(self): super(OhsumedLoader, self).__init__() def download(self, dev_ratio: float = 0.0, re_download: bool = False) -> str: r""" - 自动下载数据集 - - 如果dev_ratio不等于0,则根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。 - 下载完成后在output_dir中有train.csv, test.csv, dev.csv三个文件。否则只有train.csv和test.csv + 自动下载数据集。下载完成后在 ``output_dir`` 中有 ``train.csv`` , ``test.csv`` , ``dev.csv`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.csv`` 和 ``test.csv`` 。 - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ dataset_name = r'ohsumed' data_dir = self._get_dataset_path(dataset_name=dataset_name) diff --git a/fastNLP/io/loader/conll.py b/fastNLP/io/loader/conll.py index 0b597398..ef209fb4 100644 --- a/fastNLP/io/loader/conll.py +++ b/fastNLP/io/loader/conll.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "ConllLoader", "Conll2003Loader", @@ -17,6 +15,7 @@ import os import random import shutil import time +from typing import List from .loader import Loader from ..file_reader import _read_conll @@ -26,9 +25,7 @@ from fastNLP.core.dataset import DataSet, Instance class ConllLoader(Loader): r""" - ConllLoader支持读取的数据格式: 以空行隔开两个sample,除了分割行,每一行用空格或者制表符隔开不同的元素。如下例所示: - - Example:: + :class:`ConllLoader` 支持读取的数据格式:以空行隔开两个 sample,除了分割行之外的每一行用空格或者制表符隔开不同的元素。如下例所示:: # 文件中的内容 Nadim NNP B-NP B-PER @@ -48,19 +45,16 @@ class ConllLoader(Loader): # 
如果用以下的参数读取,返回的DataSet将包含raw_words, pos和ner三个field dataset = ConllLoader(headers=['raw_words', 'pos', 'ner'], indexes=[0, 1, 3])._load('/path/to/train.conll') - ConllLoader返回的DataSet的field由传入的headers确定。 + :class:`ConllLoader` 返回的 :class:`~fastNLP.core.DataSet` 的 `field` 由传入的 ``headers`` 确定。 + :param headers: 每一列数据的名称, ``headers`` 与 ``indexes`` 一一对应 + :param sep: 指定分隔符,默认为制表符 + :param indexes: 需要保留的数据列下标,从 **0** 开始。若为 ``None`` ,则所有列都保留。 + :param dropna: 是否忽略非法数据,若为 ``False`` ,则遇到非法数据时抛出 :class:`ValueError` 。 + :param drophash: 是否忽略以 ``#`` 开头的句子。 """ - def __init__(self, headers, sep=None, indexes=None, dropna=True, drophash=True): - r""" - - :param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 - :param str sep: 指定分隔符,默认为制表符 - :param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` - :param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` - :param bool drophashtag: 是否忽略以 ``#`` 开头的句子。 - """ + def __init__(self, headers: List[str], sep: str=None, indexes: List[int]=None, dropna: bool=True, drophash: bool=True): super(ConllLoader, self).__init__() if not isinstance(headers, (list, tuple)): raise TypeError( @@ -93,8 +87,9 @@ class ConllLoader(Loader): class Conll2003Loader(ConllLoader): r""" - 用于读取conll2003任务的数据。数据的内容应该类似与以下的内容, 第一列为raw_words, 第二列为pos, 第三列为chunking,第四列为ner。 - 数据中以"-DOCSTART-"开头的行将被忽略,因为该符号在conll 2003中被用为文档分割符。 + 用于读取 **conll2003** 任务的数据。数据的内容应该类似于以下的内容:第一列为 **raw_words** ,第二列为 **pos** , + 第三列为 **chunking** ,第四列为 **ner** 。 + 数据中以 ``"-DOCSTART-"`` 开头的行将被忽略,因为该符号在 **conll2003** 中被用为文档分割符。 Example:: @@ -108,9 +103,9 @@ class Conll2003Loader(ConllLoader): 1996-12-06 CD I-NP O ... - 返回的DataSet的内容为 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: - .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构。 + ..
csv-table:: 下面是 Conll2003Loader 加载后数据具备的结构。 :header: "raw_words", "pos", "chunk", "ner" "[Nadim, Ladki]", "[NNP, NNP]", "[B-NP, I-NP]", "[B-PER, I-PER]" @@ -152,10 +147,9 @@ class Conll2003Loader(ConllLoader): class Conll2003NERLoader(ConllLoader): r""" - 用于读取conll2003任务的NER数据。每一行有4列内容,空行意味着隔开两个句子 + 用于读取 **conll2003** 任务的 NER 数据。每一行有 4 列内容,空行意味着隔开两个句子。 - 支持读取的内容如下 - Example:: + 支持读取的内容如下:: Nadim NNP B-NP B-PER Ladki NNP I-NP I-PER @@ -167,9 +161,9 @@ class Conll2003NERLoader(ConllLoader): 1996-12-06 CD I-NP O ... - 返回的DataSet的内容为 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: - .. csv-table:: 下面是Conll2003Loader加载后数据具备的结构, target是BIO2编码 + .. csv-table:: 下面是 Conll2003Loader 加载后数据具备的结构, target 是 BIO2 编码 :header: "raw_words", "target" "[Nadim, Ladki]", "[B-PER, I-PER]" @@ -213,18 +207,16 @@ class Conll2003NERLoader(ConllLoader): class OntoNotesNERLoader(ConllLoader): r""" - 用以读取OntoNotes的NER数据,同时也是Conll2012的NER任务数据。将OntoNote数据处理为conll格式的过程可以参考 - https://github.com/yhcc/OntoNotes-5.0-NER。OntoNoteNERLoader将取第4列和第11列的内容。 - - 读取的数据格式为: + 用以读取 **OntoNotes** 的 NER 数据,同时也是 **Conll2012** 的 NER 任务数据。将 **OntoNote** 数据处理为 conll 格式的过程可以参考 + https://github.com/yhcc/OntoNotes-5.0-NER。:class:`OntoNotesNERLoader` 将取第 **4** 列和第 **11** 列的内容。 - Example:: + 读取的数据格式为:: bc/msnbc/00/msnbc_0000 0 0 Hi UH (TOP(FRAG(INTJ*) - - - Dan_Abrams * - bc/msnbc/00/msnbc_0000 0 1 everyone NN (NP*) - - - Dan_Abrams * - ... - 返回的DataSet的内容为 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_words", "target" @@ -291,7 +283,8 @@ class OntoNotesNERLoader(ConllLoader): class CTBLoader(Loader): r""" - 支持加载的数据应该具备以下格式, 其中第二列为词语,第四列为pos tag,第七列为依赖树的head,第八列为依赖树的label + **CTB** 数据集的 **Loader**。支持加载的数据应该具备以下格式, 其中第二列为 **词语** ,第四列为 **pos tag** ,第七列为 **依赖树的 head** , + 第八列为 **依赖树的 label** 。 Example:: @@ -306,7 +299,7 @@ class CTBLoader(Loader): 3 12月 _ NT NT _ 7 dep _ _ ... - 读取之后DataSet具备的格式为 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. 
csv-table:: :header: "raw_words", "pos", "dep_head", "dep_label" @@ -335,30 +328,30 @@ class CTBLoader(Loader): 由于版权限制,不能提供自动下载功能。可参考 https://catalog.ldc.upenn.edu/LDC2013T21 - - :return: """ raise RuntimeError("CTB cannot be downloaded automatically.") class CNNERLoader(Loader): - def _load(self, path: str): - r""" - 支持加载形如以下格式的内容,一行两列,以空格隔开两个sample + r""" + 支持加载形如以下格式的内容,一行两列,以空格隔开两个 sample - Example:: + Example:: - 我 O - 们 O - 变 O - 而 O - 以 O - 书 O - 会 O - ... + 我 O + 们 O + 变 O + 而 O + 以 O + 书 O + 会 O + ... - :param str path: 文件路径 - :return: DataSet,包含raw_words列和target列 + """ + def _load(self, path: str): + """ + :param path: 文件路径 + :return: :class:`~fastNLP.core.DataSet` ,包含 ``raw_words`` 列和 ``target`` 列 """ ds = DataSet() with open(path, 'r', encoding='utf-8') as f: @@ -382,9 +375,11 @@ class CNNERLoader(Loader): class MsraNERLoader(CNNERLoader): r""" - 读取MSRA-NER数据,数据中的格式应该类似与下列的内容 - - Example:: + 读取 **MSRA-NER** 数据,如果您要使用该数据,请引用以下的文章: + + Gina-Anne Levow, 2006, The Third International Chinese Language Processing Bakeoff: Word Segmentation and Named Entity Recognition. + + 数据中的格式应该类似于下列的内容:: 把 O 欧 B-LOC @@ -404,7 +399,7 @@ class MsraNERLoader(CNNERLoader): ... - 读取后的DataSet包含以下的field + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_chars", "target" @@ -420,15 +415,14 @@ class MsraNERLoader(CNNERLoader): def download(self, dev_ratio: float = 0.1, re_download: bool = False) -> str: r""" - 自动下载MSAR-NER的数据,如果你使用该数据,请引用 Gina-Anne Levow, 2006, The Third International Chinese Language - Processing Bakeoff: Word Segmentation and Named Entity Recognition. + 自动下载 **MSAR-NER** 的数据。 - 根据dev_ratio的值随机将train中的数据取出一部分作为dev数据。下载完成后在output_dir中有train.conll, test.conll, - dev.conll三个文件。 + 下载完成后在 ``output_dir`` 中有 ``train.conll`` , ``test.conll`` , ``dev.conll`` 三个文件。 + 如果 ``dev_ratio`` 为 0,则只有 ``train.conll`` 和 ``test.conll`` 。 - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str, 数据集的目录地址 + :param dev_ratio: 如果路径中没有验证集 ,从 train 划分多少作为 dev 的数据。如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 :return: """ dataset_name = 'msra-ner' @@ -470,9 +464,11 @@ class MsraNERLoader(CNNERLoader): class WeiboNERLoader(CNNERLoader): r""" - 读取WeiboNER数据,数据中的格式应该类似与下列的内容 - - Example:: + 读取 **WeiboNER** 数据,如果您要使用该数据,请引用以下的文章: + + Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for Chinese Social Media with Jointly Trained Embeddings. + + 数据中的格式应该类似与下列的内容:: 老 B-PER.NOM 百 I-PER.NOM @@ -482,7 +478,7 @@ class WeiboNERLoader(CNNERLoader): ... - 读取后的DataSet包含以下的field + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: @@ -498,10 +494,9 @@ class WeiboNERLoader(CNNERLoader): def download(self) -> str: r""" - 自动下载Weibo-NER的数据,如果你使用了该数据,请引用 Nanyun Peng and Mark Dredze, 2015, Named Entity Recognition for - Chinese Social Media with Jointly Trained Embeddings. + 自动下载 **Weibo-NER** 的数据。 - :return: str + :return: 数据集目录地址 """ dataset_name = 'weibo-ner' data_dir = self._get_dataset_path(dataset_name=dataset_name) @@ -511,9 +506,7 @@ class WeiboNERLoader(CNNERLoader): class PeopleDailyNERLoader(CNNERLoader): r""" - 支持加载的数据格式如下 - - Example:: + 加载 **People's Daily NER** 数据集的 **Loader** 。支持加载的数据格式如下:: 中 B-ORG 共 I-ORG @@ -524,9 +517,9 @@ class PeopleDailyNERLoader(CNNERLoader): 中 B-ORG ... - 读取后的DataSet包含以下的field + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: - .. csv-table:: target列是基于BIO的编码方式 + .. 
csv-table:: target 列是基于 BIO 的编码方式 :header: "raw_chars", "target" "['中', '共', '中', '央']", "['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG']" @@ -538,6 +531,11 @@ class PeopleDailyNERLoader(CNNERLoader): super().__init__() def download(self) -> str: + """ + 自动下载数据集。 + + :return: 数据集目录地址 + """ dataset_name = 'peopledaily' data_dir = self._get_dataset_path(dataset_name=dataset_name) diff --git a/fastNLP/io/loader/csv.py b/fastNLP/io/loader/csv.py index debd5222..09a60694 100644 --- a/fastNLP/io/loader/csv.py +++ b/fastNLP/io/loader/csv.py @@ -1,9 +1,9 @@ -r"""undocumented""" - __all__ = [ "CSVLoader", ] +from typing import List + from .loader import Loader from ..file_reader import _read_csv from fastNLP.core.dataset import DataSet, Instance @@ -11,19 +11,15 @@ from fastNLP.core.dataset import DataSet, Instance class CSVLoader(Loader): r""" - 读取CSV格式的数据集, 返回 ``DataSet`` 。 + 读取CSV格式的数据集, 返回 :class:`~fastNLP.core.DataSet` 。 + :param headers: CSV文件的文件头,定义每一列的属性名称,即返回的 :class:`~fastNLP.core.DataSet` 中 ``field`` 的名称。 + 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` 。 + :param sep: CSV文件中列与列之间的分隔符。 + :param dropna: 是否忽略非法数据,若为 ``True`` 则忽略;若为 ``False`` 则在遇到非法数据时抛出 :class:`ValueError`。 """ - def __init__(self, headers=None, sep=",", dropna=False): - r""" - - :param List[str] headers: CSV文件的文件头.定义每一列的属性名称,即返回的DataSet中`field`的名称 - 若为 ``None`` ,则将读入文件的第一行视作 ``headers`` . Default: ``None`` - :param str sep: CSV文件中列与列之间的分隔符. Default: "," - :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 
- Default: ``False`` - """ + def __init__(self, headers: List[str]=None, sep: str=",", dropna: bool=False): super().__init__() self.headers = headers self.sep = sep diff --git a/fastNLP/io/loader/cws.py b/fastNLP/io/loader/cws.py index d88d6a00..f7bdbcb5 100644 --- a/fastNLP/io/loader/cws.py +++ b/fastNLP/io/loader/cws.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "CWSLoader" ] @@ -16,15 +14,16 @@ from fastNLP.core.dataset import DataSet, Instance class CWSLoader(Loader): r""" - CWSLoader支持的数据格式为,一行一句话,不同词之间用空格隔开, 例如: + **Chinese word segmentor** 的 **Loader** 。如果您使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, + 2005. 更多信息可以在 http://sighan.cs.uchicago.edu/bakeoff2005/ 查看。 - Example:: + :class:`CWSLoader` 支持的数据格式为:一行一句话,不同词之间用空格隔开,例如:: 上海 浦东 开发 与 法制 建设 同步 新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 ) ... - 该Loader读取后的DataSet具有如下的结构 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_words" @@ -32,14 +31,11 @@ class CWSLoader(Loader): "上海 浦东 开发 与 法制 建设 同步" "新华社 上海 二月 十日 电 ( 记者 谢金虎 、 张持坚 )" "..." - + + :param dataset_name: data 的名称,支持 ``['pku', 'msra', 'cityu'(繁体), 'as'(繁体), None]`` """ def __init__(self, dataset_name: str = None): - r""" - - :param str dataset_name: data的名称,支持pku, msra, cityu(繁体), as(繁体), None - """ super().__init__() datanames = {'pku': 'cws-pku', 'msra': 'cws-msra', 'as': 'cws-as', 'cityu': 'cws-cityu'} if dataset_name in datanames: @@ -58,12 +54,11 @@ class CWSLoader(Loader): def download(self, dev_ratio=0.1, re_download=False) -> str: r""" - 如果你使用了该数据集,请引用以下的文章:Thomas Emerson, The Second International Chinese Word Segmentation Bakeoff, - 2005. 更多信息可以在http://sighan.cs.uchicago.edu/bakeoff2005/查看 + 自动下载数据集。 - :param float dev_ratio: 如果路径中没有dev集,从train划分多少作为dev的数据. 
如果为0,则不划分dev。 - :param bool re_download: 是否重新下载数据,以重新切分数据。 - :return: str + :param dev_ratio: 如果路径中没有验证集,从 train 划分多少作为 dev 的数据。 如果为 **0** ,则不划分 dev + :param re_download: 是否重新下载数据,以重新切分数据。 + :return: 数据集的目录地址 """ if self.dataset_name is None: return '' diff --git a/fastNLP/io/loader/json.py b/fastNLP/io/loader/json.py index e5648a26..31dfce9d 100644 --- a/fastNLP/io/loader/json.py +++ b/fastNLP/io/loader/json.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "JsonLoader" ] @@ -11,19 +9,16 @@ from fastNLP.core.dataset import DataSet, Instance class JsonLoader(Loader): r""" - 别名::class:`fastNLP.io.JsonLoader` :class:`fastNLP.io.loader.JsonLoader` - - 读取json格式数据.数据必须按行存储,每行是一个包含各类属性的json对象 + 读取 *json* 格式数据,数据必须按行存储,每行是一个包含各类属性的 json 对象。 - :param dict fields: 需要读入的json属性名称, 和读入后在DataSet中存储的field_name - ``fields`` 的 `key` 必须是json对象的属性名. ``fields`` 的 `value` 为读入后在DataSet存储的 `field_name` , - `value` 也可为 ``None`` , 这时读入后的 `field_name` 与json对象对应属性同名 - ``fields`` 可为 ``None`` , 这时,json对象所有属性都保存在DataSet中. Default: ``None`` - :param bool dropna: 是否忽略非法数据,若 ``True`` 则忽略,若 ``False`` ,在遇到非法数据时,抛出 ``ValueError`` . 
- Default: ``False`` + :param fields: 需要读入的 json 属性名称,和读入后在 :class:`~fastNLP.core.DataSet` 中存储的 `field_name`。 + ``fields`` 的 `key` 必须是 json 对象的 **属性名**, ``fields`` 的 `value` 为读入后在 ``DataSet`` 存储的 `field_name` , + `value` 也可为 ``None`` ,这时读入后的 `field_name` 与 json 对象对应属性同名。 + ``fields`` 可为 ``None`` ,这时 json 对象所有属性都保存在 ``DataSet`` 中。 + :param dropna: 是否忽略非法数据,若为 ``True`` 则忽略;若为 ``False`` 则在遇到非法数据时抛出 :class:`ValueError`。 """ - def __init__(self, fields=None, dropna=False): + def __init__(self, fields: dict=None, dropna=False): super(JsonLoader, self).__init__() self.dropna = dropna self.fields = None diff --git a/fastNLP/io/loader/loader.py b/fastNLP/io/loader/loader.py index 135a9d74..2294f476 100644 --- a/fastNLP/io/loader/loader.py +++ b/fastNLP/io/loader/loader.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "Loader" ] @@ -14,24 +12,23 @@ from fastNLP.core.dataset import DataSet class Loader: r""" - 各种数据 Loader 的基类,提供了 API 的参考. - Loader支持以下的三个函数 + 各种数据 **Loader** 的基类,提供了 API 的参考。 + :class:`Loader` 支持以下的三个函数 - - download() 函数:自动将该数据集下载到缓存地址,默认缓存地址为~/.fastNLP/datasets/。由于版权等原因,不是所有的Loader都实现了该方法。该方法会返回下载后文件所处的缓存地址。 - - _load() 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.DataSet` 。返回的DataSet的内容可以通过每个Loader的文档判断出。 - - load() 函数:将文件分别读取为DataSet,然后将多个DataSet放入到一个DataBundle中并返回 - + - :meth:`download` 函数:自动将该数据集下载到缓存地址,默认缓存地址为 ``~/.fastNLP/datasets/`` 。由于版权等原因,不是所有的 ``Loader`` 都实现了该方法。 + 该方法会返回下载后文件所处的缓存地址。 + - :meth:`_load` 函数:从一个数据文件中读取数据,返回一个 :class:`~fastNLP.core.DataSet` 。返回的 DataSet 的内容可以通过每个 ``Loader`` 的文档判断出。 + - :meth:`load` 函数:将文件分别读取为 :class:`~fastNLP.core.DataSet` ,然后将多个 DataSet 放入到一个 :class:`~fastNLP.io.DataBundle` 中并返回 """ - def __init__(self): pass def _load(self, path: str) -> DataSet: r""" - 给定一个路径,返回读取的DataSet。 + 给定一个路径,返回读取的 :class:`~fastNLP.core.DataSet` 。 - :param str path: 路径 - :return: DataSet + :param path: 路径 + :return: :class:`~fastNLP.core.DataSet` """ raise NotImplementedError @@ -39,29 +36,29 @@ class Loader: r""" 从指定一个或多个路径中的文件中读取数据,返回 
:class:`~fastNLP.io.DataBundle` 。 - :param Union[str, Dict[str, str]] paths: 支持以下的几种输入方式: - - 0.如果为None,则先查看本地是否有缓存,如果没有则自动下载并缓存。 + :param paths: 支持以下的几种输入方式: - 1.传入一个目录, 该目录下名称包含train的被认为是train,包含test的被认为是test,包含dev的被认为是dev,如果检测到多个文件名包含'train'、 'dev'、 'test'则会报错:: + - ``None`` -- 先查看本地是否有缓存,如果没有则自动下载并缓存。 + - 一个目录,该目录下名称包含 ``'train'`` 的被认为是训练集,包含 ``'test'`` 的被认为是测试集,包含 ``'dev'`` 的被认为是验证集 / 开发集, + 如果检测到多个文件名包含 ``'train'``、 ``'dev'``、 ``'test'`` 则会报错:: data_bundle = xxxLoader().load('/path/to/dir') # 返回的DataBundle中datasets根据目录下是否检测到train - # dev、 test等有所变化,可以通过以下的方式取出DataSet + # dev、 test 等有所变化,可以通过以下的方式取出 DataSet tr_data = data_bundle.get_dataset('train') te_data = data_bundle.get_dataset('test') # 如果目录下有文件包含test这个字段 - 2.传入一个dict,比如train,dev,test不在同一个目录下,或者名称中不包含train, dev, test:: + - 传入一个 :class:`dict` ,比如训练集、验证集和测试集不在同一个目录下,或者名称中不包含 ``'train'``、 ``'dev'``、 ``'test'`` :: paths = {'train':"/path/to/tr.conll", 'dev':"/to/validate.conll", "test":"/to/te.conll"} data_bundle = xxxLoader().load(paths) # 返回的DataBundle中的dataset中包含"train", "dev", "test" dev_data = data_bundle.get_dataset('dev') - 3.传入文件路径:: + - 传入文件路径:: data_bundle = xxxLoader().load("/path/to/a/train.conll") # 返回DataBundle对象, datasets中仅包含'train' tr_data = data_bundle.get_dataset('train') # 取出DataSet - :return: 返回的 :class:`~fastNLP.io.DataBundle` + :return: :class:`~fastNLP.io.DataBundle` """ if paths is None: paths = self.download() diff --git a/fastNLP/io/loader/matching.py b/fastNLP/io/loader/matching.py index ed5c84c0..27689a18 100644 --- a/fastNLP/io/loader/matching.py +++ b/fastNLP/io/loader/matching.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "MNLILoader", "SNLILoader", @@ -26,17 +24,19 @@ from fastNLP.core.log import logger class MNLILoader(Loader): r""" - 读取的数据格式为: + **MNLI** 数据集的 **Loader**,如果您使用了这个数据,请引用 - Example:: + https://www.nyu.edu/projects/bowman/multinli/paper.pdf + + 读取的数据格式为:: index promptID pairID genre sentence1_binary_parse sentence2_binary_parse sentence1_parse 
sentence2_parse sentence1 sentence2 label1 gold_label 0 31193 31193n government ( ( Conceptually ( cream skimming ) ) ... 1 101457 101457e telephone ( you ( ( know ( during ( ( ( the season ) and ) ( i guess ) ) )... ... - 读取MNLI任务的数据,读取之后的DataSet中包含以下的内容,words0是sentence1, words1是sentence2, target是gold_label, 测试集中没 - 有target列。 + 读取之后的 :class:`~fastNLP.core.DataSet` 中包含以下的内容: ``raw_words1`` 是 ``sentence1`` , ``raw_words2`` 是 ``sentence2`` , + ``target`` 是 ``gold_label``。测试集中没有 ``target`` 列。 .. csv-table:: :header: "raw_words1", "raw_words2", "target" @@ -80,10 +80,9 @@ class MNLILoader(Loader): def load(self, paths: str = None): r""" - - :param str paths: 传入数据所在目录,会在该目录下寻找dev_matched.tsv, dev_mismatched.tsv, test_matched.tsv, - test_mismatched.tsv, train.tsv文件夹 - :return: DataBundle + :param paths: 传入数据所在目录,会在该目录下寻找 ``dev_matched.tsv``, ``dev_mismatched.tsv``, ``test_matched.tsv``, + ``test_mismatched.tsv``, ``train.tsv`` 文件。 + :return: :class:`~fastNLP.io.DataBundle` """ if paths: paths = os.path.abspath(os.path.expanduser(paths)) @@ -112,10 +111,9 @@ class MNLILoader(Loader): def download(self): r""" - 如果你使用了这个数据,请引用 + 自动下载数据集。 - https://www.nyu.edu/projects/bowman/multinli/paper.pdf - :return: + :return: 数据集目录地址 """ output_dir = self._get_dataset_path('mnli') return output_dir @@ -123,9 +121,11 @@ class MNLILoader(Loader): class SNLILoader(JsonLoader): r""" - 文件每一行是一个sample,每一行都为一个json对象,其数据格式为: + **SNLI** 数据集的 **Loader**,如果您的文章使用了这份数据,请引用 - Example:: + http://nlp.stanford.edu/pubs/snli_paper.pdf + + 文件每一行是一个 sample,每一行都为一个 ``json`` 对象,其数据格式为:: {"annotator_labels": ["neutral", "entailment", "neutral", "neutral", "neutral"], "captionID": "4705552913.jpg#2", "gold_label": "neutral", "pairID": "4705552913.jpg#2r1n", @@ -137,7 +137,7 @@ class SNLILoader(JsonLoader): "sentence2_parse": "(ROOT (S (NP (DT The) (NNS sisters)) (VP (VBP are) (VP (VBG hugging) (NP (UH goodbye)) (PP (IN while) (S (VP (VBG holding) (S (VP (TO to) (VP (VB go) (NP (NNS packages)) (PP (IN after)
(S (ADVP (RB just)) (VP (VBG eating) (NP (NN lunch))))))))))))) (. .)))" } - 读取之后的DataSet中的field情况为 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: 下面是使用SNLILoader加载的DataSet所具备的field :header: "target", "raw_words1", "raw_words2", @@ -158,13 +158,11 @@ class SNLILoader(JsonLoader): def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: r""" 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + 读取的 field 根据 :class:`SNLILoader` 初始化时传入的 ``fields`` 决定。 - 读取的field根据Loader初始化时传入的field决定。 - - :param str paths: 传入一个目录, 将在该目录下寻找snli_1.0_train.jsonl, snli_1.0_dev.jsonl - 和snli_1.0_test.jsonl三个文件。 - - :return: :class:`~fastNLP.io.DataBundle` + :param str paths: 传入一个目录, 将在该目录下寻找 ``snli_1.0_train.jsonl``, ``snli_1.0_dev.jsonl`` + 和 ``snli_1.0_test.jsonl`` 三个文件。 + :return: """ _paths = {} if paths is None: @@ -187,26 +185,26 @@ class SNLILoader(JsonLoader): def download(self): r""" - 如果您的文章使用了这份数据,请引用 - - http://nlp.stanford.edu/pubs/snli_paper.pdf + 自动下载数据集。 - :return: str + :return: 数据集目录地址 """ return self._get_dataset_path('snli') class QNLILoader(JsonLoader): r""" - 第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、问题、句子和标签构成(以制表符分割),数据结构如下: + **QNLI** 数据集的 **Loader** ,如果您的实验使用到了该数据,请引用 - Example:: + https://arxiv.org/pdf/1809.05053.pdf + + 读取数据的格式为:第一行为标题(具体内容会被忽略),之后每一行是一个 sample,由 **index** 、**问题** 、**句子** 和 **标签** 构成(以制表符分割),数据结构如下:: index question sentence label 0 What came into force after the new constitution was herald? As of that day, the new constitution heralding the Second Republic came into force. entailment - QNLI数据集的Loader, - 加载的DataSet将具备以下的field, raw_words1是question, raw_words2是sentence, target是label + 加载的 :class:`~fastNLP.core.DataSet` 将具备以下的内容: ``raw_words1`` 是 ``question`` , ``raw_words2`` 是 ``sentence`` , ``target`` 是 ``label`` 。 + 测试集中没有 ``target`` 列。 .. 
csv-table:: :header: "raw_words1", "raw_words2", "target" @@ -214,8 +212,6 @@ class QNLILoader(JsonLoader): "What came into force after the new...", "As of that day...", "entailment" "...","." - test数据集没有target列 - """ def __init__(self): @@ -250,26 +246,27 @@ class QNLILoader(JsonLoader): def download(self): r""" - 如果您的实验使用到了该数据,请引用 + 自动下载数据集。 - https://arxiv.org/pdf/1809.05053.pdf - - :return: + :return: 数据集目录地址 """ return self._get_dataset_path('qnli') class RTELoader(Loader): r""" - 第一行为标题(具体内容会被忽略),之后每一行是一个sample,由index、句子1、句子2和标签构成(以制表符分割),数据结构如下: + **RTE** 数据集的 **Loader**,如果您使用了该数据,请引用 **GLUE Benchmark** : - Example:: + https://openreview.net/pdf?id=rJ4km2R5t7 + + 读取数据的格式为:第一行为标题(具体内容会被忽略),之后每一行是一个 sample,由 **index** 、**句子1** 、**句子2** 和 **标签** + 构成(以制表符分割),数据结构如下:: index sentence1 sentence2 label 0 Dana Reeve, the widow of the actor Christopher Reeve, has died of lung cancer at age 44, according to the Christopher Reeve Foundation. Christopher Reeve had an accident. not_entailment - RTE数据的loader - 加载的DataSet将具备以下的field, raw_words1是sentence0,raw_words2是sentence1, target是label + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的内容:``raw_words1`` 是 ``sentence1`` , ``raw_words2`` 是 ``sentence2`` , ``target`` 是 ``label`` 。 + 测试集中没有 ``target`` 列。 .. csv-table:: :header: "raw_words1", "raw_words2", "target" @@ -277,7 +274,6 @@ class RTELoader(Loader): "Dana Reeve, the widow of the actor...", "Christopher Reeve had an...", "not_entailment" "...","..." 
- test数据集没有target列 """ def __init__(self): @@ -312,20 +308,17 @@ class RTELoader(Loader): def download(self): r""" - 如果您的实验使用到了该数据,请引用GLUE Benchmark + 自动下载数据集。 - https://openreview.net/pdf?id=rJ4km2R5t7 - - :return: + :return: 数据集目录地址 """ return self._get_dataset_path('rte') class QuoraLoader(Loader): r""" - Quora matching任务的数据集Loader - - 支持读取的文件中的内容,应该有以下的形式, 以制表符分隔,且前三列的内容必须是:第一列是label,第二列和第三列是句子 + **Quora matching** 任务的数据集 **Loader**。 + 支持读取的文件中的内容应该有以下的形式:以制表符分隔,且前三列一定分别为 **label** , **句子1** , **句子2** 。 Example:: @@ -333,7 +326,7 @@ class QuoraLoader(Loader): 0 Is honey a viable alternative to sugar for diabetics ? How would you compare the United States ' euthanasia laws to Denmark ? 90348 ... - 加载的DataSet将具备以下的field + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_words1", "raw_words2", "target" @@ -370,21 +363,27 @@ class QuoraLoader(Loader): :return: """ + r""" + 自动下载数据集。 + + :return: 数据集目录地址 + """ raise RuntimeError("Quora cannot be downloaded automatically.") class CNXNLILoader(Loader): r""" - 数据集简介:中文句对NLI(本为multi-lingual的数据集,但是这里只取了中文的数据集)。原句子已被MOSES tokenizer处理,这里我们将其还原并重新按字tokenize - 原始数据数据为: - - Example:: + **XNLI Chinese** 数据集的 **Loader** ,该数据取自 https://arxiv.org/abs/1809.05053 ,在 https://arxiv.org/pdf/1905.05526.pdf + 、 https://arxiv.org/pdf/1901.10125.pdf 和 https://arxiv.org/pdf/1809.05053.pdf 有使用。 + + 该数据集为中文句对 NLI(本为 ``multi-lingual`` 的数据集,但是这里只取了中文的数据集)。原句子已被 + ``MOSES tokenizer`` 处理,这里我们将其还原并重新按字 tokenize 。原始数据为:: premise hypo label 我们 家里 有 一个 但 我 没 找到 我 可以 用 的 时间 我们 家里 有 一个 但 我 从来 没有 时间 使用 它 . entailment - dev和test中的数据为csv或json格式,包括十多个field,这里只取与以上三个field中的数据 - 读取后的Dataset将具有以下数据结构: + 验证集和测试集中的数据为 csv 或 json 格式,这里只取以上三个 field 中的数据。 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. 
csv-table:: :header: "raw_chars1", "raw_chars2", "target" @@ -444,6 +443,13 @@ class CNXNLILoader(Loader): return ds def load(self, paths: Union[str, Dict[str, str]] = None) -> DataBundle: + """ + 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + 读取的 field 根据 :class:`SNLILoader` 初始化时传入的 ``fields`` 决定。 + + :param paths: + :return: + """ if paths is None: paths = self.download() paths = check_loader_paths(paths) @@ -459,10 +465,9 @@ class CNXNLILoader(Loader): def download(self) -> str: r""" - 自动下载数据,该数据取自 https://arxiv.org/abs/1809.05053 - 在 https://arxiv.org/pdf/1905.05526.pdf https://arxiv.org/pdf/1901.10125.pdf - https://arxiv.org/pdf/1809.05053.pdf 有使用 - :return: + 自动下载数据集。 + + :return: 数据集目录地址 """ output_dir = self._get_dataset_path('cn-xnli') return output_dir @@ -470,23 +475,19 @@ class CNXNLILoader(Loader): class BQCorpusLoader(Loader): r""" - 别名: - 数据集简介:句子对二分类任务(判断是否具有相同的语义) - 原始数据结构为: - - Example:: + **BQ Corpus** 数据集的 **Loader** 。句子对二分类任务,判断是否具有相同的语义。原始数据结构为:: sentence1,sentence2,label 综合评分不足什么原因,综合评估的依据,0 - 什么时候我能使用微粒贷,你就赶快给我开通就行了,0 + 什么时候我能使用微粒贷,您就赶快给我开通就行了,0 - 读取后的Dataset将具有以下数据结构: + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. csv-table:: :header: "raw_chars1", "raw_chars2", "target" "综合评分不足什么原因", "综合评估的依据", "0" - "什么时候我能使用微粒贷", "你就赶快给我开通就行了", "0" + "什么时候我能使用微粒贷", "您就赶快给我开通就行了", "0" "...", "...", "..." """ @@ -514,31 +515,25 @@ class BQCorpusLoader(Loader): 由于版权限制,不能提供自动下载功能。可参考 https://github.com/ymcui/Chinese-BERT-wwm - - :return: """ raise RuntimeError("BQCorpus cannot be downloaded automatically.") class LCQMCLoader(Loader): r""" - 数据集简介:句对匹配(question matching) - - 原始数据为: - - Example:: + **LCQMC** 数据集的 **Loader**,该数据集用于句对匹配(question matching)。原始数据为:: 喜欢打篮球的男生喜欢什么样的女生 爱打篮球的男生喜欢什么样的女生 1 - 你帮我设计小说的封面吧 谁能帮我给小说设计个封面? 0 + 您帮我设计小说的封面吧 谁能帮我给小说设计个封面? 0 - 读取后的Dataset将具有以下的数据结构 + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: .. 
csv-table:: :header: "raw_chars1", "raw_chars2", "target" "喜欢打篮球的男生喜欢什么样的女生", "爱打篮球的男生喜欢什么样的女生", "1" - "你帮我设计小说的封面吧", "妇可以戴耳机听音乐吗?", "0" + "您帮我设计小说的封面吧", "妇可以戴耳机听音乐吗?", "0" "...", "...", "..." @@ -569,8 +564,6 @@ class LCQMCLoader(Loader): 由于版权限制,不能提供自动下载功能。可参考 https://github.com/ymcui/Chinese-BERT-wwm - - :return: """ raise RuntimeError("LCQMC cannot be downloaded automatically.") diff --git a/fastNLP/io/loader/qa.py b/fastNLP/io/loader/qa.py index a3140b01..b6329daf 100644 --- a/fastNLP/io/loader/qa.py +++ b/fastNLP/io/loader/qa.py @@ -1,6 +1,5 @@ r""" -该文件中的Loader主要用于读取问答式任务的数据 - +该文件中的 **Loader** 主要用于读取问答式任务的数据 """ @@ -13,20 +12,23 @@ __all__ = ['CMRC2018Loader'] class CMRC2018Loader(Loader): r""" - 请直接使用从fastNLP下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测 + **CMRC2018** 数据集的 **Loader** ,如果您使用了本数据,请引用 + A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. + + 请直接使用从 **fastNLP** 下载的数据进行处理。该数据集未提供测试集,测试需要通过上传到对应的系统进行评测。 - 读取之后训练集DataSet将具备以下的内容,每个问题的答案只有一个 + 读取之后训练集 :class:`~fastNLP.core.DataSet` 将具备以下的内容,每个问题的答案只有一个: .. csv-table:: - :header:"title", "context", "question", "answers", "answer_starts", "id" + :header: "title", "context", "question", "answers", "answer_starts", "id" "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" "...", "...", "...","...", ".", "..." - 其中title是文本的标题,多条记录可能是相同的title;id是该问题的id,具备唯一性 + 其中 ``title`` 是文本的标题,多条记录可能是相同的 ``title`` ;``id`` 是该问题的 id,具备唯一性。 - 验证集DataSet将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案) + 验证集 :class:`~fastNLP.core.DataSet` 将具备以下的内容,每个问题的答案可能有三个(有时候只是3个重复的答案): .. 
csv-table:: :header: "title", "context", "question", "answers", "answer_starts", "id" @@ -35,8 +37,8 @@ class CMRC2018Loader(Loader): "战国无双3", "《战国无双3》()是由光荣和ω-force开发...", "男女主角亦有专属声优这一模式是由谁改编的?", "['村雨城', '村雨城', '任天堂游戏谜之村雨城']", "[226, 226, 219]", "DEV_0_QUERY_1" "...", "...", "...","...", ".", "..." - 其中answer_starts是从0开始的index。例如"我来自a复旦大学?",其中"复"的开始index为4。另外"Russell评价说"中的说的index为9, 因为 - 英文和数字都直接按照character计量的。 + 其中 ``answer_starts`` 是从 0 开始的 index。例如 ``"我来自a复旦大学?"`` ,其中 ``"复"`` 的开始 index 为 **4**。另外 ``"Russell评价说"`` + 中的 ``"说"`` 的 index 为 **9** , 因为英文和数字都直接按照 character 计量的。 """ def __init__(self): super().__init__() @@ -65,9 +67,9 @@ class CMRC2018Loader(Loader): def download(self) -> str: r""" - 如果您使用了本数据,请引用A Span-Extraction Dataset for Chinese Machine Reading Comprehension. Yiming Cui, Ting Liu, etc. + 自动下载数据集。 - :return: + :return: 数据集目录地址 """ output_dir = self._get_dataset_path('cmrc2018') return output_dir diff --git a/fastNLP/io/loader/summarization.py b/fastNLP/io/loader/summarization.py new file mode 100644 index 00000000..6dcb755f --- /dev/null +++ b/fastNLP/io/loader/summarization.py @@ -0,0 +1,64 @@ +__all__ = [ + "ExtCNNDMLoader" +] + +import os +from typing import Union, Dict + +from ..data_bundle import DataBundle +from ..utils import check_loader_paths +from .json import JsonLoader + + +class ExtCNNDMLoader(JsonLoader): + r""" + **CNN / Daily Mail** 数据集的 **Loader** ,用于 **extractive summarization task** 任务。 + 如果你使用了这个数据,请引用 https://arxiv.org/pdf/1506.03340.pdf + + 读取的 :class:`~fastNLP.core.DataSet` 将具备以下的数据结构: + + .. csv-table:: + :header: "text", "summary", "label", "publication" + + "['I got new tires from them and... ','...']", "['The new tires...','...']", "[0, 1]", "cnndm" + "['Don't waste your time. 
We had two...','...']", "['Time is precious','...']", "[1]", "cnndm" + "['...']", "['...']", "[]", "cnndm" + + :param fields: + """ + + def __init__(self, fields=None): + fields = fields or {"text": None, "summary": None, "label": None, "publication": None} + super(ExtCNNDMLoader, self).__init__(fields=fields) + + def load(self, paths: Union[str, Dict[str, str]] = None): + r""" + 从指定一个或多个路径中的文件中读取数据,返回 :class:`~fastNLP.io.DataBundle` 。 + + 读取的 field 根据 :class:`ExtCNNDMLoader` 初始化时传入的 ``fields`` 决定。 + + :param paths: 传入一个目录, 将在该目录下寻找 ``train.label.jsonl`` , ``dev.label.jsonl`` , + ``test.label.jsonl`` 三个文件(该目录还应该需要有一个名字为 ``vocab`` 的文件,在 :class:`~fastNLP.io.pipe.ExtCNNDMPipe` + 当中需要用到)。 + + :return: :class:`~fastNLP.io.DataBundle` + """ + if paths is None: + paths = self.download() + paths = check_loader_paths(paths) + if ('train' in paths) and ('test' not in paths): + paths['test'] = paths['train'] + paths.pop('train') + + datasets = {name: self._load(path) for name, path in paths.items()} + data_bundle = DataBundle(datasets=datasets) + return data_bundle + + def download(self): + r""" + 自动下载数据集。 + + :return: 数据集目录地址 + """ + output_dir = self._get_dataset_path('ext-cnndm') + return output_dir diff --git a/fastNLP/io/pipe/__init__.py b/fastNLP/io/pipe/__init__.py index 05a82806..e7314e00 100644 --- a/fastNLP/io/pipe/__init__.py +++ b/fastNLP/io/pipe/__init__.py @@ -52,7 +52,7 @@ __all__ = [ "BQCorpusPipe", "RenamePipe", "GranularizePipe", - "MachingTruncatePipe", + "TruncateBertPipe", "CMRC2018BertPipe", @@ -74,7 +74,7 @@ from .conll import Conll2003Pipe, iob2, iob2bioes from .cws import CWSPipe from .matching import MatchingBertPipe, RTEBertPipe, SNLIBertPipe, QuoraBertPipe, QNLIBertPipe, MNLIBertPipe, \ MatchingPipe, RTEPipe, SNLIPipe, QuoraPipe, QNLIPipe, MNLIPipe, CNXNLIBertPipe, CNXNLIPipe, BQCorpusBertPipe, \ - LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, GranularizePipe, MachingTruncatePipe + LCQMCPipe, BQCorpusPipe, LCQMCBertPipe, RenamePipe, 
GranularizePipe, TruncateBertPipe from .pipe import Pipe from .qa import CMRC2018BertPipe diff --git a/fastNLP/io/pipe/classification.py b/fastNLP/io/pipe/classification.py index a29de173..7527a3d6 100644 --- a/fastNLP/io/pipe/classification.py +++ b/fastNLP/io/pipe/classification.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "CLSBasePipe", "AGsNewsPipe", @@ -36,8 +34,17 @@ from fastNLP.core.log import logger class CLSBasePipe(Pipe): + """ + 处理分类数据集 **Pipe** 的基类。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw', 'cn-char']`` 。``'raw'`` 表示使用空格作为切分, ``'cn-char'`` 表示 + 按字符切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param lang: :mod:`spacy` 使用的语言,当前仅支持 ``'en'`` 。 + :param num_proc: 处理数据时使用的进程数目。 + """ - def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang='en', num_proc=0): + def __init__(self, lower: bool = False, tokenizer: str = 'raw', lang: str='en', num_proc: int=0): super().__init__() self.lower = lower self.tokenizer = get_tokenizer(tokenizer, lang=lang) @@ -61,7 +68,7 @@ class CLSBasePipe(Pipe): def process(self, data_bundle: DataBundle): r""" - 传入的DataSet应该具备如下的结构 + ``data_bundle`` 中的 :class:`~fastNLP.core.DataSet` 应该具备如下的结构: .. csv-table:: :header: "raw_words", "target" @@ -71,7 +78,7 @@ class CLSBasePipe(Pipe): "...", "..." :param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ # 复制一列words data_bundle = _add_words_field(data_bundle, lower=self.lower) @@ -87,46 +94,32 @@ class CLSBasePipe(Pipe): def process_from_file(self, paths) -> DataBundle: r""" - 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` :param paths: - :return: DataBundle + :return: """ raise NotImplementedError class YelpFullPipe(CLSBasePipe): r""" - 处理YelpFull的数据, 处理之后DataSet中的内容如下 + 处理 **Yelp Review Full** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: - ..
csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field + .. csv-table:: 下面是使用 YelpFullPipe 处理后的 DataSet 所具备的 field :header: "raw_words", "target", "words", "seq_len" "I got 'new' tires from them and within...", 0 ,"[7, 110, 22, 107, 22, 499, 59, 140, 3,...]", 160 " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 "...", ., "[...]", . - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + :param lower: 是否对输入进行小写化。 + :param granularity: 支持 ``[2, 3, 5]`` 。若为 ``2`` ,则认为是二分类问题,将 **1、2** 归为一类, **4、5** 归为一类, + 丢掉 3;若为 ``3`` ,则认为是三分类问题,将 **1、2** 归为一类, **3** 归为一类, **4、5** 归为一类;若为 ``5`` ,则认为是五分类问题。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ - - def __init__(self, lower: bool = False, granularity=5, tokenizer: str = 'spacy', num_proc=0): - r""" - - :param bool lower: 是否对输入进行小写化。 - :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将1、2归为1类,4、5归为一类,丢掉2;若为3, 则有3分类问题,将 - 1、2归为1类,3归为1类,4、5归为1类;若为5, 则有5分类问题。 - :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 - """ + def __init__(self, lower: bool = False, granularity: int=5, tokenizer: str = 'spacy', num_proc: int=0): super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) assert granularity in (2, 3, 5), "granularity can only be 2,3,5." 
self.granularity = granularity @@ -140,7 +133,7 @@ class YelpFullPipe(CLSBasePipe): def process(self, data_bundle): r""" - 传入的DataSet应该具备如下的结构 + ``data_bundle`` 中的 :class:`~fastNLP.core.DataSet` 应该具备如下的结构: .. csv-table:: :header: "raw_words", "target" @@ -150,7 +143,7 @@ class YelpFullPipe(CLSBasePipe): "...", "..." :param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ if self.tag_map is not None: data_bundle = _granularize(data_bundle, self.tag_map) @@ -159,11 +152,12 @@ class YelpFullPipe(CLSBasePipe): return data_bundle - def process_from_file(self, paths=None): + def process_from_file(self, paths=None) -> DataBundle: r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` :param paths: - :return: DataBundle + :return: """ data_bundle = YelpFullLoader().load(paths) return self.process(data_bundle=data_bundle) @@ -171,7 +165,7 @@ class YelpFullPipe(CLSBasePipe): class YelpPolarityPipe(CLSBasePipe): r""" - 处理YelpPolarity的数据, 处理之后DataSet中的内容如下 + 处理 **Yelp Review Polarity** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: 下面是使用YelpFullPipe处理后的DataSet所具备的field :header: "raw_words", "target", "words", "seq_len" @@ -180,32 +174,20 @@ class YelpPolarityPipe(CLSBasePipe): " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 "...", ., "[...]", .
- dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): - r""" - - :param bool lower: 是否对输入进行小写化。 - :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 - """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int=0): super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param str paths: - :return: DataBundle + :param paths: + :return: """ data_bundle = YelpPolarityLoader().load(paths) return self.process(data_bundle=data_bundle) @@ -213,7 +195,7 @@ class YelpPolarityPipe(CLSBasePipe): class AGsNewsPipe(CLSBasePipe): r""" - 处理AG's News的数据, 处理之后DataSet中的内容如下 + 处理 **AG's News** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: 下面是使用AGsNewsPipe处理后的DataSet所具备的field :header: "raw_words", "target", "words", "seq_len" @@ -222,31 +204,20 @@ class AGsNewsPipe(CLSBasePipe): " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 "...", ., "[...]", . 
- dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): - r""" - - :param bool lower: 是否对输入进行小写化。 - :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 - """ super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" - :param str paths: - :return: DataBundle + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: """ data_bundle = AGsNewsLoader().load(paths) return self.process(data_bundle=data_bundle) @@ -254,7 +225,7 @@ class AGsNewsPipe(CLSBasePipe): class DBPediaPipe(CLSBasePipe): r""" - 处理DBPedia的数据, 处理之后DataSet中的内容如下 + 处理 **DBPedia** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: 下面是使用DBPediaPipe处理后的DataSet所具备的field :header: "raw_words", "target", "words", "seq_len" @@ -263,31 +234,20 @@ class DBPediaPipe(CLSBasePipe): " Don't waste your time. We had two dif... ", 0, "[277, 17, 278, 38, 30, 112, 24, 85, 27...", 40 "...", ., "[...]", . 
- dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): - r""" - - :param bool lower: 是否对输入进行小写化。 - :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 - """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int=0): super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" - :param str paths: - :return: DataBundle + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: """ data_bundle = DBPediaLoader().load(paths) return self.process(data_bundle=data_bundle) @@ -295,7 +255,7 @@ class DBPediaPipe(CLSBasePipe): class SSTPipe(CLSBasePipe): r""" - 经过该Pipe之后,DataSet中具备的field如下所示 + 处理 **SST** 的数据,处理之后, :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: 下面是使用SSTPipe处理后的DataSet所具备的field :header: "raw_words", "words", "target", "seq_len" @@ -304,29 +264,15 @@ class SSTPipe(CLSBasePipe): "No one goes unindicted here , which is...", 0, "[191, 126, 192, 193, 194, 4, 195, 17, ...", 13 "...", ., "[...]", . 
- dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + :param subtree: 是否将训练集、测试集和验证集数据展开为子树,扩充数据量。 + :param train_subtree: 是否将训练集通过子树扩展数据。 + :param lower: 是否对输入进行小写化。 + :param granularity: 支持 ``[2, 3, 5]`` 。若为 ``2`` ,则认为是二分类问题,将 **1、2** 归为一类, **4、5** 归为一类, + 丢掉 3;若为 ``3`` ,则认为是三分类问题,将 **1、2** 归为一类, **3** 归为一类, **4、5** 归为一类;若为 ``5`` ,则认为是五分类问题。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ - - def __init__(self, subtree=False, train_subtree=True, lower=False, granularity=5, tokenizer='spacy', num_proc=0): - r""" - - :param bool subtree: 是否将train, test, dev数据展开为子树,扩充数据量。 Default: ``False`` - :param bool train_subtree: 是否将train集通过子树扩展数据。 - :param bool lower: 是否对输入进行小写化。 - :param int granularity: 支持2, 3, 5。若为2, 则认为是2分类问题,将0、1归为1类,3、4归为一类,丢掉2;若为3, 则有3分类问题,将 - 0、1归为1类,2归为1类,3、4归为1类;若为5, 则有5分类问题。 - :param str tokenizer: 使用哪种tokenize方式将数据切成单词。支持'spacy'和'raw'。raw使用空格作为切分。 - """ + def __init__(self, subtree: bool=False, train_subtree: bool=True, lower: bool=False, granularity: int=5, tokenizer: int='spacy', num_proc: int=0): super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.subtree = subtree self.train_tree = train_subtree @@ -341,19 +287,19 @@ class SSTPipe(CLSBasePipe): else: self.tag_map = None - def process(self, data_bundle: DataBundle): + def process(self, data_bundle: DataBundle) -> DataBundle: r""" - 对DataBundle中的数据进行预处理。输入的DataSet应该至少拥有raw_words这一列,且内容类似与 + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` ` 应该至少拥有 
``raw_words`` 列,内容类似于: - .. csv-table:: 下面是使用SSTLoader读取的DataSet所具备的field + .. csv-table:: 下面是使用 SSTLoader 读取的 DataSet 所具备的 field :header: "raw_words" "(2 (3 (3 Effective) (2 but)) (1 (1 too-tepid)..." "(3 (3 (2 If) (3 (2 you) (3 (2 sometimes) ..." "..." - :param ~fastNLP.io.DataBundle data_bundle: 需要处理的DataBundle对象 - :return: + :param data_bundle: 需要处理的 :class:`~fastNLP.io.DataBundle` 对象 + :return: 处理后的 ``data_bundle`` """ # 先取出subtree for name in list(data_bundle.datasets.keys()): @@ -381,13 +327,19 @@ class SSTPipe(CLSBasePipe): return data_bundle def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = SSTLoader().load(paths) return self.process(data_bundle=data_bundle) class SST2Pipe(CLSBasePipe): r""" - 加载SST2的数据, 处理完成之后DataSet将拥有以下的field + 处理 **SST-2** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_words", "target", "words", "seq_len" @@ -396,32 +348,20 @@ class SST2Pipe(CLSBasePipe): "unflinchingly bleak and desperate", 0, "[115, 116, 5, 117]", 4 "...", "...", ., . 
- dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, lower=False, tokenizer='raw', num_proc=0): - r""" - - :param bool lower: 是否对输入进行小写化。 - :param str tokenizer: 使用哪种tokenize方式将数据切成单词。 - """ super().__init__(lower=lower, tokenizer=tokenizer, lang='en', num_proc=num_proc) def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param str paths: 如果为None,则自动下载并缓存到fastNLP的缓存地址。 - :return: DataBundle + :param paths: + :return: """ data_bundle = SST2Loader().load(paths) return self.process(data_bundle) @@ -429,43 +369,31 @@ class SST2Pipe(CLSBasePipe): class IMDBPipe(CLSBasePipe): r""" - 经过本Pipe处理后DataSet将如下 + 处理 **IMDb** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: - .. csv-table:: 输出DataSet的field + .. csv-table:: 输出 DataSet 的 field :header: "raw_words", "target", "words", "seq_len" "Bromwell High is a cartoon ... ", 0, "[3, 5, 6, 9, ...]", 20 "Story of a man who has ...", 1, "[20, 43, 9, 10, ...]", 31 "...", ., "[...]", . 
- 其中raw_words为str类型,是原文; words是转换为index的输入; target是转换为index的目标值; - words列被设置为input; target列被设置为target。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | False | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ + 其中 ``raw_words`` 为 :class:`str` 类型,是原文; ``words`` 是转换为 index 的输入; ``target`` 是转换为 index 的目标值。 + ``words`` 列被设置为 input, ``target`` 列被设置为 target。 + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): - r""" - - :param bool lower: 是否将words列的数据小写。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 - """ super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process(self, data_bundle: DataBundle): r""" - 期待的DataBunlde中输入的DataSet应该类似于如下,有两个field,raw_words和target,且均为str类型 + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` 应该具备如下的结构:有两个 field , ``raw_words`` 和 ``target`` , + 且均为 :class:`str` 类型。 .. csv-table:: 输入DataSet的field :header: "raw_words", "target" @@ -476,7 +404,7 @@ class IMDBPipe(CLSBasePipe): :param DataBunlde data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和target两个field,且raw_words列应该为str, target列应该为str。 - :return: DataBundle + :return: 处理后的 ``data_bundle`` """ # 替换
@@ -493,9 +421,10 @@ class IMDBPipe(CLSBasePipe): def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = IMDBLoader().load(paths) @@ -506,7 +435,7 @@ class IMDBPipe(CLSBasePipe): class ChnSentiCorpPipe(Pipe): r""" - 处理之后的DataSet有以下的结构 + 处理 **ChnSentiCorp** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容为: .. csv-table:: :header: "raw_chars", "target", "chars", "seq_len" @@ -515,30 +444,18 @@ class ChnSentiCorpPipe(Pipe): "<荐书> 推荐所有喜欢<红楼>...", 1, "[10, 21, ....]", 25 "..." - 其中chars, seq_len是input,target是target - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_chars | target | chars | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ + 其中 ``chars`` , ``seq_len`` 是 input, ``target`` 是 target。 + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 `~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 ``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 `~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ - def __init__(self, bigrams=False, trigrams=False, num_proc: int = 0): - r""" - - :param bool bigrams: 是否增加一列bigrams. 
bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 - 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 - data_bundle.get_vocab('bigrams')获取. - :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] - 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 - data_bundle.get_vocab('trigrams')获取. - """ + def __init__(self, bigrams: bool=False, trigrams: bool=False, num_proc: int = 0): super().__init__() self.bigrams = bigrams @@ -557,7 +474,7 @@ class ChnSentiCorpPipe(Pipe): def process(self, data_bundle: DataBundle): r""" - 可以处理的DataSet应该具备以下的field + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` 应该具备如下的结构: .. csv-table:: :header: "raw_chars", "target" @@ -567,7 +484,7 @@ class ChnSentiCorpPipe(Pipe): "..." :param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ _add_chars_field(data_bundle, lower=False) @@ -601,9 +518,10 @@ class ChnSentiCorpPipe(Pipe): def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = ChnSentiCorpLoader().load(paths) @@ -614,7 +532,7 @@ class ChnSentiCorpPipe(Pipe): class THUCNewsPipe(CLSBasePipe): r""" - 处理之后的DataSet有以下的结构 + 处理 **THUCNews** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容为: .. csv-table:: :header: "raw_chars", "target", "chars", "seq_len" @@ -622,27 +540,18 @@ class THUCNewsPipe(CLSBasePipe): "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 "..." 
- 其中chars, seq_len是input,target是target - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_chars | target | chars | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - - :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 - 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 - data_bundle.get_vocab('bigrams')获取. - :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] - 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 - data_bundle.get_vocab('trigrams')获取. + 其中 ``chars`` , ``seq_len`` 是 input, ``target`` 是target + + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 `~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 ``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 `~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ - def __init__(self, bigrams=False, trigrams=False, num_proc=0): + def __init__(self, bigrams: int=False, trigrams: int=False, num_proc: int=0): super().__init__(num_proc=num_proc) self.bigrams = bigrams @@ -663,7 +572,7 @@ class THUCNewsPipe(CLSBasePipe): def process(self, data_bundle: DataBundle): r""" - 可处理的DataSet应具备如下的field + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` 应该具备如下的结构: .. 
csv-table:: :header: "raw_words", "target" @@ -672,7 +581,7 @@ class THUCNewsPipe(CLSBasePipe): "...", "..." :param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ # 根据granularity设置tag tag_map = {'体育': 0, '财经': 1, '房产': 2, '家居': 3, '教育': 4, '科技': 5, '时尚': 6, '时政': 7, '游戏': 8, '娱乐': 9} @@ -713,8 +622,10 @@ class THUCNewsPipe(CLSBasePipe): def process_from_file(self, paths=None): r""" - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: """ data_loader = THUCNewsLoader() # 此处需要实例化一个data_loader,否则传入load()的参数为None data_bundle = data_loader.load(paths) @@ -724,32 +635,23 @@ class THUCNewsPipe(CLSBasePipe): class WeiboSenti100kPipe(CLSBasePipe): r""" - 处理之后的DataSet有以下的结构 + 处理 **WeiboSenti100k** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容为: .. csv-table:: :header: "raw_chars", "target", "chars", "seq_len" - "六一出生的?好讽刺…… //@祭春姬:他爸爸是外星人吧 //@面孔小高:现在的孩子都怎么了 [怒][怒][怒]", 0, "[0, 690, 18, ...]", 56 + "马晓旭意外受伤让国奥警惕 无奈大雨格外青睐殷家军记者傅亚雨沈阳报道...", 0, "[409, 1197, 2146, 213, ...]", 746 "..." - 其中chars, seq_len是input,target是target - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_chars | target | chars | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | False | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - - :param bool bigrams: 是否增加一列bigrams. bigrams的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]。如果 - 设置为True,返回的DataSet将有一列名为bigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 - data_bundle.get_vocab('bigrams')获取. - :param bool trigrams: 是否增加一列trigrams. trigrams的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] 
- 。如果设置为True,返回的DataSet将有一列名为trigrams, 且已经转换为了index并设置为input,对应的vocab可以通过 - data_bundle.get_vocab('trigrams')获取. + 其中 ``chars`` , ``seq_len`` 是 input, ``target`` 是target + + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 `~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 ``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 `~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, bigrams=False, trigrams=False, num_proc=0): @@ -770,7 +672,7 @@ class WeiboSenti100kPipe(CLSBasePipe): def process(self, data_bundle: DataBundle): r""" - 可处理的DataSet应具备以下的field + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` 应该具备如下的结构: .. csv-table:: :header: "raw_chars", "target" @@ -779,7 +681,7 @@ class WeiboSenti100kPipe(CLSBasePipe): "...", "..." :param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ # clean,lower @@ -811,8 +713,10 @@ class WeiboSenti100kPipe(CLSBasePipe): def process_from_file(self, paths=None): r""" - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: """ data_loader = WeiboSenti100kLoader() # 此处需要实例化一个data_loader,否则传入load()的参数为None data_bundle = data_loader.load(paths) @@ -820,20 +724,23 @@ class WeiboSenti100kPipe(CLSBasePipe): return data_bundle class MRPipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): - r""" + """ + 加载 **MR** 的数据。 - :param bool lower: 是否将words列的数据小写。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 
支持spacy, raw两种。raw即使用空格拆分。 - """ + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc=0): super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = MRLoader().load(paths) @@ -843,20 +750,23 @@ class MRPipe(CLSBasePipe): class R8Pipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc = 0): - r""" + """ + 加载 **R8** 的数据。 - :param bool lower: 是否将words列的数据小写。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 - """ + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc = 0): super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = R8Loader().load(paths) @@ -866,20 +776,23 @@ class R8Pipe(CLSBasePipe): class R52Pipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): - r""" + """ + 加载 **R52** 的数据。 - :param bool lower: 是否将words列的数据小写。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 
支持spacy, raw两种。raw即使用空格拆分。 - """ + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = R52Loader().load(paths) @@ -889,20 +802,23 @@ class R52Pipe(CLSBasePipe): class OhsumedPipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): - r""" + """ + 加载 **Ohsumed** 的数据。 - :param bool lower: 是否将words列的数据小写。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 
支持spacy, raw两种。raw即使用空格拆分。 - """ + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = OhsumedLoader().load(paths) @@ -912,20 +828,23 @@ class OhsumedPipe(CLSBasePipe): class NG20Pipe(CLSBasePipe): - def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): - r""" + """ + 加载 **NG20** 的数据。 - :param bool lower: 是否将words列的数据小写。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 
支持spacy, raw两种。raw即使用空格拆分。 - """ + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, lower: bool = False, tokenizer: str = 'spacy', num_proc: int = 0): super().__init__(tokenizer=tokenizer, lang='en', num_proc=num_proc) self.lower = lower def process_from_file(self, paths=None): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.Loader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = NG20Loader().load(paths) diff --git a/fastNLP/io/pipe/conll.py b/fastNLP/io/pipe/conll.py index efe05de0..d2d4730e 100644 --- a/fastNLP/io/pipe/conll.py +++ b/fastNLP/io/pipe/conll.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "Conll2003NERPipe", "Conll2003Pipe", @@ -60,7 +58,7 @@ class _NERPipe(Pipe): "[...]", "[...]" :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]在传入DataBundle基础上原位修改。 - :return DataBundle: + :return: 处理后的 ``data_bundle`` """ # 转换tag for name, dataset in data_bundle.iter_datasets(): @@ -79,10 +77,14 @@ class _NERPipe(Pipe): class Conll2003NERPipe(_NERPipe): r""" - Conll2003的NER任务的处理Pipe, 该Pipe会(1)复制raw_words列,并命名为words; (2)在words, target列建立词表 - (创建 :class:`fastNLP.Vocabulary` 对象,所以在返回的DataBundle中将有两个Vocabulary); (3)将words,target列根据相应的 - Vocabulary转换为index。 - 经过该Pipe过后,DataSet中的内容如下所示 + **Conll2003** 的 **NER** 任务的处理 **Pipe** , 该Pipe会: + + 1. 复制 ``raw_words`` 列,并命名为 ``words`` ; + 2. 在 ``words`` , ``target`` 列建立词表,即创建 :class:`~fastNLP.core.Vocabulary` 对象,所以在返回的 + :class:`~fastNLP.io.DataBundle` 中将有两个 ``Vocabulary`` ; + 3. 将 ``words`` , ``target`` 列根据相应的词表转换为 index。 + + 处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. 
csv-table:: Following is a demo layout of DataSet returned by Conll2003Loader :header: "raw_words", "target", "words", "seq_len" @@ -91,27 +93,21 @@ class Conll2003NERPipe(_NERPipe): "[AL-AIN, United, Arab, ...]", "[3, 4,...]", "[4, 5, 6,...]", 6 "[...]", "[...]", "[...]", . - raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 - target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | True | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ + ``raw_words`` 列为 :class:`List` [ :class:`str` ], 是未转换的原始数据; ``words`` 列为 :class:`List` [ :class:`int` ], + 是转换为 index 的输入数据; ``target`` 列是 :class:`List` [ :class:`int` ] ,是转换为 index 的 target。返回的 :class:`~fastNLP.core.DataSet` + 中被设置为 input 有 ``words`` , ``target``, ``seq_len``;target 有 ``target`` 。 + :param encoding_type: ``target`` 列使用什么类型的 encoding 方式,支持 ``['bioes', 'bio']`` 两种。 + :param lower: 是否将 ``words`` 小写化后再建立词表,绝大多数情况都不需要设置为 ``True`` 。 + :param num_proc: 处理数据时使用的进程数目。 """ def process_from_file(self, paths) -> DataBundle: r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param paths: 支持路径类型参见 :class:`fastNLP.io.loader.ConllLoader` 的load函数。 - :return: DataBundle + :param paths: + :return: """ # 读取数据 data_bundle = Conll2003NERLoader().load(paths) @@ -122,7 +118,7 @@ class Conll2003NERPipe(_NERPipe): class Conll2003Pipe(Pipe): r""" - 经过该Pipe后,DataSet中的内容如下 + 处理 **Conll2003** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. 
csv-table:: :header: "raw_words" , "pos", "chunk", "ner", "words", "seq_len" @@ -131,27 +127,14 @@ class Conll2003Pipe(Pipe): "[AL-AIN, United, Arab, ...]", "[1, 2...]", "[3, 4...]", "[3, 4...]", "[4, 5, 6,...]", 6 "[...]", "[...]", "[...]", "[...]", "[...]", . - 其中words, seq_len是input; pos, chunk, ner, seq_len是target - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+-------+-------+-------+-------+---------+ - | field_names | raw_words | pos | chunk | ner | words | seq_len | - +-------------+-----------+-------+-------+-------+-------+---------+ - | is_input | False | False | False | False | True | True | - | is_target | False | True | True | True | False | True | - | ignore_type | | False | False | False | False | False | - | pad_value | | 0 | 0 | 0 | 0 | 0 | - +-------------+-----------+-------+-------+-------+-------+---------+ - + 其中 ``words``, ``seq_len`` 是 input; ``pos``, ``chunk``, ``ner``, ``seq_len`` 是 target + :param chunk_encoding_type: ``chunk`` 列使用什么类型的 encoding 方式,支持 ``['bioes', 'bio']`` 两种。 + :param ner_encoding_type: ``ner`` 列使用什么类型的 encoding 方式,支持 ``['bioes', 'bio']`` 两种。 + :param lower: 是否将 ``words`` 小写化后再建立词表,绝大多数情况都不需要设置为 ``True`` 。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, chunk_encoding_type='bioes', ner_encoding_type='bioes', lower: bool = False, num_proc: int = 0): - r""" - - :param str chunk_encoding_type: 支持bioes, bio。 - :param str ner_encoding_type: 支持bioes, bio。 - :param bool lower: 是否将words列小写化后再建立词表 - """ + def __init__(self, chunk_encoding_type: str='bioes', ner_encoding_type: str='bioes', lower: bool = False, num_proc: int = 0): if chunk_encoding_type == 'bio': self.chunk_convert_tag = iob2 elif chunk_encoding_type == 'bioes': @@ -175,7 +158,7 @@ class Conll2003Pipe(Pipe): def process(self, data_bundle) -> DataBundle: r""" - 输入的DataSet应该类似于如下的形式 + 输入的 :class:`~fastNLP.core.DataSet` 应该类似于如下的形式: ..
csv-table:: :header: "raw_words", "pos", "chunk", "ner" @@ -185,7 +168,7 @@ class Conll2003Pipe(Pipe): "[...]", "[...]", "[...]", "[...]", . :param data_bundle: - :return: 传入的DataBundle + :return: 处理后的 ``data_bundle`` """ # 转换tag for name, dataset in data_bundle.datasets.items(): @@ -210,6 +193,7 @@ class Conll2003Pipe(Pipe): def process_from_file(self, paths): r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` :param paths: :return: @@ -220,7 +204,7 @@ class Conll2003Pipe(Pipe): class OntoNotesNERPipe(_NERPipe): r""" - 处理OntoNotes的NER数据,处理之后DataSet中的field情况为 + 处理 **OntoNotes** 的 **NER** 数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_words", "target", "words", "seq_len" @@ -229,23 +213,19 @@ class OntoNotesNERPipe(_NERPipe): "[AL-AIN, United, Arab, ...]", "[3, 4]", "[4, 5, 6,...]", 6 "[...]", "[...]", "[...]", . - raw_words列为List[str], 是未转换的原始数据; words列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 - target。返回的DataSet中被设置为input有words, target, seq_len; 设置为target有target。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_words | target | words | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | True | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ + ``raw_words`` 列为 :class:`List` [ :class:`str` ], 是未转换的原始数据; ``words`` 列为 :class:`List` [ :class:`int` ], + 是转换为 index 的输入数据; ``target`` 列是 :class:`List` [ :class:`int` ] ,是转换为 index 的 target。返回的 :class:`~fastNLP.core.DataSet` + 中被设置为 input 有 ``words`` , ``target``, ``seq_len``;target 有 ``target`` 。 """ def process_from_file(self, paths): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + 
:param paths: + :return: + """ data_bundle = OntoNotesNERLoader().load(paths) return self.process(data_bundle) @@ -301,7 +281,7 @@ class _CNNERPipe(Pipe): 是转换为index的target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 :param ~fastNLP.DataBundle data_bundle: 传入的DataBundle中的DataSet必须包含raw_words和ner两个field,且两个field的内容均为List[str]。在传入DataBundle基础上原位修改。 - :return: DataBundle + :return: 处理后的 ``data_bundle`` """ # 转换tag for name, dataset in data_bundle.datasets.items(): @@ -338,7 +318,7 @@ class _CNNERPipe(Pipe): class MsraNERPipe(_CNNERPipe): r""" - 处理MSRA-NER的数据,处理之后的DataSet的field情况为 + 处理 **MSRA-NER** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_chars", "target", "chars", "seq_len" @@ -347,30 +327,34 @@ class MsraNERPipe(_CNNERPipe): "[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 "[...]", "[...]", "[...]", . - raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 - target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_chars | target | chars | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | True | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + ``raw_chars`` 列为 :class:`List` [ :class:`str` ], 是未转换的原始数据; ``chars`` 列为 :class:`List` [ :class:`int` ], + 是转换为 index 的输入数据; ``target`` 列是 :class:`List` [ :class:`int` ] ,是转换为 index 的 target。返回的 :class:`~fastNLP.core.DataSet` + 中被设置为 input 有 ``chars`` , ``target``, ``seq_len``;target 有 ``target`` 。 + + :param encoding_type: ``target`` 列使用什么类型的 encoding 方式,支持 ``['bioes', 'bio']`` 两种。 + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", 
"旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 ``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ def process_from_file(self, paths=None) -> DataBundle: + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = MsraNERLoader().load(paths) return self.process(data_bundle) class PeopleDailyPipe(_CNNERPipe): r""" - 处理people daily的ner的数据,处理之后的DataSet的field情况为 + 处理 **People's Daily NER** 的 **ner** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_chars", "target", "chars", "seq_len" @@ -379,30 +363,34 @@ class PeopleDailyPipe(_CNNERPipe): "[青, 岛, 海, 牛, 队, 和, ...]", "[1, 2, 3, ...]", "[10, 21, ....]", 21 "[...]", "[...]", "[...]", . 
- raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 - target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_chars | target | chars | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | True | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + ``raw_chars`` 列为 :class:`List` [ :class:`str` ], 是未转换的原始数据; ``chars`` 列为 :class:`List` [ :class:`int` ], + 是转换为 index 的输入数据; ``target`` 列是 :class:`List` [ :class:`int` ] ,是转换为 index 的 target。返回的 :class:`~fastNLP.core.DataSet` + 中被设置为 input 有 ``chars`` , ``target``, ``seq_len``;target 有 ``target`` 。 + + :param encoding_type: ``target`` 列使用什么类型的 encoding 方式,支持 ``['bioes', 'bio']`` 两种。 + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 ``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ def process_from_file(self, paths=None) -> DataBundle: + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = PeopleDailyNERLoader().load(paths) return self.process(data_bundle) class WeiboNERPipe(_CNNERPipe): r""" - 处理weibo的ner的数据,处理之后的DataSet的field情况为 + 处理 **Weibo** 的 **NER** 的数据,处理之后
:class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_chars", "chars", "target", "seq_len" @@ -411,22 +399,26 @@ class WeiboNERPipe(_CNNERPipe): "['心']", "[0]", "[41]", 1 "[...]", "[...]", "[...]", . - raw_chars列为List[str], 是未转换的原始数据; chars列为List[int],是转换为index的输入数据; target列是List[int],是转换为index的 - target。返回的DataSet中被设置为input有chars, target, seq_len; 设置为target有target。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+--------+-------+---------+ - | field_names | raw_chars | target | chars | seq_len | - +-------------+-----------+--------+-------+---------+ - | is_input | False | True | True | True | - | is_target | False | True | False | True | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+--------+-------+---------+ - + ``raw_chars`` 列为 :class:`List` [ :class:`str` ], 是未转换的原始数据; ``chars`` 列为 :class:`List` [ :class:`int` ], + 是转换为 index 的输入数据; ``target`` 列是 :class:`List` [ :class:`int` ] ,是转换为 index 的 target。返回的 :class:`~fastNLP.core.DataSet` + 中被设置为 input 有 ``chars`` , ``target``, ``seq_len``;target 有 ``target`` 。 + + :param encoding_type: ``target`` 列使用什么类型的 encoding 方式,支持 ``['bioes', 'bio']`` 两种。 + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 ``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ def process_from_file(self, paths=None) -> DataBundle: + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + 
:return: + """ data_bundle = WeiboNERLoader().load(paths) return self.process(data_bundle) diff --git a/fastNLP/io/pipe/construct_graph.py b/fastNLP/io/pipe/construct_graph.py index 1448765e..26846002 100644 --- a/fastNLP/io/pipe/construct_graph.py +++ b/fastNLP/io/pipe/construct_graph.py @@ -164,7 +164,7 @@ class GraphBuilderBase: def build_graph_from_file(self, path: str): r""" - 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + 传入文件路径,生成处理好的scipy_sparse_matrix对象。paths支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` :param path: :return: scipy_sparse_matrix @@ -173,14 +173,20 @@ class GraphBuilderBase: class MRPmiGraphPipe(GraphBuilderBase): + """ + 构建 **MR** 数据集的 **Graph** 。 + :param graph_type: + :param widow_size: + :param threshold: + """ def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) def build_graph(self, data_bundle: DataBundle): r""" - params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. - return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
+ :param data_bundle: 需要处理的 :class:`~fastNLP.io.DataBundle` 对象。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 """ self._get_doc_edge(data_bundle) self._get_word_edge() @@ -190,19 +196,31 @@ class MRPmiGraphPipe(GraphBuilderBase): self.tr_doc_index, self.dev_doc_index, self.te_doc_index) def build_graph_from_file(self, path: str): + r""" + 传入文件路径,生成处理好的 ``scipy_sparse_matrix`` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param path: 数据集的路径。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 + """ data_bundle = MRLoader().load(path) return self.build_graph(data_bundle) class R8PmiGraphPipe(GraphBuilderBase): + """ + 构建 **R8** 数据集的 **Graph** 。 + :param graph_type: + :param widow_size: + :param threshold: + """ def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) def build_graph(self, data_bundle: DataBundle): r""" - params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. - return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. 
+ :param data_bundle: 需要处理的 :class:`~fastNLP.io.DataBundle` 对象。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 """ self._get_doc_edge(data_bundle) self._get_word_edge() @@ -212,19 +230,31 @@ class R8PmiGraphPipe(GraphBuilderBase): self.tr_doc_index, self.dev_doc_index, self.te_doc_index) def build_graph_from_file(self, path: str): + r""" + 传入文件路径,生成处理好的 ``scipy_sparse_matrix`` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param path: 数据集的路径。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 + """ data_bundle = R8Loader().load(path) return self.build_graph(data_bundle) class R52PmiGraphPipe(GraphBuilderBase): + """ + 构建 **R52** 数据集的 **Graph** 。 + :param graph_type: + :param widow_size: + :param threshold: + """ def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) def build_graph(self, data_bundle: DataBundle): r""" - params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. - return 返回csr类型的稀疏矩阵;训练集,验证集,测试集,在图中的index. 
+ :param data_bundle: 需要处理的 :class:`~fastNLP.io.DataBundle` 对象。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 """ self._get_doc_edge(data_bundle) self._get_word_edge() @@ -234,19 +264,31 @@ class R52PmiGraphPipe(GraphBuilderBase): self.tr_doc_index, self.dev_doc_index, self.te_doc_index) def build_graph_from_file(self, path: str): + r""" + 传入文件路径,生成处理好的 ``scipy_sparse_matrix`` 对象。``path`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param path: 数据集的路径。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 + """ data_bundle = R52Loader().load(path) return self.build_graph(data_bundle) class OhsumedPmiGraphPipe(GraphBuilderBase): + """ + 构建 **Ohsumed** 数据集的 **Graph** 。 + :param graph_type: + :param widow_size: + :param threshold: + """ def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) def build_graph(self, data_bundle: DataBundle): r""" - params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. - return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index.
+ :param data_bundle: 需要处理的 :class:`~fastNLP.io.DataBundle` 对象。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 """ self._get_doc_edge(data_bundle) self._get_word_edge() @@ -256,19 +298,31 @@ class OhsumedPmiGraphPipe(GraphBuilderBase): self.tr_doc_index, self.dev_doc_index, self.te_doc_index) def build_graph_from_file(self, path: str): + r""" + 传入文件路径,生成处理好的 ``scipy_sparse_matrix`` 对象。``path`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param path: 数据集的路径。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 + """ data_bundle = OhsumedLoader().load(path) return self.build_graph(data_bundle) class NG20PmiGraphPipe(GraphBuilderBase): + """ + 构建 **NG20** 数据集的 **Graph** 。 + :param graph_type: + :param widow_size: + :param threshold: + """ def __init__(self, graph_type='pmi', widow_size=10, threshold=0.): super().__init__(graph_type=graph_type, widow_size=widow_size, threshold=threshold) def build_graph(self, data_bundle: DataBundle): r""" - params: ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象. - return 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. + :param data_bundle: 需要处理的 :class:`~fastNLP.io.DataBundle` 对象。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 """ self._get_doc_edge(data_bundle) self._get_word_edge() @@ -279,8 +333,10 @@ class NG20PmiGraphPipe(GraphBuilderBase): def build_graph_from_file(self, path: str): r""" - param: path->数据集的路径. - return: 返回csr类型的稀疏矩阵图;训练集,验证集,测试集,在图中的index. 
+ 传入文件路径,生成处理好的 ``scipy_sparse_matrix`` 对象。``path`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param path: 数据集的路径。 + :return: 返回 ``csr`` 类型的稀疏矩阵图;包含训练集,验证集,测试集,在图中的 index 。 """ data_bundle = NG20Loader().load(path) return self.build_graph(data_bundle) diff --git a/fastNLP/io/pipe/cws.py b/fastNLP/io/pipe/cws.py index 2937f147..27068348 100644 --- a/fastNLP/io/pipe/cws.py +++ b/fastNLP/io/pipe/cws.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "CWSPipe" ] @@ -135,7 +133,7 @@ def _find_and_replace_digit_spans(line): class CWSPipe(Pipe): r""" - 对CWS数据进行预处理, 处理之后的数据,具备以下的结构 + 对 **CWS** 数据进行处理,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_words", "chars", "target", "seq_len" @@ -144,30 +142,21 @@ class CWSPipe(Pipe): "2001年 新年 钟声...", "[8, 9, 9, 7, ...]", "[0, 1, 1, 1, 2...]", 20 "...", "[...]","[...]", . - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+-----------+-------+--------+---------+ - | field_names | raw_words | chars | target | seq_len | - +-------------+-----------+-------+--------+---------+ - | is_input | False | True | True | True | - | is_target | False | False | True | True | - | ignore_type | | False | False | False | - | pad_value | | 0 | 0 | 0 | - +-------------+-----------+-------+--------+---------+ - + :param dataset_name: data 的名称,支持 ``['pku', 'msra', 'cityu'(繁体), 'as'(繁体), None]`` + :param encoding_type: ``target`` 列使用什么类型的 encoding 方式,支持 ``['bmes', 'segapp']`` 两种。``"我 来自 复旦大学..."`` 这句话 ``bmes`` 的 + tag 为 ``[S, B, E, B, M, M, E...]`` ; ``segapp`` 的 tag 为 ``[seg, app, seg, app, app, app, seg, ...]`` 。 + :param replace_num_alpha: 是否将数字和字母用特殊字符替换。 + :param bigrams: 是否增加一列 ``bigrams`` 。 ``bigrams`` 会对原文进行如下转化: ``['复', '旦', '大', '学', ...]->["复旦", "旦大", ...]`` 。如果 + 设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``bigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('bigrams')`` 获取。 + :param trigrams: 是否增加一列 ``trigrams`` 。 ``trigrams`` 会对原文进行如下转化 
``['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...]`` 。 + 如果设置为 ``True`` ,返回的 :class:`~fastNLP.core.DataSet` 将有一列名为 ``trigrams`` ,且已经转换为了 index 并设置为 input,对应的词表可以通过 + ``data_bundle.get_vocab('trigrams')`` 获取。 + :param num_proc: 处理数据时使用的进程数目。 """ - def __init__(self, dataset_name=None, encoding_type='bmes', replace_num_alpha=True, - bigrams=False, trigrams=False, num_proc: int = 0): - r""" - - :param str,None dataset_name: 支持'pku', 'msra', 'cityu', 'as', None - :param str encoding_type: 可以选择'bmes', 'segapp'两种。"我 来自 复旦大学...", bmes的tag为[S, B, E, B, M, M, E...]; segapp - 的tag为[seg, app, seg, app, app, app, seg, ...] - :param bool replace_num_alpha: 是否将数字和字母用特殊字符替换。 - :param bool bigrams: 是否增加一列bigram. bigram的构成是['复', '旦', '大', '学', ...]->["复旦", "旦大", ...] - :param bool trigrams: 是否增加一列trigram. trigram的构成是 ['复', '旦', '大', '学', ...]->["复旦大", "旦大学", ...] - """ + def __init__(self, dataset_name: str=None, encoding_type: str='bmes', replace_num_alpha: bool=True, + bigrams: bool=False, trigrams: bool=False, num_proc: int = 0): if encoding_type == 'bmes': self.word_lens_to_tags = _word_lens_to_bmes else: @@ -220,7 +209,7 @@ class CWSPipe(Pipe): def process(self, data_bundle: DataBundle) -> DataBundle: r""" - 可以处理的DataSet需要包含raw_words列 + ``data_bundle`` 中的 :class:`~fastNLP.core.DataSet` 应该包含 ``raw_words`` : .. csv-table:: :header: "raw_words" @@ -230,7 +219,7 @@ class CWSPipe(Pipe): "..." 
:param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ data_bundle.copy_field('raw_words', 'chars') @@ -276,8 +265,9 @@ class CWSPipe(Pipe): def process_from_file(self, paths=None) -> DataBundle: r""" - - :param str paths: + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: :return: """ if self.dataset_name is None and paths is None: diff --git a/fastNLP/io/pipe/matching.py b/fastNLP/io/pipe/matching.py index 5b9981c2..a9abd943 100644 --- a/fastNLP/io/pipe/matching.py +++ b/fastNLP/io/pipe/matching.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "MatchingBertPipe", "RTEBertPipe", @@ -21,7 +19,7 @@ __all__ = [ "BQCorpusPipe", "RenamePipe", "GranularizePipe", - "MachingTruncatePipe", + "TruncateBertPipe", ] from functools import partial @@ -31,14 +29,13 @@ from .utils import get_tokenizer from ..data_bundle import DataBundle from ..loader.matching import SNLILoader, MNLILoader, QNLILoader, RTELoader, QuoraLoader, BQCorpusLoader, CNXNLILoader, \ LCQMCLoader -# from ...core._logger import log # from ...core.const import Const from ...core.vocabulary import Vocabulary class MatchingBertPipe(Pipe): r""" - Matching任务的Bert pipe,输出的DataSet将包含以下的field + **Matching** 任务的 Bert pipe ,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_words1", "raw_words2", "target", "words", "seq_len" @@ -47,29 +44,17 @@ class MatchingBertPipe(Pipe): "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", 5 "...", "...", ., "[...]", . - words列是将raw_words1(即premise), raw_words2(即hypothesis)使用"[SEP]"链接起来转换为index的。 - words列被设置为input,target列被设置为target和input(设置为input以方便在forward函数中计算loss, - 如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数的形参名进行传参). 
- - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+------------+------------+--------+-------+---------+ - | field_names | raw_words1 | raw_words2 | target | words | seq_len | - +-------------+------------+------------+--------+-------+---------+ - | is_input | False | False | False | True | True | - | is_target | False | False | True | False | False | - | ignore_type | | | False | False | False | - | pad_value | | | 0 | 0 | 0 | - +-------------+------------+------------+--------+-------+---------+ + ``words`` 列是将 ``raw_words1`` (即 ``premise`` ), ``raw_words2`` (即 ``hypothesis`` )使用 ``[SEP]`` + 链接起来转换为 index 的。``words`` 列被设置为 input, ``target`` 列被设置为 target 和 input (设置为 input 以 + 方便在 :func:`forward` 函数中计算 loss,如果不在 :func:`forward` 函数中计算 loss 也不影响, **fastNLP** 将根据 :func:`forward` 函数的形参名进行 + 传参)。 + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0): - r""" - - :param bool lower: 是否将word小写化。 - :param str tokenizer: 使用什么tokenizer来将句子切分为words. 支持spacy, raw两种。raw即使用空格拆分。 - """ super().__init__() self.lower = bool(lower) @@ -89,9 +74,9 @@ class MatchingBertPipe(Pipe): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) return data_bundle - def process(self, data_bundle): + def process(self, data_bundle: DataBundle): r""" - 输入的data_bundle中的dataset需要具有以下结构: + ``data_bundle`` 中的 :class:`~fastNLP.core.DataSet` 应该具备以下结构: .. csv-table:: :header: "raw_words1", "raw_words2", "target" @@ -100,7 +85,7 @@ class MatchingBertPipe(Pipe): "...","..." 
:param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ for dataset in data_bundle.datasets.values(): if dataset.has_field('target'): @@ -164,38 +149,103 @@ class MatchingBertPipe(Pipe): class RTEBertPipe(MatchingBertPipe): + """ + 处理 **RTE** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = RTELoader().load(paths) return self.process(data_bundle) class SNLIBertPipe(MatchingBertPipe): + """ + 处理 **SNLI** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = SNLILoader().load(paths) return self.process(data_bundle) class QuoraBertPipe(MatchingBertPipe): + """ + 处理 **Quora** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = QuoraLoader().load(paths) return self.process(data_bundle) class QNLIBertPipe(MatchingBertPipe): + """ + 处理 **QNLI** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ 
def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = QNLILoader().load(paths) return self.process(data_bundle) class MNLIBertPipe(MatchingBertPipe): + """ + 处理 **MNLI** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = MNLILoader().load(paths) return self.process(data_bundle) class MatchingPipe(Pipe): r""" - Matching任务的Pipe。输出的DataSet将包含以下的field + **Matching** 任务的 Pipe,处理之后 :class:`~fastNLP.core.DataSet` 中的内容如下: .. csv-table:: :header: "raw_words1", "raw_words2", "target", "words1", "words2", "seq_len1", "seq_len2" @@ -204,21 +254,14 @@ class MatchingPipe(Pipe): "This site includes a...", "The Government Executive...", 0, "[11, 12, 13,...]", "[2, 7, ...]", 6, 7 "...", "...", ., "[...]", "[...]", ., . 
- words1是premise,words2是hypothesis。其中words1,words2,seq_len1,seq_len2被设置为input;target被设置为target - 和input(设置为input以方便在forward函数中计算loss,如果不在forward函数中计算loss也不影响,fastNLP将根据forward函数 - 的形参名进行传参)。 - - dataset的print_field_meta()函数输出的各个field的被设置成input和target的情况为:: - - +-------------+------------+------------+--------+--------+--------+----------+----------+ - | field_names | raw_words1 | raw_words2 | target | words1 | words2 | seq_len1 | seq_len2 | - +-------------+------------+------------+--------+--------+--------+----------+----------+ - | is_input | False | False | False | True | True | True | True | - | is_target | False | False | True | False | False | False | False | - | ignore_type | | | False | False | False | False | False | - | pad_value | | | 0 | 0 | 0 | 0 | 0 | - +-------------+------------+------------+--------+--------+--------+----------+----------+ + ``words1`` 是 ``premise`` ,``words2`` 是 ``hypothesis`` 。其中 ``words1`` , ``words2`` , ``seq_len1``, ``seq_len2`` + 被设置为 input; ``target`` 列被设置为 target 和 input (设置为 input 以 + 方便在 :func:`forward` 函数中计算 loss,如果不在也不影响, **fastNLP** 将根据 :func:`forward` 函数的形参名进行 + 传参)。 + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 """ def __init__(self, lower=False, tokenizer: str = 'raw', num_proc: int = 0): @@ -246,9 +289,9 @@ class MatchingPipe(Pipe): dataset.apply_field(self.tokenizer, field_name=field_name, new_field_name=new_field_name, num_proc=self.num_proc) return data_bundle - def process(self, data_bundle): + def process(self, data_bundle: DataBundle): r""" - 接受的DataBundle中的DataSet应该具有以下的field, target列可以没有 + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` 应该具备以下结构,可以没有 ``target`` 列: .. csv-table:: :header: "raw_words1", "raw_words2", "target" @@ -257,8 +300,8 @@ class MatchingPipe(Pipe): "This site includes a...", "The Government Executive...", "not_entailment" "...", "..." 
- :param ~fastNLP.DataBundle data_bundle: 通过loader读取得到的data_bundle,里面包含了数据集的原始数据内容 - :return: data_bundle + :param data_bundle: + :return: 处理后的 ``data_bundle`` """ data_bundle = self._tokenize(data_bundle, ['raw_words1', 'raw_words2'], ['words1', 'words2']) @@ -307,40 +350,117 @@ class MatchingPipe(Pipe): class RTEPipe(MatchingPipe): + """ + 处理 **RTE** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = RTELoader().load(paths) return self.process(data_bundle) class SNLIPipe(MatchingPipe): + """ + 处理 **SNLI** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = SNLILoader().load(paths) return self.process(data_bundle) class QuoraPipe(MatchingPipe): + """ + 处理 **Quora** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = QuoraLoader().load(paths) return self.process(data_bundle) class QNLIPipe(MatchingPipe): + """ + 处理 **QNLI** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 
表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = QNLILoader().load(paths) return self.process(data_bundle) class MNLIPipe(MatchingPipe): + """ + 处理 **MNLI** 数据。 + + :param lower: 是否对输入进行小写化。 + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['spacy', 'raw']`` 。``'raw'`` 表示使用空格作为切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = MNLILoader().load(paths) return self.process(data_bundle) class LCQMCPipe(MatchingPipe): - def __init__(self, tokenizer='cn=char', num_proc=0): + """ + 处理 **LCQMC** 数据。 + + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['cn-char']`` ,按字分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, tokenizer='cn-char', num_proc=0): super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = LCQMCLoader().load(paths) data_bundle = RenamePipe().process(data_bundle) data_bundle = self.process(data_bundle) @@ -349,10 +469,22 @@ class LCQMCPipe(MatchingPipe): class CNXNLIPipe(MatchingPipe): + """ + 处理 **XNLI Chinese** 数据。 + + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['cn-char']`` ,按字分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, tokenizer='cn-char', num_proc=0): super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + 
:param paths: + :return: + """ data_bundle = CNXNLILoader().load(paths) data_bundle = GranularizePipe(task='XNLI').process(data_bundle) data_bundle = RenamePipe().process(data_bundle) # 使中文数据的field @@ -362,10 +494,22 @@ class CNXNLIPipe(MatchingPipe): class BQCorpusPipe(MatchingPipe): + """ + 处理 **BQ Corpus** 数据。 + + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['cn-char']`` ,按字分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, tokenizer='cn-char', num_proc=0): super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = BQCorpusLoader().load(paths) data_bundle = RenamePipe().process(data_bundle) data_bundle = self.process(data_bundle) @@ -374,12 +518,23 @@ class BQCorpusPipe(MatchingPipe): class RenamePipe(Pipe): + """ + 重命名数据集的 Pipe ,经过处理后会将数据集中的 ``chars``, ``raw_chars1`` 等列重命名为 ``words``, + ``raw_words1``,反之亦然。 + + :param task: 任务类型,可选 ``['cn-nli', 'cn-nli-bert']`` 。 + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, task='cn-nli', num_proc=0): super().__init__() self.task = task self.num_proc = num_proc def process(self, data_bundle: DataBundle): # rename field name for Chinese Matching dataset + """ + :param data_bundle: + :return: 处理后的 ``data_bundle`` + """ if (self.task == 'cn-nli'): for name, dataset in data_bundle.datasets.items(): if (dataset.has_field('raw_chars1')): @@ -415,6 +570,16 @@ class RenamePipe(Pipe): class GranularizePipe(Pipe): + """ + 将数据集中 ``target`` 列中的 tag 按照一定的映射进行重命名,并丢弃不在映射中的 tag。 + + :param task: 任务类型,目前仅支持 ``['XNLI']``。 + + - ``'XNLI'`` -- 将 ``neutral``, ``entailment``, ``contradictory``, ``contradiction`` 分别 + 映射为 0, 1, 2, 2; + + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, task=None, num_proc=0): super().__init__() self.task = task @@ -437,6 +602,10 @@ class GranularizePipe(Pipe): return data_bundle def
process(self, data_bundle: DataBundle): + """ + :param data_bundle: + :return: 处理后的 ``data_bundle`` + """ task_tag_dict = { 'XNLI': {'neutral': 0, 'entailment': 1, 'contradictory': 2, 'contradiction': 2} } @@ -446,22 +615,23 @@ class GranularizePipe(Pipe): raise RuntimeError(f"Only support {task_tag_dict.keys()} task_tag_map.") return data_bundle - -class MachingTruncatePipe(Pipe): # truncate sentence for bert, modify seq_len - def __init__(self): - super().__init__() - - def process(self, data_bundle: DataBundle): - for name, dataset in data_bundle.datasets.items(): - pass - return None - - class LCQMCBertPipe(MatchingBertPipe): - def __init__(self, tokenizer='cn=char', num_proc=0): + """ + 处理 **LCQMC** 数据 + + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['cn-char']`` ,按字分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ + def __init__(self, tokenizer='cn-char', num_proc=0): super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = LCQMCLoader().load(paths) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) @@ -471,10 +641,22 @@ class LCQMCBertPipe(MatchingBertPipe): class BQCorpusBertPipe(MatchingBertPipe): + """ + 处理 **BQ Corpus** 数据。 + + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['cn-char']`` ,按字分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, tokenizer='cn-char', num_proc=0): super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = BQCorpusLoader().load(paths) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) data_bundle = self.process(data_bundle) @@ -484,10 +666,22 @@ class 
BQCorpusBertPipe(MatchingBertPipe): class CNXNLIBertPipe(MatchingBertPipe): + """ + 处理 **XNLI Chinese** 数据。 + + :param tokenizer: 使用哪种 tokenize 方式将数据切成单词。支持 ``['cn-char']`` ,按字分词。 + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, tokenizer='cn-char', num_proc=0): super().__init__(tokenizer=tokenizer, num_proc=num_proc) def process_from_file(self, paths=None): + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = CNXNLILoader().load(paths) data_bundle = GranularizePipe(task='XNLI').process(data_bundle) data_bundle = RenamePipe(task='cn-nli-bert').process(data_bundle) @@ -498,6 +692,13 @@ class CNXNLIBertPipe(MatchingBertPipe): class TruncateBertPipe(Pipe): + """ + 对数据进行截断的 **Pipe** 。该 **Pipe** 将会寻找每条数据中的第一个分隔符 ``[SEP]`` ,对其前后的数据分别进行截断。 + 对于中文任务会将前后的文本分别截断至长度 **250** ,对于英文任务会分别截断至 **215** 。 + + :param task: 任务类型,可选 ``['cn', 'en']`` ,分别表示 **中文任务** 和 **英文任务** 。 + :param num_proc: 处理数据时使用的进程数目。 + """ def __init__(self, task='cn', num_proc=0): super().__init__() self.task = task @@ -522,6 +723,10 @@ class TruncateBertPipe(Pipe): return words_before_sep + words_after_sep def process(self, data_bundle: DataBundle) -> DataBundle: + """ + :param data_bundle: + :return: 处理后的 ``data_bundle`` + """ for name in data_bundle.datasets.keys(): dataset = data_bundle.get_dataset(name) sep_index_vocab = data_bundle.get_vocab('words').to_index('[SEP]') diff --git a/fastNLP/io/pipe/pipe.py b/fastNLP/io/pipe/pipe.py index 4916bf09..0ee43ae8 100644 --- a/fastNLP/io/pipe/pipe.py +++ b/fastNLP/io/pipe/pipe.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "Pipe", ] @@ -9,33 +7,38 @@ from fastNLP.io.data_bundle import DataBundle class Pipe: r""" - Pipe是fastNLP中用于处理DataBundle的类,但实际是处理DataBundle中的DataSet。所有Pipe都会在其process()函数的文档中指出该Pipe可处理的DataSet应该具备怎样的格式;在Pipe - 文档中说明该Pipe返回后DataSet的格式以及其field的信息;以及新增的Vocabulary的信息。 + :class:`Pipe` 是 **fastNLP** 中用于处理 
:class:`~fastNLP.io.DataBundle` 的类,但实际是处理其中的 :class:`~fastNLP.core.DataSet` 。 + 所有 ``Pipe`` 都会在其 :meth:`process` 函数的文档中指出该 ``Pipe`` 可处理的 :class:`~fastNLP.core.DataSet` 应该具备怎样的格式;在 + ``Pipe`` 文档中说明该 ``Pipe`` 返回后 :class:`~fastNLP.core.DataSet` 的格式以及其 field 的信息;以及新增的 :class:`~fastNLP.core.Vocabulary` + 的信息。 - 一般情况下Pipe处理包含以下的几个过程,(1)将raw_words或raw_chars进行tokenize以切分成不同的词或字; - (2) 再建立词或字的 :class:`~fastNLP.Vocabulary` , 并将词或字转换为index; (3)将target列建立词表并将target列转为index; + 一般情况下 **Pipe** 处理包含以下的几个过程: + + 1. 将 ``raw_words`` 或 ``raw_chars`` 进行 tokenize 以切分成不同的词或字; + 2. 建立词或字的 :class:`~fastNLP.core.Vocabulary` ,并将词或字转换为 index; + 3. 将 ``target`` 列建立词表并将 ``target`` 列转为 index; - Pipe中提供了两个方法 + **Pipe** 中提供了两个方法: - -process()函数,输入为DataBundle - -process_from_file()函数,输入为对应Loader的load函数可接受的类型。 + - :meth:`process` 函数,输入为 :class:`~fastNLP.io.DataBundle` + - :meth:`process_from_file` 函数,输入为对应 :meth:`fastNLP.io.Loader.load` 函数可接受的类型。 """ def process(self, data_bundle: DataBundle) -> DataBundle: r""" - 对输入的DataBundle进行处理,然后返回该DataBundle。 + 对输入的 ``data_bundle`` 进行处理,然后返回该 ``data_bundle`` - :param ~fastNLP.DataBundle data_bundle: 需要处理的DataBundle对象 - :return: DataBundle + :param data_bundle: + :return: 处理后的 ``data_bundle`` """ raise NotImplementedError def process_from_file(self, paths: str) -> DataBundle: r""" - 传入文件路径,生成处理好的DataBundle对象。paths支持的路径形式可以参考 ::meth:`fastNLP.io.Loader.load()` + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` - :param str paths: - :return: DataBundle + :param paths: + :return: """ raise NotImplementedError diff --git a/fastNLP/io/pipe/qa.py b/fastNLP/io/pipe/qa.py index 23fe1367..0d646263 100644 --- a/fastNLP/io/pipe/qa.py +++ b/fastNLP/io/pipe/qa.py @@ -1,5 +1,5 @@ r""" -本文件中的Pipe主要用于处理问答任务的数据。 +本文件中的 **Pipe** 主要用于处理问答任务的数据。 """ @@ -78,32 +78,20 @@ def _concat_clip(data_bundle, max_len, concat_field_name='raw_chars'): class CMRC2018BertPipe(Pipe): r""" - 处理之后的DataSet将新增以下的field(传入的field仍然保留) + 处理 
**CMRC2018** 的数据,处理之后 :class:`~fastNLP.core.DataSet` 中新增的内容如下(原有的 field 仍然保留): .. csv-table:: :header: "context_len", "raw_chars", "target_start", "target_end", "chars" - 492, ['范', '廷', '颂... ], 30, 34, "[21, 25, ...]" - 491, ['范', '廷', '颂... ], 41, 61, "[21, 25, ...]" - + 492, "['范', '廷', '颂... ]", 30, 34, "[21, 25, ...]" + 491, "['范', '廷', '颂... ]", 41, 61, "[21, 25, ...]" ".", "...", "...","...", "..." - raw_words列是context与question拼起来的结果(连接的地方加入了[SEP]),words是转为index的值, target_start为答案start的index,target_end为答案end的index - (闭区间);context_len指示的是words列中context的长度。 - - 其中各列的meta信息如下: - - .. code:: - - +-------------+-------------+-----------+--------------+------------+-------+---------+ - | field_names | context_len | raw_chars | target_start | target_end | chars | answers | - +-------------+-------------+-----------+--------------+------------+-------+---------| - | is_input | False | False | False | False | True | False | - | is_target | True | True | True | True | False | True | - | ignore_type | False | True | False | False | False | True | - | pad_value | 0 | 0 | 0 | 0 | 0 | 0 | - +-------------+-------------+-----------+--------------+------------+-------+---------+ - + ``raw_chars`` 列是 ``context`` 与 ``question`` 拼起来的结果(连接的地方加入了 ``[SEP]`` ), ``chars`` 是转为 + index 的值, ``target_start`` 为答案开始的位置, ``target_end`` 为答案结束的位置(闭区间); ``context_len`` + 指示的是 ``chars`` 列中 context 的长度。 + + :param max_len: """ def __init__(self, max_len=510): @@ -112,17 +100,17 @@ class CMRC2018BertPipe(Pipe): def process(self, data_bundle: DataBundle) -> DataBundle: r""" - 传入的DataSet应该具备以下的field + ``data_bunlde`` 中的 :class:`~fastNLP.core.DataSet` 应该包含 ``raw_words`` : .. 
csv-table:: - :header:"title", "context", "question", "answers", "answer_starts", "id" + :header: "title", "context", "question", "answers", "answer_starts", "id" "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "范廷颂是什么时候被任为主教的?", ["1963年"], ["30"], "TRAIN_186_QUERY_0" "范廷颂", "范廷颂枢机(,),圣名保禄·若瑟()...", "1990年,范廷颂担任什么职务?", ["1990年被擢升为天..."], ["41"],"TRAIN_186_QUERY_1" "...", "...", "...","...", ".", "..." :param data_bundle: - :return: + :return: 处理后的 ``data_bundle`` """ data_bundle = _concat_clip(data_bundle, max_len=self.max_len, concat_field_name='raw_chars') @@ -138,5 +126,11 @@ class CMRC2018BertPipe(Pipe): return data_bundle def process_from_file(self, paths=None) -> DataBundle: + r""" + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: + """ data_bundle = CMRC2018Loader().load(paths) return self.process(data_bundle) diff --git a/fastNLP/io/pipe/summarization.py b/fastNLP/io/pipe/summarization.py index b413890b..dc27651c 100644 --- a/fastNLP/io/pipe/summarization.py +++ b/fastNLP/io/pipe/summarization.py @@ -1,4 +1,3 @@ -r"""undocumented""" import os import numpy as np from functools import partial @@ -20,21 +19,23 @@ TAG_UNK = "X" class ExtCNNDMPipe(Pipe): r""" - 对CNN/Daily Mail数据进行适用于extractive summarization task的预处理,预处理之后的数据,具备以下结构: + 对 **CNN/Daily Mail** 数据进行适用于 ``extractive summarization task`` 的预处理,预处理之后的数据具备以下结构: .. csv-table:: :header: "text", "summary", "label", "publication", "text_wd", "words", "seq_len", "target" - + + "['I got new tires from them and... ','...']", "['The new tires...','...']", "[0, 1]", "cnndm", "[['I','got',...'.'],...,['...']]", "[[54,89,...,5],...,[9,43,..,0]]", "[1,1,...,0]", "[0,1,...,0]" + "['Don't waste your time. 
We had two...','...']", "['Time is precious','...']", "[1]", "cnndm", "[['Don't','waste',...,'.'],...,['...']]", "[[5234,653,...,5],...,[87,234,..,0]]", "[1,1,...,0]", "[1,1,...,0]" + "['...']", "['...']", "[]", "cnndm", "[[''],...,['']]", "[[],...,[]]", "[]", "[]" + + :param vocab_size: 词表大小 + :param sent_max_len: 句子最大长度,不足的句子将 padding ,超出的将截断 + :param doc_max_timesteps: 文章最多句子个数,不足的将 padding,超出的将截断 + :param vocab_path: 外部词表路径 + :param domain: 是否需要建立 domain 词表 + :param num_proc: 处理数据时使用的进程数目。 """ - def __init__(self, vocab_size, sent_max_len, doc_max_timesteps, vocab_path=None, domain=False, num_proc=0): - r""" - - :param vocab_size: int, 词表大小 - :param sent_max_len: int, 句子最大长度,不足的句子将padding,超出的将截断 - :param doc_max_timesteps: int, 文章最多句子个数,不足的将padding,超出的将截断 - :param vocab_path: str, 外部词表路径 - :param domain: bool, 是否需要建立domain词表 - """ + def __init__(self, vocab_size: int, sent_max_len: int, doc_max_timesteps: int, vocab_path=None, domain=False, num_proc=0): self.vocab_size = vocab_size self.vocab_path = vocab_path self.sent_max_len = sent_max_len @@ -44,23 +45,24 @@ class ExtCNNDMPipe(Pipe): def process(self, data_bundle: DataBundle): r""" - 传入的DataSet应该具备如下的结构 + ``data_bundle`` 中的 :class:`~fastNLP.core.DataSet` 应该具备以下结构: .. csv-table:: :header: "text", "summary", "label", "publication" - ["I got new tires from them and... ","..."], ["The new tires...","..."], [0, 1], "cnndm" - ["Don't waste your time. We had two...","..."], ["Time is precious","..."], [1], "cnndm" - ["..."], ["..."], [], "cnndm" + "['I got new tires from them and... ','...']", "['The new tires...','...']", "[0, 1]", "cnndm" + "['Don't waste your time. We had two...','...']", "['Time is precious','...']", "[1]", "cnndm" + "['...']", "['...']", "[]", "cnndm" :param data_bundle: - :return: 处理得到的数据包括 + :return: 处理后的 ``data_bundle``,新增以下列: + ..
csv-table:: :header: "text_wd", "words", "seq_len", "target" - [["I","got",..."."],...,["..."]], [[54,89,...,5],...,[9,43,..,0]], [1,1,...,0], [0,1,...,0] - [["Don't","waste",...,"."],...,["..."]], [[5234,653,...,5],...,[87,234,..,0]], [1,1,...,0], [1,1,...,0] - [[""],...,[""]], [[],...,[]], [], [] + "[['I','got',...'.'],...,['...']]", "[[54,89,...,5],...,[9,43,..,0]]", "[1,1,...,0]", "[0,1,...,0]" + "[['Don't','waste',...,'.'],...,['...']]", "[[5234,653,...,5],...,[87,234,..,0]]", "[1,1,...,0]", "[1,1,...,0]" + "[[''],...,['']]", "[[],...,[]]", "[]", "[]" """ if self.vocab_path is None: @@ -117,8 +119,10 @@ class ExtCNNDMPipe(Pipe): def process_from_file(self, paths=None): r""" - :param paths: dict or string - :return: DataBundle + 传入文件路径,生成处理好的 :class:`~fastNLP.io.DataBundle` 对象。``paths`` 支持的路径形式可以参考 :meth:`fastNLP.io.Loader.load` + + :param paths: + :return: """ loader = ExtCNNDMLoader() if self.vocab_path is None: diff --git a/fastNLP/io/pipe/utils.py b/fastNLP/io/pipe/utils.py index c5c32d95..3f1b8563 100644 --- a/fastNLP/io/pipe/utils.py +++ b/fastNLP/io/pipe/utils.py @@ -1,5 +1,3 @@ -r"""undocumented""" - __all__ = [ "iob2", "iob2bioes", @@ -17,10 +15,10 @@ from pkg_resources import parse_version def iob2(tags: List[str]) -> List[str]: r""" - 检查数据是否是合法的IOB数据,如果是IOB1会被自动转换为IOB2。两种格式的区别见 + 检查数据是否是合法的 ``IOB`` 数据,如果是 ``IOB1`` 会被自动转换为 ``IOB2`` 。两种格式的区别见 https://datascience.stackexchange.com/questions/37824/difference-between-iob-and-iob2-format - :param tags: 需要转换的tags + :param tags: 需要转换的 tags """ for i, tag in enumerate(tags): if tag == "O": @@ -41,8 +39,9 @@ def iob2(tags: List[str]) -> List[str]: def iob2bioes(tags: List[str]) -> List[str]: r""" - 将iob的tag转换为bioes编码 - :param tags: + 将 ``iob`` 的 tag 转换为 ``bioes`` 编码 + + :param tags: 需要转换的 tags :return: """ new_tags = [] @@ -69,9 +68,10 @@ def iob2bioes(tags: List[str]) -> List[str]: def get_tokenizer(tokenize_method: str, lang='en'): r""" - :param str tokenize_method: 获取tokenzier方法 - :param str lang: 
语言,当前仅支持en - :return: tokenize函数 + :param tokenize_method: 获取 tokenizer 方法,支持 ``['spacy', 'raw', 'cn-char']`` 。``'raw'`` 表示使用空格作为切分, ``'cn-char'`` 表示 + 按字符切分,``'spacy'`` 则使用 :mod:`spacy` 库进行分词。 + :param lang: :mod:`spacy` 使用的语言,当前仅支持 ``'en'`` 。 + :return: tokenize 函数 """ tokenizer_dict = { 'spacy': None, @@ -82,7 +82,7 @@ def get_tokenizer(tokenize_method: str, lang='en'): import spacy spacy.prefer_gpu() if lang != 'en': - raise RuntimeError("Spacy only supports en right right.") + raise RuntimeError("Spacy only supports en right now.") if parse_version(spacy.__version__) >= parse_version('3.0'): en = spacy.load('en_core_web_sm') else: diff --git a/fastNLP/io/utils.py b/fastNLP/io/utils.py index 5c0b63ce..a3af71f5 100644 --- a/fastNLP/io/utils.py +++ b/fastNLP/io/utils.py @@ -1,8 +1,3 @@ -r""" -.. todo:: - doc -""" - __all__ = [ "check_loader_paths" ] @@ -16,7 +11,7 @@ from typing import Union, Dict def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: r""" - 检查传入dataloader的文件的合法性。如果为合法路径,将返回至少包含'train'这个key的dict。类似于下面的结果:: + 检查传入 ``dataloader`` 的文件的合法性。如果为合法路径,将返回至少包含 ``'train'`` 这个 key 的字典。类似于下面的结果:: { 'train': '/some/path/to/', # 一定包含,建词表应该在这上面建立,剩下的其它文件应该只需要处理并index。 @@ -24,10 +19,13 @@ def check_loader_paths(paths: Union[str, Dict[str, str]]) -> Dict[str, str]: ... } - 如果paths为不合法的,将直接进行raise相应的错误. 如果paths内不包含train也会报错。 + 如果 ``paths`` 为不合法的,将直接进行 raise 相应的错误。如果 ``paths`` 内不包含 ``'train'`` 也会报错。 - :param str paths: 路径.
可以为一个文件路径(则认为该文件就是train的文件); 可以为一个文件目录,将在该目录下寻找包含train(文件名 - 中包含train这个字段), test, dev这三个字段的文件或文件夹; 可以为一个dict, 则key是用户自定义的某个文件的名称,value是这个文件的路径。 + :param str paths: 路径。可以为: + + - 一个文件路径,此时认为该文件就是 train 的文件; + - 一个文件目录,将在该目录下寻找包含 ``train`` (文件名中包含 train 这个字段), ``test`` ,``dev`` 这三个字段的文件或文件夹; + - 一个 dict, 则 key 是用户自定义的某个文件的名称,value 是这个文件的路径。 :return: """ if isinstance(paths, (str, Path)): diff --git a/fastNLP/transformers/__init__.py b/fastNLP/transformers/__init__.py index 6403f6b9..6b175b28 100644 --- a/fastNLP/transformers/__init__.py +++ b/fastNLP/transformers/__init__.py @@ -1 +1,3 @@ -"""基于 transformers-4.11.3 版本迁移""" \ No newline at end of file +""" +:mod:`transformers` 模块,包含了常用的预训练模型。 +""" diff --git a/fastNLP/transformers/torch/__init__.py b/fastNLP/transformers/torch/__init__.py index 9ce4fb10..3b564cd4 100644 --- a/fastNLP/transformers/torch/__init__.py +++ b/fastNLP/transformers/torch/__init__.py @@ -1,9 +1,15 @@ """ -为了防止因 https://github.com/huggingface/transformers 版本变化导致代码不兼容,当前 folder 以及子 folder -都复制自 https://github.com/huggingface/transformers 的4.11.3版本。 -In order to avoid the code change of https://github.com/huggingface/transformers to cause version -mismatch, we copy code from https://github.com/huggingface/transformers(version:4.11.3) in this +为了防止因 `transformers <https://github.com/huggingface/transformers>`_ 版本变化导致代码不兼容,当前文件夹以及子文件夹 +都复制自 `transformers <https://github.com/huggingface/transformers>`_ 的 4.11.3 版本。 + +In order to avoid the code change of `transformers <https://github.com/huggingface/transformers>`_ to cause version +mismatch, we copy code from `transformers <https://github.com/huggingface/transformers>`_ (version:4.11.3) in this folder and its subfolder. + +您可以如下面代码所示使用 transformers:: + + from fastNLP.transformers.torch import BertModel + ... """ __version__ = "4.11.3" from .models import * \ No newline at end of file