@@ -1,3 +1,5 @@
 from fastNLP.envs import *
 from fastNLP.core import *
+__version__ = '0.8.0beta'
@@ -11,6 +11,7 @@ import shutil
 from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
 from fastNLP.core.log import logger
 from fastNLP.envs import all_rank_call_context
+from fastNLP.core.utils.exceptions import EarlyStopException


 class LoadBestModelCallback(HasMonitorCallback):
@@ -61,7 +62,7 @@ class LoadBestModelCallback(HasMonitorCallback):
                 save_folder = os.path.join(save_folder, os.environ.get(FASTNLP_LAUNCH_TIME))
                 self.real_save_folder = os.path.join(save_folder, 'best_so_far')
                 if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0:
-                    os.makedirs(self.real_save_folder)
+                    os.makedirs(self.real_save_folder, exist_ok=True)
             else:  # otherwise create an in-memory buffer
                 self.real_save_folder = None
                 self.buffer = BytesIO()
@@ -114,7 +115,8 @@ class LoadBestModelCallback(HasMonitorCallback):
         trainer.driver.barrier()

     def on_exception(self, trainer, exception):
-        self.encounter_exception = True
+        if not isinstance(exception, EarlyStopException):
+            self.encounter_exception = True

     def _delete_folder(self):
         if self.real_save_folder:
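With this check in place, an early stop raised by another callback no longer counts as a real failure, so LoadBestModelCallback still restores the best weights when training ends. A minimal sketch of a callback that triggers such a stop; the helper class is hypothetical and the EarlyStopException constructor is assumed to accept a plain message string (fastNLP's own EarlyStopCallback raises this exception type):

from fastNLP import Callback
from fastNLP.core.utils.exceptions import EarlyStopException

class StopAtEpochTwo(Callback):
    # hypothetical helper, for illustration only
    def on_train_epoch_end(self, trainer):
        if trainer.cur_epoch_idx >= 2:
            # constructor signature assumed: a plain message string
            raise EarlyStopException("stopping early for illustration")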
@@ -9,7 +9,6 @@ __all__ = [
     'TqdmCallback'
 ]

-from ...envs.imports import _module_available, _compare_version
 from .has_monitor_callback import HasMonitorCallback
 from fastNLP.core.utils import f_rich_progress, f_tqdm_progress
@@ -296,9 +296,9 @@ class Trainer(TrainerEventTrigger):
           written to a log file, and that log file is saved under the folder set by this parameter; defaults to "only_error";
           note that this parameter only takes effect when a distributed ``driver`` is used, e.g. ``TorchDDPDriver``;
-        * *progress_bar* -- how progress is displayed; currently supports [None, 'raw', 'rich', 'auto', 'tqdm'] or objects such as RichCallback, RawTextCallback, etc.
-          Defaults to auto: if the current terminal is detected to be interactive, a RichCallback is used, otherwise a RawTextCallback. If
-          you need to customize the progress bar, e.g. its print frequency, pass in a RichCallback, RawTextCallback, etc. object.
+        * *progress_bar* -- how progress is displayed; currently supports [None, 'raw', 'rich', 'auto', 'tqdm'] or objects such as :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`, etc.
+          Defaults to auto: if the current terminal is detected to be interactive, a :class:`~fastNLP.RichCallback` is used, otherwise a :class:`~fastNLP.RawTextCallback`. If
+          you need to customize the progress bar, e.g. its print frequency, pass in a :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback`, etc. object.
        * *train_input_mapping* -- same as input_mapping, but only used inside the ``Trainer``. Mutually exclusive with input_mapping.
        * *train_output_mapping* -- same as output_mapping, but only used inside the ``Trainer``. Mutually exclusive with output_mapping.
        * *evaluate_input_mapping* -- same as input_mapping, but only used inside the ``Evaluator``. Mutually exclusive with input_mapping.
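A hedged, self-contained sketch of passing a pre-configured progress callback instead of the string 'auto'. It assumes the torch driver's train_step contract (each batch dict is fed into model.train_step, which returns a dict containing 'loss') and assumes RichCallback accepts a print_every argument; both are illustrative assumptions, not statements about the exact API:

import torch
from torch.utils.data import DataLoader
from fastNLP import Trainer, RichCallback

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

    def train_step(self, x, y):
        # assumed contract: return a dict with a 'loss' entry
        return {'loss': torch.nn.functional.cross_entropy(self.linear(x), y)}

def collate(samples):
    return {'x': torch.stack([s[0] for s in samples]),
            'y': torch.stack([s[1] for s in samples])}

train_dl = DataLoader([(torch.randn(4), torch.tensor(0)) for _ in range(32)],
                      batch_size=8, collate_fn=collate)
model = ToyModel()

trainer = Trainer(
    model=model, driver='torch', device='cpu',
    train_dataloader=train_dl,
    optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
    n_epochs=2,
    progress_bar=RichCallback(print_every=10),  # 'print_every' is an assumed parameter name
)
trainer.run()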
@@ -296,6 +296,7 @@ class PaddleFleetDriver(PaddleDriver):
         Use FleetLauncher to spawn the child processes.
         """
         if self.local_rank == 0:
+            logger._set_distributed()
             # if this is rank 0, spawn the other child processes
             launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc)
             launcher.launch()
@@ -425,6 +425,7 @@ class TorchDDPDriver(TorchDriver):
             os.environ[FASTNLP_DISTRIBUTED_CHECK] = f"{len(self.parallel_device)}"
             os.environ[FASTNLP_GLOBAL_RANK] = "0"
+            logger._set_distributed()

             interactive_ddp_procs = []
@@ -208,6 +208,20 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
         for handler in self.handlers:
             handler.setLevel(level)

+    def _set_distributed(self):
+        """
+        Called when fastNLP spawns child processes, so that the rank information is included in the log output.
+
+        :return:
+        """
+        for handler in self.handlers:
+            if isinstance(handler, logging.FileHandler):
+                formatter = logging.Formatter(fmt='Rank: %(rank)s - %(asctime)s - %(module)s - [%(levelname)s] - %(message)s',
+                                              datefmt='%Y/%m/%d %H:%M:%S')
+            else:
+                formatter = logging.Formatter('Rank: %(rank)s - %(message)s')
+            handler.setFormatter(formatter)
+

 def _get_level(level):
     if not isinstance(level, int):
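The %(rank)s field in these format strings has to be present on every log record. The diff does not show how fastNLP attaches it, so the following is only a generic illustration of one way such a field can be injected, using a standard logging.Filter:

import logging
import os

class RankFilter(logging.Filter):
    """Attach a 'rank' attribute to every record so 'Rank: %(rank)s' can be formatted."""
    def filter(self, record):
        record.rank = os.environ.get("RANK", "0")  # assumed env variable, for illustration
        return True

demo = logging.getLogger("demo")
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("Rank: %(rank)s - %(message)s"))
demo.addHandler(handler)
demo.addFilter(RankFilter())
demo.warning("hello")  # prints e.g. "Rank: 0 - hello"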
@@ -350,6 +364,8 @@ def _init_logger(path=None, stdout='rich', level='INFO'):
     if path is not None:
         _add_file_handler(logger, path, level)

+    logger.setLevel(level)
+
     return logger
@@ -31,9 +31,10 @@ class Accuracy(Metric):
         r"""
         The get_metric function computes the final evaluation result from the statistics accumulated by update.

-        :return dict evaluate_result: {"acc": float}
+        :return dict evaluate_result: {"acc": float, 'total': float, 'correct': float}
         """
-        evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}
+        evaluate_result = {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6),
+                           'total': self.total.item(), 'correct': self.correct.item()}
         return evaluate_result

     def update(self, pred, target, seq_len=None):
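A minimal sketch of what the extended return value looks like, assuming the metric can be driven by hand outside an Evaluator and that pred/target shapes of (batch, num_classes) logits and (batch,) labels are acceptable:

import torch
from fastNLP import Accuracy

metric = Accuracy()
# assumed shapes: logits (batch, num_classes), labels (batch,)
metric.update(pred=torch.tensor([[0.9, 0.1], [0.2, 0.8]]), target=torch.tensor([0, 0]))
print(metric.get_metric())  # e.g. {'acc': 0.5, 'total': 2.0, 'correct': 1.0}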
@@ -83,7 +83,7 @@ class TqdmProgress(metaclass=Singleton):
         :return:
         """
         assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \
-            f"To use {self.__class__.__name__}, tqdm>=4.57 is needed."
+            f"To use tqdm, tqdm>=4.57 is needed."
         from .rich_progress import f_rich_progress
         assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop."
@@ -28,10 +28,10 @@ if _NEED_IMPORT_TORCH:
 class ArgMaxDatasetConfig:
     num_labels: int = 10
     feature_dimension: int = 10
-    data_num: int = 100
+    data_num: int = 50
     seed: int = 0
-    batch_size: int = 4
+    batch_size: int = 2
     shuffle: bool = True
@@ -204,99 +204,100 @@ def test_model_checkpoint_callback_1(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("only_state_dict", [True, False]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context(timeout=100)
 def test_model_checkpoint_callback_2(
         model_and_optimizers: TrainerParameters,
         driver,
-        device
+        device,
+        only_state_dict
 ):
-    for only_state_dict in [True, False]:
-        try:
-            path = Path.cwd().joinpath("test_model_checkpoint")
-            path.mkdir(exist_ok=True, parents=True)
-
-            from fastNLP.core.callbacks.callback_event import Event
-
-            @Trainer.on(Event.on_train_epoch_end())
-            def raise_exception(trainer):
-                if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
-                    raise NotImplementedError
-
-            callbacks = [
-                CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False,
-                                   on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict,
-                                   save_object='model'),
-            ]
-
-            with pytest.raises(NotImplementedError):
-                trainer = Trainer(
-                    model=model_and_optimizers.model,
-                    driver=driver,
-                    device=device,
-                    optimizers=model_and_optimizers.optimizers,
-                    train_dataloader=model_and_optimizers.train_dataloader,
-                    evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
-                    input_mapping=model_and_optimizers.input_mapping,
-                    output_mapping=model_and_optimizers.output_mapping,
-                    metrics=model_and_optimizers.metrics,
-                    n_epochs=10,
-                    callbacks=callbacks,
-                    output_from_new_proc="all"
-                )
-                trainer.run()
-
-            if dist.is_initialized():
-                dist.destroy_process_group()
-            if FASTNLP_DISTRIBUTED_CHECK in os.environ:
-                os.environ.pop(FASTNLP_DISTRIBUTED_CHECK)
-
-            # check that the number of saved model files is correct;
-            all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
-            if not isinstance(device, list):
-                assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
-                exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
-            # the file name differs under ddp, because ddp runs through the same data in fewer steps;
-            else:
-                assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths
-                exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"]
-            assert len(all_saved_model_paths) == 1
-            all_state_dicts = [exception_model_path]
-
-            for folder in all_state_dicts:
-                trainer = Trainer(
-                    model=model_and_optimizers.model,
-                    driver="torch",
-                    device=4,
-                    optimizers=model_and_optimizers.optimizers,
-                    train_dataloader=model_and_optimizers.train_dataloader,
-                    evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
-                    input_mapping=model_and_optimizers.input_mapping,
-                    output_mapping=model_and_optimizers.output_mapping,
-                    metrics=model_and_optimizers.metrics,
-                    n_epochs=2,
-                    output_from_new_proc="all"
-                )
-                trainer.load_model(folder, only_state_dict=only_state_dict)
-                trainer.run()
-                trainer.driver.barrier()
-        finally:
-            rank_zero_rm(path)
-            # pass
+    try:
+        path = Path.cwd().joinpath("test_model_checkpoint")
+        path.mkdir(exist_ok=True, parents=True)
+
+        from fastNLP.core.callbacks.callback_event import Event
+
+        @Trainer.on(Event.on_train_epoch_end())
+        def raise_exception(trainer):
+            if trainer.driver.get_local_rank() == 0 and trainer.cur_epoch_idx == 4:
+                raise NotImplementedError
+
+        callbacks = [
+            CheckpointCallback(folder=path, every_n_epochs=None, every_n_batches=None, last=False,
+                               on_exceptions=NotImplementedError, topk=None, monitor=None, only_state_dict=only_state_dict,
+                               save_object='model'),
+        ]
+
+        with pytest.raises(NotImplementedError):
+            trainer = Trainer(
+                model=model_and_optimizers.model,
+                driver=driver,
+                device=device,
+                optimizers=model_and_optimizers.optimizers,
+                train_dataloader=model_and_optimizers.train_dataloader,
+                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+                input_mapping=model_and_optimizers.input_mapping,
+                output_mapping=model_and_optimizers.output_mapping,
+                metrics=model_and_optimizers.metrics,
+                n_epochs=10,
+                callbacks=callbacks,
+                output_from_new_proc="all"
+            )
+            trainer.run()
+
+        if dist.is_initialized():
+            dist.destroy_process_group()
+        if FASTNLP_DISTRIBUTED_CHECK in os.environ:
+            os.environ.pop(FASTNLP_DISTRIBUTED_CHECK)
+
+        # check that the number of saved model files is correct;
+        all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()}
+        if not isinstance(device, list):
+            assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths
+            exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"]
+        # the file name differs under ddp, because ddp runs through the same data in fewer steps;
+        else:
+            assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths
+            exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"]
+        assert len(all_saved_model_paths) == 1
+        all_state_dicts = [exception_model_path]
+
+        for folder in all_state_dicts:
+            trainer = Trainer(
+                model=model_and_optimizers.model,
+                driver="torch",
+                device=4,
+                optimizers=model_and_optimizers.optimizers,
+                train_dataloader=model_and_optimizers.train_dataloader,
+                evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+                input_mapping=model_and_optimizers.input_mapping,
+                output_mapping=model_and_optimizers.output_mapping,
+                metrics=model_and_optimizers.metrics,
+                n_epochs=2,
+                output_from_new_proc="all"
+            )
+            trainer.load_model(folder, only_state_dict=only_state_dict)
+            trainer.run()
+            trainer.driver.barrier()
+    finally:
+        rank_zero_rm(path)
+        # pass

     if dist.is_initialized():
         dist.destroy_process_group()
 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context(timeout=100)
 def test_trainer_checkpoint_callback_1(
         model_and_optimizers: TrainerParameters,
@@ -20,13 +20,14 @@ from fastNLP.core.drivers.torch_driver import TorchSingleDriver
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchArgMaxDataset
 from tests.helpers.utils import magic_argv_env_context
+from fastNLP import logger


 @dataclass
 class ArgMaxDatasetConfig:
     num_labels: int = 10
     feature_dimension: int = 10
-    data_num: int = 100
+    data_num: int = 20
     seed: int = 0
     batch_size: int = 4

@@ -71,18 +72,31 @@ def model_and_optimizers(request):
     return trainer_params
+from fastNLP import Metric
+
+
+class CountMetrc(Metric):
+    def __init__(self):
+        super().__init__()
+        self.register_element('count', 0, aggregate_method='sum')
+
+    def update(self, pred):
+        self.count += len(pred)
+
+    def get_metric(self) -> dict:
+        return {'cnt': self.count.item()}
+

 @pytest.mark.torch
 @pytest.mark.parametrize("driver,device", [("torch", [0, 1]), ("torch", 1), ("torch", "cpu")]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_load_best_model_callback(
         model_and_optimizers: TrainerParameters,
         driver,
-        device
+        device,
 ):
     for save_folder in ['save_models', None]:
         for only_state_dict in [True, False]:
-            callbacks = [LoadBestModelCallback(monitor='acc')]
+            callbacks = [LoadBestModelCallback(monitor='acc', only_state_dict=only_state_dict,
+                                               save_folder=save_folder)]
             trainer = Trainer(
                 model=model_and_optimizers.model,
                 driver=driver,
@@ -92,16 +106,16 @@ def test_load_best_model_callback(
                 evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
                 input_mapping=model_and_optimizers.input_mapping,
                 output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']},
-                metrics=model_and_optimizers.metrics,
-                n_epochs=3,
+                metrics={'acc': Accuracy()},
+                n_epochs=2,
                 callbacks=callbacks,
                 output_from_new_proc="all"
             )

             trainer.run(num_eval_sanity_batch=0)

-            driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda'))
-            evaluator = Evaluator(model_and_optimizers.model, driver=driver, device=device,
+            _driver = TorchSingleDriver(model_and_optimizers.model, device=torch.device('cuda'))
+            evaluator = Evaluator(model_and_optimizers.model, driver=_driver, device=device,
                                   dataloaders={'dl1': model_and_optimizers.evaluate_dataloaders},
                                   metrics={'acc': Accuracy(aggregate_when_get_metric=False)},
                                   output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']},

@@ -113,5 +127,3 @@ def test_load_best_model_callback(
             shutil.rmtree(save_folder, ignore_errors=True)

     if dist.is_initialized():
         dist.destroy_process_group()
@@ -31,7 +31,7 @@ if _NEED_IMPORT_TORCH:
 class ArgMaxDatasetConfig:
     num_labels: int = 10
     feature_dimension: int = 10
-    data_num: int = 100
+    data_num: int = 20
     seed: int = 0
     batch_size: int = 4
@@ -92,7 +92,7 @@ def model_and_optimizers(request):

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_model_more_evaluate_callback_1(
         model_and_optimizers: TrainerParameters,

@@ -121,7 +121,7 @@ def test_model_more_evaluate_callback_1(
                            folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
                            save_object='model')
     ]
-    n_epochs = 5
+    n_epochs = 3
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -175,7 +175,7 @@ def test_model_more_evaluate_callback_1(

 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1]), ("torch", 0)]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", [0, 1]), ("torch", 1)
 @magic_argv_env_context
 def test_trainer_checkpoint_callback_1(
         model_and_optimizers: TrainerParameters,

@@ -204,7 +204,7 @@ def test_trainer_checkpoint_callback_1(
                            folder=path, topk=1, topk_monitor='acc', only_state_dict=only_state_dict,
                            save_object='trainer')
     ]
-    n_epochs = 5
+    n_epochs = 2
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,

@@ -241,7 +241,7 @@ def test_trainer_checkpoint_callback_1(
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,
-        n_epochs=7,
+        n_epochs=5,
         output_from_new_proc="all",
         evaluate_fn='train_step'
     )
@@ -21,7 +21,7 @@ from tests.helpers.datasets.torch_data import TorchArgMaxDataset
 class ArgMaxDatasetConfig:
     num_labels: int = 10
     feature_dimension: int = 10
-    data_num: int = 100
+    data_num: int = 20
     seed: int = 0
     batch_size: int = 4

@@ -87,7 +87,7 @@ def test_run( model_and_optimizers: TrainerParameters, device):
     if device != 'cpu' and not torch.cuda.is_available():
         pytest.skip(f"No cuda for device:{device}")
-    n_epochs = 5
+    n_epochs = 2
     for progress_bar in ['rich', 'auto', None, 'raw', 'tqdm']:
         trainer = Trainer(
             model=model_and_optimizers.model,
@@ -70,12 +70,23 @@ def magic_argv_env_context(fn=None, timeout=300):
         def _handle_timeout(signum, frame):
             raise TimeoutError(f"\nYour test fn: {fn.__name__} has timed out.\n")

+        # restore the logger afterwards
+        handlers = [handler for handler in logger.handlers]
+        formatters = [handler.formatter for handler in handlers]
+        level = logger.level
+
         signal.signal(signal.SIGALRM, _handle_timeout)
         signal.alarm(timeout)
         res = fn(*args, **kwargs)
         signal.alarm(0)
         sys.argv = deepcopy(command)
         os.environ = env
+
+        for formatter, handler in zip(formatters, handlers):
+            handler.setFormatter(formatter)
+        logger.handlers = handlers
+        logger.setLevel(level)
+
         return res
     return wrapper
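The same snapshot-and-restore pattern used here for the logger's handlers, formatters, and level could also be expressed as a small context manager. This is only a generic sketch with hypothetical names, not part of fastNLP:

import logging
from contextlib import contextmanager

@contextmanager
def preserved_logger_state(log: logging.Logger):
    """Snapshot a logger's handlers, formatters and level, and restore them on exit."""
    handlers = list(log.handlers)
    formatters = [h.formatter for h in handlers]
    level = log.level
    try:
        yield log
    finally:
        for h, fmt in zip(handlers, formatters):
            h.setFormatter(fmt)
        log.handlers = handlers
        log.setLevel(level)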