@@ -81,7 +81,7 @@ class LoadBestModelCallback(Callback):
                 real_monitor=self._real_monitor,
                 res=results)
             if (monitor_value < self.monitor_value and self.larger_better is False) or \
                     (monitor_value > self.monitor_value and self.larger_better):
                 self.monitor_value = monitor_value
                 if self.real_save_folder:
                     trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -30,6 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
 from fastNLP.core.log import logger
+from fastNLP.core.samplers import ReproducibleBatchSampler

 class TorchDriver(Driver):
@@ -178,8 +179,28 @@ class TorchDriver(Driver):
         model.load_state_dict(res.state_dict())

     @rank_zero_call
-    def save(self, folder: Path, states: Dict, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
-        # 1. Save the model state;
+    def save(self, folder: Path, states: Dict, dataloader, only_state_dict: bool = True, should_save_model: bool = True, **kwargs):
+        # The `dataloader` argument is the trainer's `dataloader` attribute: the driver never mutates its own
+        # dataloaders; instead, dataloader state is changed through `trainer.dataloader`, so that it can be
+        # adapted to the training or evaluation environment.
+        # 1. Save the sampler state, since we support resuming training, i.e. recovering exactly to a specific batch.
+        #    A PyTorch DataLoader always has a sampler. Moreover, when resuming from a checkpoint the dataloader's
+        #    sampler will have been replaced with a `ReproducibleIterator` in `replace_sampler`; otherwise (the
+        #    single-card case) its batch_sampler will have been replaced with a `ReproducibleBatchSampler`.
+        dataloader_args = self.get_dataloader_args(dataloader)
+        if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler):
+            sampler = dataloader_args.batch_sampler
+        elif dataloader_args.sampler:
+            sampler = dataloader_args.sampler
+        else:
+            raise RuntimeError("This condition should never occur. Please report a bug to us.")
+        if hasattr(sampler, 'state_dict') and callable(sampler.state_dict):
+            states['sampler_states'] = sampler.state_dict()
+        else:
+            raise RuntimeError(
+                'The sampler has no `state_dict()` method, so resuming from the exact batch will not be possible.')
+        # 2. Save the model state;
         if should_save_model:
             model = self.unwrap_model()
             if only_state_dict:
@@ -191,7 +212,7 @@ class TorchDriver(Driver):
                 torch.save(model, folder.joinpath(FASTNLP_MODEL_FILENAME))
                 logger.debug("Save model")

-        # 2. Save the optimizers' state;
+        # 3. Save the optimizers' state;
         optimizers_state_dict = {}
         for i in range(len(self.optimizers)):
             optimizer: torch.optim.Optimizer = self.optimizers[i]
@@ -203,7 +224,7 @@ class TorchDriver(Driver):
         states["optimizers_state_dict"] = optimizers_state_dict
         torch.save(states, Path(folder).joinpath(FASTNLP_CHECKPOINT_FILENAME))

-    def load(self, folder: Path, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
+    def load(self, folder: Path, dataloader, only_state_dict: bool = True, should_load_model: bool = True, **kwargs) -> Dict:
         states = torch.load(folder.joinpath(FASTNLP_CHECKPOINT_FILENAME))

         # 1. Load the optimizers' state;
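Taken together, `save` writes the model weights into `folder` under `FASTNLP_MODEL_FILENAME`, and everything else (the sampler states, the optimizer states, and whatever the caller put into `states`) under `FASTNLP_CHECKPOINT_FILENAME`. A minimal usage sketch, assuming an already set-up `TorchDriver` instance `driver` and a `trainer` whose dataloader carries a reproducible sampler (all names here are illustrative, not part of this diff):

    ckpt_dir = Path("checkpoints/step_1000")
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    driver.save(folder=ckpt_dir, states={}, dataloader=trainer.dataloader,
                only_state_dict=True, should_save_model=True)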
@@ -224,6 +245,39 @@ class TorchDriver(Driver):
             model.load_state_dict(res.state_dict())
             logger.debug("Load model.")

+        # 3. Restore the sampler state;
+        dataloader_args = self.get_dataloader_args(dataloader)
+        sampler = dataloader_args.sampler
+        if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)):
+            # The sampler cannot restore its own state, so wrap it with a ReproducibleBatchSampler;
+            if self.is_distributed():
+                raise RuntimeError(
+                    "Resuming with ddp from a checkpoint that was saved during single-device training is not allowed.")
+            sampler = ReproducibleBatchSampler(
+                batch_sampler=sampler,
+                batch_size=dataloader_args.batch_size,
+                drop_last=dataloader_args.drop_last
+            )
+        sampler.load_state_dict(states['sampler_states'])
+        states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler)
+
+        # 4. Set trainer_state.batch_idx_in_epoch;
+        # The sampler is a RandomSampler-like sampler, not a batch_sampler;
+        if not isinstance(sampler, ReproducibleBatchSampler):
+            if dataloader_args.drop_last:
+                batch_idx_in_epoch = len(sampler) // dataloader_args.batch_size - \
+                                     sampler.num_left_samples // dataloader_args.batch_size
+            else:
+                batch_idx_in_epoch = (len(sampler) + dataloader_args.batch_size - 1) // dataloader_args.batch_size - \
+                                     (sampler.num_left_samples + dataloader_args.batch_size - 1) // dataloader_args.batch_size
+        # The sampler is a batch_sampler;
+        else:
+            batch_idx_in_epoch = sampler.batch_idx_in_epoch
+        states["batch_idx_in_epoch"] = batch_idx_in_epoch

         return states

     def get_evaluate_context(self):
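To make the `batch_idx_in_epoch` arithmetic above concrete, here is a worked example with illustrative numbers, followed by the matching resume call (same assumed `driver`/`trainer` context as the save sketch above):

    # Suppose len(sampler) == 10, batch_size == 3, num_left_samples == 4, drop_last == False:
    #   batches per epoch    : (10 + 3 - 1) // 3 = 4
    #   batches still to run : (4  + 3 - 1) // 3 = 2
    #   batch_idx_in_epoch   : 4 - 2 = 2, i.e. two batches were already consumed.
    states = driver.load(folder=ckpt_dir, dataloader=trainer.dataloader,
                         only_state_dict=True, should_load_model=True)
    trainer.dataloader = states["dataloader"]     # sampler state is already restored
    resumed_batch_idx = states["batch_idx_in_epoch"]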
@@ -316,7 +316,7 @@ def test_model_checkpoint_callback_2(
     dist.destroy_process_group()

-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @pytest.mark.parametrize("only_state_dict", [True, False])
 @magic_argv_env_context
@@ -466,7 +466,7 @@ def test_trainer_checkpoint_callback_1(
 # Test saving and loading a huggingface transformers model by writing custom model_save_fn and model_load_fn;
-@pytest.mark.parametrize("driver,device", [("torch_ddp", [0, 1]), ("torch", 1)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch_ddp", [6, 7]), ("torch", 7)])  # ("torch", "cpu"), ("torch_ddp", [0, 1]), ("torch", 1)
 @pytest.mark.parametrize("version", [0, 1])
 @magic_argv_env_context
 def test_trainer_checkpoint_callback_2(
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te
 import argparse
 import os

-os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

 import sys
 path = os.path.abspath(__file__)
@@ -101,7 +101,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps(
     )
     trainer.run()

-    dist.barrier()
+    # dist.barrier()

 if __name__ == "__main__":
@@ -6,7 +6,7 @@ python -m torch.distributed.launch --nproc_per_node 2 tests/core/controllers/_te
 import argparse
 import os

-os.environ["CUDA_VISIBLE_DEVICES"] = "4,5"
+os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"

 import sys
 path = os.path.abspath(__file__)
@@ -77,15 +77,14 @@ def model_and_optimizers(request):

 # Test on cpu;
 @pytest.mark.parametrize("driver,device", [("torch", "cpu")])
-@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]])
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
-        callbacks,
         n_epochs=10,
 ):
+    callbacks = [RecordLossCallback(loss_threshold=0.1)]
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -108,8 +107,7 @@ def test_trainer_torch_without_evaluator(
     dist.destroy_process_group()

-@pytest.mark.parametrize("driver,device", [("torch", 4), ("torch", [4, 5])])  # ("torch", 4),
-@pytest.mark.parametrize("callbacks", [[RecordLossCallback(loss_threshold=0.1)]])
+@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [1, 2])])  # ("torch", 4),
 @pytest.mark.parametrize("fp16", [False, True])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
@@ -117,11 +115,11 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(
         model_and_optimizers: TrainerParameters,
         driver,
         device,
-        callbacks,
         fp16,
         accumulation_steps,
         n_epochs=10,
 ):
+    callbacks = [RecordLossCallback(loss_threshold=0.1)]
     trainer = Trainer(
         model=model_and_optimizers.model,
         driver=driver,
@@ -148,7 +146,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps(

 # Test accumulation_steps;
-@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 4), ("torch", [4, 5])])
+@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [1, 2])])
 @pytest.mark.parametrize("accumulation_steps", [1, 3])
 @magic_argv_env_context
 def test_trainer_torch_without_evaluator_accumulation_steps(
@@ -181,7 +179,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps(
     dist.destroy_process_group()

-@pytest.mark.parametrize("driver,device", [("torch", [6, 7])])
+@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])
 @pytest.mark.parametrize("output_from_new_proc", ["all", "ignore", "only_error", "test_log"])
 @magic_argv_env_context
 def test_trainer_output_from_new_proc(
@@ -244,7 +242,7 @@ def test_trainer_output_from_new_proc(
     synchronize_safe_rm(path)

-@pytest.mark.parametrize("driver,device", [("torch", [4, 5])])
+@pytest.mark.parametrize("driver,device", [("torch", [1, 2])])
 @pytest.mark.parametrize("cur_rank", [0])  # test, for each rank in turn, whether the other processes are correctly killed when the current process raises an error; , 1, 2, 3
 @magic_argv_env_context
 def test_trainer_on_exception(
@@ -0,0 +1,300 @@
import os
import tempfile
import datetime
from pathlib import Path
import logging
import re

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm


# Test TorchDDPDriver;
@magic_argv_env_context
def test_add_file_ddp_1():
    """
    Test the case where `path` points to a file whose parent folder exists.

    Building log file names from timestamps has a serious bug in the multi-card setting: the processes start
    with a slight time offset, so each of them ends up writing to its own, separate log file (a possible
    mitigation is sketched after this test).
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd()
    filepath = path.joinpath('log.txt')
    handler = logger.add_file(filepath, mode="w")
    logger.info(msg)
    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

    for h in logger.handlers:
        if isinstance(h, logging.FileHandler):
            h.flush()
    dist.barrier()
    with open(filepath, 'r') as f:
        line = ''.join([l for l in f])

    assert msg in line
    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

    pattern = re.compile(msg)
    assert len(pattern.findall(line)) == 1

    synchronize_safe_rm(filepath)
    dist.barrier()
    dist.destroy_process_group()
    logger.removeHandler(handler)
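
# A possible mitigation for the timestamp-divergence bug described in the docstring above: derive the
# launch time once on rank 0 and broadcast it to the other ranks, so that every process logs to the same
# file. This is only a sketch under that assumption, not fastNLP's actual fix; the helper name is
# hypothetical and it requires an initialized torch.distributed process group.
def _synced_launch_time():
    import torch.distributed as dist
    launch_time = [datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')]
    if dist.is_available() and dist.is_initialized():
        # ranks other than 0 overwrite their local value with rank 0's timestamp
        dist.broadcast_object_list(launch_time, src=0)
    return launch_time[0]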

@magic_argv_env_context
def test_add_file_ddp_2():
    """
    Test the case where `path` points to a file whose parent folder does not exist.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    origin_path = Path.cwd()
    try:
        path = origin_path.joinpath("not_existed")
        filepath = path.joinpath('log.txt')
        handler = logger.add_file(filepath)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()
        with open(filepath, 'r') as f:
            line = ''.join([l for l in f])
        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
        pattern = re.compile(msg)
        assert len(pattern.findall(line)) == 1
    finally:
        synchronize_safe_rm(path)
        logger.removeHandler(handler)

    dist.barrier()
    dist.destroy_process_group()

@magic_argv_env_context
def test_add_file_ddp_3():
    """
    Test the case where `path` is None.

    Building log file names from timestamps has a serious bug in the multi-card setting: the processes start
    with a slight time offset, so each of them ends up writing to its own, separate log file.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    handler = logger.add_file()
    logger.info(msg)
    logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")
    for h in logger.handlers:
        if isinstance(h, logging.FileHandler):
            h.flush()
    dist.barrier()

    file = Path.cwd().joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
    with open(file, 'r') as f:
        line = ''.join([l for l in f])
        # print(f"\nrank: {driver.get_local_rank()} line, {line}\n")

    assert msg in line
    assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line

    pattern = re.compile(msg)
    assert len(pattern.findall(line)) == 1

    synchronize_safe_rm(file)
    dist.barrier()
    dist.destroy_process_group()
    logger.removeHandler(handler)

@magic_argv_env_context
def test_add_file_ddp_4():
    """
    Test the case where `path` is a folder.
    """
    import torch
    import torch.distributed as dist
    from fastNLP.core.log.logger import logger
    from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver
    from tests.helpers.models.torch_model import TorchNormalModel_Classification_1

    model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)

    driver = TorchDDPDriver(
        model=model,
        parallel_device=[torch.device("cuda:0"), torch.device("cuda:1")],
        output_from_new_proc="all"
    )
    driver.setup()
    msg = 'some test log msg'

    path = Path.cwd().joinpath("not_existed")
    try:
        handler = logger.add_file(path)
        logger.info(msg)
        logger.warning(f"\nrank {driver.get_local_rank()} should have this message!\n")

        for h in logger.handlers:
            if isinstance(h, logging.FileHandler):
                h.flush()
        dist.barrier()

        file = path.joinpath(os.environ.get(FASTNLP_LAUNCH_TIME) + ".log")
        with open(file, 'r') as f:
            line = ''.join([l for l in f])
        assert msg in line
        assert f"\nrank {driver.get_local_rank()} should have this message!\n" in line
        pattern = re.compile(msg)
        assert len(pattern.findall(line)) == 1
    finally:
        synchronize_safe_rm(path)
        logger.removeHandler(handler)

    dist.barrier()
    dist.destroy_process_group()

class TestLogger:
    msg = 'some test log msg'

    def test_add_file_1(self):
        """
        Test the case where `path` points to a file whose parent folder exists.
        """
        from fastNLP.core.log.logger import logger

        path = Path(tempfile.mkdtemp())
        try:
            filepath = path.joinpath('log.txt')
            handler = logger.add_file(filepath)
            logger.info(self.msg)
            with open(filepath, 'r') as f:
                line = ''.join([l for l in f])
            assert self.msg in line
        finally:
            synchronize_safe_rm(path)
            logger.removeHandler(handler)

    def test_add_file_2(self):
        """
        Test the case where `path` points to a file whose parent folder does not exist.
        """
        from fastNLP.core.log.logger import logger

        origin_path = Path(tempfile.mkdtemp())
        try:
            path = origin_path.joinpath("not_existed")
            path = path.joinpath('log.txt')
            handler = logger.add_file(path)
            logger.info(self.msg)
            with open(path, 'r') as f:
                line = ''.join([l for l in f])
            assert self.msg in line
        finally:
            synchronize_safe_rm(origin_path)
            logger.removeHandler(handler)

    def test_add_file_3(self):
        """
        Test the case where `path` is None.
        """
        from fastNLP.core.log.logger import logger

        handler = logger.add_file()
        logger.info(self.msg)

        path = Path.cwd()
        cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
        for file in path.iterdir():
            if file.name.startswith(cur_datetime):
                with open(file, 'r') as f:
                    line = ''.join([l for l in f])
                assert self.msg in line
                file.unlink()
        logger.removeHandler(handler)

    def test_add_file_4(self):
        """
        Test the case where `path` is a folder.
        """
        from fastNLP.core.log.logger import logger

        path = Path(tempfile.mkdtemp())
        try:
            handler = logger.add_file(path)
            logger.info(self.msg)

            cur_datetime = str(datetime.datetime.now().strftime('%Y-%m-%d'))
            for file in path.iterdir():
                if file.name.startswith(cur_datetime):
                    with open(file, 'r') as f:
                        line = ''.join([l for l in f])
                    assert self.msg in line
        finally:
            synchronize_safe_rm(path)
            logger.removeHandler(handler)

    def test_stdout(self, capsys):
        from fastNLP.core.log.logger import logger

        handler = logger.set_stdout(stdout="raw")
        logger.info(self.msg)
        logger.debug('aabbc')
        captured = capsys.readouterr()
        assert "some test log msg\n" == captured.out

        logger.removeHandler(handler)
@@ -10,13 +10,6 @@ from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler
 from tests.helpers.datasets.torch_data import TorchNormalDataset

 class SamplerTest(unittest.TestCase):

     def test_sequentialsampler(self):