diff --git a/fastNLP/__init__.py b/fastNLP/__init__.py index b856faff..915e382d 100644 --- a/fastNLP/__init__.py +++ b/fastNLP/__init__.py @@ -1,3 +1,3 @@ from fastNLP.envs import * -from fastNLP.core import Trainer, Evaluator \ No newline at end of file +from fastNLP.core import * \ No newline at end of file diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index f1421c38..439f5886 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -57,9 +57,37 @@ __all__ = [ "TorchPaddleDriver", # log - "logger" + "logger", + "print", - # + # metrics + "Metric", + "Accuracy", + 'SpanFPreRecMetric', + 'ClassifyFPreRecMetric', + + # samplers + 'ReproducibleSampler', + 'RandomSampler', + "SequentialSampler", + "SortedSampler", + 'UnrepeatedSampler', + 'UnrepeatedRandomSampler', + "UnrepeatedSortedSampler", + "UnrepeatedSequentialSampler", + "ReproduceBatchSampler", + "BucketedBatchSampler", + "ReproducibleBatchSampler", + "RandomBatchSampler", + + # utils + "cache_results", + "f_rich_progress", + "auto_param_call", + "seq_len_to_mask", + + # vocabulary.py + 'Vocabulary' ] from .callbacks import * from .collators import * @@ -68,4 +96,7 @@ from .dataloaders import * from .dataset import * from .drivers import * from .log import * -from .utils import * \ No newline at end of file +from .metrics import * +from .samplers import * +from .utils import * +from .vocabulary import Vocabulary \ No newline at end of file diff --git a/fastNLP/core/callbacks/topk_saver.py b/fastNLP/core/callbacks/topk_saver.py index 8c3f3811..25e66cb9 100644 --- a/fastNLP/core/callbacks/topk_saver.py +++ b/fastNLP/core/callbacks/topk_saver.py @@ -7,7 +7,7 @@ from copy import deepcopy from pathlib import Path from typing import Optional, Dict, Tuple, Callable, Union -from fastNLP.core.utils import rank_zero_rm +from ...envs.distributed import rank_zero_rm from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.envs import rank_zero_call diff --git a/fastNLP/core/collators/__init__.py b/fastNLP/core/collators/__init__.py index 1e508689..3033c37e 100644 --- a/fastNLP/core/collators/__init__.py +++ b/fastNLP/core/collators/__init__.py @@ -8,6 +8,7 @@ __all__ = [ "NullPadder", "RawNumberPadder", "RawSequencePadder", + "RawTensorPadder", 'TorchNumberPadder', 'TorchSequencePadder', 'TorchTensorPadder', diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py index 5c5abda4..9ea08d95 100644 --- a/fastNLP/core/collators/collator.py +++ b/fastNLP/core/collators/collator.py @@ -67,7 +67,7 @@ def _get_backend() -> str: # 方式 (2) for backend in CHECK_BACKEND: if backend in sys.modules: - logger.debug(f"sys.modules contains backend:{catch_backend[0]}.") + logger.debug(f"sys.modules contains backend:{backend}.") return backend for key, module in sys.modules.items(): catch_backend = _check_module(module) diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py index 09a5ca8d..11ffc07b 100644 --- a/fastNLP/core/collators/padders/__init__.py +++ b/fastNLP/core/collators/padders/__init__.py @@ -9,6 +9,7 @@ __all__ = [ "RawNumberPadder", "RawSequencePadder", + "RawTensorPadder", 'TorchNumberPadder', 'TorchSequencePadder', diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py index 4d507f2e..1113c91a 100644 --- a/fastNLP/core/collators/padders/numpy_padder.py +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -79,7 +79,7 @@ class NumpyTensorPadder(Padder): 
def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], np.ndarray): - batch_field = [np.array(field.tolist()) for field in batch_field] + batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] except AttributeError: raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " f"it must have tolist() method.") diff --git a/fastNLP/core/collators/padders/paddle_padder.py b/fastNLP/core/collators/padders/paddle_padder.py index 10d5a385..f7db6534 100644 --- a/fastNLP/core/collators/padders/paddle_padder.py +++ b/fastNLP/core/collators/padders/paddle_padder.py @@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder): def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], paddle.Tensor): - batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field] + batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] except AttributeError: raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") @@ -143,8 +143,6 @@ class PaddleTensorPadder(Padder): tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) - if isinstance(field, np.ndarray): - field = paddle.to_tensor(field) tensor[slices] = field return tensor diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py index 18f414e8..f1940380 100644 --- a/fastNLP/core/collators/padders/torch_padder.py +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -114,7 +114,7 @@ class TorchTensorPadder(Padder): def pad(batch_field, pad_val, dtype): try: if not isinstance(batch_field[0], torch.Tensor): - batch_field = [torch.tensor(field.tolist()) for field in batch_field] + batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] except AttributeError: raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " f"it must have tolist() method.") @@ -124,8 +124,6 @@ class TorchTensorPadder(Padder): tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) for i, field in enumerate(batch_field): slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) - if isinstance(field, np.ndarray): - field = torch.from_numpy(field) tensor[slices] = field return tensor diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 73342748..f3a739f0 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -18,9 +18,9 @@ from fastNLP.core.utils import ( auto_param_call, check_user_specific_params, paddle_move_data_to_device, - is_in_paddle_dist, - rank_zero_rm + is_in_paddle_dist ) +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.core.samplers import ( ReproduceBatchSampler, ReproducibleSampler, diff --git a/fastNLP/core/log/__init__.py b/fastNLP/core/log/__init__.py index 3cb6d4dc..d1d95f20 100644 --- a/fastNLP/core/log/__init__.py +++ b/fastNLP/core/log/__init__.py @@ -1,6 +1,8 @@ __all__ = [ - 'logger' + 'logger', + "print" ] from .logger import logger +from .print import print diff --git a/fastNLP/core/metrics/__init__.py b/fastNLP/core/metrics/__init__.py index 82bca331..f7d60606 100644 --- a/fastNLP/core/metrics/__init__.py +++ b/fastNLP/core/metrics/__init__.py @@ -1,16 +1,11 @@ __all__ = [ "Metric", "Accuracy", - 'Backend', - 'AutoBackend', - 
'PaddleBackend', - 'TorchBackend', 'SpanFPreRecMetric', 'ClassifyFPreRecMetric', ] from .metric import Metric from .accuracy import Accuracy -from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend from .span_f1_pre_rec_metric import SpanFPreRecMetric from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 9fb538a9..4af6a24a 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -23,8 +23,6 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', - 'rank_zero_rm', - 'rank_zero_mkdir' ] from .cache_results import cache_results @@ -36,7 +34,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device from .torch_utils import torch_move_data_to_device from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ - deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir -from ..dataloaders.utils import indice_collate_wrapper + deprecated, seq_len_to_mask diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 91b3c8f6..93f38e2a 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -22,8 +22,6 @@ import numpy as np from pathlib import Path from fastNLP.core.log import logger -from fastNLP.envs import FASTNLP_GLOBAL_RANK - __all__ = [ 'get_fn_arg_names', @@ -37,8 +35,6 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', - 'rank_zero_rm', - 'rank_zero_mkdir' ] @@ -609,54 +605,6 @@ def wait_filepath(path, exist=True): logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") - -def rank_zero_rm(path: Optional[Union[str, Path]]): - """ - 这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 - 在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; - 该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 - - :param path: - :return: - """ - if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - if path is None: - return - if isinstance(path, str): - path = Path(path) - if not path.exists(): - return - _recursive_rm(path) - - -def _recursive_rm(path: Path): - if path.is_file() or path.is_symlink(): - if path.exists(): - try: - path.unlink() - except Exception: - pass - return - for sub_path in list(path.iterdir()): - _recursive_rm(sub_path) - path.rmdir() - - -def rank_zero_mkdir(path: Optional[Union[str, Path]]): - """ - 注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; - 该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 - - """ - if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - if path is None: - return - if isinstance(path, str): - path = Path(path) - - path.mkdir(parents=True, exist_ok=True) - - def get_class_that_defined_method(method): """ 给定一个method,返回这个 method 的 class 的对象 diff --git a/fastNLP/envs/__init__.py b/fastNLP/envs/__init__.py index bc09c33b..6c5e857e 100644 --- a/fastNLP/envs/__init__.py +++ b/fastNLP/envs/__init__.py @@ -3,12 +3,17 @@ r""" """ __all__ = [ 'dump_fastnlp_backend', - 'is_cur_env_distributed', - 'get_global_rank', + + # utils + 'get_gpu_count', + + # distributed + "rank_zero_rm", 'rank_zero_call', + 'get_global_rank', + 'fastnlp_no_sync_context', 'all_rank_call_context', - 'get_gpu_count', - 'fastnlp_no_sync_context' + 'is_cur_env_distributed', ] diff --git a/fastNLP/envs/distributed.py b/fastNLP/envs/distributed.py index 34515c2c..3d87c8b2 100644 --- a/fastNLP/envs/distributed.py +++ 
b/fastNLP/envs/distributed.py @@ -1,6 +1,7 @@ import os from functools import wraps -from typing import Callable, Any, Optional +from pathlib import Path +from typing import Callable, Any, Optional, Union from contextlib import contextmanager __all__ = [ @@ -8,7 +9,8 @@ __all__ = [ 'get_global_rank', 'rank_zero_call', 'all_rank_call_context', - 'fastnlp_no_sync_context' + 'fastnlp_no_sync_context', + "rank_zero_rm" ] from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC @@ -96,3 +98,35 @@ def all_rank_call_context(): os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank else: os.environ.pop(FASTNLP_GLOBAL_RANK) + + +def rank_zero_rm(path: Optional[Union[str, Path]]): + """ + This exists because distributed file systems can misbehave: rank 0 issues the removal, sees it succeed locally and moves on, but the actual removal still has to be propagated from rank 0's machine to the remote file system, so while that operation is in flight rank 1 may still see the file; + this function is therefore meant to return only after every process has observed that path is gone. Make sure path is exactly the same on every process, otherwise the processes will deadlock. + + :param path: + :return: + """ + if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: + if path is None: + return + if isinstance(path, str): + path = Path(path) + if not path.exists(): + return + _recursive_rm(path) + + +def _recursive_rm(path: Path): + if path.is_file() or path.is_symlink(): + if path.exists(): + try: + path.unlink() + except Exception: + pass + return + for sub_path in list(path.iterdir()): + _recursive_rm(sub_path) + path.rmdir() \ No newline at end of file
diff --git a/fastNLP/envs/env.py index 74d833e0..9cc05a02 100644 --- a/fastNLP/envs/env.py +++ b/fastNLP/envs/env.py @@ -22,7 +22,7 @@ FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK" FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" -# todo 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; +# every distributed driver should set this value correctly; see ddp for details; # FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" @@ -42,7 +42,7 @@ USER_CUDA_VISIBLE_DEVICES = 'USER_CUDA_VISIBLE_DEVICES' # 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' -# todo 注释 +# set to 1 when the current script is detected to have been launched via something like python -m torch.distributed.launch FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; 
diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py index 2de21825..60dcc862 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -11,7 +11,7 @@ from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy 
diff --git a/tests/core/callbacks/test_more_evaluate_callback.py index 08c6f8e2..9c32c20b 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -20,7 +20,7 @@ from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from tests.helpers.models.torch_model
import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy diff --git a/tests/core/collators/test_collator.py b/tests/core/collators/test_collator.py index ba1e7e08..65101321 100644 --- a/tests/core/collators/test_collator.py +++ b/tests/core/collators/test_collator.py @@ -83,7 +83,7 @@ class TestCollator: assert raw_pad_batch == collator(dict_batch) collator = Collator(backend='raw') raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] findListDiff(raw_pad_lst, collator(list_batch)) @@ -194,7 +194,7 @@ class TestCollator: collator.set_ignore('_0', '_3', '_1') collator.set_pad('_4', pad_val=None) raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] findListDiff(raw_pad_lst, collator(list_batch)) @@ -210,7 +210,7 @@ class TestCollator: collator.set_pad('_2', backend='numpy') collator.set_pad('_4', backend='numpy', pad_val=100) raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), - [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], [{'1'}, {'2'}]] findListDiff(raw_pad_lst, collator(list_batch)) diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 102ab310..e3d90e9b 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -13,7 +13,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch from tests.helpers.utils import magic_argv_env_context, Capturing -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch.distributed as dist diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index a184bb11..3b3f15ec 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE if _NEED_IMPORT_PADDLE: import paddle diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index b8ccd802..ba243106 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -7,7 +7,7 @@ from 
tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH if _NEED_IMPORT_PADDLE: import paddle diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py index d6f0ee77..0e3f99ad 100644 --- a/tests/core/drivers/torch_driver/test_ddp.py +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index ef60e2b6..9115ed19 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -7,7 +7,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch diff --git a/tests/core/log/test_logger_torch.py b/tests/core/log/test_logger_torch.py index 13a758e9..7d45782c 100644 --- a/tests/core/log/test_logger_torch.py +++ b/tests/core/log/test_logger_torch.py @@ -7,7 +7,7 @@ import re import pytest from fastNLP.envs.env import FASTNLP_LAUNCH_TIME -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm from fastNLP.core.log.logger import logger from tests.helpers.utils import magic_argv_env_context, recover_logger diff --git a/tests/core/utils/test_cache_results.py b/tests/core/utils/test_cache_results.py index 77c618bb..efef9f10 100644 --- a/tests/core/utils/test_cache_results.py +++ b/tests/core/utils/test_cache_results.py @@ -6,7 +6,7 @@ import sys sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) from fastNLP.core.utils.cache_results import cache_results -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm def get_subprocess_results(cmd): diff --git a/tests/envs/test_set_backend.py b/tests/envs/test_set_backend.py index 170110ce..c45acd7b 100644 --- a/tests/envs/test_set_backend.py +++ b/tests/envs/test_set_backend.py @@ -3,7 +3,7 @@ import pytest from fastNLP.envs.set_backend import dump_fastnlp_backend from tests.helpers.utils import Capturing -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm def test_dump_fastnlp_envs(): diff --git a/tests/modules/mix_modules/_test_mix_module.py b/tests/modules/mix_modules/_test_mix_module.py index 700e0cfe..87206fd6 
100644 --- a/tests/modules/mix_modules/_test_mix_module.py +++ b/tests/modules/mix_modules/_test_mix_module.py @@ -9,7 +9,7 @@ import numpy as np from fastNLP.modules.mix_modules.mix_module import MixModule from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle -from fastNLP.core import rank_zero_rm +from fastNLP.envs.distributed import rank_zero_rm ############################################################################ diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb index 28fcfddf..26675ecf 100644 --- a/tutorials/fastnlp_tutorial_0.ipynb +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -136,7 +136,7 @@ "在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", "\n", "  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n", - "***\n", + "\n", "```python\n", "class Model(torch.nn.Module):\n", " def __init__(self):\n", @@ -177,9 +177,7 @@ "\n", "  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", "\n", - "***\n", - "\n", - "![fastNLP 0.8 中,Trainer 和 Evaluator 的关系图](./figures/T0-fig-trainer-and-evaluator.png)" + "" ] }, { @@ -206,7 +204,7 @@ "  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", "\n", "    `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", - "***\n", + "\n", "```python\n", "class Dataset(torch.utils.data.Dataset):\n", " def __init__(self, x, y):\n", @@ -253,7 +251,7 @@ "id": "5314482b", "metadata": { "pycharm": { - "is_executing": false + "is_executing": true } }, "outputs": [], @@ -641,11 +639,11 @@ { "data": { "text/html": [ - "
{'acc#acc': 0.43}\n",
+       "
{'acc#acc': 0.29}\n",
        "
\n" ], "text/plain": [ - "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.43\u001b[0m\u001b[1m}\u001b[0m\n" + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.29\u001b[0m\u001b[1m}\u001b[0m\n" ] }, "metadata": {}, @@ -654,7 +652,7 @@ { "data": { "text/plain": [ - "{'acc#acc': 0.43}" + "{'acc#acc': 0.29}" ] }, "execution_count": 9, diff --git a/tutorials/fastnlp_tutorial_1.ipynb b/tutorials/fastnlp_tutorial_1.ipynb new file mode 100644 index 00000000..11bd2219 --- /dev/null +++ b/tutorials/fastnlp_tutorial_1.ipynb @@ -0,0 +1,1156 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "cdc25fcd", + "metadata": {}, + "source": [ + "# T1. dataset 和 vocabulary 的基本使用\n", + "\n", + "  1   dataset 的使用与结构\n", + " \n", + "    1.1   dataset 的结构与创建\n", + "\n", + "    1.2   dataset 的数据预处理\n", + "\n", + "    1.3   延伸:instance 和 field\n", + "\n", + "  2   vocabulary 的结构与使用\n", + "\n", + "    2.1   vocabulary 的创建与修改\n", + "\n", + "    2.2   vocabulary 与 OOV 问题\n", + "\n", + "  3   dataset 和 vocabulary 的组合使用\n", + " \n", + "    3.1   从 dataframe 中加载 dataset\n", + "\n", + "    3.2   从 dataset 中获取 vocabulary" + ] + }, + { + "cell_type": "markdown", + "id": "0eb18a22", + "metadata": {}, + "source": [ + "## 1. dataset 的基本使用\n", + "\n", + "### 1.1 dataset 的结构与创建\n", + "\n", + "在`fastNLP 0.8`中,使用`DataSet`模块表示数据集,**`dataset`类似于关系型数据库中的数据表**(下文统一为小写`dataset`)\n", + "\n", + "  **主要包含`field`字段和`instance`实例两个元素**,对应`table`中的`field`字段和`record`记录\n", + "\n", + "在`fastNLP 0.8`中,`DataSet`模块被定义在`fastNLP.core.dataset`路径下,导入该模块后,最简单的\n", + "\n", + "  初始化方法,即将字典形式的表格 **`{'field1': column1, 'field2': column2, ...}`** 传入构造函数" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a1d69ad2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "from fastNLP.core.dataset import DataSet\n", + "\n", + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"],\n", + " 'words': [['This', 'is', 'an', 'apple', '.'], \n", + " ['I', 'like', 'apples', '.'], \n", + " ['Apples', 'are', 'good', 'for', 'our', 'health', '.']],\n", + " 'num': [5, 4, 7]}\n", + "\n", + "dataset = DataSet(data)\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9260fdc6", + "metadata": {}, + "source": [ + "  在`dataset`的实例中,字段`field`的名称和实例`instance`中的字符串也可以中文" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "3d72ef00", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+------+--------------------+------------------------+------+\n", + "| 序号 | 句子 | 字符 | 长度 |\n", + "+------+--------------------+------------------------+------+\n", + "| 0 | 生活就像海洋, | ['生', '活', '就', ... | 7 |\n", + "| 1 | 只有意志坚强的人, | ['只', '有', '意', ... | 9 |\n", + "| 2 | 才能到达彼岸。 | ['才', '能', '到', ... | 7 |\n", + "+------+--------------------+------------------------+------+\n" + ] + } + ], + "source": [ + "temp = {'序号': [0, 1, 2], \n", + " '句子':[\"生活就像海洋,\", \"只有意志坚强的人,\", \"才能到达彼岸。\"],\n", + " '字符': [['生', '活', '就', '像', '海', '洋', ','], \n", + " ['只', '有', '意', '志', '坚', '强', '的', '人', ','], \n", + " ['才', '能', '到', '达', '彼', '岸', '。']],\n", + " '长度': [7, 9, 7]}\n", + "\n", + "chinese = DataSet(temp)\n", + "print(chinese)" + ] + }, + { + "cell_type": "markdown", + "id": "202e5490", + "metadata": {}, + "source": [ + "在`dataset`中,使用`drop`方法可以删除满足条件的实例,这里使用了python中的`lambda`表达式\n", + "\n", + "  注一:在`drop`方法中,通过设置`inplace`参数将删除对应实例后的`dataset`作为一个新的实例生成" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "09b478f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1969418794120 1971237588872\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... 
 + { + "cell_type": "markdown", + "id": "202e5490", + "metadata": {}, + "source": [ + "In a `dataset`, the `drop` method removes the instances satisfying some condition, here written as a python `lambda` expression\n", + "\n", + "  Note 1: with the `inplace` parameter set to `False`, `drop` generates a new `dataset` with the matching instances removed" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 3, + "id": "09b478f8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1969418794120 1971237588872\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped = dropped.drop(lambda ins:ins['num'] < 5, inplace=False)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + }, 
+ { + "cell_type": "markdown", + "id": "aa277674", + "metadata": {}, + "source": [ + "  Note 2: in `fastNLP 0.8`, **assigning a `dataset` with `=` passes a reference**, **it does not copy the data**\n", + "\n", + "    as shown below, **`dropped` and `dataset` share the same `id`**, so **deleting from `dropped` modifies `dataset` at the same time**; a real copy is sketched right after this cell" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 4, + "id": "77c8583a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1971237588872 1971237588872\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n", + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dropped = dataset\n", + "dropped.drop(lambda ins:ins['num'] < 5)\n", + "print(id(dropped), id(dataset))\n", + "print(dropped)\n", + "print(dataset)" + ] + },
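 + { + "cell_type": "markdown", + "id": "aa277699", + "metadata": {}, + "source": [ + "    When an independent copy is needed, python's `copy.deepcopy` can be used instead of `=`; a minimal sketch (it assumes `DataSet` places no restriction on `deepcopy`; the cell is left unexecuted)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "id": "aa2776aa", + "metadata": {}, + "outputs": [], + "source": [ + "from copy import deepcopy\n", + "\n", + "copied = deepcopy(dataset)  # a separate object, not a second reference\n", + "print(id(copied) == id(dataset))" + ] + },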
|\n", + "+-----+--------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset.delete_field('num')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "b1e9d42c", + "metadata": {}, + "source": [ + "### 1.2 dataset 的数据预处理\n", + "\n", + "在`dataset`模块中,`apply`、`apply_field`、`apply_more`和`apply_field_more`函数可以进行简单的数据预处理\n", + "\n", + "  **`apply`和`apply_more`针对整条实例**,**`apply_field`和`apply_field_more`仅针对实例的部分字段**\n", + "\n", + "  **`apply`和`apply_field`仅针对单个字段**,**`apply_more`和`apply_field_more`则可以针对多个字段**\n", + "\n", + "  **`apply`和`apply_field`返回的是个列表**,**`apply_more`和`apply_field_more`返回的是个字典**\n", + "\n", + "***\n", + "\n", + "`apply`的参数包括一个函数`func`和一个新字段名`new_field_name`,函数`func`的处理对象是`dataset`模块中\n", + "\n", + "  的每个`instance`实例,函数`func`的处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "72a0b5f9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "data = {'idx': [0, 1, 2], \n", + " 'sentence':[\"This is an apple .\", \"I like apples .\", \"Apples are good for our health .\"], }\n", + "dataset = DataSet(data)\n", + "dataset.apply(lambda ins: ins['sentence'].split(), new_field_name='words')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "c10275ee", + "metadata": {}, + "source": [ + "  **`apply`使用的函数可以是一个基于`lambda`表达式的匿名函数**,**也可以是一个自定义的函数**" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b1a8631f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "\n", + "def get_words(instance):\n", + " sentence = instance['sentence']\n", + " words = sentence.split()\n", + " return words\n", + "\n", + "dataset.apply(get_words, new_field_name='words')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "64abf745", + "metadata": {}, + "source": [ + "`apply_field`的参数,除了函数`func`外还有`field_name`和`new_field_name`,该函数`func`的处理对象仅\n", + "\n", + "  是`dataset`模块中的每个`field_name`对应的字段内容,处理结果存放在`new_field_name`对应的新建字段内" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "057c1d2c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------------+------------------------------+\n", + "| idx | sentence | words |\n", + "+-----+------------------------------+------------------------------+\n", + "| 0 | This is an apple . | ['This', 'is', 'an', 'app... |\n", + "| 1 | I like apples . | ['I', 'like', 'apples', '... |\n", + "| 2 | Apples are good for our h... | ['Apples', 'are', 'good',... |\n", + "+-----+------------------------------+------------------------------+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_field(lambda sent:sent.split(), field_name='sentence', new_field_name='words')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "5a9cc8b2", + "metadata": {}, + "source": [ + "`apply_more`的参数只有函数`func`,函数`func`的处理对象是`dataset`模块中的每个`instance`实例\n", + "\n", + "  要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "51e2f02c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_more(lambda ins:{'words': ins['sentence'].split(), 'num': len(ins['sentence'].split())})\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "02d2b7ef", + "metadata": {}, + "source": [ + "`apply_more`的参数只有函数`func`,函数`func`的处理对象是`dataset`模块中的每个`instance`实例\n", + "\n", + "  要求函数`func`返回一个字典,根据字典的`key-value`确定存储在`dataset`中的字段名称与内容" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "db4295d5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+-----+------------------------+------------------------+-----+\n", + "| idx | sentence | words | num |\n", + "+-----+------------------------+------------------------+-----+\n", + "| 0 | This is an apple . | ['This', 'is', 'an'... | 5 |\n", + "| 1 | I like apples . | ['I', 'like', 'appl... | 4 |\n", + "| 2 | Apples are good for... | ['Apples', 'are', '... | 7 |\n", + "+-----+------------------------+------------------------+-----+\n" + ] + } + ], + "source": [ + "dataset = DataSet(data)\n", + "dataset.apply_field_more(lambda sent:{'words': sent.split(), 'num': len(sent.split())}, \n", + " field_name='sentence')\n", + "print(dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "9c09e592", + "metadata": {}, + "source": [ + "### 1.3 延伸:instance 和 field\n", + "\n", + "在`fastNLP 0.8`中,使用`Instance`模块表示数据集`dataset`中的每条数据,被称为实例\n", + "\n", + "  构造方式类似于构造一个字典,通过键值相同的`Instance`列表,也可以初始化一个`dataset`,代码如下" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "012f537c", + "metadata": {}, + "outputs": [], + "source": [ + "from fastNLP.core.dataset import DataSet\n", + "from fastNLP.core.dataset import Instance\n", + "\n", + "dataset = DataSet([\n", + " Instance(sentence=\"This is an apple .\",\n", + " words=['This', 'is', 'an', 'apple', '.'],\n", + " num=5),\n", + " Instance(sentence=\"I like apples .\",\n", + " words=['I', 'like', 'apples', '.'],\n", + " num=4),\n", + " Instance(sentence=\"Apples are good for our health .\",\n", + " words=['Apples', 'are', 'good', 'for', 'our', 'health', '.'],\n", + " num=7),\n", + " ])" + ] + }, + { + "cell_type": "markdown", + "id": "2fafb1ef", + "metadata": {}, + "source": [ + "  通过`items`、`keys`和`values`方法,可以分别获得`dataset`的`item`列表、`key`列表、`value`列表" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a4c1c10d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_items([('sentence', 'This is an apple .'), ('words', ['This', 'is', 'an', 'apple', '.']), ('num', 5)])\n", + "dict_keys(['sentence', 'words', 'num'])\n", + "dict_values(['This is an apple .', ['This', 'is', 'an', 'apple', '.'], 5])\n" + ] + } + ], + "source": [ + "ins = Instance(sentence=\"This is an apple .\", words=['This', 'is', 'an', 'apple', '.'], num=5)\n", + "\n", + "print(ins.items())\n", + "print(ins.keys())\n", + "print(ins.values())" + ] + }, + { + "cell_type": "markdown", + "id": "b5459a2d", + "metadata": {}, + "source": [ + "  通过`add_field`方法,可以在`Instance`实例中,通过参数`field_name`添加字段,通过参数`field`赋值" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "55376402", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------------------+------------------------+-----+-----+\n", + "| sentence | words | num | idx |\n", + "+--------------------+------------------------+-----+-----+\n", + "| This is an apple . | ['This', 'is', 'an'... 
 + { + "cell_type": "markdown", + "id": "2fafb1ef", + "metadata": {}, + "source": [ + "  the `items`, `keys` and `values` methods return the `item` list, the `key` list and the `value` list of an `instance`" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 13, + "id": "a4c1c10d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dict_items([('sentence', 'This is an apple .'), ('words', ['This', 'is', 'an', 'apple', '.']), ('num', 5)])\n", + "dict_keys(['sentence', 'words', 'num'])\n", + "dict_values(['This is an apple .', ['This', 'is', 'an', 'apple', '.'], 5])\n" + ] + } + ], + "source": [ + "ins = Instance(sentence=\"This is an apple .\", words=['This', 'is', 'an', 'apple', '.'], num=5)\n", + "\n", + "print(ins.items())\n", + "print(ins.keys())\n", + "print(ins.values())" + ] + }, 
+ { + "cell_type": "markdown", + "id": "b5459a2d", + "metadata": {}, + "source": [ + "  the `add_field` method adds a field to an `Instance`, with its name given by the parameter `field_name` and its value by the parameter `field`" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 14, + "id": "55376402", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "+--------------------+------------------------+-----+-----+\n", + "| sentence | words | num | idx |\n", + "+--------------------+------------------------+-----+-----+\n", + "| This is an apple . | ['This', 'is', 'an'... | 5 | 0 |\n", + "+--------------------+------------------------+-----+-----+\n" + ] + } + ], + "source": [ + "ins.add_field(field_name='idx', field=0)\n", + "print(ins)" + ] + }, 
+ { + "cell_type": "markdown", + "id": "49caaa9c", + "metadata": {}, + "source": [ + "In `fastNLP 0.8`, the `FieldArray` module represents each field of a `dataset` (note: there is no `field` class)\n", + "\n", + "  the `get_all_fields` method returns all the fields of a `dataset`\n", + "\n", + "  the `get_field_names` method returns the list of field names of a `dataset`, as shown below" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 15, + "id": "fe15f4c1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'sentence': ,\n", + " 'words': ,\n", + " 'num': }" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.get_all_fields()" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 16, + "id": "5433815c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['num', 'sentence', 'words']" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset.get_field_names()" + ] + }, 
+ { + "cell_type": "markdown", + "id": "4964eeed", + "metadata": {}, + "source": [ + "Other basics of `dataset`: the `in` operator or the `has_field` method tells whether a `dataset` contains a given field\n", + "\n", + "  the `rename_field` method renames a field of a `dataset`; the `concat` method concatenates two `dataset`s\n", + "\n", + "  `len` counts the instances in a `dataset`; all attributes and methods of `dataset` can be listed with `dir(dataset)`" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 17, + "id": "25ce5488", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3 False\n", + "6 True\n", + "+------------------------------+------------------------------+--------+\n", + "| sentence | words | length |\n", + "+------------------------------+------------------------------+--------+\n", + "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", + "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", + "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", + "| This is an apple . | ['This', 'is', 'an', 'app... | 5 |\n", + "| I like apples . | ['I', 'like', 'apples', '... | 4 |\n", + "| Apples are good for our h... | ['Apples', 'are', 'good',... | 7 |\n", + "+------------------------------+------------------------------+--------+\n" + ] + } + ], + "source": [ + "print(len(dataset), dataset.has_field('length')) \n", + "if 'num' in dataset:\n", + " dataset.rename_field('num', 'length')\n", + "elif 'length' in dataset:\n", + " dataset.rename_field('length', 'num')\n", + "dataset.concat(dataset)\n", + "print(len(dataset), dataset.has_field('length')) \n", + "print(dataset) " + ] + }, 
+ { + "cell_type": "markdown", + "id": "e30a6cd7", + "metadata": {}, + "source": [ + "## 2. Structure and usage of vocabulary\n",
 + "\n", + "### 2.1 Creation and modification of vocabulary\n", + "\n", + "In `fastNLP 0.8`, the `Vocabulary` module represents a vocabulary, and **the core of a `vocabulary` is the mapping from words to indices**\n", + "\n", + "  it can be instantiated directly through the constructor, and the `word2idx` attribute exposes the dict behind the mapping\n", + "\n", + "  **the `padding` token defaults to `<pad>`**, **with index 0**; **the `unknown` token defaults to `<unk>`**, **with index 1**\n", + "\n", + "  printing a `vocabulary` shows the word list, in which `padding` and `unknown` are not displayed" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 18, + "id": "3515e096", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vocabulary([]...)\n", + "{'<pad>': 0, '<unk>': 1}\n", + "<pad> 0\n", + "<unk> 1\n" + ] + } + ], + "source": [ + "from fastNLP.core.vocabulary import Vocabulary\n", + "\n", + "vocab = Vocabulary()\n", + "print(vocab)\n", + "print(vocab.word2idx)\n", + "print(vocab.padding, vocab.padding_idx)\n", + "print(vocab.unknown, vocab.unknown_idx)" + ] + }, 
+ { + "cell_type": "markdown", + "id": "640be126", + "metadata": {}, + "source": [ + "In a `vocabulary`, the `add_word` and `add_word_lst` methods add a single word or a batch of words\n", + "\n", + "  `len` and the `word_count` attribute show the size of the `vocabulary` and how many times each word was added" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 19, + "id": "88c7472a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "5 Counter({'生活': 1, '就像': 1, '海洋': 1})\n", + "6 Counter({'生活': 1, '就像': 1, '海洋': 1, '只有': 1})\n" + ] + } + ], + "source": [ + "vocab.add_word_lst(['生活', '就像', '海洋'])\n", + "print(len(vocab), vocab.word_count)\n", + "vocab.add_word('只有')\n", + "print(len(vocab), vocab.word_count)" + ] + }, 
+ { + "cell_type": "markdown", + "id": "f9ec8b28", + "metadata": {}, + "source": [ + "  **the `to_word` method maps an index back to its word**, **and the `to_index` method maps a word to its index**\n", + "\n", + "    since index 0 and index 1 are already taken, **newly added words are numbered from 2**, e.g. `'生活'` maps to 2\n", + "\n", + "    the `has_word` method tells whether a word is in the vocabulary; absent words are treated as `<unk>`" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 20, + "id": "3447acde", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<pad> 0\n", + "<unk> 1\n", + "生活 2\n", + "只有 5\n", + "彼岸 1 False\n" + ] + } + ], + "source": [ + "print(vocab.to_word(0), vocab.to_index('<pad>'))\n", + "print(vocab.to_word(1), vocab.to_index('<unk>'))\n", + "print(vocab.to_word(2), vocab.to_index('生活'))\n", + "print(vocab.to_word(5), vocab.to_index('只有'))\n", + "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))" + ] + }, 
+ { + "cell_type": "markdown", + "id": "b4e36850", + "metadata": {}, + "source": [ + "**a `vocabulary` allows the same word to be added repeatedly**, **and `word_count` shows how many times each word was added**\n", + "\n", + "  `<pad>` and `<unk>` themselves are not counted there; all attributes and methods of `vocabulary` can be listed with `dir(vocabulary)`" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 21, + "id": "490b101c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13 Counter({'生活': 2, '就像': 2, '海洋': 2, '只有': 2, '意志': 1, '坚强的': 1, '人': 1, '才': 1, '能': 1, '到达': 1, '彼岸': 1})\n", + "彼岸 12 True\n" + ] + } + ], + "source": [ + "vocab.add_word_lst(['生活', '就像', '海洋', '只有', '意志', '坚强的', '人', '才', '能', '到达', '彼岸'])\n", + "print(len(vocab), vocab.word_count)\n", + "print('彼岸', vocab.to_index('彼岸'), vocab.has_word('彼岸'))" + ] + }, 
+ { + "cell_type": "markdown", + "id": "23e32a63", + "metadata": {}, + "source": [ + "### 2.2 vocabulary and the OOV problem\n", + "\n", + "When a `vocabulary` is initialized, `unknown` and `padding` can be set to `None` to disable them\n", + "\n", + "  words are then numbered from 0, and looking up an unseen word raises an error directly, i.e. out of vocabulary" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 22, + "id": "a99ff909", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'positive': 0, 'negative': 1}\n", + "ValueError: word `neutral` not in vocabulary\n" + ] + } + ], + "source": [ + "vocab = Vocabulary(unknown=None, padding=None)\n", + "\n", + "vocab.add_word_lst(['positive', 'negative'])\n", + "print(vocab.word2idx)\n", + "\n", + "try:\n", + " print(vocab.to_index('neutral'))\n", + "except ValueError:\n", + " print(\"ValueError: word `neutral` not in vocabulary\")" + ] + }, 
+ { + "cell_type": "markdown", + "id": "618da6bd", + "metadata": {}, + "source": [ + "  Accordingly, if only `unknown` is specified, the numbering shifts back by one, and every unseen word is treated as `<unk>`" + ] + }, 
+ { + "cell_type": "code", + "execution_count": 23, + "id": "432f74c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'<unk>': 0, 'positive': 1, 'negative': 2}\n", + "0 <unk>\n" + ] + } + ], + "source": [ + "vocab = Vocabulary(unknown='<unk>', padding=None)\n", + "\n", + "vocab.add_word_lst(['positive', 'negative'])\n", + "print(vocab.word2idx)\n", + "\n", + "print(vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral')))" + ] + }, 
+ { + "cell_type": "markdown", + "id": "b6263f73", + "metadata": {}, + "source": [ + "## 3. Combined usage of dataset and vocabulary\n", + " \n", + "### 3.1 Loading a dataset from a dataframe\n", + "\n" + ] + },
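 + { + "cell_type": "markdown", + "id": "b6263f99", + "metadata": {}, + "source": [ + "  Since the `DataSet` constructor accepts a dict of columns, a `pandas` `dataframe` can be bridged over with `to_dict(orient='list')`; a minimal sketch for this still-empty section (it assumes `pandas` is installed; the cell is left unexecuted)" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "id": "b6263faa", + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "\n", + "df = pd.DataFrame(data)  # any dataframe with list-like columns, e.g. from pd.read_csv\n", + "dataset = DataSet(df.to_dict(orient='list'))\n", + "print(dataset)" + ] + },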
"output_type": "stream", + "text": [ + "{'positive': 0, 'negative': 1}\n", + "ValueError: word `neutral` not in vocabulary\n" + ] + } + ], + "source": [ + "vocab = Vocabulary(unknown=None, padding=None)\n", + "\n", + "vocab.add_word_lst(['positive', 'negative'])\n", + "print(vocab.word2idx)\n", + "\n", + "try:\n", + " print(vocab.to_index('neutral'))\n", + "except ValueError:\n", + " print(\"ValueError: word `neutral` not in vocabulary\")" + ] + }, + { + "cell_type": "markdown", + "id": "618da6bd", + "metadata": {}, + "source": [ + "  相应的,如果只指定其中的`unknown`,则编号会后移一个,同时遇到未知单词全部当做``" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "432f74c1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'': 0, 'positive': 1, 'negative': 2}\n", + "0 \n" + ] + } + ], + "source": [ + "vocab = Vocabulary(unknown='', padding=None)\n", + "\n", + "vocab.add_word_lst(['positive', 'negative'])\n", + "print(vocab.word2idx)\n", + "\n", + "print(vocab.to_index('neutral'), vocab.to_word(vocab.to_index('neutral')))" + ] + }, + { + "cell_type": "markdown", + "id": "b6263f73", + "metadata": {}, + "source": [ + "## 3 dataset 和 vocabulary 的组合使用\n", + " \n", + "### 3.1 从 dataframe 中加载 dataset\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "id": "89059713", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3dbd985d", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4f634586", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "5ba13989", + "metadata": {}, + "source": [ + "### 3.2 从 dataset 中获取 vocabulary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a2de615b", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5f5eed18", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/figures/T0-fig-trainer-and-evaluator.png b/tutorials/figures/T0-fig-trainer-and-evaluator.png index a98ab83b..6e95650d 100644 Binary files a/tutorials/figures/T0-fig-trainer-and-evaluator.png and b/tutorials/figures/T0-fig-trainer-and-evaluator.png differ