
1. Add a __setitem__ method to DataSet so that random.shuffle(dataset) can be applied to it directly; 2. Polish some of the log output.

tags/v1.0.0alpha
yh_cc · 2 years ago
commit 607367588c
9 changed files with 55 additions and 29 deletions
1. +3 -2   fastNLP/core/callbacks/__init__.py
2. +2 -1   fastNLP/core/callbacks/checkpoint_callback.py
3. +9 -9   fastNLP/core/controllers/evaluator.py
4. +1 -0   fastNLP/core/controllers/trainer.py
5. +15 -5  fastNLP/core/dataset/dataset.py
6. +1 -4   fastNLP/core/dataset/field.py
7. +8 -6   fastNLP/core/drivers/driver.py
8. +2 -2   fastNLP/core/drivers/torch_driver/initialize_torch_driver.py
9. +14 -0  tests/core/dataset/test_dataset.py

+3 -2  fastNLP/core/callbacks/__init__.py

@@ -4,7 +4,8 @@ __all__ = [
     'EventsList',
     'Filter',
     'CallbackManager',
-    'CheckpointCallback',
+    'ModelCheckpointCallback',
+    'TrainerCheckpointCallback',
     'choose_progress_callback',
     'ProgressCallback',
     'RichCallback',
@@ -16,7 +17,7 @@ __all__ = [
 from .callback import Callback
 from .callback_events import EventsList, Events, Filter
 from .callback_manager import CallbackManager
-from .checkpoint_callback import CheckpointCallback
+from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
 from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
 from .lr_scheduler_callback import LRSchedCallback
 from .load_best_model_callback import LoadBestModelCallback


+2 -1  fastNLP/core/callbacks/checkpoint_callback.py

@@ -1,5 +1,6 @@
 __all__ = [
-    'CheckpointCallback'
+    'ModelCheckpointCallback',
+    'TrainerCheckpointCallback'
 ]
 import os
 from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping
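Taken together, these two files amount to a rename: the single CheckpointCallback export is split into ModelCheckpointCallback and TrainerCheckpointCallback. A minimal sketch of the import change for downstream code (class names are taken from the diff; everything else about their APIs is assumed unchanged here):

    # Before this commit:
    #   from fastNLP.core.callbacks import CheckpointCallback
    # After this commit (requires a fastNLP build containing this change):
    from fastNLP.core.callbacks import ModelCheckpointCallback, TrainerCheckpointCallback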


+9 -9  fastNLP/core/controllers/evaluator.py

@@ -133,17 +133,18 @@ class Evaluator:

         self.driver.barrier()

-    def run(self, num_eval_batch_per_dl: int = -1) -> Dict:
+    def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
         """
         Returns a dict whose keys are metric names and whose values are the corresponding metric results.
-        With multiple metrics and one dataloader, keys are named
-            metric_indicator_name#metric_name
-        With multiple datasets and one metric, keys are named
-            metric_indicator_name#dataloader_name (where # is the default separator, changeable via the Evaluator's init parameters).
-        With multiple metrics and multiple dataloaders, keys are named
-            metric_indicator_name#metric_name#dataloader_name
-        :param num_eval_batch_per_dl: how many batches of each dataloader to evaluate; -1 evaluates all data
+        With multiple metrics and one dataloader, keys are named
+            metric_indicator_name#metric_name
+        With multiple datasets and one metric, keys are named
+            metric_indicator_name#metric_name#dataloader_name (where # is the default separator, changeable via the Evaluator's init parameters).
+        With multiple metrics and multiple dataloaders, keys are named
+            metric_indicator_name#metric_name#dataloader_name
+        The metric_indicator_name part may be absent.
+
+        :param num_eval_batch_per_dl: how many batches of each dataloader to evaluate; -1 evaluates all data.
         :return:
         """
         assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
@@ -157,7 +158,6 @@ class Evaluator:
             assert self.driver.has_test_dataloaders()

         metric_results = {}
-        self.reset()
         evaluate_context = self.driver.get_evaluate_context()
         self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train')
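Since the run() docstring above is the only specification of the result-key format, here is a hedged sketch of the naming rule as executable logic; make_result_key is a hypothetical helper written for illustration, not a fastNLP API:

    def make_result_key(metric_name: str, dataloader_name: str = None,
                        indicator_name: str = None, separator: str = "#") -> str:
        # Illustrates the docstring's rule: metric_indicator_name#metric_name#dataloader_name,
        # where the indicator part may be absent and '#' is the default separator.
        parts = [p for p in (indicator_name, metric_name, dataloader_name) if p]
        return separator.join(parts)

    # e.g. one metric evaluated on the 'dev' dataloader, with indicator 'f':
    assert make_result_key("acc", "dev", "f") == "f#acc#dev"
    # single metric, single dataloader, no indicator:
    assert make_result_key("acc") == "acc"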


+1 -0  fastNLP/core/controllers/trainer.py

@@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger):
             raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")

         if self.evaluator is not None and num_eval_sanity_batch > 0:
+            logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.")
             self.on_sanity_check_begin()
             sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch)
             self.on_sanity_check_end(sanity_check_res)


+15 -5  fastNLP/core/dataset/dataset.py

@@ -8,9 +8,8 @@ __all__ = [

 import _pickle as pickle
 from copy import deepcopy
-from typing import Optional, List, Callable, Union, Dict, Any
+from typing import Optional, List, Callable, Union, Dict, Any, Mapping
 from functools import partial
-import warnings

 import numpy as np
 from threading import Thread
@@ -197,6 +196,20 @@ class DataSet:
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))

+    def __setitem__(self, key, value):
+        assert isinstance(key, int) and key<len(self)
+        assert isinstance(value, Instance) or isinstance(value, Mapping)
+        ins_keys = set(value.keys())
+        ds_keys = set(self.get_field_names())
+
+        if len(ins_keys - ds_keys) != 0:
+            raise KeyError(f"The following keys are not found in the Dataset:{list(ins_keys - ds_keys)}.")
+        if len(ds_keys - ins_keys) != 0:
+            raise KeyError(f"The following keys are not found in the Instance:{list(ds_keys - ins_keys)}.")
+
+        for field_name, field in self.field_arrays.items():
+            field[key] = value[field_name]
+
     def __getattribute__(self, item):
         return object.__getattribute__(self, item)

@@ -813,6 +826,3 @@ class DataSet:
         self.collate_fns.set_input(*field_names)


-
-class IterableDataset:
-    pass
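DataSet already had __len__ and __getitem__; adding __setitem__ is what makes in-place shufflers such as random.shuffle work, because they swap elements through indexed reads and writes. A self-contained sketch of the same mechanism, using a toy column-oriented dataset rather than the real class:

    import random

    class ToyDataset:
        # Column-oriented storage, like fastNLP's field_arrays: one list per field.
        def __init__(self, fields):
            self.fields = {name: list(col) for name, col in fields.items()}

        def __len__(self):
            return len(next(iter(self.fields.values())))

        def __getitem__(self, idx):
            # Read one row as a {field_name: value} mapping.
            return {name: col[idx] for name, col in self.fields.items()}

        def __setitem__(self, idx, row):
            # Write one row back into every field column; this is the piece
            # random.shuffle needs in order to swap elements in place.
            for name, col in self.fields.items():
                col[idx] = row[name]

    ds = ToyDataset({"x": [1, 2, 3, 4], "y": ["a", "b", "c", "d"]})
    random.shuffle(ds)   # works: shuffle swaps rows via __getitem__/__setitem__
    print(ds.fields)     # rows permuted consistently across both fields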


+1 -4  fastNLP/core/dataset/field.py

@@ -46,9 +46,6 @@ class FieldArray:

     def __setitem__(self, idx: int, val: Any):
         assert isinstance(idx, int)
-        if idx == -1:
-            idx = len(self) - 1
-        assert 0 <= idx < len(self), f"0<= idx <{len(self)}, but idx is {idx}"
         self.content[idx] = val

     def get(self, indices: Union[int, List[int]]):
@@ -79,7 +76,7 @@ class FieldArray:

     def split(self, sep: str = None, inplace: bool = True):
         r"""
-        Applies .split() to each element in turn; this is only useful when the elements of this field are str. The return value
+        Applies .split() to each element in turn; this is only useful when the elements of this field are str.

         :param sep: the delimiter; if None, str.split() is called directly.
         :param inplace: if True, the newly generated values replace this field's contents; otherwise a list is returned.
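With the bounds assert removed, FieldArray.__setitem__ now inherits plain Python list indexing, so every negative index works instead of only the special-cased -1. A quick standalone illustration (plain list, not the fastNLP class):

    content = [10, 20, 30]

    # New behavior: list semantics, so -1, -2, ... all address from the end.
    content[-2] = 99
    assert content == [10, 99, 30]

    # The old code normalized only idx == -1 and rejected other negatives via
    # `assert 0 <= idx < len(self)`; that normalization is now unnecessary.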


+8 -6  fastNLP/core/drivers/driver.py

@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
 from datetime import datetime
 from pathlib import Path
 from io import BytesIO
+import json

 __all__ = [
     'Driver'
@@ -447,13 +448,14 @@ class Driver(ABC):

         exc_type, exc_value, exc_traceback_obj = sys.exc_info()
         _write_exc_info = {
-            'exc_type': exc_type,
-            'exc_value': exc_value,
-            'time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
-            'global_rank': getattr(self, "global_rank", None),
-            'rank': self.get_local_rank(),
+            'exc_type': str(exc_type.__name__),
+            'exc_value': str(exc_value),
+            'exc_time': str(datetime.now().strftime('%Y-%m-%d-%H:%M:%S')),
+            'exc_global_rank': getattr(self, "global_rank", None),
+            'exc_local_rank': self.get_local_rank(),
         }
-        sys.stderr.write(str(_write_exc_info)+"\n")
+        sys.stderr.write("\nException info:\n")
+        sys.stderr.write(json.dumps(_write_exc_info, indent=2)+"\n")

         sys.stderr.write(f"Start to stop these pids:{self._pids}, please wait several seconds.\n")
         for pid in self._pids:
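The str() conversions above are what make the dict JSON-serializable: json.dumps raises TypeError on exception classes and instances, which the old str(_write_exc_info) never had to care about. A small standalone sketch of the failure mode and the fix:

    import json, sys

    try:
        1 / 0
    except ZeroDivisionError:
        exc_type, exc_value, _ = sys.exc_info()

    # json.dumps(exc_type) would raise:
    #   TypeError: Object of type type is not JSON serializable
    info = {
        'exc_type': str(exc_type.__name__),   # 'ZeroDivisionError'
        'exc_value': str(exc_value),          # 'division by zero'
    }
    print(json.dumps(info, indent=2))         # pretty-printed, machine-readable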


+2 -2  fastNLP/core/drivers/torch_driver/initialize_torch_driver.py

@@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
     # world_size and rank
     if FASTNLP_BACKEND_LAUNCH in os.environ:
         if device is not None:
-            logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
+            logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
                         "up your script. And we will directly get the local device via "
                         "`os.environ['LOCAL_RANK']`.")
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)
@@ -65,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic
     if not isinstance(device, List):
         return TorchSingleDriver(model, device, **kwargs)
     else:
-        logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
+        logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use "
                     "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter "
                     "`driver` as `TorchDDPDriver`.")
         return TorchDDPDriver(model, device, **kwargs)


+14 -0  tests/core/dataset/test_dataset.py

@@ -105,6 +105,20 @@ class TestDataSetMethods(unittest.TestCase):
         self.assertTrue(isinstance(field_array, FieldArray))
         self.assertEqual(len(field_array), 40)

+    def test_setitem(self):
+        ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
+        ds.add_field('i', list(range(len(ds))))
+        assert ds.get_field('i').content == list(range(len(ds)))
+        import random
+        random.shuffle(ds)
+        import numpy as np
+        np.random.shuffle(ds)
+        assert ds.get_field('i').content != list(range(len(ds)))
+
+        ins1 = ds[1]
+        ds[2] = ds[1]
+        assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y']
+
     def test_get_item_error(self):
         with self.assertRaises(RuntimeError):
             ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})

