@@ -20,6 +20,7 @@ class ProgressCallback(HasMonitorCallback):
                          must_have_monitor=must_have_monitor)
         self.best_monitor_epoch = -1
         self.best_monitor_step = -1
+        self.best_results = None

     def record_better_monitor(self, trainer):
         self.best_monitor_step = trainer.global_forward_batches
@@ -29,6 +30,8 @@ class ProgressCallback(HasMonitorCallback):
         if self.best_monitor_epoch != -1:
             msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
                   f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            if self.best_results is not None:
+                msg = msg + ' The evaluation result: \n' + str(self.best_results)
             logger.info(msg)

     @property
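For orientation, a small sketch of the behaviour the two hunks above add: the evaluation hooks shown further below store the latest results in `self.best_results`, and the summary message then carries them. This is illustration only; the values are made up and the message layout simply mirrors the f-string in the hunk above.

    # Illustration only: how the summary message is assembled once best_results is set.
    monitor_name, monitor_value = "acc#acc", 0.91
    best_monitor_epoch, best_monitor_step = 3, 1200
    best_results = {"acc#acc": 0.91, "loss#loss": 0.23}   # hypothetical evaluation output

    msg = (f"The best performance for monitor {monitor_name}:{monitor_value} was achieved in"
           f" Epoch:{best_monitor_epoch}, Global Batch:{best_monitor_step}.")
    if best_results is not None:
        msg = msg + ' The evaluation result: \n' + str(best_results)
    print(msg)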
@@ -147,9 +150,11 @@ class RichCallback(ProgressCallback):
         results = {key: trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                    not key.startswith('_')}
         if self.format_json:
-            self.progress_bar.console.print_json(json.dumps(results))
+            results = json.dumps(results)
+            self.progress_bar.console.print_json(results)
         else:
             self.progress_bar.print(results)
+        self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
@@ -227,9 +232,9 @@ class RawTextCallback(ProgressCallback):
         results = {key: trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                    not key.startswith('_')}
         if self.format_json:
-            logger.info(json.dumps(results))
-        else:
-            logger.info(results)
+            results = json.dumps(results)
+        logger.info(results)
+        self.best_results = results

     @property
     def name(self):  # name of the progress bar
@@ -316,9 +321,9 @@ class TqdmCallback(ProgressCallback):
         results = {key: trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
                    not key.startswith('_')}
         if self.format_json:
-            logger.info(json.dumps(results))
-        else:
-            logger.info(results)
+            results = json.dumps(results)
+        logger.info(results)
+        self.best_results = results

     def clear_tasks(self):
         for key, taskid in self.task2id.items():
@@ -119,19 +119,6 @@ class Trainer(TrainerEventTrigger):
         For more details on using ``TorchDDPDriver``, see :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`.
     :param n_epochs: the total number of training epochs, 20 by default; alternatively, the total number of ``batch`` iterations can be set through the ``n_batches`` parameter.
-    :param overfit_batches: use this parameter to enable the 'overfit' feature; valid values are ``-1``, ``0`` or an integer greater than 0, indicating how many
-        batches of data are used for overfit training; 0 is the default and means no overfitting; -1 means all of the data is used for training;
-
-        .. note::
-
-            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge on a small amount of data, which suggests
-            that the loss function, optimizer, etc. are fine. When this parameter is used, we take a fixed number of batches directly from ``train_dataloader`` and reuse them for overfit training in all following epochs;
-
-        .. warning::
-
-            When using this parameter you may still pass ``metrics`` for a simple evaluation; when both are given, evaluate_dataloaders
-            is replaced directly with the training data used for overfitting, so you have to make sure your ``metrics`` can be used on ``train_dataloader``;
-
     :param evaluate_dataloaders: the evaluation dataset(s); it can be a single dataset or several; when there are several, it must be a Dict; None
         by default;
     :param batch_step_fn: customizes the function executed to run one batch of data forward during training. It should accept two parameters, ``trainer`` and ``batch``,
@@ -258,7 +245,20 @@ class Trainer(TrainerEventTrigger):
         Note that this parameter only takes effect when the ``Evaluator`` built into ``Trainer`` is not None and there is a *callback* instance that needs it but has not set it;
-    :param n_batches: the number of ``batch`` iterations after which training stops. When this value is not -1, ``n_epochs`` is ignored.
+    :param n_batches: the total number of ``batch`` iterations after which training stops. When this value is not -1, ``n_epochs`` is ignored.
+    :param overfit_batches: use this parameter to enable the 'overfit' feature; valid values are ``-1``, ``0`` or an integer greater than 0, indicating how many
+        batches of data are used for overfit training; 0 means nothing is done; -1 means all of the data is used for training;
+
+        .. note::
+
+            You can use this parameter to quickly check whether your model is 'correct', i.e. whether it can converge on a small amount of data, which suggests
+            that the loss function, optimizer, etc. are fine. When this parameter is used, we take a fixed number of batches directly from ``train_dataloader`` and reuse them
+            in every epoch for training;
+
+        .. warning::
+
+            When using this parameter you may still pass ``metrics`` for a simple evaluation; when both are given, evaluate_dataloaders
+            is replaced directly with the training data used for overfitting, so you have to make sure your ``metrics`` can be used on ``train_dataloader``;
+
     :param marker: marks a ``Trainer`` instance so that when the user calls ``Trainer.on``, the function can be associated with that specific ``Trainer`` instance; None by default;
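As a usage sketch for the parameter documented above: `model`, `optimizer` and `train_dataloader` are placeholders for the user's own objects, and the top-level `Accuracy` import is an assumption about the fastNLP 1.0 API rather than something taken from this diff.

    # Hypothetical sanity-check run: overfit on 4 fixed training batches and, because
    # metrics are given, evaluate on those same batches (evaluate_dataloaders is
    # replaced internally as described in the warning above).
    from fastNLP import Trainer, Accuracy

    trainer = Trainer(
        model=model,                        # placeholder: your torch.nn.Module
        driver="torch",
        optimizers=optimizer,               # placeholder: your optimizer
        train_dataloader=train_dataloader,  # placeholder: your DataLoader
        overfit_batches=4,                  # 0 = do nothing, -1 = use all batches
        metrics={"acc": Accuracy()},
        n_epochs=50,
    )
    trainer.run()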
@@ -370,7 +370,6 @@ class Trainer(TrainerEventTrigger):
                  optimizers,
                  device: Optional[Union[int, List[int], str]] = "cpu",
                  n_epochs: int = 20,
-                 overfit_batches: int = 0,
                  evaluate_dataloaders=None,
                  batch_step_fn: Optional[Callable] = None,
                  evaluate_batch_step_fn: Optional[Callable] = None,
@@ -387,6 +386,7 @@ class Trainer(TrainerEventTrigger):
                  monitor: Union[str, Callable] = None,
                  larger_better: bool = True,
                  n_batches: int = -1,
+                 overfit_batches: int = 0,
                  marker: Optional[str] = None,
                  **kwargs
                  ):
@@ -522,8 +522,6 @@ class Trainer(TrainerEventTrigger):
         self.larger_better = larger_better
         if metrics is not None:
             if overfit_batches != 0:
-                logger.warning("Notice you are trying to 'overfit' the model and also using 'metrics', it may cause error "
-                               "because 'metrics' are prepared for 'evaluate_dataloaders', but now 'train_dataloader'.")
                 evaluate_dataloaders = self.dataloader
             if evaluate_dataloaders is not None:
                 check_evaluate_every(evaluate_every)
@@ -120,20 +120,13 @@ class OverfitDataLoader:
     def __init__(self, dataloader, overfit_batches: int):
         self.dataloader = dataloader  # keep the real dataloader mounted on this object so that operations aimed at the real dataloader still work;
         self.batches = []
+        self.overfit_batches = int(overfit_batches)

-        if isinstance(overfit_batches, int):
-            if overfit_batches < 0 and overfit_batches != -1:
-                raise ValueError("Parameter 'overfit_batches' can only be '-1' when it is smaller than 0, and it means "
-                                 "that you use all the data to check whether it could be overfitted.")
-        else:
-            raise TypeError("Parameter 'overfit_batches' can only be 'int' type, check the parameter you input into 'Trainer'.")
-
-        if overfit_batches > len(dataloader):
-            logger.warning("Parameter 'overfit_batches' is bigger than the real length of 'train dataloader'.")
+        if self.overfit_batches > len(dataloader):
+            logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")

         for idx, batch in enumerate(dataloader):
-            if idx < overfit_batches or overfit_batches == -1:
+            if idx < self.overfit_batches or self.overfit_batches <= -1:
                 self.batches.append(batch)

     def __len__(self):
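A minimal standalone sketch of what the rewritten constructor does, with the logger call replaced by a print and a plain list standing in for a real dataloader (illustration only, not fastNLP code):

    # Cache the first `overfit_batches` batches and replay only those;
    # -1 (or any value <= -1) means "keep everything".
    def collect_overfit_batches(dataloader, overfit_batches: int):
        overfit_batches = int(overfit_batches)
        if overfit_batches > len(dataloader):
            print("'overfit_batches' is bigger than the length of the dataloader.")
        batches = []
        for idx, batch in enumerate(dataloader):
            if idx < overfit_batches or overfit_batches <= -1:
                batches.append(batch)
        return batches

    assert collect_overfit_batches([10, 11, 12, 13], 2) == [10, 11]
    assert collect_overfit_batches([10, 11, 12, 13], -1) == [10, 11, 12, 13]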
@@ -140,9 +140,6 @@ if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import BatchSampler
-    from torch.utils.data import RandomSampler as TorchRandomSampler
-    from torch.utils.data import SequentialSampler as TorchSequentialSampler
-    from torch.utils.data import BatchSampler as TorchBatchSampler

 __all__ = [
     'TorchDDPDriver'
@@ -181,18 +181,16 @@ def replace_sampler(dataloader: "DataLoader", sampler):
     instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

     # 'multiprocessing_context' is a user-defined function;
-    instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
+    if getattr(dataloader, 'multiprocessing_context', None) is not None:
+        instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

     # get the default signature of the dataloader's '__init__';
     init_params = dict(inspect.signature(dataloader.__init__).parameters)

-    # The reason this is handled separately: when customizing a dataloader, the user may, for convenience, only declare a few
-    # parameters explicitly and forward the rest via **kwargs. They may then pass extra, new parameters when instantiating their
-    # dataloader (first of all this step is necessary, since it is the only way we can add a sampler; besides, the user may indeed
-    # add new parameters via **kwargs). If the user wrote "super().__init__(**kwargs)", we can only look these parameters up in DataLoader;
+    # in case the user's DataLoader subclasses pytorch's DataLoader but still passes arguments to the parent class via **kwargs
     has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
-    if has_variadic_kwargs:
-        # written this way because a parameter of the same name in the user's customized Dataloader may have a different default value, so we cannot simply overwrite it with update;
+    if has_variadic_kwargs and isinstance(dataloader, DataLoader):
+        # guard against the user having written super().__init__(**kwargs)
         for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
             if key not in init_params and key != 'self':
                 init_params[key] = value
@@ -204,7 +202,8 @@ def replace_sampler(dataloader: "DataLoader", sampler):
         non_default_params.add("dataset")

     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
+    if isinstance(dataloader, DataLoader):
+        reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

     batch_sampler = getattr(dataloader, "batch_sampler")
     if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -218,35 +217,31 @@ def replace_sampler(dataloader: "DataLoader", sampler):
                and p.name not in reconstruct_args
     }

-    # this error targets the case where an `__init__` parameter was not attached to `self` under the same name;
+    # these parameters were not found among the instance attributes, so the dataloader cannot be reinitialized
     if required_args:
         required_args = sorted(required_args)
         dataloader_self_name = dataloader.__class__.__name__
         raise Exception(
-            f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
-            f"The missing attributes are {required_args}. "
-            f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
-            "manually add the `DistributedSampler` as: "
-            f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+            f"The arguments {required_args} need to be injected into the `__init__` of `{dataloader_self_name}`, "
+            f"but they are not found among the attributes of `{dataloader_self_name}`, so fastNLP cannot determine "
+            f"their values when trying to reinitialize `{dataloader_self_name}`. Please expose {required_args} as "
+            f"attributes of `{dataloader_self_name}`."
         )

     # this error targets the case where the dataloader passed in is not a plain DataLoader but a customized one whose `__init__` has no **kwargs;
     if not has_variadic_kwargs:
         # the dataloader signature does not allow keyword arguments that need to be passed
         missing_kwargs = reconstruct_args.keys() - init_params.keys()
         if missing_kwargs:
             missing_kwargs = sorted(missing_kwargs)
             dataloader_self_name = dataloader.__class__.__name__
             raise Exception(
-                f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
-                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
-                f"The missing arguments are {missing_kwargs}. "
-                f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
-                "manually add the `DistributedSampler` as: "
-                f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
+                f"The parameters {missing_kwargs} needed to reinitialize `{dataloader_self_name}` are not found in its `__init__` signature."
             )

+    # if there is no **kwargs, make sure only the parameters the `__init__` actually accepts are passed in
+    if not isinstance(dataloader, DataLoader):
+        reconstruct_args = {key: value for key, value in reconstruct_args.items() if key in init_params}

     return type(dataloader)(**reconstruct_args)
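To make the reconstruction logic above concrete, a hedged sketch of the kind of user-defined dataloader `replace_sampler` has to handle; `MyLoader` is an invented example, not part of fastNLP:

    import inspect
    from torch.utils.data import DataLoader

    class MyLoader(DataLoader):
        def __init__(self, dataset, scale=1.0, **kwargs):
            super().__init__(dataset, **kwargs)
            # `scale` is recoverable on reconstruction only because it is stored on
            # `self` under the same name as the `__init__` argument.
            self.scale = scale

    loader = MyLoader(list(range(8)), scale=2.0, batch_size=2)

    # `scale` is found among the instance attributes, and because MyLoader forwards
    # **kwargs, the remaining keyword parameters are filled in from DataLoader.__init__.
    init_params = dict(inspect.signature(loader.__init__).parameters)
    print('scale' in vars(loader))                                     # True
    print(any(p.kind is p.VAR_KEYWORD for p in init_params.values()))  # True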
@@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
             params_keys.remove(k)
     params = {k: getattr(dataloader, k) for k in params_keys}
     params["batch_sampler"] = new_batch_sampler
+
+    if not isinstance(dataloader, DataLoader):
+        init_params = dict(inspect.signature(dataloader.__init__).parameters)
+        has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
+        if not has_variadic_kwargs:
+            params = {key: value for key, value in params.items() if key in init_params}
+
     return type(dataloader)(**params)
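The added guard is easiest to see with a loader that is not a torch DataLoader and whose `__init__` takes no `**kwargs`; a hedged, simplified sketch (`TinyLoader` is invented for illustration):

    import inspect

    class TinyLoader:
        # not a torch DataLoader: only `dataset` and `batch_sampler` can be passed back in
        def __init__(self, dataset, batch_sampler=None):
            self.dataset = dataset
            self.batch_sampler = batch_sampler
            self.collate_fn = None  # attribute that must NOT be forwarded on reconstruction

    params = {"dataset": [1, 2, 3], "batch_sampler": "new_sampler", "collate_fn": None}
    init_params = dict(inspect.signature(TinyLoader.__init__).parameters)
    has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for v in init_params.values())
    if not has_variadic_kwargs:
        params = {key: value for key, value in params.items() if key in init_params}

    rebuilt = TinyLoader(**params)  # works because `collate_fn` was filtered out
    print(sorted(params))           # ['batch_sampler', 'dataset']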
@@ -98,7 +98,7 @@ class Metric:
         return _wrap_get_metric

     def __setattr__(self, key, value):
-        if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
+        if getattr(self, '_cannot_change_element', False):
             if key in self.elements and isinstance(value, (float, int, bool)):
                 self.elements[key].fill_value(value)
                 return
@@ -109,6 +109,14 @@ class Metric:
                 raise RuntimeError("Please use register_element() function to add Element.")
         object.__setattr__(self, key, value)

+    # this is only triggered when __getattribute__ fails to find the name; it is kept mainly to avoid IDE warnings
+    def __getattr__(self, name: str) -> Element:
+        if 'elements' in self.__dict__:
+            elements = self.__dict__['elements']
+            if name in elements:
+                return elements[name]
+        raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))
+
     def _wrap_update(self, update):
         @functools.wraps(update)
         def _wrap_update(*args, **kwargs):
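For context on the `__getattr__` added above: it lets elements registered via `register_element` be read as plain attributes inside a user metric. A hedged sketch of that access pattern follows; the metric class is invented, and the `register_element` / `get_scalar` signatures are assumptions based on the fastNLP 1.0 API rather than guarantees.

    import torch
    from fastNLP.core.metrics import Metric

    class CountAccuracy(Metric):
        """Toy metric: `self.correct` and `self.total` resolve through Metric.__getattr__."""
        def __init__(self):
            super().__init__()
            self.register_element(name='correct', value=0, aggregate_method='sum')
            self.register_element(name='total', value=0, aggregate_method='sum')

        def update(self, pred: "torch.Tensor", target: "torch.Tensor"):
            self.correct += torch.sum(torch.eq(pred, target))
            self.total += target.numel()

        def get_metric(self) -> dict:
            return {'acc': round(self.correct.get_scalar() / (self.total.get_scalar() + 1e-12), 6)}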