@@ -1,3 +1,3 @@ | |||||
from fastNLP.envs import * | from fastNLP.envs import * | ||||
from fastNLP.core import Trainer, Evaluator | |||||
from fastNLP.core import * |
@@ -57,9 +57,37 @@ __all__ = [ | |||||
"TorchPaddleDriver", | "TorchPaddleDriver", | ||||
# log | # log | ||||
"logger" | |||||
"logger", | |||||
"print", | |||||
# | |||||
# metrics | |||||
"Metric", | |||||
"Accuracy", | |||||
'SpanFPreRecMetric', | |||||
'ClassifyFPreRecMetric', | |||||
# samplers | |||||
'ReproducibleSampler', | |||||
'RandomSampler', | |||||
"SequentialSampler", | |||||
"SortedSampler", | |||||
'UnrepeatedSampler', | |||||
'UnrepeatedRandomSampler', | |||||
"UnrepeatedSortedSampler", | |||||
"UnrepeatedSequentialSampler", | |||||
"ReproduceBatchSampler", | |||||
"BucketedBatchSampler", | |||||
"ReproducibleBatchSampler", | |||||
"RandomBatchSampler", | |||||
# utils | |||||
"cache_results", | |||||
"f_rich_progress", | |||||
"auto_param_call", | |||||
"seq_len_to_mask", | |||||
# vocabulary.py | |||||
'Vocabulary' | |||||
] | ] | ||||
from .callbacks import * | from .callbacks import * | ||||
from .collators import * | from .collators import * | ||||
@@ -68,4 +96,7 @@ from .dataloaders import * | |||||
from .dataset import * | from .dataset import * | ||||
from .drivers import * | from .drivers import * | ||||
from .log import * | from .log import * | ||||
from .utils import * | |||||
from .metrics import * | |||||
from .samplers import * | |||||
from .utils import * | |||||
from .vocabulary import Vocabulary |
@@ -7,7 +7,7 @@ from copy import deepcopy | |||||
from pathlib import Path | from pathlib import Path | ||||
from typing import Optional, Dict, Tuple, Callable, Union | from typing import Optional, Dict, Tuple, Callable, Union | ||||
from fastNLP.core.utils import rank_zero_rm | |||||
from ...envs.distributed import rank_zero_rm | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
@@ -8,6 +8,7 @@ __all__ = [ | |||||
"NullPadder", | "NullPadder", | ||||
"RawNumberPadder", | "RawNumberPadder", | ||||
"RawSequencePadder", | "RawSequencePadder", | ||||
"RawTensorPadder", | |||||
'TorchNumberPadder', | 'TorchNumberPadder', | ||||
'TorchSequencePadder', | 'TorchSequencePadder', | ||||
'TorchTensorPadder', | 'TorchTensorPadder', | ||||
@@ -67,7 +67,7 @@ def _get_backend() -> str: | |||||
# 方式 (2) | # 方式 (2) | ||||
for backend in CHECK_BACKEND: | for backend in CHECK_BACKEND: | ||||
if backend in sys.modules: | if backend in sys.modules: | ||||
logger.debug(f"sys.modules contains backend:{catch_backend[0]}.") | |||||
logger.debug(f"sys.modules contains backend:{backend}.") | |||||
return backend | return backend | ||||
for key, module in sys.modules.items(): | for key, module in sys.modules.items(): | ||||
catch_backend = _check_module(module) | catch_backend = _check_module(module) | ||||
@@ -9,6 +9,7 @@ __all__ = [ | |||||
"RawNumberPadder", | "RawNumberPadder", | ||||
"RawSequencePadder", | "RawSequencePadder", | ||||
"RawTensorPadder", | |||||
'TorchNumberPadder', | 'TorchNumberPadder', | ||||
'TorchSequencePadder', | 'TorchSequencePadder', | ||||
@@ -79,7 +79,7 @@ class NumpyTensorPadder(Padder): | |||||
def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
try: | try: | ||||
if not isinstance(batch_field[0], np.ndarray): | if not isinstance(batch_field[0], np.ndarray): | ||||
batch_field = [np.array(field.tolist()) for field in batch_field] | |||||
batch_field = [np.array(field.tolist(), dtype=dtype) for field in batch_field] | |||||
except AttributeError: | except AttributeError: | ||||
raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a np.ndarray (it is {type(batch_field[0])}), " | ||||
f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
@@ -131,7 +131,7 @@ class PaddleTensorPadder(Padder): | |||||
def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
try: | try: | ||||
if not isinstance(batch_field[0], paddle.Tensor): | if not isinstance(batch_field[0], paddle.Tensor): | ||||
batch_field = [paddle.to_tensor(field.tolist()) for field in batch_field] | |||||
batch_field = [paddle.to_tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||||
except AttributeError: | except AttributeError: | ||||
raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a paddle.Tensor (it is {type(batch_field[0])}), " | ||||
f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
@@ -143,8 +143,6 @@ class PaddleTensorPadder(Padder): | |||||
tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = paddle.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
if isinstance(field, np.ndarray): | |||||
field = paddle.to_tensor(field) | |||||
tensor[slices] = field | tensor[slices] = field | ||||
return tensor | return tensor | ||||
@@ -114,7 +114,7 @@ class TorchTensorPadder(Padder): | |||||
def pad(batch_field, pad_val, dtype): | def pad(batch_field, pad_val, dtype): | ||||
try: | try: | ||||
if not isinstance(batch_field[0], torch.Tensor): | if not isinstance(batch_field[0], torch.Tensor): | ||||
batch_field = [torch.tensor(field.tolist()) for field in batch_field] | |||||
batch_field = [torch.tensor(field.tolist(), dtype=dtype) for field in batch_field] | |||||
except AttributeError: | except AttributeError: | ||||
raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | raise RuntimeError(f"If the field is not a torch.Tensor (it is {type(batch_field[0])}), " | ||||
f"it must have tolist() method.") | f"it must have tolist() method.") | ||||
@@ -124,8 +124,6 @@ class TorchTensorPadder(Padder): | |||||
tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) | ||||
for i, field in enumerate(batch_field): | for i, field in enumerate(batch_field): | ||||
slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) | ||||
if isinstance(field, np.ndarray): | |||||
field = torch.from_numpy(field) | |||||
tensor[slices] = field | tensor[slices] = field | ||||
return tensor | return tensor | ||||
@@ -18,9 +18,9 @@ from fastNLP.core.utils import ( | |||||
auto_param_call, | auto_param_call, | ||||
check_user_specific_params, | check_user_specific_params, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
is_in_paddle_dist, | |||||
rank_zero_rm | |||||
is_in_paddle_dist | |||||
) | ) | ||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.core.samplers import ( | from fastNLP.core.samplers import ( | ||||
ReproduceBatchSampler, | ReproduceBatchSampler, | ||||
ReproducibleSampler, | ReproducibleSampler, | ||||
@@ -1,6 +1,8 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'logger' | |||||
'logger', | |||||
"print" | |||||
] | ] | ||||
from .logger import logger | from .logger import logger | ||||
from .print import print | |||||
@@ -1,16 +1,11 @@ | |||||
__all__ = [ | __all__ = [ | ||||
"Metric", | "Metric", | ||||
"Accuracy", | "Accuracy", | ||||
'Backend', | |||||
'AutoBackend', | |||||
'PaddleBackend', | |||||
'TorchBackend', | |||||
'SpanFPreRecMetric', | 'SpanFPreRecMetric', | ||||
'ClassifyFPreRecMetric', | 'ClassifyFPreRecMetric', | ||||
] | ] | ||||
from .metric import Metric | from .metric import Metric | ||||
from .accuracy import Accuracy | from .accuracy import Accuracy | ||||
from .backend import Backend, AutoBackend, PaddleBackend, TorchBackend | |||||
from .span_f1_pre_rec_metric import SpanFPreRecMetric | from .span_f1_pre_rec_metric import SpanFPreRecMetric | ||||
from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric | from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric |
@@ -23,8 +23,6 @@ __all__ = [ | |||||
'Option', | 'Option', | ||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'rank_zero_rm', | |||||
'rank_zero_mkdir' | |||||
] | ] | ||||
from .cache_results import cache_results | from .cache_results import cache_results | ||||
@@ -36,7 +34,6 @@ from .torch_paddle_utils import torch_paddle_move_data_to_device | |||||
from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \ | ||||
dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \ | ||||
deprecated, seq_len_to_mask, rank_zero_rm, rank_zero_mkdir | |||||
from ..dataloaders.utils import indice_collate_wrapper | |||||
deprecated, seq_len_to_mask | |||||
@@ -22,8 +22,6 @@ import numpy as np | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_GLOBAL_RANK | |||||
__all__ = [ | __all__ = [ | ||||
'get_fn_arg_names', | 'get_fn_arg_names', | ||||
@@ -37,8 +35,6 @@ __all__ = [ | |||||
'Option', | 'Option', | ||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
'rank_zero_rm', | |||||
'rank_zero_mkdir' | |||||
] | ] | ||||
@@ -609,54 +605,6 @@ def wait_filepath(path, exist=True): | |||||
logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") | logger.warning(f"Waiting path:{path} to {msg} for {count*0.01} seconds...") | ||||
def rank_zero_rm(path: Optional[Union[str, Path]]): | |||||
""" | |||||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||||
该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
:param path: | |||||
:return: | |||||
""" | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
if not path.exists(): | |||||
return | |||||
_recursive_rm(path) | |||||
def _recursive_rm(path: Path): | |||||
if path.is_file() or path.is_symlink(): | |||||
if path.exists(): | |||||
try: | |||||
path.unlink() | |||||
except Exception: | |||||
pass | |||||
return | |||||
for sub_path in list(path.iterdir()): | |||||
_recursive_rm(sub_path) | |||||
path.rmdir() | |||||
def rank_zero_mkdir(path: Optional[Union[str, Path]]): | |||||
""" | |||||
注意该函数是用来创建文件夹,如果需要创建一个文件,不要使用该函数; | |||||
该函数会保证所有进程都检测到 path 创建之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
""" | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
path.mkdir(parents=True, exist_ok=True) | |||||
def get_class_that_defined_method(method): | def get_class_that_defined_method(method): | ||||
""" | """ | ||||
给定一个method,返回这个 method 的 class 的对象 | 给定一个method,返回这个 method 的 class 的对象 | ||||
@@ -3,12 +3,17 @@ r""" | |||||
""" | """ | ||||
__all__ = [ | __all__ = [ | ||||
'dump_fastnlp_backend', | 'dump_fastnlp_backend', | ||||
'is_cur_env_distributed', | |||||
'get_global_rank', | |||||
# utils | |||||
'get_gpu_count', | |||||
# distributed | |||||
"rank_zero_rm", | |||||
'rank_zero_call', | 'rank_zero_call', | ||||
'get_global_rank', | |||||
'fastnlp_no_sync_context', | |||||
'all_rank_call_context', | 'all_rank_call_context', | ||||
'get_gpu_count', | |||||
'fastnlp_no_sync_context' | |||||
'is_cur_env_distributed', | |||||
] | ] | ||||
@@ -1,6 +1,7 @@ | |||||
import os | import os | ||||
from functools import wraps | from functools import wraps | ||||
from typing import Callable, Any, Optional | |||||
from pathlib import Path | |||||
from typing import Callable, Any, Optional, Union | |||||
from contextlib import contextmanager | from contextlib import contextmanager | ||||
__all__ = [ | __all__ = [ | ||||
@@ -8,7 +9,8 @@ __all__ = [ | |||||
'get_global_rank', | 'get_global_rank', | ||||
'rank_zero_call', | 'rank_zero_call', | ||||
'all_rank_call_context', | 'all_rank_call_context', | ||||
'fastnlp_no_sync_context' | |||||
'fastnlp_no_sync_context', | |||||
"rank_zero_rm" | |||||
] | ] | ||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC | from fastNLP.envs.env import FASTNLP_GLOBAL_RANK, FASTNLP_NO_SYNC | ||||
@@ -96,3 +98,35 @@ def all_rank_call_context(): | |||||
os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank | os.environ[FASTNLP_GLOBAL_RANK] = old_fastnlp_global_rank | ||||
else: | else: | ||||
os.environ.pop(FASTNLP_GLOBAL_RANK) | os.environ.pop(FASTNLP_GLOBAL_RANK) | ||||
def rank_zero_rm(path: Optional[Union[str, Path]]): | |||||
""" | |||||
这个是因为在分布式文件系统中可能会发生错误,rank0下发删除成功后就运行走了,但实际的删除需要rank0的机器发送到远程文件系统再去执行,这个时候 | |||||
在rank0那里,确实已经删除成功了,但是在远程文件系统那里这个操作还没完成,rank1读取的时候还是读取到存在这个文件; | |||||
该函数会保证所有进程都检测到 path 删除之后才退出,请保证不同进程上 path 是完全一样的,否则会陷入死锁状态。 | |||||
:param path: | |||||
:return: | |||||
""" | |||||
if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: | |||||
if path is None: | |||||
return | |||||
if isinstance(path, str): | |||||
path = Path(path) | |||||
if not path.exists(): | |||||
return | |||||
_recursive_rm(path) | |||||
def _recursive_rm(path: Path): | |||||
if path.is_file() or path.is_symlink(): | |||||
if path.exists(): | |||||
try: | |||||
path.unlink() | |||||
except Exception: | |||||
pass | |||||
return | |||||
for sub_path in list(path.iterdir()): | |||||
_recursive_rm(sub_path) | |||||
path.rmdir() |
@@ -22,7 +22,7 @@ FASTNLP_GLOBAL_RANK = "FASTNLP_GLOBAL_RANK" | |||||
FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" | FASTNLP_LOG_LEVEL = "FASTNLP_LOG_LEVEL" | ||||
# todo 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; | |||||
# 每一个分布式的 driver 都应当正确地设立该值;具体可见 ddp; | |||||
# FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 | # FASTNLP_LAUNCH_TIME 记录了当前 fastNLP 脚本启动的时间。 | ||||
FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" | FASTNLP_LAUNCH_TIME = "FASTNLP_LAUNCH_TIME" | ||||
@@ -42,7 +42,7 @@ USER_CUDA_VISIBLE_DEVICES = 'USER_CUDA_VISIBLE_DEVICES' | |||||
# 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] | # 用于在 torch.distributed.launch 时移除传入的 rank ,在 pytorch 中有使用。值的可选为 [0, 1] | ||||
FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | FASTNLP_REMOVE_LOCAL_RANK = 'FASTNLP_REMOVE_LOCAL_RANK' | ||||
# todo 注释 | |||||
# 检测到当前脚本是通过类似 python -m torch.launch 启动的话设置这个变量为1 | |||||
FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | FASTNLP_BACKEND_LAUNCH = "FASTNLP_BACKEND_LAUNCH" | ||||
# fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | # fastNLP中用于关闭 fastNLP 1.barrier 与 2.gather/broadcast 。默认为 '0' 表示不关闭;为 '1' 表示 fastNLP 的 barrier 不执行; | ||||
@@ -11,7 +11,7 @@ from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
@@ -20,7 +20,7 @@ from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
@@ -83,7 +83,7 @@ class TestCollator: | |||||
assert raw_pad_batch == collator(dict_batch) | assert raw_pad_batch == collator(dict_batch) | ||||
collator = Collator(backend='raw') | collator = Collator(backend='raw') | ||||
raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], | ||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
findListDiff(raw_pad_lst, collator(list_batch)) | findListDiff(raw_pad_lst, collator(list_batch)) | ||||
@@ -194,7 +194,7 @@ class TestCollator: | |||||
collator.set_ignore('_0', '_3', '_1') | collator.set_ignore('_0', '_3', '_1') | ||||
collator.set_pad('_4', pad_val=None) | collator.set_pad('_4', pad_val=None) | ||||
raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], | ||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
findListDiff(raw_pad_lst, collator(list_batch)) | findListDiff(raw_pad_lst, collator(list_batch)) | ||||
@@ -210,7 +210,7 @@ class TestCollator: | |||||
collator.set_pad('_2', backend='numpy') | collator.set_pad('_2', backend='numpy') | ||||
collator.set_pad('_4', backend='numpy', pad_val=100) | collator.set_pad('_4', backend='numpy', pad_val=100) | ||||
raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), | ||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], | |||||
[1.1, 2.1], [[1.1], [2.1]], [True, False], [[1, 0], [1, 1]], [{'1': '1'}, {'2': '2'}], | |||||
[{'1'}, {'2'}]] | [{'1'}, {'2'}]] | ||||
findListDiff(raw_pad_lst, collator(list_batch)) | findListDiff(raw_pad_lst, collator(list_batch)) | ||||
@@ -13,7 +13,7 @@ from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
from tests.helpers.utils import magic_argv_env_context, Capturing | from tests.helpers.utils import magic_argv_env_context, Capturing | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
@@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( | |||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -7,7 +7,7 @@ from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | |||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -12,7 +12,7 @@ from fastNLP.core.samplers import ( | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | ||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -7,7 +7,7 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset | ||||
from tests.helpers.datasets.paddle_data import PaddleNormalDataset | from tests.helpers.datasets.paddle_data import PaddleNormalDataset | ||||
from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE, _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -7,7 +7,7 @@ import re | |||||
import pytest | import pytest | ||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
from fastNLP.core.log.logger import logger | from fastNLP.core.log.logger import logger | ||||
from tests.helpers.utils import magic_argv_env_context, recover_logger | from tests.helpers.utils import magic_argv_env_context, recover_logger | ||||
@@ -6,7 +6,7 @@ import sys | |||||
sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | sys.path.append(os.path.join(os.path.dirname(__file__), '../../..')) | ||||
from fastNLP.core.utils.cache_results import cache_results | from fastNLP.core.utils.cache_results import cache_results | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
def get_subprocess_results(cmd): | def get_subprocess_results(cmd): | ||||
@@ -3,7 +3,7 @@ import pytest | |||||
from fastNLP.envs.set_backend import dump_fastnlp_backend | from fastNLP.envs.set_backend import dump_fastnlp_backend | ||||
from tests.helpers.utils import Capturing | from tests.helpers.utils import Capturing | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
def test_dump_fastnlp_envs(): | def test_dump_fastnlp_envs(): | ||||
@@ -9,7 +9,7 @@ import numpy as np | |||||
from fastNLP.modules.mix_modules.mix_module import MixModule | from fastNLP.modules.mix_modules.mix_module import MixModule | ||||
from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | from fastNLP.modules.mix_modules.utils import paddle2torch, torch2paddle | ||||
from fastNLP.core import rank_zero_rm | |||||
from fastNLP.envs.distributed import rank_zero_rm | |||||
############################################################################ | ############################################################################ | ||||
@@ -136,7 +136,7 @@ | |||||
"在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", | "在`fastNLP 0.8`中,使用`pytorch.nn.Module`搭建需要训练的模型,在搭建模型过程中,除了\n", | ||||
"\n", | "\n", | ||||
"  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n", | "  添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法\n", | ||||
"***\n", | |||||
"\n", | |||||
"```python\n", | "```python\n", | ||||
"class Model(torch.nn.Module):\n", | "class Model(torch.nn.Module):\n", | ||||
" def __init__(self):\n", | " def __init__(self):\n", | ||||
@@ -177,9 +177,7 @@ | |||||
"\n", | "\n", | ||||
"  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", | "  从模块角度,该字典的键值和`metric`中的`update`函数的签名一致,这样的机制在传参时被称为“**参数匹配**”\n", | ||||
"\n", | "\n", | ||||
"***\n", | |||||
"\n", | |||||
"" | |||||
"<img src=\"./figures/T0-fig-trainer-and-evaluator.png\" width=\"80%\" height=\"80%\" align=\"center\"></img>" | |||||
] | ] | ||||
}, | }, | ||||
{ | { | ||||
@@ -206,7 +204,7 @@ | |||||
"  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", | "  而在`Trainer`和`Evaluator`中的参数`model_wo_auto_param_call`被设置为`True`时\n", | ||||
"\n", | "\n", | ||||
"    `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", | "    `fastNLP 0.8`会将`batch`直接传给模型的`train_step`、`evaluate_step`或`forward`函数\n", | ||||
"***\n", | |||||
"\n", | |||||
"```python\n", | "```python\n", | ||||
"class Dataset(torch.utils.data.Dataset):\n", | "class Dataset(torch.utils.data.Dataset):\n", | ||||
" def __init__(self, x, y):\n", | " def __init__(self, x, y):\n", | ||||
@@ -253,7 +251,7 @@ | |||||
"id": "5314482b", | "id": "5314482b", | ||||
"metadata": { | "metadata": { | ||||
"pycharm": { | "pycharm": { | ||||
"is_executing": false | |||||
"is_executing": true | |||||
} | } | ||||
}, | }, | ||||
"outputs": [], | "outputs": [], | ||||
@@ -641,11 +639,11 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/html": [ | "text/html": [ | ||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.43</span><span style=\"font-weight: bold\">}</span>\n", | |||||
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">{</span><span style=\"color: #008000; text-decoration-color: #008000\">'acc#acc'</span>: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.29</span><span style=\"font-weight: bold\">}</span>\n", | |||||
"</pre>\n" | "</pre>\n" | ||||
], | ], | ||||
"text/plain": [ | "text/plain": [ | ||||
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.43\u001b[0m\u001b[1m}\u001b[0m\n" | |||||
"\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.29\u001b[0m\u001b[1m}\u001b[0m\n" | |||||
] | ] | ||||
}, | }, | ||||
"metadata": {}, | "metadata": {}, | ||||
@@ -654,7 +652,7 @@ | |||||
{ | { | ||||
"data": { | "data": { | ||||
"text/plain": [ | "text/plain": [ | ||||
"{'acc#acc': 0.43}" | |||||
"{'acc#acc': 0.29}" | |||||
] | ] | ||||
}, | }, | ||||
"execution_count": 9, | "execution_count": 9, | ||||