diff --git a/docs/source/conf.py b/docs/source/conf.py index 01884ef7..812fb0ec 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -24,9 +24,9 @@ copyright = '2022, fastNLP' author = 'fastNLP' # The short X.Y version -version = '0.8' +version = '1.0' # The full version, including alpha/beta/rc tags -release = '0.8.0' +release = '1.0.0-alpha' # -- General configuration --------------------------------------------------- @@ -45,6 +45,7 @@ extensions = [ 'sphinx.ext.todo', 'sphinx_autodoc_typehints', 'sphinx_multiversion', + 'nbsphinx', ] autodoc_default_options = { @@ -169,7 +170,7 @@ man_pages = [ # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'fastNLP', 'fastNLP Documentation', - author, 'fastNLP', 'One line description of project.', + author, 'fastNLP', 'A fast NLP tool for programming.', 'Miscellaneous'), ] diff --git a/docs/source/fastNLP.core.callbacks.rst b/docs/source/fastNLP.core.callbacks.rst index 89d85f52..d0f3d210 100644 --- a/docs/source/fastNLP.core.callbacks.rst +++ b/docs/source/fastNLP.core.callbacks.rst @@ -31,5 +31,6 @@ Submodules fastNLP.core.callbacks.lr_scheduler_callback fastNLP.core.callbacks.more_evaluate_callback fastNLP.core.callbacks.progress_callback + fastNLP.core.callbacks.timer_callback fastNLP.core.callbacks.topk_saver fastNLP.core.callbacks.utils diff --git a/docs/source/fastNLP.core.callbacks.timer_callback.rst b/docs/source/fastNLP.core.callbacks.timer_callback.rst new file mode 100644 index 00000000..884fa604 --- /dev/null +++ b/docs/source/fastNLP.core.callbacks.timer_callback.rst @@ -0,0 +1,7 @@ +fastNLP.core.callbacks.timer\_callback module +============================================= + +.. 
automodule:: fastNLP.core.callbacks.timer_callback + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst b/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst new file mode 100644 index 00000000..ced75ccb --- /dev/null +++ b/docs/source/fastNLP.core.collators.padders.oneflow_padder.rst @@ -0,0 +1,7 @@ +fastNLP.core.collators.padders.oneflow\_padder module +===================================================== + +.. automodule:: fastNLP.core.collators.padders.oneflow_padder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.collators.padders.rst b/docs/source/fastNLP.core.collators.padders.rst index 6f40becb..0c50dd4c 100644 --- a/docs/source/fastNLP.core.collators.padders.rst +++ b/docs/source/fastNLP.core.collators.padders.rst @@ -16,6 +16,7 @@ Submodules fastNLP.core.collators.padders.get_padder fastNLP.core.collators.padders.jittor_padder fastNLP.core.collators.padders.numpy_padder + fastNLP.core.collators.padders.oneflow_padder fastNLP.core.collators.padders.padder fastNLP.core.collators.padders.paddle_padder fastNLP.core.collators.padders.raw_padder diff --git a/docs/source/fastNLP.core.dataloaders.mix_dataloader.rst b/docs/source/fastNLP.core.dataloaders.mix_dataloader.rst deleted file mode 100644 index d2ffa234..00000000 --- a/docs/source/fastNLP.core.dataloaders.mix_dataloader.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.core.dataloaders.mix\_dataloader module -=============================================== - -.. 
automodule:: fastNLP.core.dataloaders.mix_dataloader - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst new file mode 100644 index 00000000..5a8939b0 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.fdl.rst @@ -0,0 +1,7 @@ +fastNLP.core.dataloaders.oneflow\_dataloader.fdl module +======================================================= + +.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader.fdl + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst new file mode 100644 index 00000000..2b2081e5 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.oneflow_dataloader.rst @@ -0,0 +1,15 @@ +fastNLP.core.dataloaders.oneflow\_dataloader package +==================================================== + +.. automodule:: fastNLP.core.dataloaders.oneflow_dataloader + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.core.dataloaders.oneflow_dataloader.fdl diff --git a/docs/source/fastNLP.core.dataloaders.rst b/docs/source/fastNLP.core.dataloaders.rst index e8c6b799..db53dbe0 100644 --- a/docs/source/fastNLP.core.dataloaders.rst +++ b/docs/source/fastNLP.core.dataloaders.rst @@ -13,6 +13,7 @@ Subpackages :maxdepth: 4 fastNLP.core.dataloaders.jittor_dataloader + fastNLP.core.dataloaders.oneflow_dataloader fastNLP.core.dataloaders.paddle_dataloader fastNLP.core.dataloaders.torch_dataloader @@ -22,6 +23,5 @@ Submodules .. 
toctree:: :maxdepth: 4 - fastNLP.core.dataloaders.mix_dataloader fastNLP.core.dataloaders.prepare_dataloader fastNLP.core.dataloaders.utils diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst new file mode 100644 index 00000000..cd8bd865 --- /dev/null +++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.mix_dataloader.rst @@ -0,0 +1,7 @@ +fastNLP.core.dataloaders.torch\_dataloader.mix\_dataloader module +================================================================= + +.. automodule:: fastNLP.core.dataloaders.torch_dataloader.mix_dataloader + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst index c9acca23..a3aeb1bf 100644 --- a/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst +++ b/docs/source/fastNLP.core.dataloaders.torch_dataloader.rst @@ -13,3 +13,4 @@ Submodules :maxdepth: 4 fastNLP.core.dataloaders.torch_dataloader.fdl + fastNLP.core.dataloaders.torch_dataloader.mix_dataloader diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst new file mode 100644 index 00000000..c7618619 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.ddp.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.ddp module +=============================================== + +.. 
automodule:: fastNLP.core.drivers.oneflow_driver.ddp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst new file mode 100644 index 00000000..9eae5d19 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.dist_utils.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.dist\_utils module +======================================================= + +.. automodule:: fastNLP.core.drivers.oneflow_driver.dist_utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst new file mode 100644 index 00000000..d7272c8e --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.initialize\_oneflow\_driver module +======================================================================= + +.. automodule:: fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst new file mode 100644 index 00000000..1f5d159e --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.oneflow_driver.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.oneflow\_driver module +=========================================================== + +.. 
automodule:: fastNLP.core.drivers.oneflow_driver.oneflow_driver + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.rst new file mode 100644 index 00000000..213dd24b --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.rst @@ -0,0 +1,20 @@ +fastNLP.core.drivers.oneflow\_driver package +============================================ + +.. automodule:: fastNLP.core.drivers.oneflow_driver + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.core.drivers.oneflow_driver.ddp + fastNLP.core.drivers.oneflow_driver.dist_utils + fastNLP.core.drivers.oneflow_driver.initialize_oneflow_driver + fastNLP.core.drivers.oneflow_driver.oneflow_driver + fastNLP.core.drivers.oneflow_driver.single_device + fastNLP.core.drivers.oneflow_driver.utils diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst new file mode 100644 index 00000000..a54e74ec --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.single_device.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.single\_device module +========================================================== + +.. automodule:: fastNLP.core.drivers.oneflow_driver.single_device + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst b/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst new file mode 100644 index 00000000..1eda7794 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.oneflow_driver.utils.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.oneflow\_driver.utils module +================================================= + +.. 
automodule:: fastNLP.core.drivers.oneflow_driver.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.rst b/docs/source/fastNLP.core.drivers.rst index bb168c76..30652fec 100644 --- a/docs/source/fastNLP.core.drivers.rst +++ b/docs/source/fastNLP.core.drivers.rst @@ -13,6 +13,7 @@ Subpackages :maxdepth: 4 fastNLP.core.drivers.jittor_driver + fastNLP.core.drivers.oneflow_driver fastNLP.core.drivers.paddle_driver fastNLP.core.drivers.torch_driver diff --git a/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst b/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst new file mode 100644 index 00000000..2944ffec --- /dev/null +++ b/docs/source/fastNLP.core.drivers.torch_driver.deepspeed.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.torch\_driver.deepspeed module +=================================================== + +.. automodule:: fastNLP.core.drivers.torch_driver.deepspeed + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst b/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst new file mode 100644 index 00000000..e68972a7 --- /dev/null +++ b/docs/source/fastNLP.core.drivers.torch_driver.fairscale.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.torch\_driver.fairscale module +=================================================== + +.. automodule:: fastNLP.core.drivers.torch_driver.fairscale + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.torch_driver.fairscale_sharded.rst b/docs/source/fastNLP.core.drivers.torch_driver.fairscale_sharded.rst deleted file mode 100644 index 765ac4ae..00000000 --- a/docs/source/fastNLP.core.drivers.torch_driver.fairscale_sharded.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.core.drivers.torch\_driver.fairscale\_sharded module -============================================================ - -.. 
automodule:: fastNLP.core.drivers.torch_driver.fairscale_sharded - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.core.drivers.torch_driver.rst b/docs/source/fastNLP.core.drivers.torch_driver.rst index 9a0109a2..c9080a86 100644 --- a/docs/source/fastNLP.core.drivers.torch_driver.rst +++ b/docs/source/fastNLP.core.drivers.torch_driver.rst @@ -13,9 +13,11 @@ Submodules :maxdepth: 4 fastNLP.core.drivers.torch_driver.ddp + fastNLP.core.drivers.torch_driver.deepspeed fastNLP.core.drivers.torch_driver.dist_utils - fastNLP.core.drivers.torch_driver.fairscale_sharded + fastNLP.core.drivers.torch_driver.fairscale fastNLP.core.drivers.torch_driver.initialize_torch_driver fastNLP.core.drivers.torch_driver.single_device fastNLP.core.drivers.torch_driver.torch_driver + fastNLP.core.drivers.torch_driver.torch_fsdp fastNLP.core.drivers.torch_driver.utils diff --git a/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst b/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst new file mode 100644 index 00000000..a799b7fc --- /dev/null +++ b/docs/source/fastNLP.core.drivers.torch_driver.torch_fsdp.rst @@ -0,0 +1,7 @@ +fastNLP.core.drivers.torch\_driver.torch\_fsdp module +===================================================== + +.. automodule:: fastNLP.core.drivers.torch_driver.torch_fsdp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst new file mode 100644 index 00000000..2389250b --- /dev/null +++ b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.backend.rst @@ -0,0 +1,7 @@ +fastNLP.core.metrics.backend.oneflow\_backend.backend module +============================================================ + +.. 
automodule:: fastNLP.core.metrics.backend.oneflow_backend.backend + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst new file mode 100644 index 00000000..cb9e9653 --- /dev/null +++ b/docs/source/fastNLP.core.metrics.backend.oneflow_backend.rst @@ -0,0 +1,15 @@ +fastNLP.core.metrics.backend.oneflow\_backend package +===================================================== + +.. automodule:: fastNLP.core.metrics.backend.oneflow_backend + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.core.metrics.backend.oneflow_backend.backend diff --git a/docs/source/fastNLP.core.metrics.backend.rst b/docs/source/fastNLP.core.metrics.backend.rst index 5a8cf4ad..4466a54a 100644 --- a/docs/source/fastNLP.core.metrics.backend.rst +++ b/docs/source/fastNLP.core.metrics.backend.rst @@ -13,6 +13,7 @@ Subpackages :maxdepth: 4 fastNLP.core.metrics.backend.jittor_backend + fastNLP.core.metrics.backend.oneflow_backend fastNLP.core.metrics.backend.paddle_backend fastNLP.core.metrics.backend.torch_backend diff --git a/docs/source/fastNLP.core.utils.oneflow_utils.rst b/docs/source/fastNLP.core.utils.oneflow_utils.rst new file mode 100644 index 00000000..f9d11510 --- /dev/null +++ b/docs/source/fastNLP.core.utils.oneflow_utils.rst @@ -0,0 +1,7 @@ +fastNLP.core.utils.oneflow\_utils module +======================================== + +.. 
automodule:: fastNLP.core.utils.oneflow_utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.utils.rst b/docs/source/fastNLP.core.utils.rst index 2d682010..9bf76a23 100644 --- a/docs/source/fastNLP.core.utils.rst +++ b/docs/source/fastNLP.core.utils.rst @@ -16,7 +16,10 @@ Submodules fastNLP.core.utils.dummy_class fastNLP.core.utils.exceptions fastNLP.core.utils.jittor_utils + fastNLP.core.utils.oneflow_utils fastNLP.core.utils.paddle_utils fastNLP.core.utils.rich_progress + fastNLP.core.utils.seq_len_to_mask fastNLP.core.utils.torch_utils + fastNLP.core.utils.tqdm_progress fastNLP.core.utils.utils diff --git a/docs/source/fastNLP.core.utils.seq_len_to_mask.rst b/docs/source/fastNLP.core.utils.seq_len_to_mask.rst new file mode 100644 index 00000000..55188a65 --- /dev/null +++ b/docs/source/fastNLP.core.utils.seq_len_to_mask.rst @@ -0,0 +1,7 @@ +fastNLP.core.utils.seq\_len\_to\_mask module +============================================ + +.. automodule:: fastNLP.core.utils.seq_len_to_mask + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.core.utils.tqdm_progress.rst b/docs/source/fastNLP.core.utils.tqdm_progress.rst new file mode 100644 index 00000000..cfcdc655 --- /dev/null +++ b/docs/source/fastNLP.core.utils.tqdm_progress.rst @@ -0,0 +1,7 @@ +fastNLP.core.utils.tqdm\_progress module +======================================== + +.. automodule:: fastNLP.core.utils.tqdm_progress + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.rst b/docs/source/fastNLP.embeddings.rst new file mode 100644 index 00000000..1b220f59 --- /dev/null +++ b/docs/source/fastNLP.embeddings.rst @@ -0,0 +1,15 @@ +fastNLP.embeddings package +========================== + +.. automodule:: fastNLP.embeddings + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. 
toctree:: + :maxdepth: 4 + + fastNLP.embeddings.torch diff --git a/docs/source/fastNLP.embeddings.torch.char_embedding.rst b/docs/source/fastNLP.embeddings.torch.char_embedding.rst new file mode 100644 index 00000000..f0d1dad7 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.char_embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.char\_embedding module +=============================================== + +.. automodule:: fastNLP.embeddings.torch.char_embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.embedding.rst b/docs/source/fastNLP.embeddings.torch.embedding.rst new file mode 100644 index 00000000..1804a70e --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.embedding module +========================================= + +.. automodule:: fastNLP.embeddings.torch.embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.rst b/docs/source/fastNLP.embeddings.torch.rst new file mode 100644 index 00000000..6294e8a2 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.rst @@ -0,0 +1,19 @@ +fastNLP.embeddings.torch package +================================ + +.. automodule:: fastNLP.embeddings.torch + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.embeddings.torch.char_embedding + fastNLP.embeddings.torch.embedding + fastNLP.embeddings.torch.stack_embedding + fastNLP.embeddings.torch.static_embedding + fastNLP.embeddings.torch.utils diff --git a/docs/source/fastNLP.embeddings.torch.stack_embedding.rst b/docs/source/fastNLP.embeddings.torch.stack_embedding.rst new file mode 100644 index 00000000..dab50088 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.stack_embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.stack\_embedding module +================================================ + +.. 
automodule:: fastNLP.embeddings.torch.stack_embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.static_embedding.rst b/docs/source/fastNLP.embeddings.torch.static_embedding.rst new file mode 100644 index 00000000..fc1a2bb9 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.static_embedding.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.static\_embedding module +================================================= + +.. automodule:: fastNLP.embeddings.torch.static_embedding + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.embeddings.torch.utils.rst b/docs/source/fastNLP.embeddings.torch.utils.rst new file mode 100644 index 00000000..9d1fc5b5 --- /dev/null +++ b/docs/source/fastNLP.embeddings.torch.utils.rst @@ -0,0 +1,7 @@ +fastNLP.embeddings.torch.utils module +===================================== + +.. automodule:: fastNLP.embeddings.torch.utils + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.loader.rst b/docs/source/fastNLP.io.loader.rst index bd91b795..20be532a 100644 --- a/docs/source/fastNLP.io.loader.rst +++ b/docs/source/fastNLP.io.loader.rst @@ -14,7 +14,6 @@ Submodules fastNLP.io.loader.classification fastNLP.io.loader.conll - fastNLP.io.loader.coreference fastNLP.io.loader.csv fastNLP.io.loader.cws fastNLP.io.loader.json diff --git a/docs/source/fastNLP.io.model_io.rst b/docs/source/fastNLP.io.model_io.rst deleted file mode 100644 index fe13d1d7..00000000 --- a/docs/source/fastNLP.io.model_io.rst +++ /dev/null @@ -1,7 +0,0 @@ -fastNLP.io.model\_io module -=========================== - -.. 
automodule:: fastNLP.io.model_io - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/fastNLP.io.pipe.rst b/docs/source/fastNLP.io.pipe.rst index 9ad7e539..53a62918 100644 --- a/docs/source/fastNLP.io.pipe.rst +++ b/docs/source/fastNLP.io.pipe.rst @@ -15,7 +15,6 @@ Submodules fastNLP.io.pipe.classification fastNLP.io.pipe.conll fastNLP.io.pipe.construct_graph - fastNLP.io.pipe.coreference fastNLP.io.pipe.cws fastNLP.io.pipe.matching fastNLP.io.pipe.pipe diff --git a/docs/source/fastNLP.io.rst b/docs/source/fastNLP.io.rst index 5f025bba..7e1a5a67 100644 --- a/docs/source/fastNLP.io.rst +++ b/docs/source/fastNLP.io.rst @@ -25,5 +25,4 @@ Submodules fastNLP.io.embed_loader fastNLP.io.file_reader fastNLP.io.file_utils - fastNLP.io.model_io fastNLP.io.utils diff --git a/docs/source/fastNLP.models.rst b/docs/source/fastNLP.models.rst new file mode 100644 index 00000000..eaef5a5e --- /dev/null +++ b/docs/source/fastNLP.models.rst @@ -0,0 +1,15 @@ +fastNLP.models package +====================== + +.. automodule:: fastNLP.models + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.models.torch diff --git a/docs/source/fastNLP.models.torch.biaffine_parser.rst b/docs/source/fastNLP.models.torch.biaffine_parser.rst new file mode 100644 index 00000000..c75d7079 --- /dev/null +++ b/docs/source/fastNLP.models.torch.biaffine_parser.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.biaffine\_parser module +============================================ + +.. 
automodule:: fastNLP.models.torch.biaffine_parser + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.cnn_text_classification.rst b/docs/source/fastNLP.models.torch.cnn_text_classification.rst new file mode 100644 index 00000000..a0b4e1bd --- /dev/null +++ b/docs/source/fastNLP.models.torch.cnn_text_classification.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.cnn\_text\_classification module +===================================================== + +.. automodule:: fastNLP.models.torch.cnn_text_classification + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.rst b/docs/source/fastNLP.models.torch.rst new file mode 100644 index 00000000..7196f3f7 --- /dev/null +++ b/docs/source/fastNLP.models.torch.rst @@ -0,0 +1,19 @@ +fastNLP.models.torch package +============================ + +.. automodule:: fastNLP.models.torch + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.models.torch.biaffine_parser + fastNLP.models.torch.cnn_text_classification + fastNLP.models.torch.seq2seq_generator + fastNLP.models.torch.seq2seq_model + fastNLP.models.torch.sequence_labeling diff --git a/docs/source/fastNLP.models.torch.seq2seq_generator.rst b/docs/source/fastNLP.models.torch.seq2seq_generator.rst new file mode 100644 index 00000000..bc1e4ca0 --- /dev/null +++ b/docs/source/fastNLP.models.torch.seq2seq_generator.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.seq2seq\_generator module +============================================== + +.. 
automodule:: fastNLP.models.torch.seq2seq_generator + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.seq2seq_model.rst b/docs/source/fastNLP.models.torch.seq2seq_model.rst new file mode 100644 index 00000000..802b8793 --- /dev/null +++ b/docs/source/fastNLP.models.torch.seq2seq_model.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.seq2seq\_model module +========================================== + +.. automodule:: fastNLP.models.torch.seq2seq_model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.models.torch.sequence_labeling.rst b/docs/source/fastNLP.models.torch.sequence_labeling.rst new file mode 100644 index 00000000..af834f53 --- /dev/null +++ b/docs/source/fastNLP.models.torch.sequence_labeling.rst @@ -0,0 +1,7 @@ +fastNLP.models.torch.sequence\_labeling module +============================================== + +.. automodule:: fastNLP.models.torch.sequence_labeling + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.rst b/docs/source/fastNLP.modules.rst index fa1d95de..b686105d 100644 --- a/docs/source/fastNLP.modules.rst +++ b/docs/source/fastNLP.modules.rst @@ -13,3 +13,4 @@ Subpackages :maxdepth: 4 fastNLP.modules.mix_modules + fastNLP.modules.torch diff --git a/docs/source/fastNLP.modules.torch.attention.rst b/docs/source/fastNLP.modules.torch.attention.rst new file mode 100644 index 00000000..52b7bf8c --- /dev/null +++ b/docs/source/fastNLP.modules.torch.attention.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.attention module +====================================== + +.. 
automodule:: fastNLP.modules.torch.attention + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.crf.rst b/docs/source/fastNLP.modules.torch.decoder.crf.rst new file mode 100644 index 00000000..2d9e3460 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.crf.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.crf module +======================================== + +.. automodule:: fastNLP.modules.torch.decoder.crf + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.mlp.rst b/docs/source/fastNLP.modules.torch.decoder.mlp.rst new file mode 100644 index 00000000..6bb9cc5c --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.mlp.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.mlp module +======================================== + +.. automodule:: fastNLP.modules.torch.decoder.mlp + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.rst b/docs/source/fastNLP.modules.torch.decoder.rst new file mode 100644 index 00000000..999ab01d --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.rst @@ -0,0 +1,18 @@ +fastNLP.modules.torch.decoder package +===================================== + +.. automodule:: fastNLP.modules.torch.decoder + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. 
toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.decoder.crf + fastNLP.modules.torch.decoder.mlp + fastNLP.modules.torch.decoder.seq2seq_decoder + fastNLP.modules.torch.decoder.seq2seq_state diff --git a/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst b/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst new file mode 100644 index 00000000..43c77fea --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.seq2seq_decoder.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.seq2seq\_decoder module +===================================================== + +.. automodule:: fastNLP.modules.torch.decoder.seq2seq_decoder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst b/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst new file mode 100644 index 00000000..05f730e4 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.decoder.seq2seq_state.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.decoder.seq2seq\_state module +=================================================== + +.. automodule:: fastNLP.modules.torch.decoder.seq2seq_state + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.io.loader.coreference.rst b/docs/source/fastNLP.modules.torch.dropout.rst similarity index 52% rename from docs/source/fastNLP.io.loader.coreference.rst rename to docs/source/fastNLP.modules.torch.dropout.rst index 58dfb880..8e4b591b 100644 --- a/docs/source/fastNLP.io.loader.coreference.rst +++ b/docs/source/fastNLP.modules.torch.dropout.rst @@ -1,7 +1,7 @@ -fastNLP.io.loader.coreference module +fastNLP.modules.torch.dropout module ==================================== -.. automodule:: fastNLP.io.loader.coreference +.. 
automodule:: fastNLP.modules.torch.dropout :members: :undoc-members: :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst b/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst new file mode 100644 index 00000000..438ec076 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.conv_maxpool.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.conv\_maxpool module +================================================== + +.. automodule:: fastNLP.modules.torch.encoder.conv_maxpool + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.lstm.rst b/docs/source/fastNLP.modules.torch.encoder.lstm.rst new file mode 100644 index 00000000..918e13cb --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.lstm.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.lstm module +========================================= + +.. automodule:: fastNLP.modules.torch.encoder.lstm + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.rst b/docs/source/fastNLP.modules.torch.encoder.rst new file mode 100644 index 00000000..14120ed1 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.rst @@ -0,0 +1,20 @@ +fastNLP.modules.torch.encoder package +===================================== + +.. automodule:: fastNLP.modules.torch.encoder + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. 
toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.encoder.conv_maxpool + fastNLP.modules.torch.encoder.lstm + fastNLP.modules.torch.encoder.seq2seq_encoder + fastNLP.modules.torch.encoder.star_transformer + fastNLP.modules.torch.encoder.transformer + fastNLP.modules.torch.encoder.variational_rnn diff --git a/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst b/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst new file mode 100644 index 00000000..152fc091 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.seq2seq_encoder.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.seq2seq\_encoder module +===================================================== + +.. automodule:: fastNLP.modules.torch.encoder.seq2seq_encoder + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst b/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst new file mode 100644 index 00000000..3257cf13 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.star_transformer.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.star\_transformer module +====================================================== + +.. automodule:: fastNLP.modules.torch.encoder.star_transformer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.transformer.rst b/docs/source/fastNLP.modules.torch.encoder.transformer.rst new file mode 100644 index 00000000..0a3c893f --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.transformer.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.transformer module +================================================ + +.. 
automodule:: fastNLP.modules.torch.encoder.transformer + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst b/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst new file mode 100644 index 00000000..71a70c3a --- /dev/null +++ b/docs/source/fastNLP.modules.torch.encoder.variational_rnn.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.encoder.variational\_rnn module +===================================================== + +.. automodule:: fastNLP.modules.torch.encoder.variational_rnn + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.generator.rst b/docs/source/fastNLP.modules.torch.generator.rst new file mode 100644 index 00000000..783db61d --- /dev/null +++ b/docs/source/fastNLP.modules.torch.generator.rst @@ -0,0 +1,15 @@ +fastNLP.modules.torch.generator package +======================================= + +.. automodule:: fastNLP.modules.torch.generator + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.generator.seq2seq_generator diff --git a/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst b/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst new file mode 100644 index 00000000..4abc102f --- /dev/null +++ b/docs/source/fastNLP.modules.torch.generator.seq2seq_generator.rst @@ -0,0 +1,7 @@ +fastNLP.modules.torch.generator.seq2seq\_generator module +========================================================= + +.. automodule:: fastNLP.modules.torch.generator.seq2seq_generator + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/source/fastNLP.modules.torch.rst b/docs/source/fastNLP.modules.torch.rst new file mode 100644 index 00000000..8e1fb0f5 --- /dev/null +++ b/docs/source/fastNLP.modules.torch.rst @@ -0,0 +1,26 @@ +fastNLP.modules.torch package +============================= + +.. 
automodule:: fastNLP.modules.torch + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.decoder + fastNLP.modules.torch.encoder + fastNLP.modules.torch.generator + +Submodules +---------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.modules.torch.attention + fastNLP.modules.torch.dropout diff --git a/docs/source/fastNLP.rst b/docs/source/fastNLP.rst index 89c8e058..f3e245fe 100644 --- a/docs/source/fastNLP.rst +++ b/docs/source/fastNLP.rst @@ -13,6 +13,9 @@ Subpackages :maxdepth: 4 fastNLP.core + fastNLP.embeddings fastNLP.envs fastNLP.io + fastNLP.models fastNLP.modules + fastNLP.transformers \ No newline at end of file diff --git a/docs/source/fastNLP.transformers.rst b/docs/source/fastNLP.transformers.rst new file mode 100644 index 00000000..023da63d --- /dev/null +++ b/docs/source/fastNLP.transformers.rst @@ -0,0 +1,14 @@ +fastNLP.transformers package +============================ +.. automodule:: fastNLP.transformers + :members: + :undoc-members: + :show-inheritance: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + fastNLP.transformers.torch diff --git a/docs/source/fastNLP.io.pipe.coreference.rst b/docs/source/fastNLP.transformers.torch.rst similarity index 53% rename from docs/source/fastNLP.io.pipe.coreference.rst rename to docs/source/fastNLP.transformers.torch.rst index bccdb0a7..9d5f0d65 100644 --- a/docs/source/fastNLP.io.pipe.coreference.rst +++ b/docs/source/fastNLP.transformers.torch.rst @@ -1,7 +1,7 @@ -fastNLP.io.pipe.coreference module +fastNLP.transformers.torch package ================================== -.. automodule:: fastNLP.io.pipe.coreference +.. 
automodule:: fastNLP.transformers.torch :members: :undoc-members: :show-inheritance: diff --git a/docs/source/index.rst b/docs/source/index.rst index 10b0d74c..abe73d8a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -2,18 +2,18 @@ fastNLP 中文文档 ===================== -用户手册 +快速上手 ---------------- .. toctree:: - :maxdepth: 1 + :maxdepth: 2 - 语法样例 + tutorials API 文档 ------------- -除了用户手册之外,你还可以通过查阅 API 文档来找到你所需要的工具。 +您可以通过查阅 API 文档来找到你所需要的工具。 .. toctree:: :titlesonly: diff --git a/docs/source/tutorials.rst b/docs/source/tutorials.rst new file mode 100644 index 00000000..bc32d10f --- /dev/null +++ b/docs/source/tutorials.rst @@ -0,0 +1,8 @@ +fastNLP 教程系列 +================ + +.. toctree:: + :maxdepth: 1 + :glob: + + tutorials/* diff --git a/docs/source/tutorials/fastnlp_torch_tutorial.ipynb b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb new file mode 100644 index 00000000..9633ac7f --- /dev/null +++ b/docs/source/tutorials/fastnlp_torch_tutorial.ipynb @@ -0,0 +1,869 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "6011adf8", + "metadata": {}, + "source": [ + "# 10 分钟快速上手 fastNLP torch\n", + "\n", + "在这个例子中,我们将使用BERT来解决conll2003数据集中的命名实体识别任务。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e166c051", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "--2022-07-07 10:12:29-- https://data.deepai.org/conll2003.zip\n", + "Resolving data.deepai.org (data.deepai.org)... 138.201.36.183\n", + "Connecting to data.deepai.org (data.deepai.org)|138.201.36.183|:443... connected.\n", + "WARNING: cannot verify data.deepai.org's certificate, issued by ‘CN=R3,O=Let's Encrypt,C=US’:\n", + " Issued certificate has expired.\n", + "HTTP request sent, awaiting response... 
200 OK\n", + "Length: 982975 (960K) [application/x-zip-compressed]\n", + "Saving to: ‘conll2003.zip’\n", + "\n", + "conll2003.zip 100%[===================>] 959.94K 653KB/s in 1.5s \n", + "\n", + "2022-07-07 10:12:32 (653 KB/s) - ‘conll2003.zip’ saved [982975/982975]\n", + "\n", + "Archive: conll2003.zip\n", + " inflating: conll2003/metadata \n", + " inflating: conll2003/test.txt \n", + " inflating: conll2003/train.txt \n", + " inflating: conll2003/valid.txt \n" + ] + } + ], + "source": [ + "# Linux/Mac 下载数据,并解压\n", + "import platform\n", + "if platform.system() != \"Windows\":\n", + " !wget https://data.deepai.org/conll2003.zip --no-check-certificate -O conll2003.zip\n", + " !unzip conll2003.zip -d conll2003\n", + "# Windows用户请通过复制该url到浏览器下载该数据并解压" + ] + }, + { + "cell_type": "markdown", + "id": "f7acbf1f", + "metadata": {}, + "source": [ + "## 目录\n", + "接下来我们将按照以下的内容介绍在如何通过fastNLP减少工程性代码的撰写 \n", + "- 1. 数据加载\n", + "- 2. 数据预处理、数据缓存\n", + "- 3. DataLoader\n", + "- 4. 模型准备\n", + "- 5. Trainer的使用\n", + "- 6. Evaluator的使用\n", + "- 7. 其它【待补充】\n", + " - 7.1 使用多卡进行训练、评测\n", + " - 7.2 使用ZeRO优化\n", + " - 7.3 通过overfit测试快速验证模型\n", + " - 7.4 复杂Monitor的使用\n", + " - 7.5 训练过程中,使用不同的测试函数\n", + " - 7.6 更有效率的Sampler\n", + " - 7.7 保存模型\n", + " - 7.8 断点重训\n", + " - 7.9 使用huggingface datasets\n", + " - 7.10 使用torchmetrics来作为metric\n", + " - 7.11 将预测结果写出到文件\n", + " - 7.12 混合 dataset 训练\n", + " - 7.13 logger的使用\n", + " - 7.14 自定义分布式 Metric 。\n", + " - 7.15 通过batch_step_fn实现R-Drop" + ] + }, + { + "cell_type": "markdown", + "id": "0657dfba", + "metadata": {}, + "source": [ + "#### 1. 
数据加载\n", + "目前在``conll2003``目录下有``train.txt``, ``test.txt``与``valid.txt``三个文件,文件的格式为[conll格式](https://universaldependencies.org/format.html),其编码格式为 [BIO](https://blog.csdn.net/HappyRocking/article/details/79716212) 类型。可以通过继承 fastNLP.io.Loader 来简化加载过程,继承了 Loader 函数后,只需要在实现读取单个文件 _load() 函数即可。" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c557f0ba", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append('../..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "6f59e438", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "In total 3 datasets:\n", + "\ttrain has 14987 instances.\n", + "\ttest has 3684 instances.\n", + "\tdev has 3466 instances.\n", + "\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet, Instance\n", + "from fastNLP.io import Loader\n", + "\n", + "\n", + "# 继承Loader之后,我们只需要实现其中_load()方法,_load()方法传入一个文件路径,返回一个fastNLP DataSet对象,其目的是读取一个文件。\n", + "class ConllLoader(Loader):\n", + " def _load(self, path):\n", + " ds = DataSet()\n", + " with open(path, 'r') as f:\n", + " segments = []\n", + " for line in f:\n", + " line = line.strip()\n", + " if line == '': # 如果为空行,说明需要切换到下一句了。\n", + " if segments:\n", + " raw_words = [s[0] for s in segments]\n", + " raw_target = [s[1] for s in segments]\n", + " # 将一个 sample 插入到 DataSet中\n", + " ds.append(Instance(raw_words=raw_words, raw_target=raw_target)) \n", + " segments = []\n", + " else:\n", + " parts = line.split()\n", + " assert len(parts)==4\n", + " segments.append([parts[0], parts[-1]])\n", + " return ds\n", + " \n", + "\n", + "# 直接使用 load() 方法加载数据集, 返回的 data_bundle 是一个 fastNLP.io.DataBundle 对象,该对象相当于将多个 dataset 放置在一起,\n", + "# 可以方便之后的预处理,DataBundle 支持的接口可以在 !!! 
查看。\n", + "data_bundle = ConllLoader().load({\n", + " 'train': 'conll2003/train.txt',\n", + " 'test': 'conll2003/test.txt',\n", + " 'dev': 'conll2003/valid.txt'\n", + "})\n", + "\"\"\"\n", + "也可以通过 ConllLoader().load('conll2003/') 来读取,其原理是load()函数将尝试从'conll2003/'文件夹下寻找文件名称中包含了\n", + "'train'、'test'和'dev'的文件,并分别读取将其命名为'train'、'test'和'dev'(如文件夹中同一个关键字出现在了多个文件名中将导致报错,\n", + "此时请通过dict的方式传入路径信息)。但在我们这里的数据里,没有文件包含dev,所以无法直接使用文件夹读取,转而通过dict的方式传入读取的路径,\n", + "该dict的key也将作为读取的数据集的名称,value即对应的文件路径。\n", + "\"\"\"\n", + "\n", + "print(data_bundle) # 打印 data_bundle 可以查看包含的 DataSet \n", + "# data_bundle.get_dataset('train') # 可以获取单个 dataset" + ] + }, + { + "cell_type": "markdown", + "id": "57ae314d", + "metadata": {}, + "source": [ + "#### 2. 数据预处理\n", + "接下来,我们将演示如何通过fastNLP提供的apply函数方便快捷地进行预处理。我们需要进行的预处理操作有: \n", + "(1)使用BertTokenizer将文本转换为index;同时记录每个word被bpe之后第一个bpe的index,用于得到word的hidden state; \n", + "(2)使用[Vocabulary](../fastNLP)来将raw_target转换为序号。 " + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "96389988", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c3bd41a323c94a41b409d29a5d4079b6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "IOPub message rate exceeded.\n", + "The notebook server will temporarily stop sending output\n", + "to the client in order to avoid crashing it.\n", + "To change this limit, set the config variable\n", + "`--NotebookApp.iopub_msg_rate_limit`.\n", + "\n", + 
"Current values:\n", + "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", + "NotebookApp.rate_limit_window=3.0 (secs)\n", + "\n" + ] + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10:48:13] INFO Save cache to /remote-home/hyan01/exps/fastNLP/fastN cache_results.py:332\n", + " LP/demo/torch_tutorial/caches/c7f74559_cache.pkl. \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10:48:13]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Save cache to \u001b[35m/remote-home/hyan01/exps/fastNLP/fastN\u001b[0m \u001b]8;id=831330;file://../../fastNLP/core/utils/cache_results.py\u001b\\\u001b[2mcache_results.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=609545;file://../../fastNLP/core/utils/cache_results.py#332\u001b\\\u001b[2m332\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[35mLP/demo/torch_tutorial/caches/\u001b[0m\u001b[95mc7f74559_cache.pkl.\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# fastNLP 中提供了BERT, RoBERTa, GPT, BART 模型,更多的预训练模型请直接使用transformers\n", + "from fastNLP.transformers.torch import BertTokenizer\n", + "from fastNLP import cache_results, Vocabulary\n", + "\n", + "# 使用cache_results来装饰函数,会将函数的返回结果缓存到'caches/{param_hash_id}_cache.pkl'路径中(其中{param_hash_id}是根据\n", + "# 传递给 process_data 函数参数决定的,因此当函数的参数变化时,会再生成新的缓存文件。如果需要重新生成新的缓存,(a) 可以在调用process_data\n", + "# 函数时,额外传入一个_refresh=True的参数; 或者(b)删除相应的缓存文件。此外,保存结果时,cache_results默认还会\n", + "# 记录 process_data 函数源码的hash值,当其源码发生了变动,直接读取缓存会发出警告,以防止在修改预处理代码之后,忘记刷新缓存。)\n", + "@cache_results('caches/cache.pkl')\n", + "def process_data(data_bundle, model_name):\n", + " tokenizer = BertTokenizer.from_pretrained(model_name)\n", + " def bpe(raw_words):\n", + " bpes = [tokenizer.cls_token_id]\n", + " first = [0]\n", + " first_index = 1 # 记录第一个bpe的位置\n", + " for word in raw_words:\n", + " bpe = tokenizer.encode(word, add_special_tokens=False)\n", + " bpes.extend(bpe)\n", + " first.append(first_index)\n", + " first_index += len(bpe)\n", + " bpes.append(tokenizer.sep_token_id)\n", + " first.append(first_index)\n", + " return {'input_ids': bpes, 'input_len': len(bpes), 'first': first, 'first_len': 
len(raw_words)}\n", + " # 对data_bundle中每个dataset的每一条数据中的raw_words使用bpe函数,并且将返回的结果加入到每条数据中。\n", + " data_bundle.apply_field_more(bpe, field_name='raw_words', num_proc=4)\n", + " # 对应我们还有 apply_field() 函数,该函数和 apply_field_more() 的区别在于传入到 apply_field() 中的函数应该返回一个 field 的\n", + " # 内容(即不需要用dict包裹了)。此外,我们还提供了 data_bundle.apply() ,传入 apply() 的函数需要支持传入一个Instance对象,\n", + " # 更多信息可以参考对应的文档。\n", + " \n", + " # tag的词表,由于这是词表,所以不需要有padding和unk\n", + " tag_vocab = Vocabulary(padding=None, unknown=None)\n", + " # 从 train 数据的 raw_target 中获取建立词表\n", + " tag_vocab.from_dataset(data_bundle.get_dataset('train'), field_name='raw_target')\n", + " # 使用词表将每个 dataset 中的raw_target转为数字,并且将写入到target这个field中\n", + " tag_vocab.index_dataset(data_bundle.datasets.values(), field_name='raw_target', new_field_name='target')\n", + " \n", + " # 可以将 vocabulary 绑定到 data_bundle 上,方便之后使用。\n", + " data_bundle.set_vocab(tag_vocab, field_name='target')\n", + " \n", + " return data_bundle, tokenizer\n", + "\n", + "data_bundle, tokenizer = process_data(data_bundle, 'bert-base-cased', _refresh=True) # 第一次调用耗时较长,第二次调用则会直接读取缓存的文件\n", + "# data_bundle = process_data(data_bundle, 'bert-base-uncased') # 由于参数变化,fastNLP 会再次生成新的缓存文件。 " + ] + }, + { + "cell_type": "markdown", + "id": "80036fcd", + "metadata": {}, + "source": [ + "### 3. DataLoader \n", + "由于现在的深度学习算法大都基于 mini-batch 进行优化,因此需要将多个 sample 组合成一个 batch 再输入到模型之中。在自然语言处理中,不同的 sample 往往长度不一致,需要进行 padding 操作。在fastNLP中,我们使用 fastNLP.TorchDataLoader 帮助用户快速进行 padding ,我们使用了 !!!fastNLP.Collator!!! 
对象来进行 pad ,Collator 会在迭代过程中根据第一个 batch 的数据自动判定每个 field 是否可以进行 pad ,可以通过 Collator.set_pad() 函数修改某个 field 的 pad 行为。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "09494695", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_dataloader\n", + "\n", + "# 将 data_bundle 中每个 dataset 取出并构造出相应的 DataLoader 对象。返回的 dls 是一个 dict ,包含了 'train', 'test', 'dev' 三个\n", + "# fastNLP.TorchDataLoader 对象。\n", + "dls = prepare_dataloader(data_bundle, batch_size=24) \n", + "\n", + "\n", + "# fastNLP 将默认尝试对所有 field 都进行 pad ,如果当前 field 是不可 pad 的类型,则不进行pad;如果是可以 pad 的类型\n", + "# 默认使用 0 进行 pad 。\n", + "for dl in dls.values():\n", + " # 可以通过 set_pad 修改 padding 的行为。\n", + " dl.set_pad('input_ids', pad_val=tokenizer.pad_token_id)\n", + " # 如果希望忽略某个 field ,可以通过 set_ignore 方法。\n", + " dl.set_ignore('raw_target')\n", + " dl.set_pad('target', pad_val=-100)\n", + "# 另一种设置的方法是,可以在 dls = prepare_dataloader(data_bundle, batch_size=32) 之前直接调用 \n", + "# data_bundle.set_pad('input_ids', pad_val=tokenizer.pad_token_id); data_bundle.set_ignore('raw_target')来进行设置。\n", + "# DataSet 也支持这两个方法。\n", + "# 若此时调用 batch = next(dls['train']),则 batch 是一个 dict ,其中包含了\n", + "# 'input_ids': torch.LongTensor([batch_size, max_len])\n", + "# 'input_len': torch.LongTensor([batch_size])\n", + "# 'first': torch.LongTensor([batch_size, max_len'])\n", + "# 'first_len': torch.LongTensor([batch_size])\n", + "# 'target': torch.LongTensor([batch_size, max_len'-2])\n", + "# 'raw_words': List[List[str]] # 因为无法判断,所以 Collator 不会做任何处理" + ] + }, + { + "cell_type": "markdown", + "id": "3583df6d", + "metadata": {}, + "source": [ + "### 4. 
模型准备\n", + "传入给fastNLP的模型,需要有两个特殊的方法``train_step``、``evaluate_step``,前者默认在 fastNLP.Trainer 中进行调用,后者默认在 fastNLP.Evaluator 中调用。如果模型中没有``train_step``方法,则Trainer会直接使用模型的``forward``函数;如果模型没有``evaluate_step``方法,则Evaluator会直接使用模型的``forward``函数。``train_step``方法(或当其不存在时,``forward``方法)的返回值必须为 dict 类型,并且必须包含``loss``这个 key 。\n", + "\n", + "此外fastNLP会使用形参名匹配的方式进行参数传递,例如以下模型\n", + "```python\n", + "class Model(nn.Module):\n", + " def train_step(self, x, y):\n", + " return {'loss': (x-y).abs().mean()}\n", + "```\n", + "fastNLP将尝试从 DataLoader 返回的 batch(假设包含的 key 为 input_ids, target) 中寻找 'x' 和 'y' 这两个 key ,如果没有找到则会报错。有以下的方法可以解决报错\n", + "- 修改 train_step 的参数为(input_ids, target),以保证和 DataLoader 返回的 batch 中的 key 匹配\n", + "- 修改 DataLoader 中返回 batch 的 key 的名字为 (x, y)\n", + "- 在 Trainer 中传入参数 train_input_mapping={'input_ids': 'x', 'target': 'y'} 将输入进行映射,train_input_mapping 也可以是一个函数,更多 train_input_mapping 的介绍可以参考文档。\n", + "\n", + "``evaluate_step``也是使用同样的匹配方式,前两条解决方法是一致的,第三种解决方案中,需要在 Evaluator 中传入 evaluate_input_mapping={'input_ids': 'x', 'target': 'y'}。" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f131c1a3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[10:48:21] WARNING Some weights of the model checkpoint at modeling_utils.py:1490\n", + " bert-base-uncased were not used when initializing \n", + " BertModel: ['cls.predictions.bias', \n", + " 'cls.predictions.transform.LayerNorm.weight', \n", + " 'cls.seq_relationship.weight', \n", + " 'cls.predictions.decoder.weight', \n", + " 'cls.predictions.transform.dense.weight', \n", + " 'cls.predictions.transform.LayerNorm.bias', \n", + " 'cls.predictions.transform.dense.bias', \n", + " 'cls.seq_relationship.bias'] \n", + " - This IS expected if you are initializing \n", + " BertModel from the checkpoint of a model trained \n", + " on another task or with another architecture (e.g. \n", + " initializing a BertForSequenceClassification model \n", + " from a BertForPreTraining model). \n", + " - This IS NOT expected if you are initializing \n", + " BertModel from the checkpoint of a model that you \n", + " expect to be exactly identical (initializing a \n", + " BertForSequenceClassification model from a \n", + " BertForSequenceClassification model). 
\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10:48:21]\u001b[0m\u001b[2;36m \u001b[0m\u001b[31mWARNING \u001b[0m Some weights of the model checkpoint at \u001b]8;id=387614;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=648168;file://../../fastNLP/transformers/torch/modeling_utils.py#1490\u001b\\\u001b[2m1490\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m bert-base-uncased were not used when initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel: \u001b[1m[\u001b[0m\u001b[32m'cls.predictions.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.decoder.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.weight'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.LayerNorm.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.predictions.transform.dense.bias'\u001b[0m, \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[32m'cls.seq_relationship.bias'\u001b[0m\u001b[1m]\u001b[0m \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m - This IS expected if you are initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model trained \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m on another task or with another architecture \u001b[1m(\u001b[0me.g. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m initializing a BertForSequenceClassification model \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m from a BertForPreTraining model\u001b[1m)\u001b[0m. 
\u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m - This IS NOT expected if you are initializing \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel from the checkpoint of a model that you \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m expect to be exactly identical \u001b[1m(\u001b[0minitializing a \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertForSequenceClassification model from a \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertForSequenceClassification model\u001b[1m)\u001b[0m. \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO All the weights of BertModel were initialized from modeling_utils.py:1507\n", + " the model checkpoint at bert-base-uncased. \n", + " If your task is similar to the task the model of \n", + " the checkpoint was trained on, you can already use \n", + " BertModel for predictions without further \n", + " training. \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m All the weights of BertModel were initialized from \u001b]8;id=544687;file://../../fastNLP/transformers/torch/modeling_utils.py\u001b\\\u001b[2mmodeling_utils.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=934505;file://../../fastNLP/transformers/torch/modeling_utils.py#1507\u001b\\\u001b[2m1507\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m the model checkpoint at bert-base-uncased. \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m If your task is similar to the task the model of \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m the checkpoint was trained on, you can already use \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m BertModel for predictions without further \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m training. 
\u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import torch\n", + "from torch import nn\n", + "from torch.nn.utils.rnn import pad_sequence\n", + "from fastNLP.transformers.torch import BertModel\n", + "from fastNLP import seq_len_to_mask\n", + "import torch.nn.functional as F\n", + "\n", + "\n", + "class BertNER(nn.Module):\n", + " def __init__(self, model_name, num_class, tag_vocab=None):\n", + " super().__init__()\n", + " self.bert = BertModel.from_pretrained(model_name)\n", + " self.mlp = nn.Sequential(nn.Linear(self.bert.config.hidden_size, self.bert.config.hidden_size),\n", + " nn.Dropout(0.3),\n", + " nn.Linear(self.bert.config.hidden_size, num_class))\n", + " self.tag_vocab = tag_vocab # 这里传入 tag_vocab 的目的是为了演示 constrined_decode \n", + " if tag_vocab is not None:\n", + " self._init_constrained_transition()\n", + " \n", + " def forward(self, input_ids, input_len, first):\n", + " attention_mask = seq_len_to_mask(input_len)\n", + " outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)\n", + " last_hidden_state = outputs.last_hidden_state\n", + " first = first.unsqueeze(-1).repeat(1, 1, last_hidden_state.size(-1))\n", + " first_bpe_state = last_hidden_state.gather(dim=1, index=first)\n", + " first_bpe_state = first_bpe_state[:, 1:-1] # 删除 cls 和 sep\n", + " \n", + " pred = self.mlp(first_bpe_state)\n", + " return {'pred': pred}\n", + " \n", + " def train_step(self, input_ids, input_len, first, target):\n", + " pred = self(input_ids, input_len, first)['pred']\n", + " loss = F.cross_entropy(pred.transpose(1, 2), target)\n", + " return {'loss': loss}\n", + " \n", + " def evaluate_step(self, input_ids, input_len, first):\n", + " pred = self(input_ids, input_len, first)['pred'].argmax(dim=-1)\n", + " return {'pred': pred}\n", + " \n", + " def constrained_decode(self, input_ids, input_len, first, first_len):\n", + " # 这个函数在推理时,将保证解码出来的 tag 一定不与前一个 tag 矛盾【例如一定不会出现 B-person 后面接着 I-Location 
的情况】\n", + " # 本身这个需求可以在 Metric 中实现,这里在模型中实现的目的是为了方便演示:如何在fastNLP中使用不同的评测函数\n", + " pred = self(input_ids, input_len, first)['pred']\n", + " cons_pred = []\n", + " for _pred, _len in zip(pred, first_len):\n", + " _pred = _pred[:_len]\n", + " tags = [_pred[0].argmax(dim=-1).item()] # 这里就不考虑第一个位置非法的情况了\n", + " for i in range(1, _len):\n", + " tags.append((_pred[i] + self.transition[tags[-1]]).argmax().item())\n", + " cons_pred.append(torch.LongTensor(tags))\n", + " cons_pred = pad_sequence(cons_pred, batch_first=True)\n", + " return {'pred': cons_pred}\n", + " \n", + " def _init_constrained_transition(self):\n", + " from fastNLP.modules.torch import allowed_transitions\n", + " allowed_trans = allowed_transitions(self.tag_vocab)\n", + " transition = torch.ones((len(self.tag_vocab), len(self.tag_vocab)))*-100000.0\n", + " for s, e in allowed_trans:\n", + " transition[s, e] = 0\n", + " self.register_buffer('transition', transition)\n", + "\n", + "model = BertNER('bert-base-uncased', len(data_bundle.get_vocab('target')), data_bundle.get_vocab('target'))" + ] + }, + { + "cell_type": "markdown", + "id": "5aeee1e9", + "metadata": {}, + "source": [ + "### Trainer 的使用\n", + "fastNLP 的 Trainer 是用于对模型进行训练的部件。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f4250f0b", + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[10:49:22] INFO Running evaluator sanity check for 2 batches. trainer.py:661\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10:49:22]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=246773;file://../../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=639347;file://../../fastNLP/core/controllers/trainer.py#661\u001b\\\u001b[2m661\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
+++++++++++++++++++++++++++++ Eval. results on Epoch:1, Batch:0 +++++++++++++++++++++++++++++\n", + "\n" + ], + "text/plain": [ + "\u001b[38;5;41m+++++++++++++++++++++++++++++ \u001b[0m\u001b[1mEval. results on Epoch:\u001b[0m\u001b[1;36m1\u001b[0m\u001b[1m, Batch:\u001b[0m\u001b[1;36m0\u001b[0m\u001b[38;5;41m +++++++++++++++++++++++++++++\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#f\": 0.402447,\n", + " \"pre#f\": 0.447906,\n", + " \"rec#f\": 0.365365\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#f\"\u001b[0m: \u001b[1;36m0.402447\u001b[0m,\n", + " \u001b[1;34m\"pre#f\"\u001b[0m: \u001b[1;36m0.447906\u001b[0m,\n", + " \u001b[1;34m\"rec#f\"\u001b[0m: \u001b[1;36m0.365365\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[10:51:15] INFO The best performance for monitor f#f:0.402447 was progress_callback.py:37\n", + " achieved in Epoch:1, Global Batch:625. The \n", + " evaluation result: \n", + " {'f#f': 0.402447, 'pre#f': 0.447906, 'rec#f': \n", + " 0.365365} \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[10:51:15]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m The best performance for monitor f#\u001b[1;92mf:0\u001b[0m.\u001b[1;36m402447\u001b[0m was \u001b]8;id=192029;file://../../fastNLP/core/callbacks/progress_callback.py\u001b\\\u001b[2mprogress_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=994998;file://../../fastNLP/core/callbacks/progress_callback.py#37\u001b\\\u001b[2m37\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m achieved in Epoch:\u001b[1;36m1\u001b[0m, Global Batch:\u001b[1;36m625\u001b[0m. The \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m evaluation result: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.402447\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.447906\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.365365\u001b[0m\u001b[1m}\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO Loading best model from buffer with f#f: load_best_model_callback.py:115\n", + " 0.402447... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from buffer with f#f: \u001b]8;id=654516;file://../../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96586;file://../../fastNLP/core/callbacks/load_best_model_callback.py#115\u001b\\\u001b[2m115\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m0.402447\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from torch import optim\n", + "from fastNLP import Trainer, LoadBestModelCallback, TorchWarmupCallback\n", + "from fastNLP import SpanFPreRecMetric\n", + "\n", + "optimizer = optim.AdamW(model.parameters(), lr=2e-5)\n", + "callbacks = [\n", + " LoadBestModelCallback(), # 用于在训练结束之后加载性能最好的model的权重\n", + " TorchWarmupCallback()\n", + "] \n", + "\n", + "trainer = Trainer(model=model, train_dataloader=dls['train'], optimizers=optimizer, \n", + " evaluate_dataloaders=dls['dev'], \n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n", + " n_epochs=1, callbacks=callbacks, \n", + " # 在评测时将 dataloader 中的 first_len 映射 seq_len, 因为 Accuracy.update 接口需要输入一个名为 seq_len 的参数\n", + " evaluate_input_mapping={'first_len': 'seq_len'}, overfit_batches=0,\n", + " device=0, monitor='f#f', fp16=False) # fp16 为 True 的话,将使用 float16 进行训练。\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c600a450", + "metadata": {}, + "source": [ + "### Evaluator的使用\n", + "fastNLP中用于评测数据的对象。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b19f0ba", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + 
}, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'f#f'\u001b[0m: \u001b[1;36m0.390326\u001b[0m, \u001b[32m'pre#f'\u001b[0m: \u001b[1;36m0.414741\u001b[0m, \u001b[32m'rec#f'\u001b[0m: \u001b[1;36m0.368626\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#f': 0.390326, 'pre#f': 0.414741, 'rec#f': 0.368626}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP import SpanFPreRecMetric\n", + "\n", + "evaluator = Evaluator(model=model, dataloaders=dls['test'], \n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))}, \n", + " evaluate_input_mapping={'first_len': 'seq_len'}, \n", + " device=0)\n", + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "52f87770", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f723fe399df34917875ad74c2542508c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# 如果想评测一下使用 constrained decoding的性能,则可以通过传入 evaluate_fn 指定使用的函数\n", + "def input_mapping(x):\n", + " x['seq_len'] = x['first_len']\n", + " return x\n", + "evaluator = Evaluator(model=model, dataloaders=dls['test'], device=0,\n", + " metrics={'f': SpanFPreRecMetric(tag_vocab=data_bundle.get_vocab('target'))},\n", + " evaluate_fn='constrained_decode',\n", + " # 如果将 first_len 重新命名为了 seq_len, 将导致 constrained_decode 的输入缺少 first_len 参数,因此\n", + " # 额外重复一下 'first_len': 'first_len',使得这个参数不会消失。\n", + " evaluate_input_mapping=input_mapping)\n", + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "419e718b", + "metadata": {}, + "outputs": [], + "source": 
[] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_0.ipynb b/docs/source/tutorials/fastnlp_tutorial_0.ipynb new file mode 100644 index 00000000..09667794 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_0.ipynb @@ -0,0 +1,1352 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aec0fde7", + "metadata": {}, + "source": [ + "# T0. trainer 和 evaluator 的基本使用\n", + "\n", + " 1 trainer 和 evaluator 的基本关系\n", + " \n", + " 1.1 trainer 和 evaluater 的初始化\n", + "\n", + " 1.2 driver 的含义与使用要求\n", + "\n", + " 1.3 trainer 内部初始化 evaluater\n", + "\n", + " 2 使用 fastNLP 搭建 argmax 模型\n", + "\n", + " 2.1 trainer_step 和 evaluator_step\n", + "\n", + " 2.2 trainer 和 evaluator 的参数匹配\n", + "\n", + " 2.3 示例:argmax 模型的搭建\n", + "\n", + " 3 使用 fastNLP 训练 argmax 模型\n", + " \n", + " 3.1 trainer 外部初始化的 evaluator\n", + "\n", + " 3.2 trainer 内部初始化的 evaluator " + ] + }, + { + "cell_type": "markdown", + "id": "09ea669a", + "metadata": {}, + "source": [ + "## 1. 
trainer 和 evaluator 的基本关系\n", + "\n", + "### 1.1 trainer 和 evaluator 的初始化\n", + "\n", + "在`fastNLP 1.0`中,`Trainer`模块和`Evaluator`模块分别表示 **“训练器”和“评测器”**\n", + "\n", + " 对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n", + "\n", + "在`fastNLP 1.0`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n", + "\n", + " 非常关键的问题在于**如何正确设置二者的 driver**。这就引入了另一个问题:什么是 `driver`?\n", + "\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model, # 模型基于 torch.nn.Module\n", + " train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", + " optimizers=optimizer, # 优化模块基于 torch.optim.*\n", + " ...\n", + " driver=\"torch\", # 使用 pytorch 模块进行训练 \n", + " device='cuda', # 使用 GPU:0 显卡执行训练\n", + " ...\n", + " )\n", + "...\n", + "evaluator = Evaluator(\n", + " model=model, # 模型基于 torch.nn.Module\n", + " dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", + " metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", + " ...\n", + " driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", + " device=None,\n", + " ...\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "3c11fe1a", + "metadata": {}, + "source": [ + "### 1.2 driver 的含义与使用要求\n", + "\n", + "在`fastNLP 1.0`中,**driver**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n", + "\n", + " 例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n", + "\n", + "在`fastNLP 1.0`中,**Trainer 和 Evaluator 都依赖于具体的 driver 来完成整体的工作流程**\n", + "\n", + " 具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n", + "\n", + "注:这里给出一条建议:**在同一脚本中**,**所有的** Trainer **和** Evaluator **使用的** driver **应当保持一致**\n", + "\n", + " 尽量不出现,之前使用单卡的`driver`,后面又使用多卡的`driver`,这是因为,当脚本执行至\n", + "\n", + " 多卡`driver`处时,会重启一个进程执行之前所有内容,如此一来可能会造成一些意想不到的麻烦" + ] + }, + { + "cell_type": "markdown", + "id": "2cac4a1a", + "metadata": {}, + "source": [ + "### 1.3 Trainer 内部初始化 Evaluator\n", + "\n", + "在`fastNLP 1.0`中,如果在**初始化 Trainer 时**,**传入参数 evaluator_dataloaders 和 metrics **\n", + "\n", + " 
则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", + " metrics={'acc': Accuracy()}, # 传入参数 metrics\n", + " ...\n", + " )\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "0c9c7dda", + "metadata": {}, + "source": [ + "## 2. argmax 模型的搭建实例" + ] + }, + { + "cell_type": "markdown", + "id": "524ac200", + "metadata": {}, + "source": [ + "### 2.1 trainer_step 和 evaluator_step\n", + "\n", + "在`fastNLP 1.0`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", + "\n", + " 添加`pytorch`要求的`forward`方法外,还需要添加 `train_step` 和 `evaluate_step` 这两个方法\n", + "\n", + "```python\n", + "class Model(torch.nn.Module):\n", + " def __init__(self):\n", + " super(Model, self).__init__()\n", + " self.loss_fn = torch.nn.CrossEntropyLoss()\n", + " pass\n", + "\n", + " def forward(self, x):\n", + " pass\n", + "\n", + " def train_step(self, x, y):\n", + " pred = self(x)\n", + " return {\"loss\": self.loss_fn(pred, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " pred = self(x)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": y}\n", + "```\n", + "***\n", + "在`fastNLP 1.0`中,**函数 train_step 是 Trainer 中参数 train_fn 的默认值**\n", + "\n", + " 由于,在`Trainer`训练时,**Trainer 通过参数 train_fn 对应的模型方法获得当前数据批次的损失值**\n", + "\n", + " 因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n", + "\n", + " 如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n", + "\n", + "注:在`fastNLP 1.0`中,**Trainer 要求模型通过 train_step 来返回一个字典**,**满足如 {\"loss\": loss} 的形式**\n", + "\n", + " 此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现输出的转换,详见(trainer的详细讲解,待补充)\n", + "\n", + "同样,在`fastNLP 1.0`中,**函数 evaluate_step 是 Evaluator 中参数 evaluate_fn 的默认值**\n", + "\n", + " 在`Evaluator`测试时,**Evaluator 通过参数 
evaluate_fn 对应的模型方法获得当前数据批次的评测结果**\n", + "\n", + " 从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n", + "\n", + " 从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "fb3272eb", + "metadata": {}, + "source": [ + "### 2.2 trainer 和 evaluator 的参数匹配\n", + "\n", + "在`fastNLP 1.0`中,参数匹配涉及到两个方面,分别是在\n", + "\n", + " 一方面,**在模型的前向传播中**,**dataloader 向 train_step 或 evaluate_step 函数传递 batch**\n", + "\n", + " 另方面,**在模型的评测过程中**,**evaluate_dataloader 向 metric 的 update 函数传递 batch**\n", + "\n", + "对于前者,在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`False`时\n", + "\n", + " **fastNLP 1.0 要求 dataloader 生成的每个 batch **,**满足如 {\"x\": x, \"y\": y} 的形式**\n", + "\n", + " 同时,`fastNLP 1.0`会查看模型的`train_step`和`evaluate_step`方法的参数签名,并为对应参数传入对应数值\n", + "\n", + " **字典形式的定义**,**对应在 Dataset 定义的 \\_\\_getitem\\_\\_ 方法中**,例如下方的`ArgMaxDatset`\n", + "\n", + " 而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", + "\n", + " `fastNLP 1.0`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", + "\n", + "```python\n", + "class Dataset(torch.utils.data.Dataset):\n", + " def __init__(self, x, y):\n", + " self.x = x\n", + " self.y = y\n", + "\n", + " def __len__(self):\n", + " return len(self.x)\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "f5f1a6aa", + "metadata": {}, + "source": [ + "对于后者,首先要明确,在`Trainer`和`Evaluator`中,`metrics`的计算分为`update`和`get_metric`两步\n", + "\n", + " **update 函数**,**针对一个 batch 的预测结果**,计算其累计的评价指标\n", + "\n", + " **get_metric 函数**,**统计 update 函数累计的评价指标**,来计算最终的评价结果\n", + "\n", + " 例如对于`Accuracy`来说,`update`函数会更新一个`batch`的正例数量`right_num`和负例数量`total_num`\n", + "\n", + " 而`get_metric`函数则会返回所有`batch`的评测值`right_num / total_num`\n", + "\n", + " 在此基础上,**fastNLP 1.0 要求 evaluate_dataloader 生成的每个 batch 传递给对应的 metric**\n", + "\n", + " **以 {\"pred\": y_pred, 
\"target\": y_true} 的形式**,对应其`update`函数的函数签名\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "f62b7bb1", + "metadata": {}, + "source": [ + "### 2.3 示例:argmax 模型的搭建\n", + "\n", + "下文将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n", + "\n", + " 首先,使用`pytorch.nn.Module`定义`argmax`模型,目标是输入一组固定维度的向量,输出其中数值最大的数的索引" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "5314482b", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class ArgMaxModel(nn.Module):\n", + " def __init__(self, num_labels, feature_dimension):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + "\n", + " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", + " self.ac1 = nn.ReLU()\n", + " self.linear2 = nn.Linear(in_features=10, out_features=10)\n", + " self.ac2 = nn.ReLU()\n", + " self.output = nn.Linear(in_features=10, out_features=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " pred = self.ac1(self.linear1(x))\n", + " pred = self.ac2(self.linear2(pred))\n", + " pred = self.output(pred)\n", + " return pred\n", + "\n", + " def train_step(self, x, y):\n", + " pred = self(x)\n", + " return {\"loss\": self.loss_fn(pred, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " pred = self(x)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": y}" + ] + }, + { + "cell_type": "markdown", + "id": "71f3fa6b", + "metadata": {}, + "source": [ + " 接着,使用`torch.utils.data.Dataset`定义`ArgMaxDataset`数据集\n", + "\n", + " 数据集包含三个参数:维度`feature_dimension`、数据量`data_num`和随机种子`seed`\n", + "\n", + " 数据及初始化是,自动生成指定维度的向量,并为每个向量标注出其中最大值的索引作为预测标签" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fe612e61", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset\n", + 
"\n", + "class ArgMaxDataset(Dataset):\n", + " def __init__(self, feature_dimension, data_num=1000, seed=0):\n", + " self.num_labels = feature_dimension\n", + " self.feature_dimension = feature_dimension\n", + " self.data_num = data_num\n", + " self.seed = seed\n", + "\n", + " g = torch.Generator()\n", + " g.manual_seed(1000)\n", + " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n", + " self.y = torch.max(self.x, dim=-1)[1]\n", + "\n", + " def __len__(self):\n", + " return self.data_num\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}" + ] + }, + { + "cell_type": "markdown", + "id": "2cb96332", + "metadata": {}, + "source": [ + " 然后,根据`ArgMaxModel`类初始化模型实例,保持输入维度`feature_dimension`和输出标签数量`num_labels`一致\n", + "\n", + " 再根据`ArgMaxDataset`类初始化两个数据集实例,分别用来模型测试和模型评测,数据量各1000笔" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76172ef8", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "model = ArgMaxModel(num_labels=10, feature_dimension=10)\n", + "\n", + "train_dataset = ArgMaxDataset(feature_dimension=10, data_num=1000)\n", + "evaluate_dataset = ArgMaxDataset(feature_dimension=10, data_num=100)" + ] + }, + { + "cell_type": "markdown", + "id": "4e7d25ee", + "metadata": {}, + "source": [ + " 此外,使用`torch.utils.data.DataLoader`初始化两个数据加载模块,批量大小同为8,分别用于训练和测评" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "363b5b09", + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n", + "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)" + ] + }, + { + "cell_type": "markdown", + "id": "c8d4443f", + "metadata": {}, + "source": [ + " 最后,使用`torch.optim.SGD`初始化一个优化模块,基于随机梯度下降法" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "dc28a2d9", + 
"metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.optim import SGD\n", + "\n", + "optimizer = SGD(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "eb8ca6cf", + "metadata": {}, + "source": [ + "## 3. 使用 fastNLP 1.0 训练 argmax 模型\n", + "\n", + "### 3.1 trainer 外部初始化的 evaluator" + ] + }, + { + "cell_type": "markdown", + "id": "55145553", + "metadata": {}, + "source": [ + "通过从`fastNLP`库中导入`Trainer`类,初始化`trainer`实例,对模型进行训练\n", + "\n", + " 需要导入预先定义好的模型`model`、对应的数据加载模块`train_dataloader`、优化模块`optimizer`\n", + "\n", + " 通过`progress_bar`设定进度条格式,默认为`\"auto\"`,此外还有`\"rich\"`、`\"raw\"`和`None`\n", + "\n", + " 但对于`\"auto\"`和`\"rich\"`格式,在`jupyter`中,进度条会在训练结束后会被丢弃\n", + "\n", + " 通过`n_epochs`设定优化迭代轮数,默认为20;全部`Trainer`的全部变量与函数可以通过`dir(trainer)`查询" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b51b7a2d", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"torch\",\n", + " device='cuda',\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 设定迭代轮数 \n", + " progress_bar=\"auto\" # 设定进度条格式\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6e202d6e", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + " 其中,可以通过参数`num_train_batch_per_epoch`决定每个`epoch`运行多少个`batch`后停止,默认全部\n", + "\n", + " `run`函数完成后在`jupyter`中没有输出保留,此外,通过`help(trainer.run)`可以查询`run`函数的详细内容" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "ba047ead", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "c16c5fa4", + "metadata": {}, + "source": [ + "通过从`fastNLP`库中导入`Evaluator`类,初始化`evaluator`实例,对模型进行评测\n", + "\n", + " 需要导入预先定义好的模型`model`、对应的数据加载模块`evaluate_dataloader`\n", + "\n", + " 需要注意的是评测方法`metrics`,设定为形如`{'acc': fastNLP.core.metrics.Accuracy()}`的字典\n", + "\n", + " 类似地,也可以通过`progress_bar`限定进度条格式,默认为`\"auto\"`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1c6b6b36", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP import Accuracy\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 需要使用 trainer 已经启动的 driver\n", + " device=None,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 需要严格使用此种形式的字典\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "8157bb9b", + "metadata": {}, + "source": [ + "通过使用`Evaluator`类的`run`函数,进行训练\n", + "\n", + " 其中,可以通过参数`num_eval_batch_per_dl`决定每个`evaluate_dataloader`运行多少个`batch`停止,默认全部\n", + "\n", + " 最终,输出形如`{'acc#acc': acc}`的字典,在`jupyter`中,进度条会在评测结束后会被丢弃" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f7cb0165", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.31\u001b[0m, \u001b[32m'total#acc'\u001b[0m: \u001b[1;36m100.0\u001b[0m, \u001b[32m'correct#acc'\u001b[0m: \u001b[1;36m31.0\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.31, 'total#acc': 100.0, 'correct#acc': 31.0}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "dd9f68fa", + "metadata": {}, + "source": [ + "### 3.2 trainer 内部初始化的 evaluator \n", + "\n", + "通过在初始化`trainer`实例时加入`evaluate_dataloaders`和`metrics`,可以实现在训练过程中进行评测\n", + "\n", + " 通过`progress_bar`同时设定训练和评估进度条格式,在`jupyter`中,在进度条训练结束后会被丢弃\n", + "\n", + " 但是中间的评估结果仍会保留;**通过 evaluate_every 设定评估频率**,可以为负数、正数或者函数:\n", + "\n", + " **为负数时**,**表示每隔几个 epoch 评估一次**;**为正数时**,**则表示每隔几个 batch 评估一次**" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "183c7d19", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver=trainer.driver, # 因为是在同个脚本中,这里的 driver 同样需要重用\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " optimizers=optimizer,\n", + " n_epochs=10, \n", + " evaluate_every=-1, # 表示每个 epoch 的结束进行评估\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "714cc404", + "metadata": {}, + "source": [ + "通过使用`Trainer`类的`run`函数,进行训练\n", + "\n", + " 还可以通过**参数 num_eval_sanity_batch 决定每次训练前运行多少个 evaluate_batch 进行评测**,**默认为 2 **\n", + "\n", + " 之所以“先评测后训练”,是为了保证训练很长时间的数据,不会在评测阶段出问题,故作此**试探性评测**" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "2e4daa2c", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + 
"outputs": [ + { + "data": { + "text/html": [ + "
[18:28:25] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[18:28:25]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=549287;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=645362;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.31,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 31.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.31\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m31.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.33,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 33.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.33\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m33.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.34,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 34.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.34\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m34.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.36,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 36.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.36\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m36.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.37,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 37.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.37\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m37.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.4,\n", + " \"total#acc\": 100.0,\n", + " \"correct#acc\": 40.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.4\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m40.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "c4e9c619", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.4, 'total#acc': 100.0, 'correct#acc': 40.0}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bc7cb4a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_1.ipynb b/docs/source/tutorials/fastnlp_tutorial_1.ipynb new file mode 100644 index 00000000..cff81a21 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_1.ipynb @@ -0,0 +1,1333 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cdc25fcd", + "metadata": {}, + "source": [ + "# T1. 
dataset 和 vocabulary 的基本使用\n", + "\n", + " 1 dataset 的使用与结构\n", + " \n", + " 1.1 dataset 的结构与创建\n", + "\n", + " 1.2 dataset 的数据预处理\n", + "\n", + " 1.3 延伸:instance 和 field\n", + "\n", + " 2 vocabulary 的结构与使用\n", + "\n", + " 2.1 vocabulary 的创建与修改\n", + "\n", + " 2.2 vocabulary 与 OOV 问题\n", + "\n", + " 3 dataset 和 vocabulary 的组合使用\n", + " \n", + " 3.1 从 dataframe 中加载 dataset\n", + "\n", + " 3.2 从 dataset 中获取 vocabulary" + ] + }, + { + "cell_type": "markdown", + "id": "0eb18a22", + "metadata": {}, + "source": [ + "## 1. dataset 的基本使用\n", + "\n", + "### 1.1 dataset 的结构与创建\n", + "\n", + "在`fastNLP 1.0`中,使用`DataSet`模块表示数据集,**dataset 类似于关系型数据库中的数据表**(下文统一为小写 `dataset`)\n", + "\n", + " **主要包含 field 字段和 instance 实例两个元素**,对应 table 中的 field 字段和`record`记录\n", + "\n", + "在`fastNLP 1.0`中,`DataSet`模块被定义在`fastNLP.core.dataset`路径下,导入该模块后,最简单的\n", + "\n", + " 初始化方法,即将字典形式的表格 **{'field1': column1, 'field2': column2, ...}** 传入构造函数" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a1d69ad2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"],\n", + " 'words': [['This', 'is', 'an', 'apple', '.'], \n", + " ['I', 'like', 'apples', '.'], \n", + " ['Apples', 'are', 'good', 'for', 'our', 'health', '.']],\n", + " 'num': [5, 4, 7]}\n", + "\n", + "dataset = DataSet(data)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9260fdc6", + "metadata": {}, + "source": [ + " 在`dataset`的实例中,字段`field`的名称和实例`instance`中的字符串也可以中文" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3d72ef00", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------+--------------------+------------------------+------+\n", + "| 序号 | 句子 | 字符 | 长度 |\n", + "+------+--------------------+------------------------+------+\n", + "| 0 | 生活就像海洋, | ['生', '活', '就', ... | 7 |\n", + "| 1 | 只有意志坚强的人, | ['只', '有', '意', ... | 9 |\n", + "| 2 | 才能到达彼岸。 | ['才', '能', '到', ... 
| 7 |\n", + "+------+--------------------+------------------------+------+\n" + ] + } + ], + "source": [ + "temp = {'序号': [0, 1, 2], \n", + " '句子':[\"生活就像海洋,\", \"只有意志坚强的人,\", \"才能到达彼岸。\"],\n", + " '字符': [['生', '活', '就', '像', '海', '洋', ','], \n", + " ['只', '有', '意', '志', '坚', '强', '的', '人', ','], \n", + " ['才', '能', '到', '达', '彼', '岸', '。']],\n", + " '长度': [7, 9, 7]}\n", + "\n", + "chinese = DataSet(temp)\n", + "print(chinese)" + ] + }, + { + "cell_type": "markdown", + "id": "202e5490", + "metadata": {}, + "source": [ + "在`dataset`中,使用`drop`方法可以删除满足条件的实例,这里使用了python中的`lambda`表达式\n", + "\n", + " 注一:在`drop`方法中,通过设置`inplace`参数将删除对应实例后的`dataset`作为一个新的实例生成" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "09b478f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2492313174344 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped = dropped.drop(lambda ins:ins['num'] < 5, inplace=False)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "aa277674", + "metadata": {}, + "source": [ + " 注二:**对对象使用等号一般表示传引用**,所以对`dataset`使用等号,是传引用而不是赋值\n", + "\n", + " 如下所示,**dropped 和 dataset 具有相同 id**,**对 dropped 执行删除操作 dataset 同时会被修改**" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "77c8583a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2491986424200 2491986424200\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped.drop(lambda ins:ins['num'] < 5)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "a76199dc", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delete_instance`方法可以删除对应序号的`instance`实例,序号从0开始" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d8824b40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------+-----+\n", + "| idx | sentence           | words                  | num |\n", + "+-----+--------------------+------------------------+-----+\n", + "| 0   | This is an apple . | ['This', 'is', 'an'... | 5   |\n", + "| 1   | I like apples .    | ['I', 'like', 'appl... | 4   |\n", + "+-----+--------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.delete_instance(2)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "f4fa9f33", + "metadata": {}, + "source": [ + "在`dataset`中,使用`delete_field`方法可以删除对应名称的`field`字段" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f68ddb40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+--------------------+------------------------------+\n", + "| idx | sentence           | words                        |\n", + "+-----+--------------------+------------------------------+\n", + "| 0   | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1   | I like apples .    | ['I', 'like', 'apples', '... 
|\n", + "+-----+--------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset.delete_field('num')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "b1e9d42c", + "metadata": {}, + "source": [ + "### 1.2 dataset 的数据预处理\n", + "\n", + "在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n", + "\n", + " **apply 和 apply_more 输入整条实例**,**apply_field 和 apply_field_more 仅输入实例的部分字段**\n", + "\n", + " **apply 和 apply_field 仅输出单个字段**,**apply_more 和 apply_field_more 则是输出多个字段**\n", + "\n", + " **apply 和 apply_field 返回的是个列表**,**apply_more 和 apply_field_more 返回的是个字典**\n", + "\n", + " 预处理过程中,通过`progress_bar`参数设置显示进度条类型,通过`num_proc`设置多进程\n", + "***\n", + "\n", + "`apply`的参数包括一个函数`func`和一个新字段名`new_field_name`,函数`func`的处理对象是`dataset`模块中\n", + "\n", + " 的每个`instance`实例,函数`func`的处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "72a0b5f9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... 
|\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"], }\n", + "dataset = DataSet(data)\n", + "dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words', progress_bar=\"tqdm\") #\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "c10275ee", + "metadata": {}, + "source": [ + " **apply 使用的函数可以是一个基于 lambda 表达式的匿名函数**,**也可以是一个自定义的函数**" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b1a8631f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... 
|\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "\n", + "def get_words(instance):\n", + " sentence = instance['sentence']\n", + " words = sentence.split()\n", + " return words\n", + "\n", + "dataset.apply(get_words, new_field_name='words', progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "64abf745", + "metadata": {}, + "source": [ + "`apply_field`的参数,除了函数`func`外还有`field_name`和`new_field_name`,该函数`func`的处理对象仅\n", + "\n", + " 是`dataset`模块中的每个`field_name`对应的字段内容,处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "057c1d2c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... 
|\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words', \n", + " progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "5a9cc8b2", + "metadata": {}, + "source": [ + "`apply_more`的参数只有函数`func`,函数`func`的处理对象是`dataset`模块中的每个`instance`实例\n", + "\n", + " 要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "51e2f02c", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].split(), 'num': len(ins['sentence'].split())}, \n", + "                   progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "02d2b7ef", + "metadata": {}, + "source": [ + "`apply_field_more`的参数,除了函数`func`外还有`field_name`,函数`func`的处理对象是`dataset`模块中每个`field_name`对应的字段内容\n", + "\n", + " 要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "db4295d5", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing:   0%|          | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence               | words                  | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0   | This is an apple .     | ['This', 'is', 'an'... | 5   |\n", + "| 1   | I like apples .        | ['I', 'like', 'appl... | 4   |\n", + "| 2   | Apples are good for... | ['Apples', 'are', '... 
| 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_field_more(lambda sent:{'words': sent.split(), 'num': len(sent.split())}, \n", + " field_name='sentence', progress_bar=\"tqdm\")\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9c09e592", + "metadata": {}, + "source": [ + "### 1.3 延伸:instance 和 field\n", + "\n", + "在`fastNLP 1.0`中,使用`Instance`模块表示数据集`dataset`中的每条数据,被称为实例\n", + "\n", + " 构造方式类似于构造一个字典,通过键值相同的`Instance`列表,也可以初始化一个`dataset`,代码如下" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "012f537c", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import DataSet\n", + "from fastNLP import Instance\n", + "\n", + "dataset = DataSet([\n", + " Instance(sentence=\"This is an apple .\",\n", + " words=['This', 'is', 'an', 'apple', '.'],\n", + " num=5),\n", + " Instance(sentence=\"I like apples .\",\n", + " words=['I', 'like', 'apples', '.'],\n", + " num=4),\n", + " Instance(sentence=\"Apples are good for our health .\",\n", + " words=['Apples', 'are', 'good', 'for', 'our', 'health', '.'],\n", + " num=7),\n", + " ])" + ] + }, + { + "cell_type": "markdown", + "id": "2fafb1ef", + "metadata": {}, + "source": [ + " 通过`items`、`keys`和`values`方法,可以分别获得`dataset`的`item`列表、`key`列表、`value`列表" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a4c1c10d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_items([('sentence', 'This is an apple .'), ('words', ['This', 'is', 'an', 'apple', '.']), ('num', 5)])\n", + "dict_keys(['sentence', 'words', 'num'])\n", + "dict_values(['This is an apple .', ['This', 'is', 'an', 'apple', '.'], 5])\n" + ] + } + ], + "source": [ + "ins = Instance(sentence=\"This is an apple .\", words=['This', 'is', 'an', 'apple', '.'], num=5)\n", + "\n", + "print(ins.items())\n", + "print(ins.keys())\n", + "print(ins.values())" + ] 
+ }, + { + "cell_type": "markdown", + "id": "b5459a2d", + "metadata": {}, + "source": [ + " 通过`add_field`方法,可以在`Instance`实例中,通过参数`field_name`添加字段,通过参数`field`赋值" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "55376402", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------------------+------------------------+-----+-----+\n", + "| sentence | words | num | idx |\n", + "+--------------------+------------------------+-----+-----+\n", + "| This is an apple . | ['This', 'is', 'an'... | 5 | 0 |\n", + "+--------------------+------------------------+-----+-----+\n" + ] + } + ], + "source": [ + "ins.add_field(field_name='idx', field=0)\n", + "print(ins)" + ] + }, + { + "cell_type": "markdown", + "id": "49caaa9c", + "metadata": {}, + "source": [ + "在`fastNLP 1.0`中,使用`FieldArray`模块表示数据集`dataset`中的每条字段名(注:没有`field`类)\n", + "\n", + " 通过`get_all_fields`方法可以获取`dataset`的字段列表\n", + "\n", + " 通过`get_field_names`方法可以获取`dataset`的字段名称列表,代码如下" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "fe15f4c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sentence':
\n", + " | SentenceId | \n", + "Sentence | \n", + "Sentiment | \n", + "
---|---|---|---|
0 | \n", + "1 | \n", + "A series of escapades demonstrating the adage ... | \n", + "negative | \n", + "
1 | \n", + "2 | \n", + "This quiet , introspective and entertaining in... | \n", + "positive | \n", + "
2 | \n", + "3 | \n", + "Even fans of Ismail Merchant 's work , I suspe... | \n", + "negative | \n", + "
3 | \n", + "4 | \n", + "A positively thrilling combination of ethnogra... | \n", + "neutral | \n", + "
4 | \n", + "5 | \n", + "A comedy-drama of nearly epic proportions root... | \n", + "positive | \n", + "
5 | \n", + "6 | \n", + "The Importance of Being Earnest , so thick wit... | \n", + "neutral | \n", + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------------------------------------+----------+\n", + "| text | label |\n", + "+------------------------------------------+----------+\n", + "| ['a', 'series', 'of', 'escapades', 'd... | negative |\n", + "| ['this', 'quiet', ',', 'introspective... | positive |\n", + "| ['even', 'fans', 'of', 'ismail', 'mer... | negative |\n", + "| ['the', 'importance', 'of', 'being', ... | neutral |\n", + "+------------------------------------------+----------+\n", + "+------------------------------------------+----------+\n", + "| text | label |\n", + "+------------------------------------------+----------+\n", + "| ['a', 'comedy-drama', 'of', 'nearly',... | positive |\n", + "| ['a', 'positively', 'thrilling', 'com... | neutral |\n", + "+------------------------------------------+----------+\n", + "{'
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/2 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| SentenceId | Sentence | Sentiment | input_ids | token_type_ids | attention_mask | target |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n", + "| 1 | A series of... | negative | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 4 | A positivel... | neutral | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 2 |\n", + "| 3 | Even fans o... | negative | [101, 2130,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... | 1 |\n", + "| 5 | A comedy-dr... | positive | [101, 1037,... | [0, 0, 0, 0, 0,... | [1, 1, 1, 1, 1,... 
| 0 |\n", + "+------------+----------------+-----------+----------------+--------------------+--------------------+--------+\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import pandas as pd\n", + "from functools import partial\n", + "from fastNLP.transformers.torch import BertTokenizer\n", + "\n", + "from fastNLP import DataSet\n", + "from fastNLP import Vocabulary\n", + "from fastNLP.io import DataBundle\n", + "\n", + "\n", + "class PipeDemo:\n", + " def __init__(self, tokenizer='bert-base-uncased'):\n", + " self.tokenizer = BertTokenizer.from_pretrained(tokenizer)\n", + "\n", + " def process_from_file(self, path='./data/test4dataset.tsv'):\n", + " datasets = DataSet.from_pandas(pd.read_csv(path, sep='\\t'))\n", + " train_ds, test_ds = datasets.split(ratio=0.7)\n", + " train_ds, dev_ds = datasets.split(ratio=0.8)\n", + " data_bundle = DataBundle(datasets={'train': train_ds, 'dev': dev_ds, 'test': test_ds})\n", + "\n", + " encode = partial(self.tokenizer.encode_plus, max_length=100, truncation=True,\n", + " return_attention_mask=True)\n", + " data_bundle.apply_field_more(encode, field_name='Sentence', progress_bar='tqdm')\n", + " \n", + " target_vocab = Vocabulary(padding=None, unknown=None)\n", + "\n", + " target_vocab.from_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment')\n", + " target_vocab.index_dataset(*[ds for _, ds in data_bundle.iter_datasets()], field_name='Sentiment',\n", + " new_field_name='target')\n", + "\n", + " data_bundle.set_pad('input_ids', pad_val=self.tokenizer.pad_token_id)\n", + " data_bundle.set_ignore('SentenceId', 'Sentence', 'Sentiment') \n", + " return data_bundle\n", + "\n", + " \n", + "pipe = PipeDemo(tokenizer='bert-base-uncased')\n", + "\n", + "data_bundle = pipe.process_from_file('./data/test4dataset.tsv')\n", + "\n", + "print(data_bundle.get_dataset('train'))" + ] + }, + { + "cell_type": "markdown", + "id": "76e6b8ab", + "metadata": {}, + "source": [ + "### 
1.2 dataloader 的函数创建\n", + "\n", + "在`fastNLP 1.0`中,**更方便、可能更常用的 dataloader 创建方法是通过 prepare_xx_dataloader 函数**\n", + "\n", + " 例如下方的`prepare_torch_dataloader`函数,指定必要参数,读取数据集,生成对应`dataloader`\n", + "\n", + " 类型为`TorchDataLoader`,只能适用于`pytorch`框架,因此对应`trainer`初始化时`driver='torch'`\n", + "\n", + "同时我们还可以发现,在`fastNLP 1.0`中,**batch 表示为字典 dict 类型**,**key 值就是原先数据集中各个字段**\n", + "\n", + " **除去经过 DataBundle.set_ignore 函数隐去的部分**,而`value`值为`pytorch`框架对应的`torch.Tensor`类型" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5fd60e42", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split(), 'target': ins['label']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('label')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "96380c67", + "metadata": {}, + "source": [ + " 然后,使用`tutorial-3`中的知识,**通过 prepare_torch_dataloader 处理数据集得到 dataloader**" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b9dd1273", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "eb75aaba", + "metadata": {}, + "source": [ + "模型使用方面,这里使用`Embedding`、`LSTM`、`MLP`等模块搭建模型,方法类似`pytorch`,结构如下所示\n", + "\n", + "```\n", + "ClsByModules(\n", + " (embedding): Embedding(\n", + " (embed): Embedding(8458, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (lstm): LSTM(\n", + " (lstm): LSTM(100, 64, 
num_layers=2, batch_first=True, bidirectional=True)\n", + " )\n", + " (mlp): MLP(\n", + " (hiddens): ModuleList()\n", + " (output): Linear(in_features=128, out_features=2, bias=True)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", + " )\n", + " (loss_fn): CrossEntropyLoss()\n", + ")\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "0b25b25c", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "from fastNLP.modules.torch import LSTM, MLP\n", + "from fastNLP.embeddings.torch import Embedding\n", + "\n", + "\n", + "class ClsByModules(nn.Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " nn.Module.__init__(self)\n", + "\n", + " self.embedding = Embedding((vocab_size, embedding_dim))\n", + " self.lstm = LSTM(embedding_dim, hidden_dim, num_layers=num_layers, bidirectional=True)\n", + " self.mlp = MLP([hidden_dim * 2, output_dim], dropout=dropout)\n", + " \n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(torch.cat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " return {\"loss\": self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {\"pred\": pred, \"target\": target}" + ] + }, + { + "cell_type": "markdown", + "id": "4890de5a", + "metadata": {}, + "source": [ + " 接着,初始化模型`model`实例,同时,使用`torch.optim.AdamW`初始化`optimizer`实例" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9dbbf50d", + "metadata": {}, + "outputs": [], + "source": [ + "model = ClsByModules(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "from 
torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "id": "054538f5", + "metadata": {}, + "source": [ + " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "7a93432f", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "31102e0f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[16:20:10] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:20:10]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=908530;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=864197;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.525,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 84.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m84.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.54375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 87.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.54375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m87.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.55,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 88.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.55\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m88.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 100.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m100.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.65,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 104.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.65\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m104.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.69375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 111.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.69375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m111.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 108.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.66875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 107.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.66875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 108.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m108.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.68125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 109.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.68125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m109.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8bc4bfb2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.712222, 'total#acc': 900.0, 'correct#acc': 641.0}" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "07538876", + "metadata": {}, + "source": [ + " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "1b52eafd", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "383" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "\n", + "del model\n", + "del trainer\n", + "del dataset\n", + "del sst2data\n", + "\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "d9443213", + "metadata": {}, + "source": [ + "## 2. 
fastNLP 中 models 模块的介绍\n", + "\n", + "### 2.1 示例一:models 实现 CNN 分类\n", + "\n", + " 本示例使用`fastNLP 1.0`中预定义模型`models`中的`CNNText`模型,实现`SST-2`文本二分类任务\n", + "\n", + "数据使用方面,此处沿用在上个示例中展示的`SST-2`数据集,数据加载过程相同且已经执行过了,因此简略\n", + "\n", + "模型使用方面,如上所述,这里使用**基于卷积神经网络 CNN 的预定义文本分类模型 CNNText**,结构如下所示\n", + "\n", + " 首先是内置的`100`维嵌入层、`dropout`层、紧接着是三个一维卷积,将`100`维嵌入特征,分别通过\n", + "\n", + " **感受野为 1 、 3 、 5 的卷积算子变换至 30 维、 40 维、 50 维的卷积特征**,再将三者拼接\n", + "\n", + " 最终再次通过`dropout`层、线性变换层,映射至二元的输出值,对应两个分类结果上的几率`logits`\n", + "\n", + "```\n", + "CNNText(\n", + " (embed): Embedding(\n", + " (embed): Embedding(5194, 100)\n", + " (dropout): Dropout(p=0.0, inplace=False)\n", + " )\n", + " (conv_pool): ConvMaxpool(\n", + " (convs): ModuleList(\n", + " (0): Conv1d(100, 30, kernel_size=(1,), stride=(1,), bias=False)\n", + " (1): Conv1d(100, 40, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)\n", + " (2): Conv1d(100, 50, kernel_size=(5,), stride=(1,), padding=(2,), bias=False)\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=120, out_features=2, bias=True)\n", + ")\n", + "```\n", + "\n", + "对应到代码上,**从 fastNLP.models.torch 路径下导入 CNNText**,初始化`CNNText`和`optimizer`实例\n", + "\n", + " 注意:初始化`CNNText`时,**二元组参数 embed 、分类数量 num_classes 是必须传入的**,其中\n", + "\n", + " **embed 表示嵌入层的嵌入抽取矩阵大小**,因此第二个元素对应的是默认隐藏层维度 `100` 维" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "f6e76e2e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.models.torch import CNNText\n", + "\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "0cc5ca10", + "metadata": {}, + "source": [ + " 最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "50a13ee5", + "metadata": {}, + 
"outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "28903a7d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[16:21:57] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:21:57]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=813103;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=271516;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.654444,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 589.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.654444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m589.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.767778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 691.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.767778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m691.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.797778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 718.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.797778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m718.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.803333,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 723.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.803333\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m723.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.807778,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 727.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.807778\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m727.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.812222,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 731.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.812222\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m731.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.804444,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 724.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.804444\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m724.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.811111,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 730.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.811111,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 730.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.811111\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m730.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.806667,\n", + " \"total#acc\": 900.0,\n", + " \"correct#acc\": 726.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.806667\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m900.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m726.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f47a6a35", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.806667, 'total#acc': 900.0, 'correct#acc': 726.0}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "5b5c0446", + "metadata": {}, + "source": [ + " 注:此处使用`gc`模块删除相关变量,释放内存,为接下来新的模型训练预留存储空间,下同" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "e9e70f88", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "344" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import gc\n", + "\n", + "del model\n", + "del trainer\n", + "\n", + "gc.collect()" + ] + }, + { + "cell_type": "markdown", + "id": "6aec2a19", + "metadata": {}, + "source": [ + "### 2.2 示例二:models 实现 BiLSTM 标注\n", + "\n", + " 通过两个示例一的对比可以发现,得益于`models`对模型结构的封装,使用`models`明显更加便捷\n", + "\n", + " 针对更加复杂的模型时,编码更加轻松;本示例将使用`models`中的`BiLSTMCRF`模型\n", + "\n", + " 避免`CRF`和`Viterbi`算法代码书写的困难,轻松实现`CoNLL-2003`中的命名实体识别`NER`任务\n", + "\n", + "模型使用方面,如上所述,这里使用**基于双向 LSTM +条件随机场 CRF 的标注模型 BiLSTMCRF**,结构如下所示\n", + "\n", + " 其中,隐藏层维度默认`100`维,因此对应双向`LSTM`输出`200`维,`dropout`层退学概率、`LSTM`层数可调\n", + "\n", + "```\n", + "BiLSTMCRF(\n", + " (embed): Embedding(7590, 100)\n", + " (lstm): LSTM(\n", + " (lstm): LSTM(100, 100, batch_first=True, bidirectional=True)\n", 
+ " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (fc): Linear(in_features=200, out_features=9, bias=True)\n", + " (crf): ConditionalRandomField()\n", + ")\n", + "```\n", + "\n", + "数据使用方面,此处仍然**使用 datasets 模块中的 load_dataset 函数**,以如下形式,加载`CoNLL-2003`数据集\n", + "\n", + " 首次下载后会保存至`~.cache/huggingface/datasets/conll2003/conll2003/1.0.0/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "03e66686", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset conll2003 (/remote-home/xrliu/.cache/huggingface/datasets/conll2003/conll2003/1.0.0/63f4ebd1bcb7148b1644497336fd74643d4ce70123334431a3c053b7ee4e96ee)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "593bc03ed5914953ab94268ff2f01710", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "ner2data = load_dataset('conll2003', 'conll2003')" + ] + }, + { + "cell_type": "markdown", + "id": "fc505631", + "metadata": {}, + "source": [ + "紧接着,使用`tutorial-1`和`tutorial-2`中的知识,将数据集转化为`fastNLP`中的`DataSet`格式\n", + "\n", + " 完成数据集格式调整、文本序列化等操作;此处**需要 'words' 、 'seq_len' 、 'target' 三个字段**\n", + "\n", + "此外,**需要定义 NER 标签到标签序号的映射**(**词汇表 label_vocab**),数据集中标签已经完成了序号映射\n", + "\n", + " 所以需要人工定义**9 个标签对应之前的 9 个分类目标**;数据集说明中规定,`'O'`表示其他标签\n", + "\n", + " **后缀 '-PER' 、 '-ORG' 、 '-LOC' 、 '-MISC' 对应人名、组织名、地名、时间等其他命名**\n", + "\n", + " **前缀 'B-' 表示起始标签、 'I-' 表示终止标签**;例如,`'B-PER'`表示人名实体的起始标签" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1f88cad4", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/4000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": 
"display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(ner2data['train'].to_pandas())[:4000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['tokens'], 'seq_len': len(ins['tokens']), 'target': ins['ner_tags']}, \n", + " progress_bar=\"tqdm\")\n", + "dataset.delete_field('tokens')\n", + "dataset.delete_field('ner_tags')\n", + "dataset.delete_field('pos_tags')\n", + "dataset.delete_field('chunk_tags')\n", + "dataset.delete_field('id')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "token_vocab = Vocabulary()\n", + "token_vocab.from_dataset(dataset, field_name='words')\n", + "token_vocab.index_dataset(dataset, field_name='words')\n", + "label_vocab = Vocabulary(padding=None, unknown=None)\n", + "label_vocab.add_word_lst(['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC'])\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)" + ] + }, + { + "cell_type": "markdown", + "id": "d9889427", + "metadata": {}, + "source": [ + "然后,同样使用`tutorial-3`中的知识,通过`prepare_torch_dataloader`处理数据集得到`dataloader`" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "7802a072", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "2bc7831b", + "metadata": {}, + "source": [ + "接着,**从 fastNLP.models.torch 路径下导入 BiLSTMCRF**,初始化`BiLSTMCRF`实例和优化器\n", + "\n", + " 注意:初始化`BiLSTMCRF`时,和`CNNText`相同,**参数 embed 、 num_classes 是必须传入的**\n", + "\n", + " 隐藏层维度`hidden_size`默认`100`维,调整`150`维;退学概率默认`0.1`,调整`0.2`" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "4e12c09f", + "metadata": {}, + "outputs": [], + 
"source": [ + "from fastNLP.models.torch import BiLSTMCRF\n", + "\n", + "model = BiLSTMCRF(embed=(len(token_vocab), 150), num_classes=len(label_vocab), \n", + " num_layers=1, hidden_size=150, dropout=0.2)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-3)" + ] + }, + { + "cell_type": "markdown", + "id": "bf30608f", + "metadata": {}, + "source": [ + "最后,使用`trainer`模块,集成`model`、`optimizer`、`dataloader`、`metric`训练\n", + "\n", + " **使用 SpanFPreRecMetric 作为 NER 的评价标准**,详细请参考接下来的`tutorial-5`\n", + "\n", + " 同时,**初始化时需要添加 vocabulary 形式的标签与序号之间的映射 tag_vocab**" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "cbd6c205", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, SpanFPreRecMetric\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'F1': SpanFPreRecMetric(tag_vocab=label_vocab)}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "0f8eff34", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[16:23:41] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[16:23:41]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=565652;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=224849;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.169014,\n", + " \"pre#F1\": 0.170732,\n", + " \"rec#F1\": 0.167331\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.169014\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.170732\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.167331\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.361809,\n", + " \"pre#F1\": 0.312139,\n", + " \"rec#F1\": 0.430279\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.361809\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.312139\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.430279\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.525,\n", + " \"pre#F1\": 0.475728,\n", + " \"rec#F1\": 0.585657\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.525\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.475728\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.585657\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.627306,\n", + " \"pre#F1\": 0.584192,\n", + " \"rec#F1\": 0.677291\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.627306\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.584192\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.677291\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.710937,\n", + " \"pre#F1\": 0.697318,\n", + " \"rec#F1\": 0.7251\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.710937\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.697318\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.7251\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.739563,\n", + " \"pre#F1\": 0.738095,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.739563\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.738095\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.748491,\n", + " \"pre#F1\": 0.756098,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.748491\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.756098\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.716763,\n", + " \"pre#F1\": 0.69403,\n", + " \"rec#F1\": 0.741036\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.716763\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.69403\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.741036\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.768293,\n", + " \"pre#F1\": 0.784232,\n", + " \"rec#F1\": 0.752988\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.768293\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.784232\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.752988\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"f#F1\": 0.757692,\n", + " \"pre#F1\": 0.732342,\n", + " \"rec#F1\": 0.784861\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"f#F1\"\u001b[0m: \u001b[1;36m0.757692\u001b[0m,\n", + " \u001b[1;34m\"pre#F1\"\u001b[0m: \u001b[1;36m0.732342\u001b[0m,\n", + " \u001b[1;34m\"rec#F1\"\u001b[0m: \u001b[1;36m0.784861\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "37871d6b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'f#F1': 0.766798, 'pre#F1': 0.741874, 'rec#F1': 0.793456}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "96bae094", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_5.ipynb b/docs/source/tutorials/fastnlp_tutorial_5.ipynb new file mode 100644 index 00000000..ab759feb --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_5.ipynb @@ -0,0 +1,1242 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T5. 
trainer 和 evaluator 的深入介绍\n", + "\n", + " 1 fastNLP 中 driver 的补充介绍\n", + " \n", + " 1.1 trainer 和 driver 的构想 \n", + "\n", + " 1.2 device 与 多卡训练\n", + "\n", + " 2 fastNLP 中的更多 metric 类型\n", + "\n", + " 2.1 预定义的 metric 类型\n", + "\n", + " 2.2 自定义的 metric 类型\n", + "\n", + " 3 fastNLP 中 trainer 的补充介绍\n", + "\n", + " 3.1 trainer 的内部结构" + ] + }, + { + "cell_type": "markdown", + "id": "08752c5a", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## 1. fastNLP 中 driver 的补充介绍\n", + "\n", + "### 1.1 trainer 和 driver 的构想\n", + "\n", + "在`fastNLP 1.0`中,模型训练最关键的模块便是**训练模块 trainer 、评测模块 evaluator 、驱动模块 driver**,\n", + "\n", + " 在`tutorial 0`中,已经简单介绍过上述三个模块:**driver 用来控制训练评测中的 model 的最终运行**\n", + "\n", + " **evaluator 封装评测的 metric**,**trainer 封装训练的 optimizer**,**也可以包括 evaluator**\n", + "\n", + "之所以做出上述的划分,其根本目的在于要**达成对于多个 python 学习框架**,**例如 pytorch 、 paddle 、 jittor 的兼容**\n", + "\n", + " 对于训练环节,其伪代码如下方左边紫色一栏所示,由于**不同框架对模型、损失、张量的定义各有不同**,所以将训练环节\n", + "\n", + " 划分为**框架无关的循环控制、批量分发部分**,**由 trainer 模块负责**实现,对应的伪代码如下方中间一栏所示\n", + "\n", + " 以及**随框架不同的模型调用、数值优化部分**,**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n", + "\n", + "|训练过程|框架无关 对应`Trainer`|框架相关 对应`Driver`\n", + "|----|----|----|\n", + "| try: | try: | |\n", + "| for epoch in 1:n_eoochs: | for epoch in 1:n_eoochs: | |\n", + "| for step in 1:total_steps: | for step in 1:total_steps: | |\n", + "| batch = fetch_batch() | batch = fetch_batch() | |\n", + "| loss = model.forward(batch) | | loss = model.forward(batch) |\n", + "| loss.backward() | | loss.backward() |\n", + "| model.clear_grad() | | model.clear_grad() |\n", + "| model.update() | | model.update() |\n", + "| if need_save: | if need_save: | |\n", + "| model.save() | | model.save() |\n", + "| except: | except: | |\n", + "| process_exception() | process_exception() | |" + ] + }, + { + "cell_type": "markdown", + "id": "3e55f07b", + "metadata": {}, + "source": [ + " 对于评测环节,其伪代码如下方左边紫色一栏所示,同样由于不同框架对模型、损失、张量的定义各有不同,所以将评测环节\n", + "\n", + " 
划分为**框架无关的循环控制、分发汇总部分**,**由 evaluator 模块负责**实现,对应的伪代码如下方中间一栏所示\n", + "\n", + " 以及**随框架不同的模型调用、评测计算部分**,同样**由 driver 模块负责**实现,对应的伪代码如下方右边一栏所示\n", + "\n", + "|评测过程|框架无关 对应`Evaluator`|框架相关 对应`Driver`\n", + "|----|----|----|\n", + "| try: | try: | |\n", + "| model.set_eval() | model.set_eval() | |\n", + "| for step in 1:total_steps: | for step in 1:total_steps: | |\n", + "| batch = fetch_batch() | batch = fetch_batch() | |\n", + "| outputs = model.evaluate(batch) | | outputs = model.evaluate(batch) |\n", + "| metric.compute(batch, outputs) | | metric.compute(batch, outputs) |\n", + "| results = metric.get_metric() | results = metric.get_metric() | |\n", + "| except: | except: | |\n", + "| process_exception() | process_exception() | |" + ] + }, + { + "cell_type": "markdown", + "id": "94ba11c6", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "由此,从程序员的角度,`fastNLP v1.0` **通过一个 driver 让基于 pytorch 、 paddle 、 jittor 、 oneflow 框架的模型**\n", + "\n", + " **都能在相同的 trainer 和 evaluator 上运行**,这也**是 fastNLP v1.0 相比于之前版本的一大亮点**\n", + "\n", + " 而从`driver`的角度,`fastNLP v1.0`通过定义一个`driver`基类,**将所有张量转化为 numpy.tensor**\n", + "\n", + " 并由此泛化出`torch_driver`、`paddle_driver`、`jittor_driver`三个子类,从而实现了\n", + "\n", + " 对`pytorch`、`paddle`、`jittor`的兼容,有关后两者的实践请参考接下来的`tutorial-6`" + ] + }, + { + "cell_type": "markdown", + "id": "ab1cea7d", + "metadata": {}, + "source": [ + "### 1.2 device 与 多卡训练\n", + "\n", + "**fastNLP v1.0 支持多卡训练**,实现方法则是**通过将 trainer 中的 device 设置为对应显卡的序号列表**\n", + "\n", + " 由单卡切换成多卡,无论是数据、模型还是评测都会面临一定的调整,`fastNLP v1.0`保证:\n", + "\n", + " 数据拆分时,不同卡之间相互协调,所有数据都可以被训练,且不会使用到相同的数据\n", + "\n", + " 模型训练时,模型之间需要交换梯度;评测计算时,每张卡先各自计算,再汇总结果\n", + "\n", + " 例如,在评测计算运行`get_metric`函数时,`fastNLP v1.0`将自动按照`self.right`和`self.total`\n", + "\n", + " 指定的 **aggregate_method 方法**,默认为`sum`,将每张卡上结果汇总起来,因此最终\n", + "\n", + " 在调用`get_metric`方法时,`Accuracy`类能够返回全部的统计结果,代码如下\n", + " \n", + "```python\n", + "trainer = Trainer(\n", + " model=model, # model 基于 pytorch 实现 \n", + " 
train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " ...\n", + " driver='torch', # driver 使用 torch_driver \n", + " device=[0, 1], # gpu 选择 cuda:0 + cuda:1\n", + " ...\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " ...\n", + " )\n", + "\n", + "class Accuracy(Metric):\n", + " def __init__(self):\n", + " super().__init__()\n", + " self.register_element(name='total', value=0, aggregate_method='sum')\n", + " self.register_element(name='right', value=0, aggregate_method='sum')\n", + "```\n" + ] + }, + { + "cell_type": "markdown", + "id": "e2e0a210", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "注:`fastNLP v1.0`中要求`jupyter`不能多卡,仅能单卡,故在所有`tutorial`中均不作相关演示" + ] + }, + { + "cell_type": "markdown", + "id": "8d19220c", + "metadata": {}, + "source": [ + "## 2. fastNLP 中的更多 metric 类型\n", + "\n", + "### 2.1 预定义的 metric 类型\n", + "\n", + "在`fastNLP 1.0`中,除了前几篇`tutorial`中经常见到的**正确率 Accuracy**,还有其他**预定义的评测标准 metric**\n", + "\n", + " 包括**所有 metric 的基类 Metric**、适配`Transformers`中相关模型的正确率`TransformersAccuracy`\n", + "\n", + " **适用于分类语境下的 F1 值 ClassifyFPreRecMetric**(其中也包括召回率`Pre`、精确率`Rec`\n", + "\n", + " **适用于抽取语境下的 F1 值 SpanFPreRecMetric**;相关基本信息内容见下表,之后是详细分析\n", + "\n", + "代码名称|简要介绍|代码路径\n", + "----|----|----|\n", + " `Metric` | 定义`metrics`时继承的基类 | `/core/metrics/metric.py` |\n", + " `Accuracy` | 正确率,最为常用 | `/core/metrics/accuracy.py` |\n", + " `TransformersAccuracy` | 正确率,为了兼容`Transformers`中相关模型 | `/core/metrics/accuracy.py` |\n", + " `ClassifyFPreRecMetric` | 召回率、精确率、F1值,适用于**分类问题** | `/core/metrics/classify_f1_pre_rec_metric.py` |\n", + " `SpanFPreRecMetric` | 召回率、精确率、F1值,适用于**抽取问题** | `/core/metrics/span_f1_pre_rec_metric.py` |" + ] + }, + { + "cell_type": "markdown", + "id": "fdc083a3", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + " 如`tutorial-0`中所述,所有的`metric`都包含`get_metric`和`update`函数,其中\n", + "\n", + " **update 函数更新单个 batch 的统计量**,**get_metric 
函数返回最终结果**,并打印显示\n", + "\n", + "\n", + "### 2.1.1 Accuracy 与 TransformersAccuracy\n", + "\n", + "`Accuracy`,正确率,预测正确的数据`right_num`在总数据`total_num`,中的占比(公式就不用列了\n", + "\n", + " `get_metric`函数打印格式为 **{\"acc#xx\": float, 'total#xx': float, 'correct#xx': float}**\n", + "\n", + " 一般在初始化时不需要传参,`fastNLP`会根据`update`函数的传入参数确定对应后台框架`backend`\n", + "\n", + " **update 函数的参数包括 pred 、 target 、 seq_len**,**后者用来标记批次中每笔数据的长度**\n", + "\n", + "`TransformersAccuracy`,继承自`Accuracy`,只是为了兼容`Transformers`框架中相关模型\n", + "\n", + " 在`update`函数中,将`Transformers`框架输出的`attention_mask`参数转化为`seq_len`参数\n", + "\n", + "\n", + "### 2.1.2 ClassifyFPreRecMetric 与 SpanFPreRecMetric\n", + "\n", + "`ClassifyFPreRecMetric`,分类评价,`SpanFPreRecMetric`,抽取评价,后者在`tutorial-4`中已出现\n", + "\n", + " 两者的相同之处在于:**第一**,**都包括召回率/查全率 ec**、**精确率/查准率 Pre**、**F1 值**这三个指标\n", + "\n", + " `get_metric`函数打印格式为 **{\"f#xx\": float, 'pre#xx': float, 'rec#xx': float}**\n", + "\n", + " 三者的计算公式如下,其中`beta`默认为`1`,即`F1`值是召回率`Rec`和精确率`Pre`的调和平均数\n", + "\n", + "$$\\text{召回率}\\ Rec=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有本来是正例的数量}}\\qquad \\text{精确率}\\ Pre=\\dfrac{\\text{正确预测为正例的数量}}{\\text{所有预测为正例的数量}}$$\n", + "\n", + "$$F_{beta} = \\frac{(1 + {beta}^{2})*(Pre*Rec)}{({beta}^{2}*Pre + Rec)}$$\n", + "\n", + " **第二**,可以通过参数`only_gross`为`False`,要求返回所有类别的`Rec-Pre-F1`,同时`F1`值又根据参数`f_type`又分为\n", + "\n", + " **micro F1**(**直接统计所有类别的 Rec-Pre-F1**)、**macro F1**(**统计各类别的 Rec-Pre-F1 再算术平均**)\n", + "\n", + " **第三**,两者在初始化时还可以**传入基于 fastNLP.Vocabulary 的 tag_vocab 参数记录数据集中的标签序号**\n", + "\n", + " **与标签名称之间的映射**,通过字符串列表`ignore_labels`参数,指定若干标签不用于`Rec-Pre-F1`的计算\n", + "\n", + "两者的不同之处在于:`ClassifyFPreRecMetric`针对简单的分类问题,每个分类标签之间彼此独立,不构成标签对\n", + "\n", + " **SpanFPreRecMetric 针对更复杂的抽取问题**,**规定标签 B-xx 和 I-xx 或 B-xx 和 E-xx 构成标签对**\n", + "\n", + " 在计算`Rec-Pre-F1`时,`ClassifyFPreRecMetric`只需要考虑标签本身是否正确这就足够了,但是\n", + "\n", + " 对于`SpanFPreRecMetric`,需要保证**标签符合规则且覆盖的区间与正确结果重合才算正确**\n", + "\n", + " 因此回到`tutorial-4`中`CoNLL-2003`的`NER`任务,如果评测方法选择`ClassifyFPreRecMetric`\n", + 
"\n", + " 或者`Accuracy`,会发现虽然评测结果显示很高,这是因为选择的评测方法要求太低\n", + "\n", + " 最后通过`CoNLL-2003`的词性标注`POS`任务简单演示下`ClassifyFPreRecMetric`相关的使用\n", + "\n", + "```python\n", + "from fastNLP import Vocabulary\n", + "from fastNLP import ClassifyFPreRecMetric\n", + "\n", + "tag_vocab = Vocabulary(padding=None, unknown=None) # 记录序号与标签之间的映射\n", + "tag_vocab.add_word_lst(['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', \n", + " 'CC', 'CD', 'DT', 'EX', 'FW', 'IN', 'JJ', 'JJR', 'JJS', 'LS', \n", + " 'MD', 'NN', 'NNP', 'NNPS', 'NNS', 'NN|SYM', 'PDT', 'POS', 'PRP', 'PRP$', \n", + " 'RB', 'RBR', 'RBS', 'RP', 'SYM', 'TO', 'UH', 'VB', 'VBD', 'VBG', \n", + " 'VBN', 'VBP', 'VBZ', 'WDT', 'WP', 'WP+', 'WRB', ]) # CoNLL-2003 中的 pos_tags\n", + "ignore_labels = ['\"', \"''\", '#', '$', '(', ')', ',', '.', ':', '``', ]\n", + "\n", + "FPreRec = ClassifyFPreRecMetric(tag_vocab=tag_vocab, \n", + " ignore_labels=ignore_labels, # 表示评测/优化中不考虑上述标签的正误/损失\n", + " only_gross=True, # 默认为 True 表示输出所有类别的综合统计结果\n", + " f_type='micro') # 默认为 'micro' 表示统计所有类别的 Rec-Pre-F1\n", + "metrics = {'F1': FPreRec}\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "8a22f522", + "metadata": {}, + "source": [ + "### 2.2 自定义的 metric 类型\n", + "\n", + "如上文所述,`Metric`作为所有`metric`的基类,`Accuracy`等都是其子类,同样地,对于**自定义的 metric 类型**\n", + "\n", + " 也**需要继承自 Metric 类**,同时**内部自定义好 __init__ 、 update 和 get_metric 函数**\n", + "\n", + " 在`__init__`函数中,根据需求定义评测时需要用到的变量,此处沿用`Accuracy`中的`total_num`和`right_num`\n", + "\n", + " 在`update`函数中,根据需求定义评测变量的更新方式,需要注意的是如`tutorial-0`中所述,**update`的参数名**\n", + "\n", + " **需要待评估模型在 evaluate_step 中的输出名称一致**,由此**和数据集中对应字段名称一致**,即**参数匹配**\n", + "\n", + " 在`fastNLP v1.0`中,`update`函数的默认输入参数:`pred`,对应预测值;`target`,对应真实值\n", + "\n", + " 此处仍然沿用,因为接下来会需要使用`fastNLP`函数的与定义模型,其输入参数格式即使如此\n", + "\n", + " 在`get_metric`函数中,根据需求定义评测指标最终的计算,此处直接计算准确率,该函数必须返回一个字典\n", + "\n", + " 其中,字串`'prefix'`表示该`metric`的名称,会对应显示到`trainer`的`progress bar`中\n", + "\n", + "根据上述要求,这里简单定义了一个名为`MyMetric`的评测模块,用于分类问题的评测,以此展开一个实例展示" + ] + 
}, + { + "cell_type": "code", + "execution_count": 1, + "id": "08a872e9", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import sys\n", + "sys.path.append('..')\n", + "\n", + "from fastNLP import Metric\n", + "\n", + "class MyMetric(Metric):\n", + "\n", + " def __init__(self):\n", + " Metric.__init__(self)\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + "\n", + " def update(self, pred, target):\n", + " self.total_num += target.size(0)\n", + " self.right_num += target.eq(pred).sum().item()\n", + "\n", + " def get_metric(self, reset=True):\n", + " acc = self.right_num / self.total_num\n", + " if reset:\n", + " self.total_num = 0\n", + " self.right_num = 0\n", + " return {'prefix': acc}" + ] + }, + { + "cell_type": "markdown", + "id": "0155f447", + "metadata": {}, + "source": [ + " 数据使用方面,此处仍然使用`datasets`模块中的`load_dataset`函数,加载`SST-2`二分类数据集" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5ad81ac7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ef923b90b19847f4916cccda5d33fc36", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "sst2data = load_dataset('glue', 'sst2')" + ] + }, + { + "cell_type": "markdown", + "id": "e9d81760", + "metadata": {}, + "source": [ + " 在数据预处理中,需要注意的是,这里原本应该根据`metric`和`model`的输入参数格式,调整\n", + "\n", + " 数据集中表示预测目标的字段,调整为`target`,在后文中会揭晓为什么,以及如何补救" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "cfb28b1b", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { 
+ "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import DataSet\n", + "\n", + "dataset = DataSet.from_pandas(sst2data['train'].to_pandas())[:6000]\n", + "\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].lower().split()}, progress_bar=\"tqdm\")\n", + "dataset.delete_field('sentence')\n", + "dataset.delete_field('idx')\n", + "\n", + "from fastNLP import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "vocab.from_dataset(dataset, field_name='words')\n", + "vocab.index_dataset(dataset, field_name='words')\n", + "\n", + "train_dataset, evaluate_dataset = dataset.split(ratio=0.85)\n", + "\n", + "from fastNLP import prepare_torch_dataloader\n", + "\n", + "train_dataloader = prepare_torch_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_torch_dataloader(evaluate_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "id": "af3f8c63", + "metadata": {}, + "source": [ + " 模型使用方面,此处仍然使用`tutorial-4`中介绍过的预定义`CNNText`模型,实现`SST-2`二分类" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "2fd210c5", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.models.torch import CNNText\n", + "\n", + "model = CNNText(embed=(len(vocab), 100), num_classes=2, dropout=0.1)\n", + "\n", + "from torch.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-4)" + ] + }, + { + "cell_type": "markdown", + "id": "6e723b87", + "metadata": {}, + "source": [ + "## 3. 
fastNLP 中 trainer 的补充介绍\n", + "\n", + "### 3.1 trainer 的内部结构\n", + "\n", + "在`tutorial-0`中,我们已经介绍了`trainer`的基本使用,从`tutorial-1`到`tutorial-4`,我们也已经展示了\n", + "\n", + " 很多`trainer`的使用案例,这里通过表格,相对完整地介绍`trainer`模块的属性和初始化参数(标粗为必选参数\n", + "\n", + "\n", + "名称|参数|属性|功能|内容\n", + "----|----|----|----|----|\n", + "| **model** | √ | √ | 指定`trainer`控制的模型 | 视框架而定,如`torch.nn.Module` |\n", + "| `device` | √ | | 指定`trainer`运行的卡位 | 例如`'cpu'`、`'cuda'`、`0`、`[0, 1]`等 |\n", + "| | | √ | 记录`trainer`运行的卡位 | `Device`类型,在初始化阶段生成 |\n", + "| **driver** | √ | | 指定`trainer`驱动的框架 | 包括`'torch'`、`'paddle'`、`'jittor'` |\n", + "| | | √ | 记录`trainer`驱动的框架 | `Driver`类型,在初始化阶段生成 |\n", + "| `n_epochs` | √ | - | 指定`trainer`迭代的轮数 | 默认`20`,记录在`driver.n_epochs`中 |\n", + "| **optimizers** | √ | √ | 指定`trainer`优化的方法 | 视框架而定,如`torch.optim.Adam` |\n", + "| `metrics` | √ | √ | 指定`trainer`评测的方法 | 字典类型,如`{'acc': Metric()}` |\n", + "| `evaluator` | | √ | 内置的`trainer`评测模块 | `Evaluator`类型,在初始化阶段生成 |\n", + "| `input_mapping` | √ | √ | 调整`dataloader`的参数不匹配 | 函数类型,输出字典匹配`forward`输入参数 |\n", + "| `output_mapping` | √ | √ | 调整`forward`输出的参数不匹配 | 函数类型,输出字典匹配`xx_step`输入参数 |\n", + "| **train_dataloader** | √ | √ | 指定`trainer`训练的数据 | `DataLoader`类型,生成视框架而定 |\n", + "| `evaluate_dataloaders` | √ | √ | 指定`trainer`评测的数据 | `DataLoader`类型,生成视框架而定 |\n", + "| `train_fn` | √ | √ | 指定`trainer`获取某个批次的损失值 | 函数类型,默认为`model.train_step` |\n", + "| `evaluate_fn` | √ | √ | 指定`trainer`获取某个批次的评估量 | 函数类型,默认为`model.evaluate_step` |\n", + "| `batch_step_fn` | √ | √ | 指定`trainer`训练时前向传输一个批次的方式 | 函数类型,默认为`TrainBatchLoop.batch_step_fn` |\n", + "| `evaluate_batch_step_fn` | √ | √ | 指定`trainer`评测时前向传输一个批次的方式 | 函数类型,默认为`EvaluateBatchLoop.batch_step_fn` |\n", + "| `accumulation_steps` | √ | √ | 指定`trainer`训练时反向传播的频率 | 默认为`1`,即每个批次都反向传播 |\n", + "| `evaluate_every` | √ | √ | 指定`evaluator`评测时计算的频率 | 默认`-1`表示每个循环一次,相反`1`表示每个批次一次 |\n", + "| `progress_bar` | √ | √ | 指定`trainer`训练和评测时的进度条样式 | 包括`'auto'`、`'tqdm'`、`'raw'`、`'rich'` |\n", + "| `callbacks` | √ | | 
指定`trainer`训练时需要触发的函数 | `Callback`列表类型,详见`tutorial-7` |\n", + "| `callback_manager` | | √ | 记录与管理`callbacks`相关内容 | `CallbackManager`类型,详见`tutorial-7` |\n", + "| `monitor` | √ | √ | 辅助部分的`callbacks`相关内容 | 字符串/函数类型,详见`tutorial-7` |\n", + "| `marker` | √ | √ | 标记`trainer`实例,辅助`callbacks`相关内容 | 字符串型,详见`tutorial-7` |\n", + "| `trainer_state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `TrainerState`类型,详见`tutorial-7` |\n", + "| `state` | | √ | 记录`trainer`状态,辅助`callbacks`相关内容 | `State`类型,详见`tutorial-7` |\n", + "| `fp16` | √ | √ | 指定`trainer`是否进行混合精度训练 | 布尔类型,默认`False` |\n", + "\n", + "其中,**input_mapping 和 output_mapping** 定义形式如下:输入字典形式的数据,根据参数匹配要求调整数据格式,这里就回应了前文未在数据集预处理时调整格式的问题,**总之参数匹配一定要求**" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "de96c1d1", + "metadata": {}, + "outputs": [], + "source": [ + "def input_mapping(data):\n", + " data['target'] = data['label']\n", + " return data" + ] + }, + { + "cell_type": "markdown", + "id": "2fc8b9f3", + "metadata": {}, + "source": [ + " 而`trainer`模块的基础方法列表如下,相关进阶操作,如`on`系列函数、`callback`控制,请参考后续的`tutorial-7`\n", + "\n", + "|名称|功能|主要参数|\n", + "|----|----|----|\n", + "| `run` | 控制`trainer`中模型的训练和评测 | 详见后文 |\n", + "| `train_step` | 实现`trainer`训练中一个批数据的前向传播过程 | 输入`batch` |\n", + "| `backward` | 实现`trainer`训练中一次损失的反向传播过程 | 输入`output` |\n", + "| `zero_grad` | 实现`trainer`训练中`optimizers`的梯度置零 | 无输入 |\n", + "| `step` | 实现`trainer`训练中`optimizers`的参数更新 | 无输入 |\n", + "| `epoch_evaluate` | 实现`trainer`训练中每个循环的评测,实际是否执行取决于评测频率 | 无输入 |\n", + "| `step_evaluate` | 实现`trainer`训练中每个批次的评测,实际是否执行取决于评测频率 | 无输入 |\n", + "| `save_model` | 保存`trainer`中的模型参数/状态字典至`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`False` |\n", + "| `load_model` | 加载`trainer`中的模型参数/状态字典自`fastnlp_model.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只加载状态字典,默认`True` |\n", + "| `save_checkpoint` | 保存`trainer`中模型参数/状态字典 以及 `callback`、`sampler` 和`optimizer`的状态至`fastnlp_model/checkpoint.pkl.tar` | 
`folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` |\n", + "| `load_checkpoint` | 加载`trainer`中模型参数/状态字典 以及 `callback`、`sampler` 和`optimizer`的状态自`fastnlp_model/checkpoint.pkl.tar` | `folder`指明路径,`only_state_dict`指明是否只保存状态字典,默认`True` `resume_training`指明是否只精确到上次训练的批量,默认`True` |\n", + "| `add_callback_fn` | 在`trainer`初始化后添加`callback`函数 | 输入`event`指明回调时机,`fn`指明回调函数 |\n", + "| `on` | 函数修饰器,将一个函数转变为`callback`函数 | 详见`tutorial-7` |\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "1e21df35", + "metadata": {}, + "source": [ + "紧接着,初始化`trainer`实例,继续完成`SST-2`分类,其中`metrics`输入的键值对,字串`'suffix'`和之前定义的\n", + "\n", + " 字串`'prefix'`将拼接在一起显示到`progress bar`中,故完整的输出形式为`{'prefix#suffix': float}`" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "926a9c50", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " input_mapping=input_mapping,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'suffix': MyMetric()}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b1b2e8b7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "source": [ + "最后就是`run`函数的使用,关于其参数,这里也以表格形式列出,由此就解答了`num_eval_batch_per_dl=10`的含义\n", + "\n", + "|名称|功能|默认值|\n", + "|----|----|----|\n", + "| `num_train_batch_per_epoch` | 指定`trainer`训练时,每个循环计算批量数目 | 整数类型,默认`-1`,表示训练时,每个循环计算所有批量 |\n", + "| `num_eval_batch_per_dl` | 指定`trainer`评测时,每个循环计算批量数目 | 整数类型,默认`-1`,表示评测时,每个循环计算所有批量 |\n", + "| `num_eval_sanity_batch` | 指定`trainer`训练开始前,试探性评测批量数目 | 整数类型,默认`2`,表示训练开始前评估两个批量 |\n", + "| `resume_from` | 指定`trainer`恢复状态的路径,需要是文件夹 | 字符串型,默认`None`,使用可参考`CheckpointCallback` |\n", + "| `resume_training` | 指定`trainer`恢复状态的程度 | 布尔类型,默认`True`恢复所有状态,`False`仅恢复`model`和`optimizers`状态 |" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + 
"id": "43be274f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
[09:30:35] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:30:35]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=954293;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=366534;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.6875\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.6875\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8125\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.825\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.825\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8125\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8125\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.8\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.8\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"prefix#suffix\": 0.80625\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"prefix#suffix\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1abfa0a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_6.ipynb b/docs/source/tutorials/fastnlp_tutorial_6.ipynb new file mode 100644 index 00000000..63f7481e --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_6.ipynb @@ -0,0 +1,1646 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "fdd7ff16", + "metadata": {}, + "source": [ + "# T6. 
fastNLP 与 paddle 或 jittor 的结合\n", + "\n", + " 1 fastNLP 结合 paddle 训练模型\n", + " \n", + " 1.1 关于 paddle 的简单介绍\n", + "\n", + " 1.2 使用 paddle 搭建并训练模型\n", + "\n", + " 2 fastNLP 结合 jittor 训练模型\n", + "\n", + " 2.1 关于 jittor 的简单介绍\n", + "\n", + " 2.2 使用 jittor 搭建并训练模型\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "08752c5a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6b13d42c39ba455eb370bf2caaa3a264", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "sst2data = load_dataset('glue', 'sst2')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "7e8cc210", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[38;5;2m[i 0604 21:01:38.510813 72 log.cc:351] Load log_sync: 1\u001b[m\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Processing: 0%| | 0/6000 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "
[21:03:08] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:03:08]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=894986;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=567751;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:111: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " if ip and hasattr(ip, 'kernel') and hasattr(ip.kernel, '_parent_header'):\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/ipywidgets/widgets/widget_\n", + "output.py:112: DeprecationWarning: Kernel._parent_header is deprecated in ipykernel 6. Use \n", + ".get_parent()\n", + " self.msg_id = ip.kernel._parent_header['header']['msg_id']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n", + "\n" + ], + "text/plain": [ + "/remote-home/xrliu/anaconda3/envs/demo/lib/python3.7/site-packages/paddle/tensor/creation.py:\n", + "125: DeprecationWarning: `np.object` is a deprecated alias for the builtin `object`. To \n", + "silence this warning, use `object` by itself. Doing this will not modify any behavior and is \n", + "safe. \n", + "Deprecated in NumPy 1.20; for more details and guidance: \n", + "https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations\n", + " if data.dtype == np.object:\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.78125,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 125.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.78125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 126.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.81875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 131.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.81875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m131.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.80625,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 129.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.80625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m129.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.79375,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 127.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.79375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m127.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7875,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 126.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m126.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8,\n", + " \"total#acc\": 160.0,\n", + " \"correct#acc\": 128.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m128.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10) " + ] + }, + { + "cell_type": "markdown", + "id": "cb9a0b3c", + "metadata": {}, + "source": [ + "## 2. fastNLP 结合 jittor 训练模型" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c600191d", + "metadata": {}, + "outputs": [], + "source": [ + "import jittor\n", + "import jittor.nn as nn\n", + "\n", + "from jittor import Module\n", + "\n", + "\n", + "class ClsByJittor(Module):\n", + " def __init__(self, vocab_size, embedding_dim, output_dim, hidden_dim=64, num_layers=2, dropout=0.5):\n", + " Module.__init__(self)\n", + " self.hidden_dim = hidden_dim\n", + "\n", + " self.embedding = nn.Embedding(num=vocab_size, dim=embedding_dim)\n", + " self.lstm = nn.LSTM(input_size=embedding_dim, hidden_size=hidden_dim, batch_first=True, # 默认 batch_first=False\n", + " num_layers=num_layers, bidirectional=True, dropout=dropout)\n", + " self.mlp = nn.Sequential([nn.Dropout(p=dropout),\n", + " nn.Linear(hidden_dim * 2, hidden_dim * 2),\n", + " nn.ReLU(),\n", + " nn.Linear(hidden_dim * 2, output_dim),\n", + " nn.Sigmoid(),])\n", + "\n", + " self.loss_fn = nn.MSELoss()\n", + "\n", + " def execute(self, words):\n", + " output = self.embedding(words)\n", + " output, (hidden, cell) = self.lstm(output)\n", + " output = self.mlp(jittor.concat((hidden[-1], hidden[-2]), dim=1))\n", + " return output\n", + " \n", + " def train_step(self, words, target):\n", + " pred = self(words)\n", + " target = jittor.stack((1 - target, target), dim=1)\n", + " return {'loss': self.loss_fn(pred, target)}\n", + "\n", + " def evaluate_step(self, words, target):\n", + " pred = self(words)\n", + " pred = jittor.argmax(pred, dim=-1)[0]\n", + " return {'pred': pred, 'target': target}" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "a94ed8c4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + 
"ClsByJittor(\n", + " embedding: Embedding(8458, 100)\n", + " lstm: LSTM(100, 64, 2, bias=True, batch_first=True, dropout=0.5, bidirectional=True, proj_size=0)\n", + " mlp: Sequential(\n", + " 0: Dropout(0.5, is_train=False)\n", + " 1: Linear(128, 128, float32[128,], None)\n", + " 2: relu()\n", + " 3: Linear(128, 2, float32[2,], None)\n", + " 4: Sigmoid()\n", + " )\n", + " loss_fn: MSELoss(mean)\n", + ")" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model = ClsByJittor(vocab_size=len(vocab), embedding_dim=100, output_dim=2)\n", + "\n", + "model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6d15ebc1", + "metadata": {}, + "outputs": [], + "source": [ + "from jittor.optim import AdamW\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "95d8d09e", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import prepare_jittor_dataloader\n", + "\n", + "train_dataloader = prepare_jittor_dataloader(train_dataset, batch_size=16, shuffle=True)\n", + "evaluate_dataloader = prepare_jittor_dataloader(evaluate_dataset, batch_size=16)\n", + "\n", + "# dl_bundle = prepare_jittor_dataloader(data_bundle, batch_size=16, shuffle=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "917eab81", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP import Trainer, Accuracy\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver='jittor',\n", + " device='gpu', # 'cpu', 'gpu', 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=train_dataloader, # dl_bundle['train'],\n", + " evaluate_dataloaders=evaluate_dataloader, # dl_bundle['dev'],\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "f7c4ac5a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ 
+ "
[21:05:51] INFO Running evaluator sanity check for 2 batches. trainer.py:596\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:05:51]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=69759;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=202322;file://../fastNLP/core/controllers/trainer.py#596\u001b\\\u001b[2m596\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n", + "Compiling Operators(5/6) used: 8.31s eta: 1.66s 6/6) used: 9.33s eta: 0s \n", + "\n", + "Compiling Operators(31/31) used: 7.31s eta: 0s \n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.61875,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 99\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.61875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m99\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 112\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m112\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.725,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 116\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.725\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m116\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.74375,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 119\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 121\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.75625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 121\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.75625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m121\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.73125,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 117\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.73125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m117\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 122\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.74375,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 119\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.74375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m119\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.7625,\n", + " \"total#acc\": 160,\n", + " \"correct#acc\": 122\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.7625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m160\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m122\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3df5f425", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb new file mode 100644 index 00000000..af8e60a0 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_e1.ipynb @@ -0,0 +1,1280 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 从这篇开始,我们将开启 **fastNLP v1.0 tutorial 的 example 系列**,在接下来的\n", + "\n", + " 每篇`tutorial`里,我们将会介绍`fastNLP v1.0`在自然语言处理任务上的应用实例" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP import Accuracy\n", + "\n", + "print(transformers.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:GLUE 通用语言理解评估、SST-2 文本情感二分类数据集\n", + "\n", + " 本示例使用`GLUE`评估基准中的`SST-2`数据集,通过`fine-tuning`方式\n", + "\n", + " 调整`distilbert-bert`分类模型,以下首先简单介绍下`GLUE`和`SST-2`\n", + "\n", + "**GLUE**,**全称 General Language Understanding Evaluation**,**通用语言理解评估**,\n", + "\n", + " 包含9个数据集,各语料的语言均为英语,涉及多个自然语言理解`NLU`任务,包括\n", + "\n", + " **CoLA**,文本分类任务,预测单句语法正误分类;**SST-2**,文本分类任务,预测单句情感二分类\n", + "\n", + " **MRPC**,句对分类任务,预测句对语义一致性;**STS-B**,相似度打分任务,预测句对语义相似度回归\n", + "\n", + " **QQP**,句对分类任务,预测问题对语义一致性;**MNLI**,文本推理任务,预测句对蕴含/矛盾/中立预测\n", + "\n", + " **QNLI / RTE / WNLI**,文本推理,预测是否蕴含二分类(其中,`QNLI`从`SQuAD`转化而来\n", + "\n", + " 诸如`BERT`、`T5`等经典模型都会在此基准上验证效果,更多参考[GLUE论文](https://arxiv.org/pdf/1804.07461v3.pdf)\n", + "\n", + " 此处,我们使用`SST-2`来训练`bert`,实现文本分类,其他任务描述见下图" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "GLUE_TASKS = ['cola', 'mnli', 'mrpc', 'qnli', 'qqp', 'rte', 'sst2', 'stsb', 'wnli']\n", + "\n", + "task = 'sst2'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "\n", + "**SST**,**全称`Stanford Sentiment Treebank**,**斯坦福情感树库**,**单句情感分类**数据集\n", + "\n", + " 包含电影评论语句和对应的情感极性,1 对应`positive` 正面情感,0 对应`negative` 负面情感\n", + 
"\n", + " 数据集包括三部分:训练集 67350 条,验证集 873 条,测试集 1821 条,更多参考[下载链接](https://gluebenchmark.com/tasks)\n", + "\n", + "对应到代码上,此处使用`datasets`模块中的`load_dataset`函数,指定`SST-2`数据集,自动加载\n", + "\n", + " 首次下载后会保存至`~/.cache/huggingface/modules/datasets_modules/datasets/glue/`目录下" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c5915debacf9443986b5b3b34870b303", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset\n", + "\n", + "dataset = load_dataset('glue', task)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 加载之后,根据`GLUE`中`SST-2`数据集的格式,尝试打印部分数据,检查加载结果" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sentence: hide new secretions from the parental units \n" + ] + } + ], + "source": [ + "task_to_keys = {\n", + " 'cola': ('sentence', None),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mnli': ('premise', 'hypothesis'),\n", + " 'mrpc': ('sentence1', 'sentence2'),\n", + " 'qnli': ('question', 'sentence'),\n", + " 'qqp': ('question1', 'question2'),\n", + " 'rte': ('sentence1', 'sentence2'),\n", + " 'sst2': ('sentence', None),\n", + " 'stsb': ('sentence1', 'sentence2'),\n", + " 'wnli': ('sentence1', 'sentence2'),\n", + "}\n", + "\n", + "sentence1_key, sentence2_key = task_to_keys[task]\n", + "\n", + "if sentence2_key is None:\n", + " print(f\"Sentence: {dataset['train'][0][sentence1_key]}\")\n", + "else:\n", + " 
print(f\"Sentence 1: {dataset['train'][0][sentence1_key]}\")\n", + " print(f\"Sentence 2: {dataset['train'][0][sentence2_key]}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:加载 tokenizer、预处理 dataset、dataloader 使用\n", + "\n", + " 接下来进入模型训练的准备工作,分别需要使用`tokenizer`模块对数据集进行分词与标注\n", + "\n", + " 定义`SeqClsDataset`对应`dataloader`模块用来实现数据集在训练/测试时的加载\n", + "\n", + "此处的`tokenizer`和`SequenceClassificationModel`都是基于**distilbert-base-uncased 模型**\n", + "\n", + " 即使用较小的、不区分大小写的数据集,**对 bert-base 进行知识蒸馏后的版本**,结构上\n", + "\n", + " 包含**1个编码层**、**6个自注意力层**,**参数量`66M**,详解见本篇末尾,更多请参考[DistilBert论文](https://arxiv.org/pdf/1910.01108.pdf)\n", + "\n", + "首先,通过从`transformers`库中导入 **AutoTokenizer 模块**,**使用 from_pretrained 函数初始化**\n", + "\n", + " 此处的`use_fast`表示是否使用`tokenizer`的快速版本;尝试序列化示例数据,检查加载结果\n", + "\n", + " 需要注意的是,处理后返回的两个键值,**'input_ids'**表示原始文本对应的词素编号序列\n", + "\n", + " **'attention_mask'**表示自注意力运算时的掩码(标上`0`的部分对应`padding`的内容" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': [101, 7592, 1010, 2023, 2028, 6251, 999, 102, 1998, 2023, 6251, 3632, 2007, 2009, 1012, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n" + ] + } + ], + "source": [ + "model_checkpoint = 'distilbert-base-uncased'\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)\n", + "\n", + "print(tokenizer(\"Hello, this one sentence!\", \"And this sentence goes with it.\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,定义预处理函数,**通过 dataset.map 方法**,**将数据集中的文本**,**替换为词素编号序列**" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at 
/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ca1fbe5e8eb059f3.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-03661263fbf302f5.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-fbe8e7a4e4f18f45.arrow\n" + ] + } + ], + "source": [ + "def preprocess_function(examples):\n", + " if sentence2_key is None:\n", + " return tokenizer(examples[sentence1_key], truncation=True)\n", + " return tokenizer(examples[sentence1_key], examples[sentence2_key], truncation=True)\n", + "\n", + "encoded_dataset = dataset.map(preprocess_function, batched=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,通过继承`torch`中的`Dataset`类,定义`SeqClsDataset`类,需要注意的是\n", + "\n", + " 其中,**\\_\\_getitem\\_\\_ 函数各返回值引用的键值**,**必须和原始数据集中的属性对应**\n", + "\n", + " 例如,`'label'`是`SST-2`数据集中原有的内容(包括`'sentence'`和`'label'`\n", + "\n", + " `'input_ids'`和`'attention_mask'`则是`tokenizer`处理后添加的字段" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " Dataset.__init__(self)\n", + " self.dataset = dataset\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, item):\n", + " item = self.dataset[item]\n", + " return item['input_ids'], item['attention_mask'], [item['label']] " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,**定义校对函数 collate_fn 对齐同个 batch 内的每笔数据**,需要注意的是该函数的\n", + "\n", + " **返回值必须是字典**,**键值必须同待训练模型的 train_step 和 evaluate_step 函数的参数**\n", + "\n", + " **相对应**;这也就是在`tutorial-0`中便被强调的,`fastNLP v1.0`的第一条**参数匹配**机制" + ] + 
}, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "def collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item[0])\n", + " max_length[0] = max(max_length[0], len(each_item[0]))\n", + " atten_mask.append(each_item[1])\n", + " max_length[1] = max(max_length[1], len(each_item[1]))\n", + " labels.append(each_item[2])\n", + " max_length[2] = max(max_length[2], len(each_item[2]))\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", + "dataloader_train = DataLoader(dataset=dataset_train, \n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", + "dataloader_valid = DataLoader(dataset=dataset_valid, \n", + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
模型训练:加载 distilbert-base、fastNLP 参数匹配、fine-tuning\n", + "\n", + " 最后就是模型训练的,分别需要使用`distilbert-base-uncased`搭建分类模型\n", + "\n", + " 初始化优化器`optimizer`、训练模块`trainer`,通过`run`函数完成训练\n", + "\n", + "此处使用的`nn.Module`模块搭建模型,与`tokenizer`类似,通过从`transformers`库中\n", + "\n", + " 导入`AutoModelForSequenceClassification`模块,基于`distilbert-base-uncased`模型初始\n", + "\n", + "需要注意的是**AutoModelForSequenceClassification 模块的输入参数和输出结构**\n", + "\n", + " 一方面,可以**通过输入标签值 labels**,**使用模块内的损失函数计算损失 loss**\n", + "\n", + " 并且可以选择输入是词素编号序列`input_ids`,还是词素嵌入序列`inputs_embeds`\n", + "\n", + " 另方面,该模块不会直接输出预测结果,而是会**输出各预测分类上的几率 logits**\n", + "\n", + " 基于上述描述,此处完成了中`train_step`和`evaluate_step`函数的定义\n", + "\n", + " 同样需要注意,函数的返回值体现了`fastNLP v1.0`的第二条**参数匹配**机制" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, num_labels, model_checkpoint):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " output = self.back_bone(input_ids=input_ids, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return output\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_transform.bias', 'vocab_projector.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_layer_norm.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.weight', 'classifier.bias', 'pre_classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "num_labels = 3 if task == 'mnli' else 1 if task == 'stsb' else 2\n", + "\n", + "model = SeqClsModel(num_labels=num_labels, model_checkpoint=model_checkpoint)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=5e-5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=0, # 'cuda'\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dataloader_train,\n", + " evaluate_dataloaders=dataloader_valid,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + 
"cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,使用`trainer.run`方法,训练模型,`n_epochs`参数中已经指定需要迭代`10`轮\n", + "\n", + " `num_eval_batch_per_dl`参数则指定每次只对验证集中的`10`个`batch`进行评估" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[09:12:45] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[09:12:45]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=408427;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=303634;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.884375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 283.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.878125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 281.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.884375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 283.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.9,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 288.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.9\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m288.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.8875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 284.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.8875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m284.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.88125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 282.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.88125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m282.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 280.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m280.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.865625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 277.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.865625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m277.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.884375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 283.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.884375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m283.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.878125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 281.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.878125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m281.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.884174, 'total#acc': 872.0, 'correct#acc': 771.0}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 附:`DistilBertForSequenceClassification`模块结构\n", + "\n", + "```\n", + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "4.18.0\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "from torch.optim import AdamW\n", + "from torch.utils.data import DataLoader, Dataset\n", + "\n", + "import transformers\n", + "from transformers import AutoTokenizer\n", + "from transformers import AutoModelForSequenceClassification\n", + "\n", + "import sys\n", + "sys.path.append('..')\n", + "\n", + "import fastNLP\n", + "from fastNLP import Trainer\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "print(transformers.__version__)\n", + "\n", + "task = 'sst2'\n", + "model_checkpoint = 'distilbert-base-uncased' # 'bert-base-uncased'" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:P-Tuning v2 原理概述、P-Tuning v2 模型搭建\n", + "\n", + " 本示例使用`P-Tuning v2`作为`prompt-based tuning`与`fastNLP v1.0`结合的案例\n", + "\n", + " 以下首先简述`P-Tuning v2`的论文原理,并由此引出`fastNLP v1.0`的代码实践\n", + "\n", + "**P-Tuning v2**出自论文[Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks](https://arxiv.org/pdf/2110.07602.pdf)\n", + "\n", + " 其主要贡献在于,**在 PrefixTuning 等深度提示学习基础上**,**提升了其在分类标注等 NLU 任务的表现**\n", + "\n", + " 并使之在中等规模模型,主要是**参数量在 100M-1B 区间的模型上**,**获得与全参数微调相同的效果**\n", + "\n", + " 其结构如图所示,通过**在输入序列的分类符 [CLS] 之前**,**加入前缀序列**(**序号对应嵌入是待训练的连续值向量**\n", + "\n", + " **刺激模型在新任务下**,从`[CLS]`对应位置,**输出符合微调任务的输出**,从而达到适应微调任务的目的\n", + "\n", + "\n", + "\n", + "本示例使用`bert-base-uncased`模型,作为`P-Tuning v2`的基础`backbone`,设置`requires_grad=False`\n", + "\n", + " 固定其参数不参与训练,**设置 pre_seq_len 长的 prefix_tokens 作为输入的提示前缀序列**\n", + "\n", + " **使用基于 nn.Embedding 的 prefix_encoder 为提示前缀嵌入**,通过`get_prompt`函数获取,再将之\n", + "\n", + " 拼接至批量内每笔数据前得到`inputs_embeds`,同时更新自注意力掩码`attention_mask`\n", + "\n", + " 将`inputs_embeds`、`attention_mask`和`labels`输入`backbone`,**得到输出包括 loss 和 logits**" + ] + }, + { + 
"cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsModel(nn.Module):\n", + " def __init__(self, model_checkpoint, num_labels, pre_seq_len):\n", + " nn.Module.__init__(self)\n", + " self.num_labels = num_labels\n", + " self.back_bone = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, \n", + " num_labels=num_labels)\n", + " self.embeddings = self.back_bone.get_input_embeddings()\n", + "\n", + " for param in self.back_bone.parameters():\n", + " param.requires_grad = False\n", + " \n", + " self.pre_seq_len = pre_seq_len\n", + " self.prefix_tokens = torch.arange(self.pre_seq_len).long()\n", + " self.prefix_encoder = nn.Embedding(self.pre_seq_len, self.embeddings.embedding_dim)\n", + " \n", + " def get_prompt(self, batch_size):\n", + " prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.back_bone.device)\n", + " prompts = self.prefix_encoder(prefix_tokens)\n", + " return prompts\n", + "\n", + " def forward(self, input_ids, attention_mask, labels=None):\n", + " \n", + " batch_size = input_ids.shape[0]\n", + " raw_embedding = self.embeddings(input_ids)\n", + " \n", + " prompts = self.get_prompt(batch_size=batch_size)\n", + " inputs_embeds = torch.cat((prompts, raw_embedding), dim=1)\n", + " prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.back_bone.device)\n", + " attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)\n", + "\n", + " outputs = self.back_bone(inputs_embeds=inputs_embeds, \n", + " attention_mask=attention_mask, labels=labels)\n", + " return outputs\n", + "\n", + " def train_step(self, input_ids, attention_mask, labels):\n", + " loss = self(input_ids, attention_mask, labels).loss\n", + " return {'loss': loss}\n", + "\n", + " def evaluate_step(self, input_ids, attention_mask, labels):\n", + " pred = self(input_ids, attention_mask, labels).logits\n", + " pred = torch.max(pred, dim=-1)[1]\n", + " return {'pred': 
pred, 'target': labels}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "接着,通过确定分类数量初始化模型实例,同时调用`torch.optim.AdamW`模块初始化优化器\n", + "\n", + " 根据`P-Tuning v2`论文:`Generally, simple classification tasks prefer shorter prompts (less than 20)`\n", + "\n", + " 此处`pre_seq_len`参数设定为`20`,学习率相应做出调整,其他内容和`tutorial-E1`中的内容一致" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.bias', 'vocab_transform.weight', 'vocab_projector.bias']\n", + "- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['pre_classifier.weight', 'classifier.weight', 'pre_classifier.bias', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = SeqClsModel(model_checkpoint=model_checkpoint, num_labels=2, pre_seq_len=20)\n", + "\n", + "optimizers = AdamW(params=model.parameters(), lr=1e-2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 
模型训练:加载 tokenizer、预处理 dataset、模型训练与分析\n", + "\n", + " 本示例沿用`tutorial-E1`中的数据集,即使用`GLUE`评估基准中的`SST-2`数据集\n", + "\n", + " 以`bert-base-uncased`模型作为基准,基于`P-Tuning v2`方式微调\n", + "\n", + " 数据集加载相关代码流程见下,内容和`tutorial-E1`中的内容基本一致\n", + "\n", + "首先,使用`datasets.load_dataset`加载数据集,使用`transformers.AutoTokenizer`\n", + "\n", + " 构建`tokenizer`实例,通过`dataset.map`使用`tokenizer`将文本替换为词素序号序列" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset glue (/remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "21cbd92c3397497d84dc10f017ec96f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/3 [00:00, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from datasets import load_dataset, load_metric\n", + "\n", + "dataset = load_dataset('glue', task)\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-294e481a713c5754.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-ed9d9258aaf0fb54.arrow\n", + "Loading cached processed dataset at /remote-home/xrliu/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f44c5576b89f9e6b.arrow\n" + ] + } + ], + "source": [ + "def 
preprocess_function(examples):\n", + " return tokenizer(examples['sentence'], truncation=True)\n", + "\n", + "encoded_dataset = dataset.map(preprocess_function, batched=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "然后,定义`SeqClsDataset`类、定义校对函数`collate_fn`,这里沿用`tutorial-E1`中的内容\n", + "\n", + " 同样需要注意/强调的是,**\\_\\_getitem\\_\\_ 函数的返回值必须和原始数据集中的属性对应**\n", + "\n", + " **collate_fn 函数的返回值必须和 train_step 和 evaluate_step 函数的参数匹配**" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "class SeqClsDataset(Dataset):\n", + " def __init__(self, dataset):\n", + " Dataset.__init__(self)\n", + " self.dataset = dataset\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, item):\n", + " item = self.dataset[item]\n", + " return item['input_ids'], item['attention_mask'], [item['label']] \n", + "\n", + "def collate_fn(batch):\n", + " input_ids, atten_mask, labels = [], [], []\n", + " max_length = [0] * 3\n", + " for each_item in batch:\n", + " input_ids.append(each_item[0])\n", + " max_length[0] = max(max_length[0], len(each_item[0]))\n", + " atten_mask.append(each_item[1])\n", + " max_length[1] = max(max_length[1], len(each_item[1]))\n", + " labels.append(each_item[2])\n", + " max_length[2] = max(max_length[2], len(each_item[2]))\n", + "\n", + " for i in range(3):\n", + " each = (input_ids, atten_mask, labels)[i]\n", + " for item in each:\n", + " item.extend([0] * (max_length[i] - len(item)))\n", + " return {'input_ids': torch.cat([torch.tensor([item]) for item in input_ids], dim=0),\n", + " 'attention_mask': torch.cat([torch.tensor([item]) for item in atten_mask], dim=0),\n", + " 'labels': torch.cat([torch.tensor(item) for item in labels], dim=0)}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "再然后,分别对`tokenizer`处理过的训练集数据、验证集数据,进行预处理和批量划分" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + 
"metadata": {}, + "outputs": [], + "source": [ + "dataset_train = SeqClsDataset(encoded_dataset['train'])\n", + "dataloader_train = DataLoader(dataset=dataset_train, \n", + " batch_size=32, shuffle=True, collate_fn=collate_fn)\n", + "dataset_valid = SeqClsDataset(encoded_dataset['validation'])\n", + "dataloader_valid = DataLoader(dataset=dataset_valid, \n", + " batch_size=32, shuffle=False, collate_fn=collate_fn)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "最后,使用之前完成的`dataloader_train`和`dataloader_valid`,定义训练模块`trainer`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = Trainer(\n", + " model=model,\n", + " driver='torch',\n", + " device=1, # [0, 1],\n", + " n_epochs=10,\n", + " optimizers=optimizers,\n", + " train_dataloader=dataloader_train,\n", + " evaluate_dataloaders=dataloader_valid,\n", + " metrics={'acc': Accuracy()}\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + " 使用`trainer.run`方法训练模型,同样每次只对验证集中的`10`个`batch`进行评估" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[22:53:00] INFO Running evaluator sanity check for 2 batches. trainer.py:592\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[22:53:00]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=406635;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=951504;file://../fastNLP/core/controllers/trainer.py#592\u001b\\\u001b[2m592\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:1, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.540625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 173.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.540625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m173.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:2, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m2\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.5,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 160.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.5\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m160.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:3, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m3\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.509375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 163.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.509375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m163.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:4, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m4\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.634375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 203.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.634375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m203.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:5, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m5\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.6125,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 196.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.6125\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m196.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:6, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m6\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.675,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 216.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.675\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m216.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:7, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m7\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.64375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 206.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.64375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m206.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:8, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m8\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.665625,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 213.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.665625\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m213.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
----------------------------- Eval. results on Epoch:9, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "----------------------------- Eval. results on Epoch:\u001b[1;36m9\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.659375,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 211.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.659375\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m211.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:10, Batch:0 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m10\u001b[0m, Batch:\u001b[1;36m0\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#acc\": 0.696875,\n", + " \"total#acc\": 320.0,\n", + " \"correct#acc\": 223.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#acc\"\u001b[0m: \u001b[1;36m0.696875\u001b[0m,\n", + " \u001b[1;34m\"total#acc\"\u001b[0m: \u001b[1;36m320.0\u001b[0m,\n", + " \u001b[1;34m\"correct#acc\"\u001b[0m: \u001b[1;36m223.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run(num_eval_batch_per_dl=10)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "可以发现,其效果远远逊色于`fine-tuning`,这是因为`P-Tuning v2`虽然能够适应参数量\n", + "\n", + " 在`100M-1B`区间的模型,但是,**distilbert-base 的参数量仅为 66M**,无法触及其下限\n", + "\n", + "另一方面,**fastNLP v1.0 不支持 jupyter 多卡**,所以无法在笔者的电脑/服务器上,完成\n", + "\n", + " 合适规模模型的学习,例如`110M`的`bert-base`模型,以及`340M`的`bert-large`模型" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.737385, 'total#acc': 872.0, 'correct#acc': 643.0}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 1 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb new file 
mode 100644 index 00000000..a5883416 --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e1.ipynb @@ -0,0 +1,1086 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E3. 使用 paddlenlp 和 fastNLP 实现中文文本情感分析\n", + "\n", + "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何使用 `paddlenlp` 自然语言处理库和 `fastNLP` 来完成比较简单的情感分析任务。\n", + "\n", + "1. 基础介绍:飞桨自然语言处理库 ``paddlenlp`` 和语义理解框架 ``ERNIE``\n", + "\n", + "2. 准备工作:使用 ``tokenizer`` 处理数据并构造 ``dataloader``\n", + "\n", + "3. 模型训练:加载 ``ERNIE`` 预训练模型,使用 ``fastNLP`` 进行训练" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:飞桨自然语言处理库 paddlenlp 和语义理解框架 ERNIE\n", + "\n", + "#### 1.1 飞桨自然语言处理库 paddlenlp\n", + "\n", + "``paddlenlp`` 是由百度以飞桨 ``PaddlePaddle`` 为核心开发的自然语言处理库,集成了多个数据集和 NLP 模型,包括百度自研的语义理解框架 ``ERNIE`` 。在本篇教程中,我们会以 ``paddlenlp`` 为基础,使用模型 ``ERNIE`` 完成中文情感分析任务。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.3.3\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append(\"../\")\n", + "\n", + "import paddle\n", + "import paddlenlp\n", + "from paddlenlp.transformers import AutoTokenizer\n", + "from paddlenlp.transformers import AutoModelForSequenceClassification\n", + "\n", + "print(paddlenlp.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 1.2 语义理解框架 ERNIE\n", + "\n", + "``ERNIE(Enhanced Representation from kNowledge IntEgration)`` 是百度提出的基于知识增强的持续学习语义理解框架,至今已有 ``ERNIE 2.0``、``ERNIE 3.0``、``ERNIE-M``、``ERNIE-tiny`` 等多种预训练模型。``ERNIE 1.0`` 采用``Transformer Encoder`` 作为其语义表示的骨架,并改进了两种 ``mask`` 策略,分别为基于**短语**和**实体**(人名、组织等)的策略。在 ``ERNIE`` 中,由多个字组成的短语或者实体将作为一个统一单元,在训练的时候被统一地 ``mask`` 掉,这样可以潜在地学习到知识的依赖以及更长的语义依赖来让模型更具泛化性。\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "``ERNIE 2.0`` 则提出了连续学习(``Continual 
Learning``)的概念,即首先用一个简单的任务来初始化模型,在更新时用前一个任务训练好的参数作为下一个任务模型初始化的参数。这样在训练新的任务时,模型便可以记住之前学习到的知识,使得模型在新任务上获得更好的表现。``ERNIE 2.0`` 分别构建了词法、语法、语义不同级别的预训练任务,并使用不同的 task id 来标示不同的任务,在共计16个中英文任务上都取得了SOTA效果。\n", + "\n", + "\n", + "\n", + "``ERNIE 3.0`` 将自回归和自编码网络融合在一起进行预训练,其中自编码网络采用 ``ERNIE 2.0`` 的多任务学习增量式构建预训练任务,持续进行语义理解学习。其中自编码网络增加了知识增强的预训练任务。自回归网络则基于 ``Tranformer-XL`` 结构,支持长文本语言模型建模,并在多个自然语言处理任务中取得了SOTA的效果。\n", + "\n", + "\n", + "\n", + "接下来,我们将展示如何在 ``fastNLP`` 中使用基于 ``paddle`` 的 ``ERNIE 1.0`` 框架进行中文情感分析。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 使用 tokenizer 处理数据并构造 dataloader\n", + "\n", + "#### 2.1 加载中文数据集 ChnSentiCorp\n", + "\n", + "``ChnSentiCorp`` 数据集是由中国科学院发布的中文句子级情感分析数据集,包含了从网络上获取的酒店、电影、书籍等多个领域的评论,每条评论都被划分为两个标签:消极(``0``)和积极(``1``),可以用于二分类的中文情感分析任务。通过 ``paddlenlp.datasets.load_dataset`` 函数,我们可以加载并查看 ``ChnSentiCorp`` 数据集的内容。" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "训练集大小: 9600\n", + "{'text': '选择珠江花园的原因就是方便,有电动扶梯直接到达海边,周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般,但还算整洁。 泳池在大堂的屋顶,因此很小,不过女儿倒是喜欢。 包的早餐是西式的,还算丰富。 服务吗,一般', 'label': 1, 'qid': ''}\n", + "{'text': '15.4寸笔记本的键盘确实爽,基本跟台式机差不多了,蛮喜欢数字小键盘,输数字特方便,样子也很美观,做工也相当不错', 'label': 1, 'qid': ''}\n", + "{'text': '房间太小。其他的都一般。。。。。。。。。', 'label': 0, 'qid': ''}\n" + ] + } + ], + "source": [ + "from paddlenlp.datasets import load_dataset\n", + "\n", + "train_dataset, val_dataset, test_dataset = load_dataset(\"chnsenticorp\", splits=[\"train\", \"dev\", \"test\"])\n", + "print(\"训练集大小:\", len(train_dataset))\n", + "for i in range(3):\n", + " print(train_dataset[i])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.2 处理数据\n", + "\n", + "可以看到,原本的数据集仅包含中文的文本和标签,这样的数据是无法被模型识别的。同英文文本分类任务一样,我们需要使用 ``tokenizer`` 对文本进行分词并转换为数字形式的结果。我们可以加载已经预训练好的中文分词模型 ``ernie-1.0-base-zh``,将分词的过程写在函数 ``_process`` 中,然后调用数据集的 ``map`` 函数对每一条数据进行分词。其中:\n", + "- 参数 ``max_length`` 
代表句子的最大长度;\n", + "- ``padding=\"max_length\"`` 表示将长度不足的结果 padding 至和最大长度相同;\n", + "- ``truncation=True`` 表示将长度过长的句子进行截断。\n", + "\n", + "至此,我们得到了每条数据长度均相同的数据集。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m[2022-06-22 21:31:04,168] [ INFO]\u001b[0m - We are using
[21:31:16] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:31:16]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=4641;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=822054;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:60 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.895833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1075.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.895833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1075.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:120 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.8975,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1077.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.8975\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1077.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:180 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.911667,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1094.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.911667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1094.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:240 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.9225,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1107.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9225\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1107.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.9275,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1113.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9275\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1113.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:60 -----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m60\u001b[0m -----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.930833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1117.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.930833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1117.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:120 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m120\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.935833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1123.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:180 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m180\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.935833,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1123.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.935833\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1123.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:240 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m240\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.9375,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1125.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.9375\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1125.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:1, Batch:300 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m1\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"acc#accuracy\": 0.941667,\n", + " \"total#accuracy\": 1200.0,\n", + " \"correct#accuracy\": 1130.0\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"acc#accuracy\"\u001b[0m: \u001b[1;36m0.941667\u001b[0m,\n", + " \u001b[1;34m\"total#accuracy\"\u001b[0m: \u001b[1;36m1200.0\u001b[0m,\n", + " \u001b[1;34m\"correct#accuracy\"\u001b[0m: \u001b[1;36m1130.0\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:28] INFO Loading best model from fnlp-ernie/2022-0 load_best_model_callback.py:111\n", + " 6-22-21_29_12_898095/best_so_far with \n", + " acc#accuracy: 0.941667... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m0\u001b[0m \u001b]8;id=340364;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=763898;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m6\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_898095/best_so_far with \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m acc#accuracy: \u001b[1;36m0.941667\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[21:34:34] INFO Deleting fnlp-ernie/2022-06-22-21_29_12_8 load_best_model_callback.py:131\n", + " 98095/best_so_far... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[21:34:34]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m22\u001b[0m-21_29_12_8 \u001b]8;id=430330;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=508566;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 98095/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import LRSchedCallback, LoadBestModelCallback\n", + "from fastNLP import Trainer, Accuracy\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 2\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(5e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks = [\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"acc#accuracy\", larger_better=True, save_folder=\"fnlp-ernie\"),\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"paddle\",\n", + " optimizers=optimizer,\n", + " device=0,\n", + " n_epochs=n_epochs,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " evaluate_every=60,\n", + " metrics={\"accuracy\": Accuracy()},\n", + " callbacks=callbacks,\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 测试和评估\n", + "\n", + "现在我们已经得到了一个表现良好的 ``ERNIE`` 模型,接下来可以在测试集上测试模型的效果了。``fastNLP.Evaluator`` 提供了定制函数的功能。我们以 ``test_dataloader`` 初始化一个 ``Evaluator``,然后将写好的测试函数 ``test_batch_step_fn`` 传给参数 ``evaluate_batch_step_fn``,``Evaluate`` 在对每个 batch 进行评估时就会调用我们自定义的 ``test_batch_step_fn`` 函数而不是 ``evaluate_step`` 函数。在这里,我们仅测试 5 条数据并输出文本和对应的标签。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n", + "\n" + ], + "text/plain": [ + "text: ['这个宾馆比较陈旧了,特价的房间也很一般。总体来说一般']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n", + "\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n", + "\n" + ], + "text/plain": [ + "text: ['怀着十分激动的心情放映,可是看着看着发现,在放映完毕后,出现一集米老鼠的动画片!开始\n", + "还怀疑是不是赠送的个别现象,可是后来发现每张DVD后面都有!真不知道生产商怎么想的,我想看的是猫\n", + "和老鼠,不是米老鼠!如果厂家是想赠送的话,那就全套米老鼠和唐老鸭都赠送,只在每张DVD后面添加一\n", + "集算什么??简直是画蛇添足!!']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n", + "\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n", + "\n" + ], + "text/plain": [ + "text: ['还稍微重了点,可能是硬盘大的原故,还要再轻半斤就好了。其他要进一步验证。贴的几种膜气\n", + "泡较多,用不了多久就要更换了,屏幕膜稍好点,但比没有要强多了。建议配赠几张膜让用用户自己贴。'\n", + "]\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 0\n", + "\n" + ], + "text/plain": [ + "labels: 0\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['交通方便;环境很好;服务态度很好 房间较小']\n", + "\n" + ], + "text/plain": [ + "text: ['交通方便;环境很好;服务态度很好 房间较小']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n", + "\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n", + "\n" + ], + "text/plain": [ + "text: ['不错,作者的观点很颠覆目前中国父母的教育方式,其实古人们对于教育已经有了很系统的体系\n", + "了,可是现在的父母以及祖父母们更多的娇惯纵容孩子,放眼看去自私的孩子是大多数,父母觉得自己的\n", + "孩子在外面只要不吃亏就是好事,完全把古人几千年总结的教育古训抛在的九霄云外。所以推荐准妈妈们\n", + "可以在等待宝宝降临的时候,好好学习一下,怎么把孩子教育成一个有爱心、有责任心、宽容、大度的人\n", + "。']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
labels: 1\n", + "\n" + ], + "text/plain": [ + "labels: 1\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{}" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "def test_batch_step_fn(evaluator, batch):\n", + " input_ids = batch[\"input_ids\"]\n", + " attention_mask = batch[\"attention_mask\"]\n", + " token_type_ids = batch[\"token_type_ids\"]\n", + " logits = model(input_ids, attention_mask, token_type_ids)\n", + " predict = logits.argmax().item()\n", + " print(\"text:\", batch['text'])\n", + " print(\"labels:\", predict)\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=test_dataloader,\n", + " driver=\"paddle\",\n", + " device=0,\n", + " evaluate_batch_step_fn=test_batch_step_fn,\n", + ")\n", + "evaluator.run(5) " + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb new file mode 100644 index 00000000..439d7f9f --- /dev/null +++ b/docs/source/tutorials/fastnlp_tutorial_paddle_e2.ipynb @@ -0,0 +1,1510 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# E4. 
使用 paddlenlp 和 fastNLP 训练中文阅读理解任务\n", + "\n", + "本篇教程属于 **fastNLP v1.0 tutorial 的 paddle examples 系列**。在本篇教程中,我们将为您展示如何在 `fastNLP` 中通过自定义 `Metric` 和 损失函数来完成进阶的问答任务。\n", + "\n", + "1. 基础介绍:自然语言处理中的阅读理解任务\n", + "\n", + "2. 准备工作:加载 `DuReader-robust` 数据集,并使用 `tokenizer` 处理数据\n", + "\n", + "3. 模型训练:自己定义评测用的 `Metric` 实现更加自由的任务评测" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1. 基础介绍:自然语言处理中的阅读理解任务\n", + "\n", + "阅读理解任务,顾名思义,就是给出一段文字,然后让模型理解这段文字所含的语义。大部分机器阅读理解任务都采用问答式测评,即设计与文章内容相关的自然语言式问题,让模型理解问题并根据文章作答。与文本分类任务不同的是,在阅读理解任务中我们有时需要需要输入“一对”句子,分别代表问题和上下文;答案的格式也分为多种:\n", + "\n", + "- 多项选择:让模型从多个答案选项中选出正确答案\n", + "- 区间答案:答案为上下文的一段子句,需要模型给出答案的起始位置\n", + "- 自由回答:不做限制,让模型自行生成答案\n", + "- 完形填空:在原文中挖空部分关键词,让模型补全;这类答案往往不需要问题\n", + "\n", + "如果您对 `transformers` 有所了解的话,其中的 `ModelForQuestionAnswering` 系列模型就可以用于这项任务。阅读理解模型的泛用性是衡量该技术能否在实际应用中大规模落地的重要指标之一,随着当前技术的进步,许多模型虽然能够在一些测试集上取得较好的性能,但在实际应用中,这些模型仍然难以让人满意。在本篇教程中,我们将会为您展示如何训练一个问答模型。\n", + "\n", + "在这一领域,`SQuAD` 数据集是一个影响深远的数据集。它的全称是斯坦福问答数据集(Stanford Question Answering Dataset),每条数据包含 `(问题,上下文,答案)` 三部分,规模大(约十万条,2.0又新增了五万条),在提出之后很快成为训练问答任务的经典数据集之一。`SQuAD` 数据集有两个指标来衡量模型的表现:`EM`(Exact Match,精确匹配)和 `F1`(模糊匹配)。前者反应了模型给出的答案中有多少和正确答案完全一致,后者则反应了模型给出的答案中与正确答案重叠的部分,均为越高越好。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. 准备工作:加载 DuReader-robust 数据集,并使用 tokenizer 处理数据" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/remote-home/shxing/anaconda3/envs/fnlp-paddle/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2.3.3\n" + ] + } + ], + "source": [ + "import sys\n", + "sys.path.append(\"../\")\n", + "import paddle\n", + "import paddlenlp\n", + "\n", + "print(paddlenlp.__version__)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在数据集方面,我们选用 `DuReader-robust` 中文数据集作为训练数据。它是一种抽取式问答数据集,采用 `SQuAD` 数据格式,能够评估真实应用场景下模型的泛用性。" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n", + "Reusing dataset dureader_robust (/remote-home/shxing/.cache/huggingface/datasets/dureader_robust/plain_text/1.0.0/d462ecadc8c010cee20f57632f1413f272867cd802a91a602df48c7d34eb0c27)\n", + "\u001b[32m[2022-06-27 19:22:46,998] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/vocab.txt\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'id': '0a25cb4bc1ab6f474c699884e04601e4', 'title': '', 'context': '第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。', 'question': '仙剑奇侠传3第几集上天界', 'answers': {'text': ['第35集'], 'answer_start': [0]}}\n", + "{'id': '7de192d6adf7d60ba73ba25cf590cc1e', 'title': '', 'context': '选择燃气热水器时,一定要关注这几个问题:1、出水稳定性要好,不能出现忽热忽冷的现象2、快速到达设定的需求水温3、操作要智能、方便4、安全性要好,要装有安全报警装置 
市场上燃气热水器品牌众多,购买时还需多加对比和仔细鉴别。方太今年主打的磁化恒温热水器在使用体验方面做了全面升级:9秒速热,可快速进入洗浴模式;水温持久稳定,不会出现忽热忽冷的现象,并通过水量伺服技术将出水温度精确控制在±0.5℃,可满足家里宝贝敏感肌肤洗护需求;配备CO和CH4双气体报警装置更安全(市场上一般多为CO单气体报警)。另外,这款热水器还有智能WIFI互联功能,只需下载个手机APP即可用手机远程操作热水器,实现精准调节水温,满足家人多样化的洗浴需求。当然方太的磁化恒温系列主要的是增加磁化功能,可以有效吸附水中的铁锈、铁屑等微小杂质,防止细菌滋生,使沐浴水质更洁净,长期使用磁化水沐浴更利于身体健康。', 'question': '燃气热水器哪个牌子好', 'answers': {'text': ['方太'], 'answer_start': [110]}}\n", + "{'id': 'b9e74d4b9228399b03701d1fe6d52940', 'title': '', 'context': '迈克尔.乔丹在NBA打了15个赛季。他在84年进入nba,期间在1993年10月6日第一次退役改打棒球,95年3月18日重新回归,在99年1月13日第二次退役,后于2001年10月31日复出,在03年最终退役。迈克尔·乔丹(Michael Jordan),1963年2月17日生于纽约布鲁克林,美国著名篮球运动员,司职得分后卫,历史上最伟大的篮球运动员。1984年的NBA选秀大会,乔丹在首轮第3顺位被芝加哥公牛队选中。 1986-87赛季,乔丹场均得到37.1分,首次获得分王称号。1990-91赛季,乔丹连夺常规赛MVP和总决赛MVP称号,率领芝加哥公牛首次夺得NBA总冠军。 1997-98赛季,乔丹获得个人职业生涯第10个得分王,并率领公牛队第六次夺得总冠军。2009年9月11日,乔丹正式入选NBA名人堂。', 'question': '乔丹打了多少个赛季', 'answers': {'text': ['15个'], 'answer_start': [12]}}\n", + "训练集大小: 14520\n", + "验证集大小: 1417\n" + ] + } + ], + "source": [ + "from paddlenlp.datasets import load_dataset\n", + "train_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"train\")\n", + "val_dataset = load_dataset(\"PaddlePaddle/dureader_robust\", splits=\"validation\")\n", + "for i in range(3):\n", + " print(train_dataset[i])\n", + "print(\"训练集大小:\", len(train_dataset))\n", + "print(\"验证集大小:\", len(val_dataset))\n", + "\n", + "MODEL_NAME = \"ernie-1.0-base-zh\"\n", + "from paddlenlp.transformers import ErnieTokenizer\n", + "tokenizer =ErnieTokenizer.from_pretrained(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 2.1 处理训练集\n", + "\n", + "对于阅读理解任务,数据处理的方式较为麻烦。接下来我们会为您详细讲解处理函数 `_process_train` 的功能,同时也将通过实践展示关于 `tokenizer` 的更多功能,让您更加深入地了解自然语言处理任务。首先让我们向 `tokenizer` 输入一条数据(以列表的形式):" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2\n", + "dict_keys(['offset_mapping', 'input_ids', 'token_type_ids', 
'overflow_to_sample'])\n" + ] + } + ], + "source": [ + "result = tokenizer(\n", + " [train_dataset[0][\"question\"]],\n", + " [train_dataset[0][\"context\"]],\n", + " stride=128,\n", + " max_length=256,\n", + " padding=\"max_length\",\n", + " return_dict=False\n", + ")\n", + "\n", + "print(len(result))\n", + "print(result[0].keys())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "首先不难理解的是,模型必须要同时接受问题(`question`)和上下文(`context`)才能够进行阅读理解,因此我们需要将二者同时进行分词(`tokenize`)。所幸,`Tokenizer` 提供了这一功能,当我们调用 `tokenizer` 的时候,其第一个参数名为 `text`,第二个参数名为 `text_pair`,这使得我们可以同时对一对文本进行分词。同时,`tokenizer` 还需要标记出一条数据中哪些属于问题,哪些属于上下文,这一功能则由 `token_type_ids` 完成。`token_type_ids` 会将输入的第一个文本(问题)标记为 `0`,第二个文本(上下文)标记为 `1`,这样模型在训练时便可以将问题和上下文区分开来:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 
3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2]\n", + "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓', '缓', '张', '开', '眼', '睛', ',', '景', '天', '又', '惊', '又', '喜', '之', '际', ',', '长', '卿', '和', '紫', '萱', '的', '仙', '船', '驶', '至', ',', '见', '众', '人', '无', '恙', ',', '也', '十', '分', '高', '兴', '。', '众', '人', '登', '船', ',', '用', '尽', '合', '力', '把', '自', '身', '的', '真', '气', '和', '水', '分', '输', '给', '她', '。', '雪', '见', '终', '于', '醒', '过', '来', '了', ',', '但', '却', '一', '脸', '木', '然', ',', '全', '无', '反', '应', '。', '众', '人', '向', '常', '胤', '求', '助', ',', '却', '发', '现', '人', '世', '界', '竟', '没', '有', '雪', '见', '的', '身', '世', '纪', '录', '。', '长', '卿', '询', '问', '清', '微', '的', '身', '世', ',', '清', '微', '语', '带', '双', '关', '说', '一', '切', '上', '了', '天', '界', '便', '有', '答', '案', '。', '长', '卿', '驾', '驶', '仙', '船', ',', '众', '人', '决', '定', '立', '马', '动', '身', ',', '往', '天', '界', '而', '去', '。', '众', '人', '来', '到', '一', '荒', '山', ',', '长', '卿', '指', '出', ',', '魔', '界', '和', '天', '界', '相', '连', '。', '由', '魔', '界', '进', '入', '通', '过', '神', '魔', '之', '井', ',', '便', '可', '登', '天', '。', '众', '人', '至', '魔', '界', '入', '口', ',', '仿', '若', '一', '黑', '色', '的', '蝙', '蝠', '洞', ',', '但', '始', '终', '无', '法', '进', '入', '。', '后', '来', '花', '楹', '发', '现', '只', '要', '有', '翅', '膀', '便', '能', '飞', '入', '[SEP]']\n", + "[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n" + ] + } + ], + "source": [ + "print(result[0][\"input_ids\"])\n", + "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"]))\n", + "print(result[0][\"token_type_ids\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "根据上面的输出我们可以看出,`tokenizer` 会将数据开头用 `[CLS]` 标记,用 `[SEP]` 来分割句子。同时,根据 `token_type_ids` 得到的 0、1 串,我们也很容易将问题和上下文区分开。顺带一提,如果一条数据进行了 `padding`,那么这部分会被标记为 `0` 。\n", + "\n", + "在输出的 `keys` 中还有一项名为 `offset_mapping` 的键。该项数据能够表示分词后的每个 `token` 在原文中对应文字或词语的位置。比如我们可以像下面这样将数据打印出来:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7)]\n", + "[1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427]\n", + "['[CLS]', '仙', '剑', '奇', '侠', '传', '3', '第', '几', '集', '上', '天', '界', '[SEP]', '第', '35', '集', '雪', '见', '缓']\n" + ] + } + ], + "source": [ + "print(result[0][\"offset_mapping\"][:20])\n", + "print(result[0][\"input_ids\"][:20])\n", + "print(tokenizer.convert_ids_to_tokens(result[0][\"input_ids\"])[:20])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`[CLS]` 由于是 `tokenizer` 自己添加进去用于标记数据的 `token`,因此它在原文中找不到任何对应的词语,所以给出的位置范围就是 `(0, 0)`;第二个 `token` 对应第一个 `“仙”` 字,因此映射的位置就是 `(0, 1)`;同理,后面的 `[SEP]` 也不对应任何文字,映射的位置为 `(0, 0)`;而接下来的 `token` 对应 **上下文** 中的第一个字 `“第”`,映射出的位置为 `(0, 1)`;再后面的 `token` 对应原文中的两个字符 `35`,因此其位置映射为 `(1, 3)` 。通过这种手段,我们可以更方便地获取 `token` 与原文的对应关系。\n", + "\n", + "最后,您也许会注意到我们获取的 `result` 长度为 2 。这是文本在分词后长度超过了 `max_length` 256 ,`tokenizer` 
将数据分成了两部分所致。在阅读理解任务中,我们不可能像文本分类那样轻易地将一条数据截断,因为答案很可能就出现在后面被丢弃的那部分数据中,因此,我们需要保留所有的数据(当然,您也可以直接丢弃这些超长的数据)。`overflow_to_sample` 则可以标识当前数据在原数据的索引:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[CLS]仙剑奇侠传3第几集上天界[SEP]第35集雪见缓缓张开眼睛,景天又惊又喜之际,长卿和紫萱的仙船驶至,见众人无恙,也十分高兴。众人登船,用尽合力把自身的真气和水分输给她。雪见终于醒过来了,但却一脸木然,全无反应。众人向常胤求助,却发现人世界竟没有雪见的身世纪录。长卿询问清微的身世,清微语带双关说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入[SEP]\n", + "overflow_to_sample: 0\n", + "[CLS]仙剑奇侠传3第几集上天界[SEP]说一切上了天界便有答案。长卿驾驶仙船,众人决定立马动身,往天界而去。众人来到一荒山,长卿指出,魔界和天界相连。由魔界进入通过神魔之井,便可登天。众人至魔界入口,仿若一黑色的蝙蝠洞,但始终无法进入。后来花楹发现只要有翅膀便能飞入。于是景天等人打下许多乌鸦,模仿重楼的翅膀,制作数对翅膀状巨物。刚佩戴在身,便被吸入洞口。众人摔落在地,抬头发现魔界守卫。景天和众魔套交情,自称和魔尊重楼相熟,众魔不理,打了起来。[SEP][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD][PAD]\n", + "overflow_to_sample: 0\n" + ] + } + ], + "source": [ + "for res in result:\n", + " tokens = tokenizer.convert_ids_to_tokens(res[\"input_ids\"])\n", + " print(\"\".join(tokens))\n", + " print(\"overflow_to_sample: \", res[\"overflow_to_sample\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "将两条数据均输出之后可以看到,它们都出自我们传入的数据,并且存在一部分重合。`tokenizer` 的 `stride` 参数可以设置重合部分的长度,这也可以帮助模型识别被分割开的两条数据;`overflow_to_sample` 的 `0` 则代表它们来自于第 `0` 条数据。\n", + "\n", + "基于以上信息,我们处理训练集的思路如下:\n", + "\n", + "1. 通过 `overflow_to_sample` 来获取原来的数据\n", + "2. 通过原数据的 `answers` 找到答案的起始位置\n", + "3. 
通过 `offset_mapping` 给出的映射关系在分词处理后的数据中找到答案的起始位置,分别记录在 `start_pos` 和 `end_pos` 中;如果没有找到答案(比如答案被截断了),那么答案的起始位置就被标记为 `[CLS]` 的位置。\n", + "\n", + "这样 `_process_train` 函数就呼之欲出了,我们调用 `train_dataset.map` 函数,并将 `batched` 参数设置为 `True` ,将所有数据批量地进行更新。有一点需要注意的是,**在处理过后数据量会增加**。" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'offset_mapping': [(0, 0), (0, 1), (1, 2), (2, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (0, 0), (0, 1), (1, 3), (3, 4), (4, 5), (5, 6), (6, 7), (7, 8), (8, 9), (9, 10), (10, 11), (11, 12), (12, 13), (13, 14), (14, 15), (15, 16), (16, 17), (17, 18), (18, 19), (19, 20), (20, 21), (21, 22), (22, 23), (23, 24), (24, 25), (25, 26), (26, 27), (27, 28), (28, 29), (29, 30), (30, 31), (31, 32), (32, 33), (33, 34), (34, 35), (35, 36), (36, 37), (37, 38), (38, 39), (39, 40), (40, 41), (41, 42), (42, 43), (43, 44), (44, 45), (45, 46), (46, 47), (47, 48), (48, 49), (49, 50), (50, 51), (51, 52), (52, 53), (53, 54), (54, 55), (55, 56), (56, 57), (57, 58), (58, 59), (59, 60), (60, 61), (61, 62), (62, 63), (63, 64), (64, 65), (65, 66), (66, 67), (67, 68), (68, 69), (69, 70), (70, 71), (71, 72), (72, 73), (73, 74), (74, 75), (75, 76), (76, 77), (77, 78), (78, 79), (79, 80), (80, 81), (81, 82), (82, 83), (83, 84), (84, 85), (85, 86), (86, 87), (87, 88), (88, 89), (89, 90), (90, 91), (91, 92), (92, 93), (93, 94), (94, 95), (95, 96), (96, 97), (97, 98), (98, 99), (99, 100), (100, 101), (101, 102), (102, 103), (103, 104), (104, 105), (105, 106), (106, 107), (107, 108), (108, 109), (109, 110), (110, 111), (111, 112), (112, 113), (113, 114), (114, 115), (115, 116), (116, 117), (117, 118), (118, 119), (119, 120), (120, 121), (121, 122), (122, 123), (123, 124), (124, 125), (125, 126), (126, 127), (127, 128), (128, 129), (129, 130), (130, 131), (131, 132), (132, 133), (133, 134), (134, 135), (135, 136), (136, 137), (137, 138), (138, 
139), (139, 140), (140, 141), (141, 142), (142, 143), (143, 144), (144, 145), (145, 146), (146, 147), (147, 148), (148, 149), (149, 150), (150, 151), (151, 152), (152, 153), (153, 154), (154, 155), (155, 156), (156, 157), (157, 158), (158, 159), (159, 160), (160, 161), (161, 162), (162, 163), (163, 164), (164, 165), (165, 166), (166, 167), (167, 168), (168, 169), (169, 170), (170, 171), (171, 172), (172, 173), (173, 174), (174, 175), (175, 176), (176, 177), (177, 178), (178, 179), (179, 180), (180, 181), (181, 182), (182, 183), (183, 184), (184, 185), (185, 186), (186, 187), (187, 188), (188, 189), (189, 190), (190, 191), (191, 192), (192, 193), (193, 194), (194, 195), (195, 196), (196, 197), (197, 198), (198, 199), (199, 200), (200, 201), (201, 202), (202, 203), (203, 204), (204, 205), (205, 206), (206, 207), (207, 208), (208, 209), (209, 210), (210, 211), (211, 212), (212, 213), (213, 214), (214, 215), (215, 216), (216, 217), (217, 218), (218, 219), (219, 220), (220, 221), (221, 222), (222, 223), (223, 224), (224, 225), (225, 226), (226, 227), (227, 228), (228, 229), (229, 230), (230, 231), (231, 232), (232, 233), (233, 234), (234, 235), (235, 236), (236, 237), (237, 238), (238, 239), (239, 240), (240, 241), (241, 242), (0, 0)], 'input_ids': [1, 1034, 1189, 734, 2003, 241, 284, 131, 553, 271, 28, 125, 280, 2, 131, 1773, 271, 1097, 373, 1427, 1427, 501, 88, 662, 1906, 4, 561, 125, 311, 1168, 311, 692, 46, 430, 4, 84, 2073, 14, 1264, 3967, 5, 1034, 1020, 1829, 268, 4, 373, 539, 8, 154, 5210, 4, 105, 167, 59, 69, 685, 12043, 539, 8, 883, 1020, 4, 29, 720, 95, 90, 427, 67, 262, 5, 384, 266, 14, 101, 59, 789, 416, 237, 12043, 1097, 373, 616, 37, 1519, 93, 61, 15, 4, 255, 535, 7, 1529, 619, 187, 4, 62, 154, 451, 149, 12043, 539, 8, 253, 223, 3679, 323, 523, 4, 535, 34, 87, 8, 203, 280, 1186, 340, 9, 1097, 373, 5, 262, 203, 623, 704, 12043, 84, 2073, 1137, 358, 334, 702, 5, 262, 203, 4, 334, 702, 405, 360, 653, 129, 178, 7, 568, 28, 15, 125, 280, 518, 9, 1179, 487, 
12043, 84, 2073, 1621, 1829, 1034, 1020, 4, 539, 8, 448, 91, 202, 466, 70, 262, 4, 638, 125, 280, 83, 299, 12043, 539, 8, 61, 45, 7, 1537, 176, 4, 84, 2073, 288, 39, 4, 889, 280, 14, 125, 280, 156, 538, 12043, 190, 889, 280, 71, 109, 124, 93, 292, 889, 46, 1248, 4, 518, 48, 883, 125, 12043, 539, 8, 268, 889, 280, 109, 270, 4, 1586, 845, 7, 669, 199, 5, 3964, 3740, 1084, 4, 255, 440, 616, 154, 72, 71, 109, 12043, 49, 61, 283, 3591, 34, 87, 297, 41, 9, 1993, 2602, 518, 52, 706, 109, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'overflow_to_sample': 0, 'start_pos': 14, 'end_pos': 16}\n", + "处理后的训练集大小: 26198\n" + ] + } + ], + "source": [ + "max_length = 256\n", + "doc_stride = 128\n", + "def _process_train(data):\n", + "\n", + " contexts = [data[i][\"context\"] for i in range(len(data))]\n", + " questions = [data[i][\"question\"] for i in range(len(data))]\n", + "\n", + " tokenized_data_list = tokenizer(\n", + " questions,\n", + " contexts,\n", + " stride=doc_stride,\n", + " max_length=max_length,\n", + " padding=\"max_length\",\n", + " return_dict=False\n", + " )\n", + "\n", + " for i, tokenized_data in enumerate(tokenized_data_list):\n", + " # 获取 [CLS] 对应的位置\n", + " input_ids = tokenized_data[\"input_ids\"]\n", + " cls_index = 
input_ids.index(tokenizer.cls_token_id)\n", + "\n", + " # 在 tokenize 的过程中,汉字和 token 在位置上并非一一对应的\n", + " # 而 offset mapping 记录了每个 token 在原文中对应的起始位置\n", + " offsets = tokenized_data[\"offset_mapping\"]\n", + " # token_type_ids 记录了一条数据中哪些是问题,哪些是上下文\n", + " token_type_ids = tokenized_data[\"token_type_ids\"]\n", + "\n", + " # 一条数据可能因为长度过长而在 tokenized_data 中存在多个结果\n", + " # overflow_to_sample 表示了当前 tokenize_example 属于 data 中的哪一条数据\n", + " sample_index = tokenized_data[\"overflow_to_sample\"]\n", + " answers = data[sample_index][\"answers\"]\n", + "\n", + " # answers 和 answer_starts 均为长度为 1 的 list\n", + " # 我们可以计算出答案的结束位置\n", + " start_char = answers[\"answer_start\"][0]\n", + " end_char = start_char + len(answers[\"text\"][0])\n", + "\n", + " token_start_index = 0\n", + " while token_type_ids[token_start_index] != 1:\n", + " token_start_index += 1\n", + "\n", + " token_end_index = len(input_ids) - 1\n", + " while token_type_ids[token_end_index] != 1:\n", + " token_end_index -= 1\n", + " # 分词后一条数据的结尾一定是 [SEP],因此还需要减一\n", + " token_end_index -= 1\n", + "\n", + " if not (offsets[token_start_index][0] <= start_char and\n", + " offsets[token_end_index][1] >= end_char):\n", + " # 如果答案不在这条数据中,则将答案位置标记为 [CLS] 的位置\n", + " tokenized_data_list[i][\"start_pos\"] = cls_index\n", + " tokenized_data_list[i][\"end_pos\"] = cls_index\n", + " else:\n", + " # 否则,我们可以找到答案对应的 token 的起始位置,记录在 start_pos 和 end_pos 中\n", + " while token_start_index < len(offsets) and offsets[\n", + " token_start_index][0] <= start_char:\n", + " token_start_index += 1\n", + " tokenized_data_list[i][\"start_pos\"] = token_start_index - 1\n", + " while offsets[token_end_index][1] >= end_char:\n", + " token_end_index -= 1\n", + " tokenized_data_list[i][\"end_pos\"] = token_end_index + 1\n", + "\n", + " return tokenized_data_list\n", + "\n", + "train_dataset.map(_process_train, batched=True, num_workers=5)\n", + "print(train_dataset[0])\n", + "print(\"处理后的训练集大小:\", len(train_dataset))" + ] + }, + { + "cell_type": 
"markdown", + "metadata": {}, + "source": [ + "#### 2.2 处理验证集\n", + "\n", + "对于验证集的处理则简单得多,我们只需要保存原数据的 `id` 并将 `offset_mapping` 中不属于上下文的部分设置为 `None` 即可。" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP.core import PaddleDataLoader\n", + "\n", + "train_dataloader = PaddleDataLoader(train_dataset, batch_size=32, shuffle=True)\n", + "val_dataloader = PaddleDataLoader(val_dataset, batch_size=16)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. 模型训练:自己定义评测用的 Metric 实现更加自由的任务评测\n", + "\n", + "#### 3.1 损失函数\n", + "\n", + "对于阅读理解任务,我们使用的是 `ErnieForQuestionAnswering` 模型。该模型在接受输入后会返回两个值:`start_logits` 和 `end_logits` ,大小均为 `(batch_size, sequence_length)`,反映了每条数据每个词语为答案起始位置的可能性,因此我们需要自定义一个损失函数来计算 `loss`。 `CrossEntropyLossForSquad` 会分别对答案起始位置的预测值和真实值计算交叉熵,最后返回其平均值作为最终的损失。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "class CrossEntropyLossForSquad(paddle.nn.Layer):\n", + " def __init__(self):\n", + " super(CrossEntropyLossForSquad, self).__init__()\n", + "\n", + " def forward(self, start_logits, end_logits, start_pos, end_pos):\n", + " start_pos = paddle.unsqueeze(start_pos, axis=-1)\n", + " end_pos = paddle.unsqueeze(end_pos, axis=-1)\n", + " start_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=start_logits, label=start_pos)\n", + " start_loss = paddle.mean(start_loss)\n", + " end_loss = paddle.nn.functional.softmax_with_cross_entropy(\n", + " logits=end_logits, label=end_pos)\n", + " end_loss = paddle.mean(end_loss)\n", + "\n", + " loss = (start_loss + end_loss) / 2\n", + " return loss" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.2 定义模型\n", + "\n", + "模型的核心则是 `ErnieForQuestionAnswering` 的 `ernie-1.0-base-zh` 预训练模型,同时按照 `fastNLP` 的规定定义 `train_step` 和 `evaluate_step` 函数。这里 `evaluate_step` 函数并没有像文本分类那样直接返回该批次数据的评测结果,这一点我们将在下面为您讲解。" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + 
"text": [ + "\u001b[32m[2022-06-27 19:00:15,825] [ INFO]\u001b[0m - Already cached /remote-home/shxing/.paddlenlp/models/ernie-1.0-base-zh/ernie_v1_chn_base.pdparams\u001b[0m\n", + "W0627 19:00:15.831080 21543 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.5, Driver API Version: 11.2, Runtime API Version: 11.2\n", + "W0627 19:00:15.843276 21543 gpu_context.cc:306] device: 0, cuDNN Version: 8.1.\n" + ] + } + ], + "source": [ + "from paddlenlp.transformers import ErnieForQuestionAnswering\n", + "\n", + "class QAModel(paddle.nn.Layer):\n", + " def __init__(self, model_checkpoint):\n", + " super(QAModel, self).__init__()\n", + " self.model = ErnieForQuestionAnswering.from_pretrained(model_checkpoint)\n", + " self.loss_func = CrossEntropyLossForSquad()\n", + "\n", + " def forward(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self.model(input_ids, token_type_ids)\n", + " return start_logits, end_logits\n", + "\n", + " def train_step(self, input_ids, token_type_ids, start_pos, end_pos):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " loss = self.loss_func(start_logits, end_logits, start_pos, end_pos)\n", + " return {\"loss\": loss}\n", + "\n", + " def evaluate_step(self, input_ids, token_type_ids):\n", + " start_logits, end_logits = self(input_ids, token_type_ids)\n", + " return {\"start_logits\": start_logits, \"end_logits\": end_logits}\n", + "\n", + "model = QAModel(MODEL_NAME)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.3 自定义 Metric 进行数据的评估\n", + "\n", + "`paddlenlp` 为我们提供了评测 `SQuAD` 格式数据集的函数 `compute_prediction` 和 `squad_evaluate`:\n", + "- `compute_prediction` 函数要求传入原数据 `examples` 、处理后的数据 `features` 和 `features` 对应的结果 `predictions`(一个包含所有数据 `start_logits` 和 `end_logits` 的元组)\n", + "- `squad_evaluate` 要求传入原数据 `examples` 和预测结果 `all_predictions`(通常来自于 `compute_prediction`)\n", + "\n", + "在使用这两个函数的时候,我们需要向其中传入数据集,但显然根据 `fastNLP` 的设计,我们无法在 `evaluate_step` 
里实现这一过程,并且 `fastNLP` 也并没有提供计算 `F1` 和 `EM` 的 `Metric`,故我们需要自己定义用于评测的 `Metric`。\n", + "\n", + "在初始化之外,一个 `Metric` 还需要实现三个函数:\n", + "\n", + "1. `reset` - 该函数会在验证数据集的迭代之前被调用,用于清空数据;在我们自定义的 `Metric` 中,我们需要将 `all_start_logits` 和 `all_end_logits` 清空,重新收集每个 `batch` 的结果。\n", + "2. `update` - 该函数会在每个 `batch` 得到结果后被调用,用于更新 `Metric` 的状态;它的参数即为 `evaluate_step` 返回的内容。我们在这里将得到的 `start_logits` 和 `end_logits` 收集起来。\n", + "3. `get_metric` - 该函数会在数据集被迭代完毕后调用,用于计算评测的结果。现在我们有了整个验证集的 `all_start_logits` 和 `all_end_logits` ,将它们传入 `compute_prediction` 函数得到预测的结果,并继续使用 `squad_evaluate` 函数得到评测的结果。\n", + "    - 注:`squad_evaluate` 函数会自己输出评测结果,为了不让其干扰 `fastNLP` 输出,这里我们使用 `contextlib.redirect_stdout(None)` 将函数的标准输出屏蔽掉。\n", + "\n", + "综上,`SquadEvaluateMetric` 实现的评估过程是:将验证集中所有数据的 `logits` 收集起来,然后统一传入 `compute_prediction` 和 `squad_evaluate` 中进行评估。值得一提的是,`paddlenlp.datasets.load_dataset` 返回的结果是一个 `MapDataset` 类型,其 `data` 成员为加载时的数据,`new_data` 为经过 `map` 函数处理后更新的数据,因此可以分别作为 `examples` 和 `features` 传入。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core import Metric\n", + "from paddlenlp.metrics.squad import squad_evaluate, compute_prediction\n", + "import contextlib\n", + "\n", + "class SquadEvaluateMetric(Metric):\n", + "    def __init__(self, examples, features, testing=False):\n", + "        super(SquadEvaluateMetric, self).__init__(\"paddle\", False)\n", + "        self.examples = examples\n", + "        self.features = features\n", + "        self.all_start_logits = []\n", + "        self.all_end_logits = []\n", + "        self.testing = testing\n", + "\n", + "    def reset(self):\n", + "        self.all_start_logits = []\n", + "        self.all_end_logits = []\n", + "\n", + "    def update(self, start_logits, end_logits):\n", + "        for start, end in zip(start_logits, end_logits):\n", + "            self.all_start_logits.append(start.numpy())\n", + "            self.all_end_logits.append(end.numpy())\n", + "\n", + "    def get_metric(self):\n", + "        all_predictions, _, _ = compute_prediction(\n", + "            
self.examples, self.features[:len(self.all_start_logits)],\n", + " (self.all_start_logits, self.all_end_logits),\n", + " False, 20, 30\n", + " )\n", + " with contextlib.redirect_stdout(None):\n", + " result = squad_evaluate(\n", + " examples=self.examples,\n", + " preds=all_predictions,\n", + " is_whitespace_splited=False\n", + " )\n", + "\n", + " if self.testing:\n", + " self.print_predictions(all_predictions)\n", + " return result\n", + "\n", + " def print_predictions(self, preds):\n", + " for i, data in enumerate(self.examples):\n", + " if i >= 5:\n", + " break\n", + " print()\n", + " print(\"原文:\", data[\"context\"])\n", + " print(\"问题:\", data[\"question\"], \\\n", + " \"答案:\", preds[data[\"id\"]], \\\n", + " \"正确答案:\", data[\"answers\"][\"text\"])\n", + "\n", + "metric = SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.4 训练\n", + "\n", + "至此所有的准备工作已经完成,可以使用 `Trainer` 进行训练了。学习率我们依旧采用线性预热策略 `LinearDecayWithWarmup`,优化器为 `AdamW`;回调模块我们选择 `LRSchedCallback` 更新学习率和 `LoadBestModelCallback` 监视评测结果的 `f1` 分数。初始化好 `Trainer` 之后,就将训练的过程交给 `fastNLP` 吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
[19:04:54] INFO Running evaluator sanity check for 2 batches. trainer.py:631\n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[19:04:54]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Running evaluator sanity check for \u001b[1;36m2\u001b[0m batches. \u001b]8;id=367046;file://../fastNLP/core/controllers/trainer.py\u001b\\\u001b[2mtrainer.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=96810;file://../fastNLP/core/controllers/trainer.py#631\u001b\\\u001b[2m631\u001b[0m\u001b]8;;\u001b\\\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:100 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m100\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 49.25899788285109,\n", + " \"f1#squad\": 66.55559127349602,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 49.25899788285109,\n", + " \"HasAns_f1#squad\": 66.55559127349602,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m49.25899788285109\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m66.55559127349602\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:200 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m200\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 57.37473535638673,\n", + " \"f1#squad\": 70.93036525200617,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 57.37473535638673,\n", + " \"HasAns_f1#squad\": 70.93036525200617,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m57.37473535638673\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m70.93036525200617\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:300 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m300\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 63.86732533521524,\n", + " \"f1#squad\": 78.62546663568186,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 63.86732533521524,\n", + " \"HasAns_f1#squad\": 78.62546663568186,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m63.86732533521524\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m78.62546663568186\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:400 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m400\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 64.92589978828511,\n", + " \"f1#squad\": 79.36746074079691,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 64.92589978828511,\n", + " \"HasAns_f1#squad\": 79.36746074079691,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m64.92589978828511\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.36746074079691\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:500 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m500\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 65.70218772053634,\n", + " \"f1#squad\": 80.33295482054824,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 65.70218772053634,\n", + " \"HasAns_f1#squad\": 80.33295482054824,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:600 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m600\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 65.41990119971771,\n", + " \"f1#squad\": 79.7483487059053,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 65.41990119971771,\n", + " \"HasAns_f1#squad\": 79.7483487059053,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.41990119971771\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.7483487059053\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:700 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m700\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 66.61961891319689,\n", + " \"f1#squad\": 80.32432238994133,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 66.61961891319689,\n", + " \"HasAns_f1#squad\": 80.32432238994133,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m66.61961891319689\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m80.32432238994133\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
---------------------------- Eval. results on Epoch:0, Batch:800 ----------------------------\n", + "\n" + ], + "text/plain": [ + "---------------------------- Eval. results on Epoch:\u001b[1;36m0\u001b[0m, Batch:\u001b[1;36m800\u001b[0m ----------------------------\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " \"exact#squad\": 65.84333098094567,\n", + " \"f1#squad\": 79.23169801265415,\n", + " \"total#squad\": 1417,\n", + " \"HasAns_exact#squad\": 65.84333098094567,\n", + " \"HasAns_f1#squad\": 79.23169801265415,\n", + " \"HasAns_total#squad\": 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[1;34m\"exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[1;34m\"HasAns_exact#squad\"\u001b[0m: \u001b[1;36m65.84333098094567\u001b[0m,\n", + " \u001b[1;34m\"HasAns_f1#squad\"\u001b[0m: \u001b[1;36m79.23169801265415\u001b[0m,\n", + " \u001b[1;34m\"HasAns_total#squad\"\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
[19:20:28] INFO Loading best model from fnlp-ernie-squad/ load_best_model_callback.py:111\n", + " 2022-06-27-19_00_15_388554/best_so_far \n", + " with f1#squad: 80.33295482054824... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m[19:20:28]\u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Loading best model from fnlp-ernie-squad/ \u001b]8;id=163935;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=31503;file://../fastNLP/core/callbacks/load_best_model_callback.py#111\u001b\\\u001b[2m111\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m \u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_00_15_388554/best_so_far \u001b[2m \u001b[0m\n", + "\u001b[2;36m \u001b[0m with f1#squad: \u001b[1;36m80.33295482054824\u001b[0m\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
INFO Deleting fnlp-ernie-squad/2022-06-27-19_0 load_best_model_callback.py:131\n", + " 0_15_388554/best_so_far... \n", + "\n" + ], + "text/plain": [ + "\u001b[2;36m \u001b[0m\u001b[2;36m \u001b[0m\u001b[34mINFO \u001b[0m Deleting fnlp-ernie-squad/\u001b[1;36m2022\u001b[0m-\u001b[1;36m06\u001b[0m-\u001b[1;36m27\u001b[0m-19_0 \u001b]8;id=560859;file://../fastNLP/core/callbacks/load_best_model_callback.py\u001b\\\u001b[2mload_best_model_callback.py\u001b[0m\u001b]8;;\u001b\\\u001b[2m:\u001b[0m\u001b]8;id=573263;file://../fastNLP/core/callbacks/load_best_model_callback.py#131\u001b\\\u001b[2m131\u001b[0m\u001b]8;;\u001b\\\n", + "\u001b[2;36m \u001b[0m 0_15_388554/best_so_far\u001b[33m...\u001b[0m \u001b[2m \u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Trainer, LRSchedCallback, LoadBestModelCallback\n", + "from paddlenlp.transformers import LinearDecayWithWarmup\n", + "\n", + "n_epochs = 1\n", + "num_training_steps = len(train_dataloader) * n_epochs\n", + "lr_scheduler = LinearDecayWithWarmup(3e-5, num_training_steps, 0.1)\n", + "optimizer = paddle.optimizer.AdamW(\n", + " learning_rate=lr_scheduler,\n", + " parameters=model.parameters(),\n", + ")\n", + "callbacks=[\n", + " LRSchedCallback(lr_scheduler, step_on=\"batch\"),\n", + " LoadBestModelCallback(\"f1#squad\", larger_better=True, save_folder=\"fnlp-ernie-squad\")\n", + "]\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=val_dataloader,\n", + " device=1,\n", + " optimizers=optimizer,\n", + " n_epochs=n_epochs,\n", + " callbacks=callbacks,\n", + " evaluate_every=100,\n", + " metrics={\"squad\": metric},\n", + ")\n", + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### 3.5 测试\n", + "\n", + "最后,我们可以使用 `Evaluator` 查看我们训练的结果。我们在之前为 `SquadEvaluateMetric` 设置了 `testing` 参数来在测试阶段进行输出,可以看到,训练的结果还是比较不错的。" + ] + }, + { + "cell_type": "code", + 
"execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n", + "\n" + ], + "text/plain": [ + "原文: 爬行垫根据中间材料的不同可以分为:XPE爬行垫、EPE爬行垫、EVA爬行垫、PVC爬行垫;其中XPE爬\n", + "行垫、EPE爬行垫都属于PE材料加保鲜膜复合而成,都是无异味的环保材料,但是XPE爬行垫是品质较好的爬\n", + "行垫,韩国进口爬行垫都是这种爬行垫,而EPE爬行垫是国内厂家为了减低成本,使用EPE(珍珠棉)作为原料生\n", + "产的一款爬行垫,该材料弹性差,易碎,开孔发泡防水性弱。EVA爬行垫、PVC爬行垫是用EVA或PVC作为原材料\n", + "与保鲜膜复合的而成的爬行垫,或者把图案转印在原材料上,这两款爬行垫通常有异味,如果是图案转印的爬\n", + "行垫,油墨外露容易脱落。 \n", + "当时我儿子爬的时候,我们也买了垫子,但是始终有味。最后就没用了,铺的就的薄毯子让他爬。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n", + "\n" + ], + "text/plain": [ + "问题: 爬行垫什么材质的好 答案: EPE(珍珠棉 正确答案: ['XPE']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n", + "\n" + ], + "text/plain": [ + "原文: 真实情况是160-162。她平时谎报的168是因为不离脚穿高水台恨天高(15厘米) 图1她穿着高水台恨\n", + "天高和刘亦菲一样高,(刘亦菲对外报身高172)范冰冰礼服下厚厚的高水台暴露了她的心机,对比一下两者的\n", + "鞋子吧 图2 穿着高水台恨天高才和刘德华谢霆锋持平,如果她真的有168,那么加上鞋高,刘和谢都要有180?\n", + "明显是不可能的。所以刘德华对外报的身高174减去10-15厘米才是范冰冰的真实身高 图3,范冰冰有一次脱\n", + "鞋上场,这个最说明问题了,看看她的身体比例吧。还有目测一下她手上鞋子的鞋跟有多高多厚吧,至少超过\n", + "10厘米。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n", + "\n" + ], + "text/plain": [ + "问题: 范冰冰多高真实身高 答案: 160-162 正确答案: ['160-162']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n", + "\n" + ], + "text/plain": [ + "原文: 防水作为目前高端手机的标配,特别是苹果也支持防水之后,国产大多数高端旗舰手机都已经支持防\n", + "水。虽然我们真的不会故意把手机放入水中,但是有了防水之后,用户心里会多一重安全感。那么近日最为\n", + "火热的小米6防水吗?小米6的防水级别又是多少呢? 小编查询了很多资料发现,小米6确实是防水的,但是为\n", + "了保持低调,同时为了不被别人说防水等级不够,很多资料都没有标注小米是否防水。根据评测资料显示,小\n", + "米6是支持IP68级的防水,是绝对能够满足日常生活中的防水需求的。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n", + "\n" + ], + "text/plain": [ + "问题: 小米6防水等级 答案: IP68级 正确答案: ['IP68级']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n", + "\n" + ], + "text/plain": [ + "原文: 这位朋友你好,女性出现妊娠反应一般是从6-12周左右,也就是女性怀孕1个多月就会开始出现反应,\n", + "第3个月的时候,妊辰反应基本结束。 而大部分女性怀孕初期都会出现恶心、呕吐的感觉,这些症状都是因\n", + "人而异的,除非恶心、呕吐的非常厉害,才需要就医,否则这些都是刚怀孕的的正常症状。1-3个月的时候可\n", + "以观察一下自己的皮肤,一般女性怀孕初期可能会产生皮肤色素沉淀或是腹壁产生妊娠纹,特别是在怀孕的\n", + "后期更加明显。 还有很多女性怀孕初期会出现疲倦、嗜睡的情况。怀孕三个月的时候,膀胱会受到日益胀\n", + "大的子宫的压迫,容量会变小,所以怀孕期间也会有尿频的现象出现。月经停止也是刚怀孕最容易出现的症\n", + "状,只要是平时月经正常的女性,在性行为后超过正常经期两周,就有可能是怀孕了。 如果你想判断自己是\n", + "否怀孕,可以看看自己有没有这些反应。当然这也只是多数人的怀孕表现,也有部分女性怀孕表现并不完全\n", + "是这样,如果你无法确定自己是否怀孕,最好去医院检查一下。\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n", + "\n" + ], + "text/plain": [ + "问题: 怀孕多久会有反应 答案: 6-12周左右 正确答案: ['6-12周左右', '6-12周', '1个多月']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n", + "\n" + ], + "text/plain": [ + "原文: 【东奥会计在线——中级会计职称频道推荐】根据《关于提高科技型中小企业研究开发费用税前加计\n", + "扣除比例的通知》的规定,研发费加计扣除比例提高到75%。|财政部、国家税务总局、科技部发布《关于提\n", + "高科技型中小企业研究开发费用税前加计扣除比例的通知》。|通知称,为进一步激励中小企业加大研发投\n", + "入,支持科技创新,就提高科技型中小企业研究开发费用(以下简称研发费用)税前加计扣除比例有关问题发\n", + "布通知。|通知明确,科技型中小企业开展研发活动中实际发生的研发费用,未形成无形资产计入当期损益的\n", + ",在按规定据实扣除的基础上,在2017年1月1日至2019年12月31日期间,再按照实际发生额的75%在税前加计\n", + "扣除;形成无形资产的,在上述期间按照无形资产成本的175%在税前摊销。|科技型中小企业享受研发费用税\n", + "前加计扣除政策的其他政策口径按照《财政部国家税务总局科技部关于完善研究开发费用税前加计扣除政\n", + "策的通知》(财税〔2015〕119号)规定执行。|科技型中小企业条件和管理办法由科技部、财政部和国家税\n", + "务总局另行发布。科技、财政和税务部门应建立信息共享机制,及时共享科技型中小企业的相关信息,加强\n", + "协调配合,保障优惠政策落实到位。|上一篇文章:关于2016年度企业研究开发费用税前加计扣除政策企业所\n", + "得税纳税申报问题的公告 下一篇文章:关于提高科技型中小企业研究开发费用税前加计扣除比例的通知\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n", + "\n" + ], + "text/plain": [ + "问题: 研发费用加计扣除比例 答案: 75% 正确答案: ['75%']\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n" + ], + "text/plain": [] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{\n", + " 'exact#squad': 65.70218772053634,\n", + " 'f1#squad': 80.33295482054824,\n", + " 'total#squad': 1417,\n", + " 'HasAns_exact#squad': 65.70218772053634,\n", + " 'HasAns_f1#squad': 80.33295482054824,\n", + " 'HasAns_total#squad': 1417\n", + "}\n", + "\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\n", + " \u001b[32m'exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m,\n", + " \u001b[32m'HasAns_exact#squad'\u001b[0m: \u001b[1;36m65.70218772053634\u001b[0m,\n", + " \u001b[32m'HasAns_f1#squad'\u001b[0m: \u001b[1;36m80.33295482054824\u001b[0m,\n", + " \u001b[32m'HasAns_total#squad'\u001b[0m: \u001b[1;36m1417\u001b[0m\n", + "\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from fastNLP import Evaluator\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=val_dataloader,\n", + " device=1,\n", + " metrics={\n", + " \"squad\": SquadEvaluateMetric(\n", + " val_dataloader.dataset.data,\n", + " val_dataloader.dataset.new_data,\n", + " testing=True,\n", + " ),\n", + " },\n", + ")\n", + "result = evaluator.run()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.7.13 ('fnlp-paddle')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "31f2d9d3efc23c441973d7c4273acfea8b132b6a578f002629b6b44b8f65e720" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/tutorials/figures/E1-fig-glue-benchmark.png b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png new file mode 100644 
index 00000000..515db700 Binary files /dev/null and b/docs/source/tutorials/figures/E1-fig-glue-benchmark.png differ diff --git a/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png new file mode 100644 index 00000000..b5a9c1b8 Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-p-tuning-v2-model.png differ diff --git a/docs/source/tutorials/figures/E2-fig-pet-model.png b/docs/source/tutorials/figures/E2-fig-pet-model.png new file mode 100644 index 00000000..c3c377c0 Binary files /dev/null and b/docs/source/tutorials/figures/E2-fig-pet-model.png differ diff --git a/docs/source/tutorials/figures/T0-fig-parameter-matching.png b/docs/source/tutorials/figures/T0-fig-parameter-matching.png new file mode 100644 index 00000000..24013cc1 Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-parameter-matching.png differ diff --git a/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png new file mode 100644 index 00000000..38222ee8 Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-trainer-and-evaluator.png differ diff --git a/docs/source/tutorials/figures/T0-fig-training-structure.png b/docs/source/tutorials/figures/T0-fig-training-structure.png new file mode 100644 index 00000000..edc2e2ff Binary files /dev/null and b/docs/source/tutorials/figures/T0-fig-training-structure.png differ diff --git a/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png new file mode 100644 index 00000000..803cf34a Binary files /dev/null and b/docs/source/tutorials/figures/T1-fig-dataset-and-vocabulary.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png new file mode 100644 index 00000000..ff2519c4 Binary files /dev/null and 
b/docs/source/tutorials/figures/paddle-ernie-1.0-masking-levels.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png new file mode 100644 index 00000000..ed003a2f Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-1.0-masking.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png new file mode 100644 index 00000000..d45f65d8 Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-2.0-continual-pretrain.png differ diff --git a/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png new file mode 100644 index 00000000..f50ddb1c Binary files /dev/null and b/docs/source/tutorials/figures/paddle-ernie-3.0-framework.png differ diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index 9885a175..31249c80 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -2,4 +2,4 @@ from fastNLP.envs import * from fastNLP.core import * -__version__ = '0.8.0beta' +__version__ = '1.0.0alpha' diff --git a/fastNLP/core/callbacks/callback_event.py b/fastNLP/core/callbacks/callback_event.py index 8a51b6de..f632cf3c 100644 --- a/fastNLP/core/callbacks/callback_event.py +++ b/fastNLP/core/callbacks/callback_event.py @@ -35,14 +35,14 @@ class Event: :param value: Trainer 的 callback 时机; :param every: 每触发多少次才真正运行一次; - :param once: 在第一次运行后时候再次执行; + :param once: 是否仅运行一次; :param filter_fn: 输入参数的应该为 ``(filter, trainer)``,其中 ``filter`` 对象中包含了 `filter.num_called` 和 `filter.num_executed` 两个变量分别获取当前被调用了多少次,真正执行了多少次;``trainer`` 对象即为当前正在运行的 Trainer; """ every: Optional[int] - once: Optional[int] + once: Optional[bool] - def __init__(self, value: str, every: Optional[int] = None, once: Optional[int] = None, + def __init__(self, value: str, every: Optional[int] = None, once: Optional[bool] = 
None, filter_fn: Optional[Callable] = None): self.every = every self.once = once @@ -68,7 +68,6 @@ class Event: return Event(value='on_after_trainer_initialized', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_sanity_check_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_sanity_check_begin` 时触发; @@ -85,7 +84,6 @@ class Event: return Event(value='on_sanity_check_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_sanity_check_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_sanity_check_end` 时触发; @@ -101,7 +99,6 @@ class Event: return Event(value='on_sanity_check_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_begin` 时触发; @@ -117,7 +114,6 @@ class Event: return Event(value='on_train_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_end` 时触发; @@ -133,7 +129,6 @@ class Event: return Event(value='on_train_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_epoch_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_epoch_begin` 时触发; @@ -149,7 +144,6 @@ class Event: return Event(value='on_train_epoch_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_epoch_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_epoch_end` 时触发; @@ -165,7 +159,6 @@ class Event: return Event(value='on_train_epoch_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_fetch_data_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_fetch_data_begin` 时触发; @@ -181,7 +174,6 @@ class Event: return Event(value='on_fetch_data_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_fetch_data_end(every=None, once=None, 
filter_fn=None): """ 当 Trainer 运行到 :func:`on_fetch_data_end` 时触发; @@ -197,7 +189,6 @@ class Event: return Event(value='on_fetch_data_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_batch_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_batch_begin` 时触发; @@ -213,7 +204,6 @@ class Event: return Event(value='on_train_batch_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_train_batch_end(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_train_batch_end` 时触发; @@ -229,7 +219,6 @@ class Event: return Event(value='on_train_batch_end', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_exception(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_exception` 时触发; @@ -245,7 +234,6 @@ class Event: return Event(value='on_exception', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_save_model(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_save_model` 时触发; @@ -261,7 +249,6 @@ class Event: return Event(value='on_save_model', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_load_model(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_load_model` 时触发; @@ -277,7 +264,6 @@ class Event: return Event(value='on_load_model', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_save_checkpoint(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_save_checkpoint` 时触发; @@ -293,7 +279,6 @@ class Event: return Event(value='on_save_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_load_checkpoint(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_load_checkpoint` 时触发; @@ -309,7 +294,6 @@ class Event: return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_load_checkpoint(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 
:func:`on_load_checkpoint` 时触发; @@ -325,7 +309,6 @@ class Event: return Event(value='on_load_checkpoint', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_before_backward(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_before_backward` 时触发; @@ -341,7 +324,6 @@ class Event: return Event(value='on_before_backward', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_after_backward(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_after_backward` 时触发; @@ -357,7 +339,6 @@ class Event: return Event(value='on_after_backward', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_before_optimizers_step(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_before_optimizers_step` 时触发; @@ -373,7 +354,6 @@ class Event: return Event(value='on_before_optimizers_step', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_after_optimizers_step(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_after_optimizers_step` 时触发; @@ -389,7 +369,6 @@ class Event: return Event(value='on_after_optimizers_step', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_before_zero_grad(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_before_zero_grad` 时触发; @@ -405,7 +384,6 @@ class Event: return Event(value='on_before_zero_grad', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_after_zero_grad(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_after_zero_grad` 时触发; @@ -421,7 +399,6 @@ class Event: return Event(value='on_after_zero_grad', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_evaluate_begin(every=None, once=None, filter_fn=None): """ 当 Trainer 运行到 :func:`on_evaluate_begin` 时触发; @@ -437,7 +414,6 @@ class Event: return Event(value='on_evaluate_begin', every=every, once=once, filter_fn=filter_fn) @staticmethod - def on_evaluate_end(every=None, once=None, filter_fn=None): 
""" 当 Trainer 运行到 :func:`on_evaluate_end` 时触发; diff --git a/fastNLP/core/collators/padders/oneflow_padder.py b/fastNLP/core/collators/padders/oneflow_padder.py index 5e235a0f..30d73e26 100644 --- a/fastNLP/core/collators/padders/oneflow_padder.py +++ b/fastNLP/core/collators/padders/oneflow_padder.py @@ -7,6 +7,7 @@ from inspect import isclass import numpy as np from fastNLP.envs.imports import _NEED_IMPORT_ONEFLOW +from fastNLP.envs.utils import _module_available if _NEED_IMPORT_ONEFLOW: import oneflow diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 14dad89b..454db3d0 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -84,13 +84,13 @@ class Trainer(TrainerEventTrigger): .. warning:: 当使用分布式训练时, **fastNLP** 会默认将 ``dataloader`` 中的 ``Sampler`` 进行处理,以使得在一个 epoch 中,不同卡 - 用以训练的数据是不重叠的。如果你对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 - 你自身保证每张卡上所使用的数据是不同的。 + 用以训练的数据是不重叠的。如果您对 sampler 有特殊处理,那么请将 ``use_dist_sampler`` 参数设置为 ``False`` ,此刻需要由 + 您自身保证每张卡上所使用的数据是不同的。 :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; :param device: 该参数用来指定具体训练时使用的机器;注意当该参数仅当您通过 ``torch.distributed.launch/run`` 启动时可以为 ``None``, - 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是你可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 - 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 + 此时 fastNLP 不会对模型和数据进行设备之间的移动处理,但是您可以通过参数 ``input_mapping`` 和 ``output_mapping`` 来实现设备之间 + 数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时您也可以通过在 kwargs 添加参数 ``data_device`` 来让我们帮助您将数据 迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前自己构造 DDP 的场景); device 的可选输入如下所示: @@ -196,7 +196,7 @@ class Trainer(TrainerEventTrigger): 3. 如果此时 batch 此时是其它类型,那么我们将会直接报错; 2. 
如果 ``input_mapping`` 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; - 注意该参数会被传进 ``Evaluator`` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时); + 注意该参数会被传进 ``Evaluator`` 中;因此您可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 ``device`` 为 ``None`` 时); 如果 ``Trainer`` 和 ``Evaluator`` 需要使用不同的 ``input_mapping``, 请使用 ``train_input_mapping`` 与 ``evaluate_input_mapping`` 分别进行设置。 :param output_mapping: 应当为一个字典或者函数。作用和 ``input_mapping`` 类似,区别在于其用于转换输出: @@ -367,7 +367,7 @@ class Trainer(TrainerEventTrigger): .. note:: ``Trainer`` 是通过在内部直接初始化一个 ``Evaluator`` 来进行验证; - ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,你需要保证这几个参数得到正确的传入: + ``Trainer`` 内部的 ``Evaluator`` 默认是 None,如果您需要在训练过程中进行验证,您需要保证这几个参数得到正确的传入: 必须的参数:``metrics`` 与 ``evaluate_dataloaders``; @@ -898,7 +898,7 @@ class Trainer(TrainerEventTrigger): 这段代码意味着 ``fn1`` 和 ``fn2`` 会被加入到 ``trainer1``,``fn3`` 会被加入到 ``trainer2``; - 注意如果你使用该函数修饰器来为你的训练添加 callback,请务必保证你加入 callback 函数的代码在实例化 `Trainer` 之前; + 注意如果您使用该函数修饰器来为您的训练添加 callback,请务必保证您加入 callback 函数的代码在实例化 `Trainer` 之前; 补充性的解释见 :meth:`~fastNLP.core.controllers.Trainer.add_callback_fn`; diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 862c7ee7..8ef521eb 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -584,7 +584,7 @@ class DataSet: 将 :class:`DataSet` 每个 ``instance`` 中为 ``field_name`` 的 field 传给函数 ``func``,并写入到 ``new_field_name`` 中。 - :param func: 对指定 fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; :param field_name: 传入 ``func`` 的 field 名称; :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; @@ -624,10 +624,9 @@ class DataSet: ``apply_field_more`` 与 ``apply_field`` 的区别参考 :meth:`~fastNLP.core.dataset.DataSet.apply_more` 中关于 ``apply_more`` 与 ``apply`` 区别的介绍。 - :param func: 对指定 
fiel` 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; - :param field_name: 传入 ``func`` 的 fiel` 名称; - :param new_field_name: 函数执行结果写入的 ``field`` 名称。该函数会将 ``func`` 返回的内容放入到 ``new_field_name`` 对 - 应的 ``field`` 中,注意如果名称与已有的 field 相同则会进行覆盖。如果为 ``None`` 则不会覆盖和创建 field ; + :param func: 对指定 field 进行处理的函数,注意其输入应为 ``instance`` 中名为 ``field_name`` 的 field 的内容; + :param field_name: 传入 ``func`` 的 field 名称; + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. note:: @@ -751,8 +750,8 @@ class DataSet: 3. ``apply_more`` 默认修改 ``DataSet`` 中的 field ,``apply`` 默认不修改。 - :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 + :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 ``True`` :param num_proc: 使用进程的数量。 .. note:: diff --git a/fastNLP/core/drivers/torch_driver/torch_fsdp.py b/fastNLP/core/drivers/torch_driver/torch_fsdp.py index 9359615a..0b1948e8 100644 --- a/fastNLP/core/drivers/torch_driver/torch_fsdp.py +++ b/fastNLP/core/drivers/torch_driver/torch_fsdp.py @@ -1,15 +1,17 @@ -from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12 +from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_12, _NEED_IMPORT_TORCH if _TORCH_GREATER_EQUAL_1_12: from torch.distributed.fsdp import FullyShardedDataParallel, StateDictType, FullStateDictConfig, OptimStateKeyType +if _NEED_IMPORT_TORCH: + import torch + import torch.distributed as dist + from torch.nn.parallel import DistributedDataParallel + import os -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel from typing import Optional, Union, List, Dict, Mapping from pathlib import Path diff --git a/fastNLP/core/utils/cache_results.py b/fastNLP/core/utils/cache_results.py index d462c06c..dc114b48 100644 --- a/fastNLP/core/utils/cache_results.py +++ b/fastNLP/core/utils/cache_results.py @@ -1,3 +1,14 @@ +""" 
+:func:`cache_results` 函数是 **fastNLP** 中用于缓存数据的装饰器,通过该函数您可以省去调试代码过程中一些耗时过长程序 +带来的时间开销。比如在加载并处理较大的数据时,每次修改训练参数都需要从头开始执行处理数据的过程,那么 :func:`cache_results` +便可以跳过这部分漫长的时间。详细的使用方法和原理请参见下面的说明。 + +.. warning:: + + 如果您发现对代码进行修改之后程序执行的结果没有变化,很有可能是这个函数的原因;届时删除掉缓存数据即可。 + +""" + from datetime import datetime import hashlib import _pickle diff --git a/fastNLP/embeddings/torch/static_embedding.py b/fastNLP/embeddings/torch/static_embedding.py index 12e7294c..6980c851 100644 --- a/fastNLP/embeddings/torch/static_embedding.py +++ b/fastNLP/embeddings/torch/static_embedding.py @@ -86,7 +86,7 @@ class StaticEmbedding(TokenEmbedding): :param requires_grad: 是否需要梯度。 :param init_method: 如何初始化没有找到的值。可以使用 :mod:`torch.nn.init` 中的各种方法,传入的方法应该接受一个 tensor,并 inplace 地修改其值。 - :param lower: 是否将 ``vocab`` 中的词语小写后再和预训练的词表进行匹配。如果你的词表中包含大写的词语,或者就是需要单独 + :param lower: 是否将 ``vocab`` 中的词语小写后再和预训练的词表进行匹配。如果您的词表中包含大写的词语,或者就是需要单独 为大写的词语开辟一个 vector 表示,则将 ``lower`` 设置为 ``False``。 :param dropout: 以多大的概率对 embedding 的表示进行 Dropout。0.1 即随机将 10% 的值置为 0。 :param word_dropout: 按照一定概率随机将 word 设置为 ``unk_index`` ,这样可以使得 ``