
update overfit_batches

tags/v1.0.0alpha
yhcc committed 3 years ago · commit 399065ae04
6 changed files with 64 additions and 61 deletions
  1. +12 -7   fastNLP/core/callbacks/progress_callback.py
  2. +15 -17  fastNLP/core/controllers/trainer.py
  3. +4  -11  fastNLP/core/dataloaders/utils.py
  4. +0  -3   fastNLP/core/drivers/torch_driver/ddp.py
  5. +24 -22  fastNLP/core/drivers/torch_driver/utils.py
  6. +9  -1   fastNLP/core/metrics/metric.py

+12 -7  fastNLP/core/callbacks/progress_callback.py

@@ -20,6 +20,7 @@ class ProgressCallback(HasMonitorCallback):
                              must_have_monitor=must_have_monitor)
         self.best_monitor_epoch = -1
         self.best_monitor_step = -1
+        self.best_results = None

     def record_better_monitor(self, trainer):
         self.best_monitor_step = trainer.global_forward_batches
@@ -29,6 +30,8 @@ class ProgressCallback(HasMonitorCallback):
         if self.best_monitor_epoch != -1:
             msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
                   f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            if self.best_results is not None:
+                msg = msg + ' The evaluation result: \n' + str(self.best_results)
             logger.info(msg)

     @property
@@ -147,9 +150,11 @@ class RichCallback(ProgressCallback):
             results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                        not key.startswith('_')}
             if self.format_json:
-                self.progress_bar.console.print_json(json.dumps(results))
+                results = json.dumps(results)
+                self.progress_bar.console.print_json(results)
             else:
                 self.progress_bar.print(results)
+            self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
@@ -227,9 +232,9 @@ class RawTextCallback(ProgressCallback):
             results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                        not key.startswith('_')}
             if self.format_json:
-                logger.info(json.dumps(results))
-            else:
-                logger.info(results)
+                results = json.dumps(results)
+            logger.info(results)
+            self.best_results = results

     @property
     def name(self):  # name of the progress bar
@@ -316,9 +321,9 @@ class TqdmCallback(ProgressCallback):
             results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                        not key.startswith('_')}
             if self.format_json:
-                logger.info(json.dumps(results))
-            else:
-                logger.info(results)
+                results = json.dumps(results)
+            logger.info(results)
+            self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():

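All three progress callbacks now share one pattern: serialize the evaluation results once, log them, and stash the logged object in `best_results` so the summary printed at the end of training can replay them next to the best monitor value. A minimal standalone sketch of that flow (the `MiniProgress` class and the `logger` setup are illustrative, not fastNLP API):

import json
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("progress")

class MiniProgress:
    def __init__(self, format_json=True):
        self.format_json = format_json
        self.best_results = None

    def on_evaluate_end(self, results):
        # mirror the committed pattern: dump once, then reuse the same object
        # for both logging and best_results, instead of dumping in one branch only
        if self.format_json:
            results = json.dumps(results)
        logger.info(results)
        self.best_results = results

cb = MiniProgress()
cb.on_evaluate_end({'acc#acc': 0.93})
print(cb.best_results)  # the JSON string that was just logged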

+15 -17  fastNLP/core/controllers/trainer.py

@@ -119,19 +119,6 @@ class Trainer(TrainerEventTrigger):
     For more details on using ``TorchDDPDriver``, see :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`.

     :param n_epochs: The total number of training epochs, 20 by default; the total number of batches to iterate can also be set via the ``n_batches`` parameter.
-    :param overfit_batches: Use this parameter to enable 'overfit' training; accepted values are ``-1``, ``0``, or an integer greater than 0, indicating how many
-        batches of data to use for the overfit training; 0 is the default and means no overfitting; -1 means train on all of the data;
-
-        .. note::
-
-            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data,
-            which indicates that the loss function, optimizer, etc. are fine. When this parameter is used, we extract a fixed number of batches directly from
-            ``train_dataloader`` and train on exactly these data in all subsequent epochs;
-
-        .. warning::
-
-            When using this parameter you may still pass ``metrics`` for a simple evaluation; when both are given, we replace evaluate_dataloaders
-            directly with the training data used for overfitting; you therefore need to make sure your ``metrics`` can run on ``train_dataloader``;
-
     :param evaluate_dataloaders: The dataset(s) used for evaluation, either a single dataset or several; when there are several, note that it must be a Dict;
         None by default;
     :param batch_step_fn: Customizes the function executed to run forward over one batch of data during training. The function should accept two parameters, ``trainer`` and ``batch``,
@@ -258,7 +245,20 @@ class Trainer(TrainerEventTrigger):


     Note that this parameter only takes effect when the ``Evaluator`` built into ``Trainer`` is not None and there are *callback* instances that need this parameter but have not set it;

-    :param n_batches: Training ends after iterating this many batches. When this value is not -1, the value of ``n_epochs`` is ignored directly.
+    :param n_batches: Training ends after iterating this many batches in total. When this value is not -1, the value of ``n_epochs`` is ignored directly.
+    :param overfit_batches: Use this parameter to enable 'overfit' training; accepted values are ``-1``, ``0``, or an integer greater than 0, indicating how many
+        batches of data to use for the overfit training; 0 means do nothing; -1 means train on all of the data;
+
+        .. note::
+
+            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data,
+            which indicates that the loss function, optimizer, etc. are fine. When this parameter is used, we extract a fixed number of batches directly from
+            ``train_dataloader`` and train on exactly these data in every epoch;
+
+        .. warning::
+
+            When using this parameter you may still pass ``metrics`` for a simple evaluation; when both are given, we replace evaluate_dataloaders
+            directly with the training data used for overfitting; you therefore need to make sure your ``metrics`` can run on ``train_dataloader``;

     :param marker: Used to mark a ``Trainer`` instance so that when the user calls ``Trainer.on``, the marked function belongs to this specific ``Trainer`` instance; None by default;

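The relocated docstring is the user-facing contract for the feature. As a usage sketch (keyword names beyond those visible in this diff, and the dict-batch/`train_step` conventions, are assumed from typical fastNLP 1.0 code, so treat this as illustrative rather than canonical):

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from fastNLP import Trainer

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def train_step(self, x, y):
        # fastNLP's torch driver conventionally expects a dict with a 'loss' key
        return {'loss': nn.functional.cross_entropy(self.fc(x), y)}

def collate(samples):
    xs, ys = zip(*samples)
    return {'x': torch.stack(xs), 'y': torch.stack(ys)}

dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
model = ToyModel()

trainer = Trainer(
    model=model,
    driver='torch',
    device='cpu',
    optimizers=torch.optim.SGD(model.parameters(), lr=0.1),
    train_dataloader=DataLoader(dataset, batch_size=8, collate_fn=collate),
    overfit_batches=4,   # cache 4 batches and train only on them, every epoch
    n_epochs=20,
)
trainer.run()  # the loss should drop quickly if model/optimizer/loss are wired correctly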

@@ -370,7 +370,6 @@ class Trainer(TrainerEventTrigger):
         optimizers,
         device: Optional[Union[int, List[int], str]] = "cpu",
         n_epochs: int = 20,
-        overfit_batches: int = 0,
         evaluate_dataloaders=None,
         batch_step_fn: Optional[Callable] = None,
         evaluate_batch_step_fn: Optional[Callable] = None,
@@ -387,6 +386,7 @@ class Trainer(TrainerEventTrigger):
         monitor: Union[str, Callable] = None,
         larger_better: bool = True,
         n_batches: int = -1,
+        overfit_batches: int = 0,
         marker: Optional[str] = None,
         **kwargs
     ):
@@ -522,8 +522,6 @@ class Trainer(TrainerEventTrigger):
         self.larger_better = larger_better
         if metrics is not None:
             if overfit_batches != 0:
-                logger.warning("Notice you are trying to 'overfit' the model and also using 'metrics', it may cause error "
-                               "because 'metrics' are prepared for 'evaluate_dataloaders', but now 'train_dataloader'.")
                 evaluate_dataloaders = self.dataloader
         if evaluate_dataloaders is not None:
             check_evaluate_every(evaluate_every)


+4 -11  fastNLP/core/dataloaders/utils.py

@@ -120,20 +120,13 @@ class OverfitDataLoader:
     def __init__(self, dataloader, overfit_batches: int):
         self.dataloader = dataloader  # the real dataloader is attached to this object so that operations on it can still be served;
         self.batches = []
+        self.overfit_batches = int(overfit_batches)

-        if isinstance(overfit_batches, int):
-            if overfit_batches < 0 and overfit_batches != -1:
-                raise ValueError("Parameter 'overfit_batches' can only be '-1' when it is smaller than 0, and it means"
-                                 "that you use all the data to check whether it could be overfitted.")
-        else:
-            raise TypeError("Parameter 'overfit_batches' can only be 'int' type, check the parameter you input into 'Trainer'.")
-
-        if overfit_batches > len(dataloader):
-            logger.warning("Parameter 'overfit_batches' is bigger than the real length of 'train dataloader'.")
+        if self.overfit_batches > len(dataloader):
+            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")

         for idx, batch in enumerate(dataloader):
-            if idx < overfit_batches or overfit_batches == -1:
+            if idx < self.overfit_batches or self.overfit_batches < 0:
                 self.batches.append(batch)

     def __len__(self):

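After this hunk the constructor trusts `int()` coercion instead of an explicit type check, and the cached batches are simply the first `overfit_batches` items of the wrapped loader (any negative value keeping everything). A standalone sketch of the resulting semantics (`MiniOverfitLoader` is illustrative, not the fastNLP class):

import torch
from torch.utils.data import DataLoader, TensorDataset

class MiniOverfitLoader:
    def __init__(self, dataloader, overfit_batches: int):
        # int() coercion replaces the old isinstance/TypeError check
        self.overfit_batches = int(overfit_batches)
        # negative values (i.e. -1) mean "keep every batch"
        self.batches = [batch for idx, batch in enumerate(dataloader)
                        if idx < self.overfit_batches or self.overfit_batches < 0]

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)

loader = DataLoader(TensorDataset(torch.arange(40.).view(10, 4)), batch_size=2)
overfit = MiniOverfitLoader(loader, overfit_batches=3)
assert len(overfit) == 3
for epoch in range(2):
    for batch in overfit:  # the same 3 batches replayed every epoch
        pass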

+0 -3  fastNLP/core/drivers/torch_driver/ddp.py

@@ -140,9 +140,6 @@ if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import BatchSampler
-    from torch.utils.data import RandomSampler as TorchRandomSampler
-    from torch.utils.data import SequentialSampler as TorchSequentialSampler
-    from torch.utils.data import BatchSampler as TorchBatchSampler

 __all__ = [
     'TorchDDPDriver'


+24 -22  fastNLP/core/drivers/torch_driver/utils.py

@@ -181,18 +181,16 @@ def replace_sampler(dataloader: "DataLoader", sampler):
     instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

     # 'multiprocessing_context' is a user-defined function;
-    instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+    if getattr(dataloader, 'multiprocessing_context', None) is not None:
+        instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

     # get the default signature of the dataloader's '__init__' function;
     init_params = dict(inspect.signature(dataloader.__init__).parameters)

-    # The reason this needs special handling is that a user customizing a dataloader may, for convenience, only declare a few
-    # parameters and forward the rest via **kwargs, possibly passing additional new parameters when initializing their dataloader
-    # instance (first, this step is necessary, since it is the only way we can inject the sampler; besides, the user may indeed
-    # add new parameters via **kwargs). If the user wrote "super().__init__(**kwargs)", we can only look them up in DataLoader;
+    # guard against the user's DataLoader subclassing pytorch's DataLoader while still passing arguments to the parent via **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
-    if has_variadic_kwargs:
-        # Written this way because a parameter with the same name in a user-customized Dataloader may have a different default value, so we cannot simply overwrite with update;
+    if has_variadic_kwargs and isinstance(dataloader, DataLoader):
+        # guard against the user having written super().__init__(**kwargs)
         for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
             if key not in init_params and key != 'self':
                 init_params[key] = value
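The added `isinstance` guard means the parent's signature is only merged in when the object really is a torch ``DataLoader`` subclass that hides the parent's parameters behind ``**kwargs``. A minimal sketch of that introspection step (`MyLoader` is a made-up example class):

import inspect
from torch.utils.data import DataLoader

class MyLoader(DataLoader):
    def __init__(self, dataset, my_option=True, **kwargs):
        self.my_option = my_option
        super().__init__(dataset, **kwargs)  # the pattern the added comment guards against

init_params = dict(inspect.signature(MyLoader.__init__).parameters)
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for v in init_params.values())

# MyLoader's own signature only names 'dataset' and 'my_option'; batch_size, sampler,
# etc. are hidden behind **kwargs, so recover their defaults from the parent class
if has_variadic_kwargs and issubclass(MyLoader, DataLoader):
    for key, value in inspect.signature(DataLoader.__init__).parameters.items():
        if key not in init_params and key != 'self':
            init_params[key] = value

print('sampler' in init_params, 'batch_size' in init_params)  # True True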
@@ -204,7 +202,8 @@ def replace_sampler(dataloader: "DataLoader", sampler):
         non_default_params.add("dataset")

     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
+    if isinstance(dataloader, DataLoader):
+        reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

     batch_sampler = getattr(dataloader, "batch_sampler")
     if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -218,35 +217,31 @@ def replace_sampler(dataloader: "DataLoader", sampler):
                    and p.name not in reconstruct_args
     }

-    # This error targets __init__ parameters that were not attached to self under the same names;
+    # these parameters were not found among the instance attributes, so the dataloader cannot be reinitialized
     if required_args:
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+            f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
+            f"But they are not found in the attributes of `{dataloader_self_name}`, so fastNLP cannot determine their "
+            f"values when trying to reinitialize `{dataloader_self_name}`; please add `{required_args}` as "
+            f"`{dataloader_self_name}`'s attributes."
         )

     # This error targets a dataloader that is not a plain DataLoader but a customized one whose __init__ lacks **kwargs;
     if not has_variadic_kwargs:

         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = reconstruct_args.keys() - init_params.keys()
         if missing_kwargs:
             missing_kwargs = sorted(missing_kwargs)
             dataloader_self_name = dataloader.__class__.__name__
             raise Exception(
-                f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. "
-                f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
-                "manually add the `DistributedSampler` as: "
-                f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+                f"The parameters {missing_kwargs} needed to reinitialize `{dataloader_self_name}` were not found."
             )
+        # when there is no **kwargs, make sure only the parameters __init__ accepts are passed in
+        if not isinstance(dataloader, DataLoader):
+            reconstruct_args = {key: value for key, value in reconstruct_args.items() if key in init_params}

     return type(dataloader)(**reconstruct_args)

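That final filter is what lets `replace_sampler` rebuild containers that are not torch DataLoaders and whose `__init__` takes no ``**kwargs``. A toy illustration of the step (`PlainLoader` is hypothetical):

import inspect

class PlainLoader:
    # a non-DataLoader container whose __init__ has no **kwargs
    def __init__(self, dataset, batch_size=8):
        self.dataset = dataset
        self.batch_size = batch_size

loader = PlainLoader(dataset=list(range(10)), batch_size=4)

reconstruct_args = dict(vars(loader))
reconstruct_args['extra_state'] = object()  # an attribute __init__ does not accept

init_params = dict(inspect.signature(loader.__init__).parameters)
# mirror of the added guard: without **kwargs, pass only what __init__ accepts
reconstruct_args = {key: value for key, value in reconstruct_args.items() if key in init_params}

rebuilt = type(loader)(**reconstruct_args)
assert rebuilt.batch_size == 4  # reinitialization succeeds without the stray attribute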



@@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
             params_keys.remove(k)
     params = {k: getattr(dataloader, k) for k in params_keys}
     params["batch_sampler"] = new_batch_sampler
+
+    if not isinstance(dataloader, DataLoader):
+        init_params = dict(inspect.signature(dataloader.__init__).parameters)
+        has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
+        if not has_variadic_kwargs:
+            params = {key: value for key, value in params.items() if key in init_params}
+
     return type(dataloader)(**params)






+9 -1  fastNLP/core/metrics/metric.py

@@ -98,7 +98,7 @@ class Metric:
         return _wrap_get_metric

     def __setattr__(self, key, value):
-        if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
+        if getattr(self, '_cannot_change_element', False):
             if key in self.elements and isinstance(value, (float, int, bool)):
                 self.elements[key].fill_value(value)
                 return
@@ -109,6 +109,14 @@ class Metric:
             raise RuntimeError("Please use register_element() function to add Element.")
         object.__setattr__(self, key, value)

+    # only triggered when __getattribute__ fails to find the name; kept merely to silence ide warnings
+    def __getattr__(self, name: str) -> Element:
+        if 'elements' in self.__dict__:
+            elements = self.__dict__['elements']
+            if name in elements:
+                return elements[name]
+        raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))
+
     def _wrap_update(self, update):
         @functools.wraps(update)
         def _wrap_update(*args, **kwargs):

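The `__getattr__` hook makes registered elements readable as plain attributes while keeping IDEs quiet. A self-contained sketch of the pattern (`MiniMetric` and the plain-float elements stand in for fastNLP's `Metric`/`Element`):

class MiniMetric:
    def __init__(self):
        # write through __dict__ so the lookup below has something to find
        self.__dict__['elements'] = {'tp': 0.0, 'fp': 0.0}

    def __getattr__(self, name):
        # only invoked when normal attribute lookup fails, so metric.tp
        # transparently resolves to elements['tp']
        if 'elements' in self.__dict__:
            elements = self.__dict__['elements']
            if name in elements:
                return elements[name]
        raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))

m = MiniMetric()
assert m.tp == 0.0  # resolved via __getattr__
try:
    m.missing
except AttributeError as err:
    print(err)  # `MiniMetric` object has no attribute `missing`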
