@@ -4,7 +4,8 @@ __all__ = [ | |||
'EventsList', | |||
'Filter', | |||
'CallbackManager', | |||
'CheckpointCallback', | |||
'ModelCheckpointCallback', | |||
'TrainerCheckpointCallback', | |||
'choose_progress_callback', | |||
'ProgressCallback', | |||
'RichCallback', | |||
@@ -16,7 +17,7 @@ __all__ = [ | |||
from .callback import Callback | |||
from .callback_events import EventsList, Events, Filter | |||
from .callback_manager import CallbackManager | |||
from .checkpoint_callback import CheckpointCallback | |||
from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||
from .lr_scheduler_callback import LRSchedCallback | |||
from .load_best_model_callback import LoadBestModelCallback | |||
@@ -1,5 +1,6 @@ | |||
__all__ = [ | |||
'CheckpointCallback' | |||
'ModelCheckpointCallback', | |||
'TrainerCheckpointCallback' | |||
] | |||
import os | |||
from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping | |||
@@ -133,17 +133,18 @@ class Evaluator: | |||
self.driver.barrier() | |||
def run(self, num_eval_batch_per_dl: int = -1) -> Dict: | |||
def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: | |||
""" | |||
返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 | |||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||
metric_indicator_name#metric_name | |||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||
metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||
metric_indicator_name#metric_name#dataloader_name | |||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||
如果存在多个metric,一个dataloader的情况,key的命名规则是 | |||
metric_indicator_name#metric_name | |||
如果存在多个数据集,一个metric的情况,key的命名规则是 | |||
metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 | |||
如果存在多个metric,多个dataloader的情况,key的命名规则是 | |||
metric_indicator_name#metric_name#dataloader_name | |||
其中 metric_indicator_name 可能不存在。 | |||
:param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 | |||
:return: | |||
""" | |||
assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." | |||
@@ -157,7 +158,6 @@ class Evaluator: | |||
assert self.driver.has_test_dataloaders() | |||
metric_results = {} | |||
self.reset() | |||
evaluate_context = self.driver.get_evaluate_context() | |||
self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') | |||
@@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): | |||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | |||
if self.evaluator is not None and num_eval_sanity_batch > 0: | |||
logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") | |||
self.on_sanity_check_begin() | |||
sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) | |||
self.on_sanity_check_end(sanity_check_res) | |||
@@ -8,9 +8,8 @@ __all__ = [ | |||
import _pickle as pickle | |||
from copy import deepcopy | |||
from typing import Optional, List, Callable, Union, Dict, Any | |||
from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||
from functools import partial | |||
import warnings | |||
import numpy as np | |||
from threading import Thread | |||
@@ -197,6 +196,20 @@ class DataSet: | |||
else: | |||
raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) | |||
def __setitem__(self, key, value):
    """
    Overwrite the row at integer position ``key`` with the contents of ``value``.

    :param key: row position; must be an ``int`` smaller than ``len(self)``.
    :param value: an ``Instance`` or any ``Mapping`` whose key set is exactly
        the set of field names returned by ``self.get_field_names()``.
    :raises AssertionError: if ``key`` is not an ``int``, is not below
        ``len(self)``, or ``value`` is neither an ``Instance`` nor a ``Mapping``.
    :raises KeyError: if ``value`` carries a key the dataset does not have, or
        lacks a key the dataset does have.
    """
    assert isinstance(key, int)
    assert key < len(self)
    assert isinstance(value, (Instance, Mapping))

    incoming = set(value.keys())
    existing = set(self.get_field_names())

    # Keys present in the incoming row but unknown to the dataset.
    unknown = incoming - existing
    if unknown:
        raise KeyError(f"The following keys are not found in the Dataset:{list(unknown)}.")

    # Dataset fields for which the incoming row provides no value.
    absent = existing - incoming
    if absent:
        raise KeyError(f"The following keys are not found in the Instance:{list(absent)}.")

    # Both key sets match, so every field array receives its new value at ``key``.
    for name, column in self.field_arrays.items():
        column[key] = value[name]
def __getattribute__(self, item): | |||
return object.__getattribute__(self, item) | |||
@@ -813,6 +826,3 @@ class DataSet: | |||
self.collate_fns.set_input(*field_names) | |||
class IterableDataset:
    # Placeholder type: the class body declares no attributes or methods.
    ...
@@ -46,9 +46,6 @@ class FieldArray: | |||
def __setitem__(self, idx: int, val: Any):
    """
    Store ``val`` at position ``idx`` of this field's content.

    Negative indices count from the end, as with built-in sequences: any
    ``idx`` in ``[-len(self), len(self))`` is accepted, not only ``-1``.

    :param idx: integer position of the element to overwrite.
    :param val: the new value for that position.
    :raises AssertionError: if ``idx`` is not an ``int`` or is out of range.
    """
    assert isinstance(idx, int)
    size = len(self)
    # Translate a negative index into its non-negative position so that
    # ``self.content`` is always indexed with a value in ``[0, size)``.
    position = idx + size if idx < 0 else idx
    assert 0 <= position < size, f"0<= idx <{size}, but idx is {idx}"
    self.content[position] = val
def get(self, indices: Union[int, List[int]]): | |||
@@ -79,7 +76,7 @@ class FieldArray: | |||
def split(self, sep: str = None, inplace: bool = True): | |||
r""" | |||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。将返回值 | |||
依次对自身的元素使用.split()方法,应该只有当本field的元素为str时,该方法才有用。 | |||
:param sep: 分割符,如果为None则直接调用str.split()。 | |||
:param inplace: 如果为True,则将新生成值替换本field。否则返回list。 | |||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod | |||
from datetime import datetime | |||
from pathlib import Path | |||
from io import BytesIO | |||
import json | |||
__all__ = [ | |||
'Driver' | |||
@@ -447,13 +448,14 @@ class Driver(ABC): | |||
exc_type, exc_value, exc_traceback_obj = sys.exc_info() | |||
_write_exc_info = { | |||
'exc_type': exc_type, | |||
'exc_value': exc_value, | |||
'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||
'global_rank': getattr(self, "global_rank", None), | |||
'rank': self.get_local_rank(), | |||
'exc_type': str(exc_type.__name__), | |||
'exc_value': str(exc_value), | |||
'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')), | |||
'exc_global_rank': getattr(self, "global_rank", None), | |||
'exc_local_rank': self.get_local_rank(), | |||
} | |||
sys.stderr.write(str(_write_exc_info)+"\n") | |||
sys.stderr.write("\nException info:\n") | |||
sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n") | |||
sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n") | |||
for pid in self._pids: | |||
@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||
# world_size 和 rank | |||
if FASTNLP_BACKEND_LAUNCH in os.environ: | |||
if device is not None: | |||
logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||
logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " | |||
"up your script. And we will directly get the local device via " | |||
"`os.environ['LOCAL_RANK']`.") | |||
return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) | |||
@@ -65,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic | |||
if not isinstance(device, List): | |||
return TorchSingleDriver(model, device, **kwargs) | |||
else: | |||
logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||
logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " | |||
"`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" | |||
"`driver` as `TorchDDPDriver`.") | |||
return TorchDDPDriver(model, device, **kwargs) | |||
@@ -105,6 +105,20 @@ class TestDataSetMethods(unittest.TestCase): | |||
self.assertTrue(isinstance(field_array, FieldArray)) | |||
self.assertEqual(len(field_array), 40) | |||
def test_setitem(self):
    """
    Check item assignment on a DataSet.

    Builds a 40-row dataset with a unique index column ``i``, shuffles it in
    place with both ``random.shuffle`` and ``np.random.shuffle``, then copies
    row 1 over row 2 and verifies that every field of row 2 now matches.
    """
    import random
    import numpy as np

    ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
    ds.add_field('i', list(range(len(ds))))
    self.assertEqual(ds.get_field('i').content, list(range(len(ds))))

    random.shuffle(ds)
    np.random.shuffle(ds)
    self.assertNotEqual(ds.get_field('i').content, list(range(len(ds))))

    ins1 = ds[1]
    ds[2] = ds[1]
    self.assertEqual(ds[2]['x'], ins1['x'])
    self.assertEqual(ds[2]['y'], ins1['y'])
    # 'x' and 'y' hold the same value in every row, so only the unique 'i'
    # column can reveal whether row 2 was really overwritten.
    self.assertEqual(ds[2]['i'], ins1['i'])
def test_get_item_error(self): | |||
with self.assertRaises(RuntimeError): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) | |||