@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback): | |||
real_monitor=self._real_monitor, | |||
res=results) | |||
if (monitor_value < self.monitor_value and self.larger_better is False) or \ | |||
(monitor_value > self.monitor_value and self.larger_better): | |||
(monitor_value > self.monitor_value and self.larger_better): | |||
self.monitor_value = monitor_value | |||
if self.real_save_folder: | |||
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler | |||
class TorchDriver(Driver): | |||
@@ -178,8 +179,28 @@ class TorchDriver(Driver): | |||
model.load_state_dict(res.state_dict()) | |||
@rank_zero_call | |||
def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
# 1. 保存模型的状态; | |||
def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs): | |||
# 传入的 dataloader 参数是 trainer 的 dataloader 属性,因为 driver 的所有 dataloader 我们是不会去改变它的,而是通过改变 | |||
# trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; | |||
# 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; | |||
# 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 | |||
# sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): | |||
sampler = dataloader_args.batch_sampler | |||
elif dataloader_args.sampler: | |||
sampler = dataloader_args.sampler | |||
else: | |||
raise RuntimeError("This condition is not supposed to appear. Please report a bug to us.") | |||
if hasattr(sampler, 'state_dict') and callable(sampler.state_dict): | |||
states['sampler_states'] = sampler.state_dict() | |||
else: | |||
raise RuntimeError( | |||
'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') | |||
# 2. 保存模型的状态; | |||
if should_save_model: | |||
model = self.unwrap_model() | |||
if only_state_dict: | |||
@@ -191,7 +212,7 @@ class TorchDriver(Driver): | |||
torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME)) | |||
logger.debug("Save model") | |||
# 2. 保存 optimizers 的状态; | |||
# 3. 保存 optimizers 的状态; | |||
optimizers_state_dict = {} | |||
for i in range(len(self.optimizers)): | |||
optimizer: torch.optim.Optimizer = self.optimizers[i] | |||
@@ -203,7 +224,7 @@ class TorchDriver(Driver): | |||
states["optimizers_state_dict"] = optimizers_state_dict | |||
torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||
def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict: | |||
states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME)) | |||
# 1. 加载 optimizers 的状态; | |||
@@ -224,6 +245,39 @@ class TorchDriver(Driver): | |||
model.load_state_dict(res.state_dict()) | |||
logger.debug("Load model.") | |||
# 3. 恢复 sampler 的状态; | |||
dataloader_args = self.get_dataloader_args(dataloader) | |||
sampler = dataloader_args.sampler | |||
if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): | |||
# 说明这里需要使用 ReproduceSampler 来弄一下了 | |||
if self.is_distributed(): | |||
raise RuntimeError( | |||
"It is not allowed to use single device checkpoint retraining before but ddp now.") | |||
sampler = ReproducibleBatchSampler( | |||
batch_sampler=sampler, | |||
batch_size=dataloader_args.batch_size, | |||
drop_last=dataloader_args.drop_last | |||
) | |||
sampler.load_state_dict(states['sampler_states']) | |||
states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) | |||
# 4. 修改 trainer_state.batch_idx_in_epoch | |||
# sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; | |||
if not isinstance(sampler, ReproducibleBatchSampler): | |||
if dataloader_args.drop_last: | |||
batch_idx_in_epoch = len( | |||
sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size | |||
else: | |||
batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \ | |||
(sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size | |||
# sampler 是 batch_sampler; | |||
else: | |||
batch_idx_in_epoch = sampler.batch_idx_in_epoch | |||
states["batch_idx_in_epoch"] = batch_idx_in_epoch | |||
return states | |||
def get_evaluate_context(self): | |||
@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2( | |||
dist.destroy_process_group() | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("version", [0, 1]) | |||
@pytest.mark.parametrize("only_state_dict", [True, False]) | |||
@magic_argv_env_context | |||
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1( | |||
# 通过自己编写 model_save_fn 和 model_load_fn 来测试 huggingface 的 transformers 的模型的保存和加载; | |||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)]) # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1) | |||
@pytest.mark.parametrize("version", [0, 1]) | |||
@magic_argv_env_context | |||
def test_trainer_checkpoint_callback_2( | |||
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te | |||
import argparse | |||
import os | |||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" | |||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" | |||
import sys | |||
path = os.path.abspath(__file__) | |||
@@ -101,7 +101,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( | |||
) | |||
trainer.run() | |||
dist.barrier() | |||
# dist.barrier() | |||
if __name__ == "__main__": | |||
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te | |||
import argparse | |||
import os | |||
os.environ["CUDA_VISIBLE_DEVICES"] = "4,5" | |||
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2" | |||
import sys | |||
path = os.path.abspath(__file__) | |||
@@ -77,15 +77,14 @@ def model_and_optimizers(request): | |||
# 测试一下 cpu; | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu")]) | |||
@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) | |||
@magic_argv_env_context | |||
def test_trainer_torch_without_evaluator( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
callbacks, | |||
n_epochs=10, | |||
): | |||
callbacks = [RecordLossCallback(loss_threshold=0.1)] | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
@@ -108,8 +107,7 @@ def test_trainer_torch_without_evaluator( | |||
dist.destroy_process_group() | |||
@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])]) # ("torch", 4), | |||
@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]]) | |||
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])]) # ("torch", 4), | |||
@pytest.mark.parametrize("fp16", [False, True]) | |||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | |||
@magic_argv_env_context | |||
@@ -117,11 +115,11 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||
model_and_optimizers: TrainerParameters, | |||
driver, | |||
device, | |||
callbacks, | |||
fp16, | |||
accumulation_steps, | |||
n_epochs=10, | |||
): | |||
callbacks = [RecordLossCallback(loss_threshold=0.1)] | |||
trainer = Trainer( | |||
model=model_and_optimizers.model, | |||
driver=driver, | |||
@@ -148,7 +146,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( | |||
# 测试 accumulation_steps; | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])]) | |||
@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])]) | |||
@pytest.mark.parametrize("accumulation_steps", [1, 3]) | |||
@magic_argv_env_context | |||
def test_trainer_torch_without_evaluator_accumulation_steps( | |||
@@ -181,7 +179,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( | |||
dist.destroy_process_group() | |||
@pytest.mark.parametrize("driver,device", [("torch", [6, 7])]) | |||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||
@pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"]) | |||
@magic_argv_env_context | |||
def test_trainer_output_from_new_proc( | |||
@@ -244,7 +242,7 @@ def test_trainer_output_from_new_proc( | |||
synchronize_safe_rm(path) | |||
@pytest.mark.parametrize("driver,device", [("torch", [4, 5])]) | |||
@pytest.mark.parametrize("driver,device", [("torch", [1, 2])]) | |||
@pytest.mark.parametrize("cur_rank", [0]) # 依次测试如果是当前进程出现错误,是否能够正确地 kill 掉其他进程; , 1, 2, 3 | |||
@magic_argv_env_context | |||
def test_trainer_on_exception( | |||
@@ -0,0 +1,300 @@ | |||
import os | |||
import tempfile | |||
import datetime | |||
from pathlib import Path | |||
import logging | |||
import re | |||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import synchronize_safe_rm | |||
# 测试 TorchDDPDriver; | |||
@magic_argv_env_context
def test_add_file_ddp_1():
    """
    Test ``logger.add_file`` when the path is a file whose parent folder already exists.

    With multiple GPUs, deriving the log file name from the current time has a serious
    bug: the processes do not start at exactly the same moment, so each of them may end
    up writing to its own separate log file.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd()
    filepath = path.joinpath('log.txt')
    # Bound before the ``try`` so that the ``finally`` clause can always refer to it.
    handler = None
    try:
        handler = logger.add_file(filepath, mode="w")
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

        # Flush every file handler so the content is on disk before it is read back.
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()
        with open(filepath, 'r') as f:
            line = f.read()
        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

        # The plain message must appear exactly once in the shared log file.
        assert len(re.findall(msg, line)) == 1
    finally:
        # Clean up even when an assertion fails, so a failing run does not leave the
        # log file, the process group or the extra handler behind for later tests.
        synchronize_safe_rm(filepath)
        dist.barrier()
        dist.destroy_process_group()
        if handler is not None:
            logger.removeHandler(handler)
@magic_argv_env_context
def test_add_file_ddp_2():
    """
    Test ``logger.add_file`` when the path is a file whose parent folder does not exist.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()

    msg = 'some test log msg'

    origin_path = Path.cwd()
    path = origin_path.joinpath("not_existed")
    filepath = path.joinpath('log.txt')
    # Bound before the ``try`` so that a failure inside ``add_file`` is not masked by a
    # NameError raised from the ``finally`` clause.
    handler = None
    try:
        handler = logger.add_file(filepath)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

        # Flush every file handler so the content is on disk before it is read back.
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()
        with open(filepath, 'r') as f:
            line = f.read()
        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

        # The plain message must appear exactly once in the shared log file.
        assert len(re.findall(msg, line)) == 1
    finally:
        # Remove the whole folder that was created for the log file.
        synchronize_safe_rm(path)
        if handler is not None:
            logger.removeHandler(handler)

        dist.barrier()
        dist.destroy_process_group()
@magic_argv_env_context
def test_add_file_ddp_3():
    """
    Test ``logger.add_file`` when no path is given (path = None).

    With multiple GPUs, deriving the log file name from the current time has a serious
    bug: the processes do not start at exactly the same moment, so each of them may end
    up writing to its own separate log file.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    # Both names are bound before the ``try`` so the ``finally`` clause can test them.
    handler = None
    file = None
    try:
        handler = logger.add_file()
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

        # Flush every file handler so the content is on disk before it is read back.
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()
        # The test expects the default log file to live in the current working
        # directory and to be named after the FASTNLP_LAUNCH_TIME environment variable.
        file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME)+".log")
        with open(file, 'r') as f:
            line = f.read()

        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

        # The plain message must appear exactly once in the shared log file.
        assert len(re.findall(msg, line)) == 1
    finally:
        # Clean up even when an assertion fails, so a failing run does not leave the
        # log file, the process group or the extra handler behind for later tests.
        if file is not None:
            synchronize_safe_rm(file)
        dist.barrier()
        dist.destroy_process_group()
        if handler is not None:
            logger.removeHandler(handler)
@magic_argv_env_context
def test_add_file_ddp_4():
    """
    Test ``logger.add_file`` when the path is a folder.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd().joinpath("not_existed")
    # Bound before the ``try`` so that a failure inside ``add_file`` is not masked by a
    # NameError raised from the ``finally`` clause.
    handler = None
    try:
        handler = logger.add_file(path)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

        # Flush every file handler so the content is on disk before it is read back.
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()

        # The test expects the log file inside the folder to be named after the
        # FASTNLP_LAUNCH_TIME environment variable.
        file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
        with open(file, 'r') as f:
            line = f.read()
        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

        # The plain message must appear exactly once in the shared log file.
        assert len(re.findall(msg, line)) == 1
    finally:
        # Remove the whole folder that was created for the log file.
        synchronize_safe_rm(path)
        if handler is not None:
            logger.removeHandler(handler)

        dist.barrier()
        dist.destroy_process_group()
class TestLogger:
    """Single-process tests for the file and stdout handlers of the fastNLP logger."""

    # Message written through the logger and then searched for in the output.
    msg = 'some test log msg'

    def test_add_file_1(self):
        """
        Test ``logger.add_file`` when the path is a file whose parent folder already exists.
        """
        from fastNLP.core.log.logger import logger

        path = Path(tempfile.mkdtemp())
        # Bound before the ``try`` so that a failure inside ``add_file`` is not masked
        # by a NameError raised from the ``finally`` clause.
        handler = None
        try:
            filepath = path.joinpath('log.txt')
            handler = logger.add_file(filepath)
            logger.info(self.msg)
            with open(filepath, 'r') as f:
                line = f.read()
            assert self.msg in line
        finally:
            synchronize_safe_rm(path)
            if handler is not None:
                logger.removeHandler(handler)

    def test_add_file_2(self):
        """
        Test ``logger.add_file`` when the path is a file whose parent folder does not exist.
        """
        from fastNLP.core.log.logger import logger

        origin_path = Path(tempfile.mkdtemp())
        handler = None
        try:
            path = origin_path.joinpath("not_existed")
            path = path.joinpath('log.txt')
            handler = logger.add_file(path)
            logger.info(self.msg)
            with open(path, 'r') as f:
                line = f.read()
            assert self.msg in line
        finally:
            # Removing the temporary root also removes the folder created for the log.
            synchronize_safe_rm(origin_path)
            if handler is not None:
                logger.removeHandler(handler)

    def test_add_file_3(self):
        """
        Test ``logger.add_file`` when the path is None.
        """
        from fastNLP.core.log.logger import logger

        handler = logger.add_file()
        try:
            logger.info(self.msg)

            path = Path.cwd()
            cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
            # NOTE(review): the loop checks (and deletes) every file in the working
            # directory whose name starts with today's date; it passes without checking
            # anything when no such file exists -- confirm that this is intended.
            for file in path.iterdir():
                if file.name.startswith(cur_datetime):
                    with open(file, 'r') as f:
                        line = f.read()
                    assert self.msg in line
                    file.unlink()
        finally:
            # Always detach the handler so a failed assertion does not leak it into
            # the tests that run afterwards.
            logger.removeHandler(handler)

    def test_add_file_4(self):
        """
        Test ``logger.add_file`` when the path is a folder.
        """
        from fastNLP.core.log.logger import logger

        path = Path(tempfile.mkdtemp())
        handler = None
        try:
            handler = logger.add_file(path)
            logger.info(self.msg)

            cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
            # Only files whose name starts with today's date are inspected.
            for file in path.iterdir():
                if file.name.startswith(cur_datetime):
                    with open(file, 'r') as f:
                        line = f.read()
                    assert self.msg in line
        finally:
            synchronize_safe_rm(path)
            if handler is not None:
                logger.removeHandler(handler)

    def test_stdout(self, capsys):
        """
        Test that with ``stdout="raw"`` only the info message reaches standard output.
        """
        from fastNLP.core.log.logger import logger

        handler = logger.set_stdout(stdout="raw")
        try:
            logger.info(self.msg)
            logger.debug('aabbc')
            captured = capsys.readouterr()
            # The debug message is expected to be absent from the captured output.
            assert "some test log msg\n" == captured.out
        finally:
            # Always detach the handler so a failed assertion does not leak it into
            # the tests that run afterwards.
            logger.removeHandler(handler)
@@ -10,13 +10,6 @@ from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset | |||
class SamplerTest(unittest.TestCase): | |||
def test_sequentialsampler(self): | |||