@@ -4,7 +4,8 @@ __all__ = [ | |||||
'EventsList', | 'EventsList', | ||||
'Filter', | 'Filter', | ||||
'CallbackManager', | 'CallbackManager', | ||||
'CheckpointCallback', | |||||
'ModelCheckpointCallback', | |||||
'TrainerCheckpointCallback', | |||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
@@ -16,7 +17,7 @@ __all__ = [ | |||||
from .callback import Callback | from .callback import Callback | ||||
from .callback_events import EventsList, Events, Filter | from .callback_events import EventsList, Events, Filter | ||||
from .callback_manager import CallbackManager | from .callback_manager import CallbackManager | ||||
from .checkpoint_callback import CheckpointCallback | |||||
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | ||||
from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
@@ -1,5 +1,6 @@ | |||||
__all__ = [ | __all__ = [ | ||||
'CheckpointCallback' | |||||
'ModelCheckpointCallback', | |||||
'TrainerCheckpointCallback' | |||||
] | ] | ||||
import os | import os | ||||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | ||||
@@ -133,17 +133,18 @@ class Evaluator: | |||||
self.driver.barrier() | self.driver.barrier() | ||||
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | |||||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||||
""" | """ | ||||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | ||||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name | |||||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name | |||||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||||
metric_indicator_name#metric_name#dataloader_name | |||||
其中 metric_indicator_name 可能不存在。 | |||||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||||
:return: | :return: | ||||
""" | """ | ||||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | ||||
@@ -157,7 +158,6 @@ class Evaluator: | |||||
assert self.driver.has_test_dataloaders() | assert self.driver.has_test_dataloaders() | ||||
metric_results = {} | metric_results = {} | ||||
self.reset() | self.reset() | ||||
evaluate_context = self.driver.get_evaluate_context() | evaluate_context = self.driver.get_evaluate_context() | ||||
self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | ||||
@@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): | |||||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | ||||
if self.evaluator is not None and num_eval_sanity_batch > 0: | if self.evaluator is not None and num_eval_sanity_batch > 0: | ||||
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") | |||||
self.on_sanity_check_begin() | self.on_sanity_check_begin() | ||||
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | ||||
self.on_sanity_check_end(sanity_check_res) | self.on_sanity_check_end(sanity_check_res) | ||||
@@ -8,9 +8,8 @@ __all__ = [ | |||||
import _pickle as pickle | import _pickle as pickle | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from typing import Optional, List, Callable, Union, Dict, Any | |||||
from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||||
from functools import partial | from functools import partial | ||||
import warnings | |||||
import numpy as np | import numpy as np | ||||
from threading import Thread | from threading import Thread | ||||
@@ -197,6 +196,20 @@ class DataSet: | |||||
else: | else: | ||||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | ||||
def __setitem__(self, key, value):
    """
    Overwrite the row at position ``key`` with the contents of ``value``.

    Every field array of the dataset receives ``value[field_name]`` at index
    ``key``, so ``value`` must provide exactly the dataset's field names.

    :param key: integer row index; asserted to be smaller than ``len(self)``.
    :param value: an ``Instance`` or a ``Mapping`` keyed by field name.
    :raises KeyError: if ``value`` has keys the dataset does not have, or is
        missing keys the dataset has.
    """
    assert isinstance(key, int) and key < len(self)
    assert isinstance(value, (Instance, Mapping))

    given_names = set(value.keys())
    expected_names = set(self.get_field_names())

    # Validate both directions before touching any field, so a mismatching
    # value never leaves the row partially overwritten.
    unknown = given_names - expected_names
    if unknown:
        raise KeyError(f"The following keys are not found in the Dataset:{list(unknown)}.")
    missing = expected_names - given_names
    if missing:
        raise KeyError(f"The following keys are not found in the Instance:{list(missing)}.")

    for field_name, field in self.field_arrays.items():
        field[key] = value[field_name]
def __getattribute__(self, item): | def __getattribute__(self, item): | ||||
return object.__getattribute__(self, item) | return object.__getattribute__(self, item) | ||||
@@ -813,6 +826,3 @@ class DataSet: | |||||
self.collate_fns.set_input(*field_names) | self.collate_fns.set_input(*field_names) | ||||
class IterableDataset: | |||||
pass | |||||
@@ -46,9 +46,6 @@ class FieldArray: | |||||
def __setitem__(self, idx: int, val: Any): | def __setitem__(self, idx: int, val: Any): | ||||
assert isinstance(idx, int) | assert isinstance(idx, int) | ||||
if idx == -1: | |||||
idx = len(self) - 1 | |||||
assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}" | |||||
self.content[idx] = val | self.content[idx] = val | ||||
def get(self, indices: Union[int, List[int]]): | def get(self, indices: Union[int, List[int]]): | ||||
@@ -79,7 +76,7 @@ class FieldArray: | |||||
def split(self, sep: str = None, inplace: bool = True): | def split(self, sep: str = None, inplace: bool = True): | ||||
r""" | r""" | ||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | |||||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。 | |||||
:param sep: 分割符,如果为None则直接调用str.split()。 | :param sep: 分割符,如果为None则直接调用str.split()。 | ||||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | :param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | ||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod | |||||
from datetime import datetime | from datetime import datetime | ||||
from pathlib import Path | from pathlib import Path | ||||
from io import BytesIO | from io import BytesIO | ||||
import json | |||||
__all__ = [ | __all__ = [ | ||||
'Driver' | 'Driver' | ||||
@@ -447,13 +448,14 @@ class Driver(ABC): | |||||
exc_type, exc_value, exc_traceback_obj = sys.exc_info() | exc_type, exc_value, exc_traceback_obj = sys.exc_info() | ||||
_write_exc_info = { | _write_exc_info = { | ||||
'exc_type': exc_type, | |||||
'exc_value': exc_value, | |||||
'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||||
'global_rank': getattr(self, "global_rank", None), | |||||
'rank': self.get_local_rank(), | |||||
'exc_type': str(exc_type.__name__), | |||||
'exc_value': str(exc_value), | |||||
'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||||
'exc_global_rank': getattr(self, "global_rank", None), | |||||
'exc_local_rank': self.get_local_rank(), | |||||
} | } | ||||
sys.stderr.write(str(_write_exc_info)+"\n") | |||||
sys.stderr.write("\nException info:\n") | |||||
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n") | |||||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | ||||
for pid in self._pids: | for pid in self._pids: | ||||
@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
# world_size 和 rank | # world_size 和 rank | ||||
if FASTNLP_BACKEND_LAUNCH in os.environ: | if FASTNLP_BACKEND_LAUNCH in os.environ: | ||||
if device is not None: | if device is not None: | ||||
logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||||
"up your script. And we will directly get the local device via " | "up your script. And we will directly get the local device via " | ||||
"`os.environ['LOCAL_RANK']`.") | "`os.environ['LOCAL_RANK']`.") | ||||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | ||||
@@ -65,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||||
if not isinstance(device, List): | if not isinstance(device, List): | ||||
return TorchSingleDriver(model, device, **kwargs) | return TorchSingleDriver(model, device, **kwargs) | ||||
else: | else: | ||||
logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||||
logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||||
"`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | ||||
"`driver` as `TorchDDPDriver`.") | "`driver` as `TorchDDPDriver`.") | ||||
return TorchDDPDriver(model, device, **kwargs) | return TorchDDPDriver(model, device, **kwargs) | ||||
@@ -105,6 +105,20 @@ class TestDataSetMethods(unittest.TestCase): | |||||
self.assertTrue(isinstance(field_array, FieldArray)) | self.assertTrue(isinstance(field_array, FieldArray)) | ||||
self.assertEqual(len(field_array), 40) | self.assertEqual(len(field_array), 40) | ||||
def test_setitem(self):
    """Check that item assignment lets in-place shufflers reorder rows and copies one row onto another."""
    import random
    import numpy as np

    ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
    ds.add_field('i', list(range(len(ds))))
    initial_order = list(range(len(ds)))
    assert ds.get_field('i').content == initial_order

    # Both shufflers swap rows through __getitem__/__setitem__ on the dataset.
    random.shuffle(ds)
    np.random.shuffle(ds)
    assert ds.get_field('i').content != initial_order

    # Assigning one row to another index must copy its field values.
    ins1 = ds[1]
    ds[2] = ds[1]
    copied = ds[2]
    assert copied['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
def test_get_item_error(self): | def test_get_item_error(self): | ||||
with self.assertRaises(RuntimeError): | with self.assertRaises(RuntimeError): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | ||||