@@ -146,11 +146,13 @@ class CallbackManager: | |||||
r""" | r""" | ||||
用于断点重训的 callback 的保存函数; | 用于断点重训的 callback 的保存函数; | ||||
该函数主要涉及两个方面: | 该函数主要涉及两个方面: | ||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 | |||||
断点重训应当保存的状态; | |||||
2. 每一个具体的 callback 函数的 filter 的状态; | |||||
:return: 一个包含上述内容的字典:: | |||||
1. callback 的状态的保存;我们会调用每一个 callback 的 `on_save_checkpoint` 方法,该方法应当返回一个字典,其中包含着 | |||||
断点重训应当保存的状态; | |||||
2. 每一个具体的 callback 函数的 filter 的状态; | |||||
:return: 一个包含上述内容的字典: | |||||
.. code-block:: | |||||
{ | { | ||||
"callback_name_1": { | "callback_name_1": { | ||||
@@ -158,6 +160,7 @@ class CallbackManager: | |||||
"filter_states": {"on_train_begin": filter1.state_dict(), ...} | "filter_states": {"on_train_begin": filter1.state_dict(), ...} | ||||
} | } | ||||
} | } | ||||
""" | """ | ||||
states = {} | states = {} | ||||
@@ -39,7 +39,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 | ||||
取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 | 取值: (1) str 类型,监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最 | ||||
匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | 匹配的那个作为 monitor ; (2) 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor | ||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | |||||
的结果,如果当前结果中没有相关的monitor 值请返回 None 。 | |||||
:param watch_monitor_larger_better: watch_monitor 是否越大越好。 | :param watch_monitor_larger_better: watch_monitor 是否越大越好。 | ||||
:param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 | ||||
`model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 | ||||
@@ -10,13 +10,13 @@ class TorchGradClipCallback(Callback): | |||||
在每次 optimizer update 之前将 parameter 进行 clip | 在每次 optimizer update 之前将 parameter 进行 clip | ||||
:param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | :param float clip_value: 将gradient 限制到[-clip_value, clip_value]。clip_value应该为正数 | ||||
:param str clip_type: 支持'norm', 'value'两种:: | |||||
:param str clip_type: 支持'norm', 'value'两种: | |||||
1 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||||
1. 'norm', 将gradient的norm rescale到[-clip_value, clip_value] | |||||
2. 'value', 将gradient限制在[-clip_value, clip_value], | |||||
小于-clip_value的gradient被赋值为-clip_value; | |||||
大于clip_value的gradient被赋值为clip_value. | |||||
2 'value', 将gradient限制在[-clip_value, clip_value], | |||||
小于-clip_value的gradient被赋值为-clip_value; | |||||
大于clip_value的gradient被赋值为clip_value. | |||||
:param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | :param None,torch.Tensor,List[torch.Tensor] parameters: 一般通过model.parameters()获得。 | ||||
如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 | 如果为None则默认对 Trainer 的 optimizers 中所有参数进行梯度裁剪。 | ||||
""" | """ | ||||
@@ -9,6 +9,7 @@ from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPad | |||||
from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder | from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder | ||||
from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder | from .raw_padder import RawNumberPadder, RawSequencePadder, RawTensorPadder | ||||
from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | from .paddle_padder import PaddleTensorPadder, PaddleSequencePadder, PaddleNumberPadder | ||||
from .jittor_padder import JittorTensorPadder, JittorSequencePadder, JittorNumberPadder | |||||
from .exceptions import * | from .exceptions import * | ||||
@@ -91,6 +92,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return TorchNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'jittor': | |||||
return JittorNumberPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
else: | else: | ||||
raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") | raise ValueError(f"backend={backend} is not supported for list(Field:{field_name}).") | ||||
@@ -103,6 +106,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return TorchSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | return PaddleSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | ||||
elif backend == 'jittor': | |||||
return JittorSequencePadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
else: | else: | ||||
raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") | raise ValueError(f"backend={backend} is not supported for nested list(Field:{field_name}).") | ||||
@@ -116,6 +121,8 @@ def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)-> | |||||
return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | return TorchTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | ||||
elif backend == 'paddle': | elif backend == 'paddle': | ||||
return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | return PaddleTensorPadder(pad_val=pad_val, ele_dtype=None, dtype=dtype) | ||||
elif backend == 'jittor': | |||||
return JittorTensorPadder(pad_val=pad_val, ele_dtype=ele_dtype, dtype=dtype) | |||||
else: | else: | ||||
raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") | raise ValueError(f"backend={backend} is not supported for tensors(Field:{field_name}).") | ||||
@@ -0,0 +1,195 @@ | |||||
__all__ = [ | |||||
'JittorNumberPadder', | |||||
'JittorSequencePadder', | |||||
'JittorTensorPadder' | |||||
] | |||||
from inspect import isclass | |||||
import numpy as np | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor | |||||
numpy_to_jittor_dtype_dict = { | |||||
np.bool_: 'bool', | |||||
np.uint8: 'uint8', | |||||
np.int8: "int8", | |||||
np.int16: "int16", | |||||
np.int32: "int32", | |||||
np.int64: "int64", | |||||
np.float16: "float16", | |||||
np.float32: 'float32', | |||||
np.float64: 'float32', # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 | |||||
} | |||||
# number_to_jittor_dtype_dict = { | |||||
# float: 'float32', # 因为 paddle.tensor([1], dtype=float)是paddle.float64 | |||||
# int: 'int64', | |||||
# bool: 'bool' | |||||
# } | |||||
from .padder import Padder | |||||
from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class | |||||
from .exceptions import * | |||||
def is_jittor_tensor(dtype): | |||||
if not isclass(dtype) and isinstance(dtype, jittor.jittor_core.Var): | |||||
return True | |||||
return False | |||||
def is_jittor_dtype_str(dtype): | |||||
try: | |||||
if isinstance(dtype, str) and dtype in {'bool', 'float16', 'uint16', 'float32', 'float64', 'int8', | |||||
'int16', 'int32', 'int64', 'uint8', 'complex64', 'complex128', | |||||
u'bool', u'float16', u'uint16', u'float32', u'float64', u'int8', | |||||
u'int16', u'int32', u'int64', u'uint8'}: | |||||
return True | |||||
except: | |||||
pass | |||||
return False | |||||
def _get_dtype(ele_dtype, dtype, class_name): | |||||
if not (ele_dtype is None or ( | |||||
is_number_or_numpy_number(ele_dtype) or is_jittor_tensor(ele_dtype) or is_jittor_dtype_str(dtype))): | |||||
raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " | |||||
f"or numpy numbers or jittor.Var but get `{ele_dtype}`.") | |||||
if dtype is not None: | |||||
if not (is_jittor_tensor(dtype) or is_number(dtype) or is_jittor_dtype_str(dtype)): | |||||
raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " | |||||
f"or jittor.dtype but get `{dtype}`.") | |||||
# dtype = number_to_jittor_dtype_dict.get(dtype, dtype) | |||||
else: | |||||
# if (is_number(ele_dtype) or is_jittor_tensor(ele_dtype)): | |||||
# # ele_dtype = number_to_jittor_dtype_dict.get(ele_dtype, ele_dtype) | |||||
# dtype = ele_dtype | |||||
# elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 | |||||
# dtype = numpy_to_jittor_dtype_dict.get(ele_dtype.type) | |||||
if is_numpy_generic_class(ele_dtype): | |||||
dtype = numpy_to_jittor_dtype_dict.get(ele_dtype) | |||||
else: | |||||
dtype = ele_dtype | |||||
return dtype | |||||
class JittorNumberPadder(Padder): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
可以将形如 [1, 2, 3] 这类的数据转为 jittor.Var([1, 2, 3]) | |||||
:param pad_val: 该值无意义 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
return jittor.Var(np.array(batch_field, dtype=dtype)) | |||||
class JittorSequencePadder(Padder): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
将类似于 [[1], [1, 2]] 的内容 pad 为 jittor.Var([[1, 0], [1, 2]]) 可以 pad 多重嵌套的数据。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
tensor = get_padded_jittor_tensor(batch_field, dtype=dtype, pad_val=pad_val) | |||||
return tensor | |||||
class JittorTensorPadder(Padder): | |||||
def __init__(self, pad_val=0, ele_dtype=None, dtype=None): | |||||
""" | |||||
目前支持 [jittor.Var([3, 2], jittor.Var([1])] 类似的。若内部元素不为 jittor.Var ,则必须含有 tolist() 方法。 | |||||
:param pad_val: 需要 pad 的值。 | |||||
:param ele_dtype: 用于检测当前 field 的元素类型是否可以转换为 jittor.Var 类型。 | |||||
:param dtype: 输出的数据的 dtype 是什么。如 jittor.long, jittor.float32, int, float 等 | |||||
""" | |||||
dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) | |||||
super().__init__(pad_val=pad_val, dtype=dtype) | |||||
@staticmethod | |||||
def pad(batch_field, pad_val, dtype): | |||||
try: | |||||
if not isinstance(batch_field[0], jittor.Var): | |||||
batch_field = [jittor.Var(np.array(field.tolist(), dtype=dtype)) for field in batch_field] | |||||
except AttributeError: | |||||
raise RuntimeError(f"If the field is not a jittor.Var (it is {type(batch_field[0])}), " | |||||
f"it must have tolist() method.") | |||||
shapes = [field.shape for field in batch_field] | |||||
max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] | |||||
# if dtype is not None: | |||||
# tensor = jittor.full(max_shape, pad_val, dtype=dtype) | |||||
# else: | |||||
tensor = jittor.full(max_shape, pad_val, dtype=dtype) | |||||
for i, field in enumerate(batch_field): | |||||
slices = (i,) + tuple(slice(0, s) for s in shapes[i]) | |||||
tensor[slices] = field | |||||
return tensor | |||||
def fill_tensor(batch_field, padded_batch, dtype): | |||||
""" | |||||
将 batch_field 中的值填入到 tensor 中。 | |||||
:param batch_field: 需要填充进入 array 中的内容 | |||||
:param padded_batch: 待填充的 tensor | |||||
:param dtype: 数据的类别 | |||||
:return: | |||||
""" | |||||
if padded_batch.ndim == 2: | |||||
for i, content_i in enumerate(batch_field): | |||||
padded_batch[i, :len(content_i)] = jittor.Var(np.array(content_i, dtype=dtype)) | |||||
elif padded_batch.ndim == 3: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
padded_batch[i, j, :len(content_ii)] = jittor.Var(np.array(content_ii, dtype=dtype)) | |||||
elif padded_batch.ndim == 4: | |||||
try: # 应该是图像,所以直接应该就 ok 了。 | |||||
padded_batch = np.array(batch_field) | |||||
except: | |||||
for i, content_i in enumerate(batch_field): | |||||
for j, content_ii in enumerate(content_i): | |||||
for k, content_iii in enumerate(content_ii): | |||||
padded_batch[i, j, k, :len(content_iii)] = jittor.Var(np.array(content_iii, dtype=dtype)) | |||||
elif padded_batch.ndim == 1: | |||||
padded_batch[:] = jittor.Var(np.array(batch_field, dtype=dtype)) | |||||
else: | |||||
raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " | |||||
"report.") | |||||
return padded_batch | |||||
def get_padded_jittor_tensor(batch_field, dtype=None, pad_val=0): | |||||
""" | |||||
例如: | |||||
[[1,2], [3]] -> jittor.LongTensor([[1, 2], [3, 0]]) | |||||
:param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) | |||||
/4d(多为图片)。 | |||||
:param dtype: 目标类别是什么 | |||||
:param pad_val: pad 的 value | |||||
:return: | |||||
""" | |||||
shapes = get_shape(batch_field) | |||||
tensor = jittor.full(shapes, pad_val, dtype=dtype) | |||||
tensor = fill_tensor(batch_field, tensor, dtype=dtype) | |||||
return tensor |
@@ -51,23 +51,30 @@ class Evaluator: | |||||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; | ||||
:param fp16: 是否使用 fp16 。 | :param fp16: 是否使用 fp16 。 | ||||
:param verbose: 是否打印 evaluate 的结果。 | :param verbose: 是否打印 evaluate 的结果。 | ||||
:param kwargs: | |||||
bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout | |||||
与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 | |||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 | |||||
TODO 还没完成。 | |||||
Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 | |||||
tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, | |||||
当 auto_tensor_conversion_for_metric 为True时,fastNLP 将自动将输出中 paddle 的 tensor (其它非 tensor 的参数 | |||||
不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的输出 tensor 类型通过 driver 来决定, | |||||
metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换,请使用 input_mapping、output_mapping 参数进行。 | |||||
use_dist_sampler: 是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
progress_bar: evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 | |||||
到当前terminal为交互型则使用 rich,否则使用 raw。 | |||||
:param \**kwargs: | |||||
See below | |||||
:kwargs: | |||||
* *model_use_eval_mode* (``bool``) -- | |||||
是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的 | |||||
dropout 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 | |||||
该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 | |||||
TODO 还没完成。 | |||||
* *auto_tensor_conversion_for_metric* (``Union[bool]``) -- | |||||
是否自动将输出中的 tensor 适配到 metrics 支持的。例如 model 输出是 | |||||
paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象,当 auto_tensor_conversion_for_metric 为True时,fastNLP 将 | |||||
自动将输出中 paddle 的 tensor (其它非 tensor 的参数不做任何处理)转换为 pytorch 的 tensor 再输入到 metrics 中进行评测。 model 的 | |||||
输出 tensor 类型通过 driver 来决定,metrics 支持的输入类型由 metrics 决定。如果需要更复杂的转换, | |||||
请使用 input_mapping、output_mapping 参数进行。 | |||||
* *use_dist_sampler* -- | |||||
是否使用分布式evaluate的方式。仅当 driver 为分布式类型时,该参数才有效。默认为根据 driver 是否支持 | |||||
分布式进行设置。如果为True,将使得每个进程上的 dataloader 自动使用不同数据,所有进程的数据并集是整个数据集。 | |||||
* *output_from_new_proc* -- | |||||
应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||||
* *progress_bar* -- | |||||
evaluate 的时候显示的 progress bar 。目前支持三种 [None, 'raw', 'rich', 'auto'], auto 表示如果检测 | |||||
到当前terminal为交互型则使用 rich,否则使用 raw。 | |||||
""" | """ | ||||
self.model = model | self.model = model | ||||
@@ -67,20 +67,21 @@ class Trainer(TrainerEventTrigger): | |||||
要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | 要自己实现模型部分,而将训练层面的逻辑完全地交给 fastNLP; | ||||
:param model: 训练所需要的模型,目前支持 pytorch; | :param model: 训练所需要的模型,目前支持 pytorch; | ||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle | |||||
等国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 | |||||
:param driver: 训练模型所使用的具体的驱动模式,应当为以下选择中的一个:["torch", "torch_ddp", ],之后我们会加入 jittor、paddle 等 | |||||
国产框架的训练模式;其中 "torch" 表示使用 cpu 或者单张 gpu 进行训练 | |||||
:param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | :param train_dataloader: 训练数据集,注意其必须是单独的一个数据集,不能是 List 或者 Dict; | ||||
:param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | :param optimizers: 训练所需要的优化器;可以是单独的一个优化器实例,也可以是多个优化器组成的 List; | ||||
:param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你 | :param device: 该参数用来指定具体训练时使用的机器;注意当该参数为 None 时,fastNLP 不会将模型和数据进行设备之间的移动处理,但是你 | ||||
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 | |||||
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 | |||||
自己构造 DDP 的多进程场景); | |||||
可以通过参数 `input_mapping` 和 `output_mapping` 来实现设备之间数据迁移的工作(通过这两个参数传入两个处理数据的函数);同时你也 | |||||
可以通过在 kwargs 添加参数 "data_device" 来让我们帮助您将数据迁移到指定的机器上(注意这种情况理应只出现在用户在 Trainer 实例化前 | |||||
自己构造 DDP 的多进程场景); | |||||
device 的可选输入如下所示: | device 的可选输入如下所示: | ||||
1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; | 1. 可选输入:str: ['cpu', 'cuda', 'cuda:0', 'cuda:1', ...] 依次为'cpu'中, 可见的第一个GPU中, 可见的第一个GPU中, 可见的第二个GPU中; | ||||
2. torch.device:将模型装载到torch.device上; | 2. torch.device:将模型装载到torch.device上; | ||||
3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`; | 3. int: 将使用device_id为该值的gpu进行训练;如果值为 -1,那么默认使用全部的显卡,此时是 `TorchDDPDriver`; | ||||
4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | 4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; | ||||
5. None: 为None则不对模型进行任何处理; | 5. None: 为None则不对模型进行任何处理; | ||||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; | :param n_epochs: 训练总共的 epoch 的数量,默认为 20; | ||||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | ||||
为 None; | 为 None; | ||||
@@ -121,26 +122,27 @@ class Trainer(TrainerEventTrigger): | |||||
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 | 如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 | ||||
:param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
:param kwargs: 一些其它的可能需要的参数; | |||||
torch_non_blocking: 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
data_device: 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; | |||||
注意如果 model_device 为 None,那么 data_device 不会起作用; | |||||
torch_ddp_kwargs: 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入 | |||||
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。 | |||||
set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | |||||
:param kwargs: 一些其它的可能需要的参数,见下方的说明 | |||||
:kwargs: | |||||
* *torch_non_blocking* -- 表示用于 pytorch 的 tensor 的 to 方法的参数 non_blocking; | |||||
* *data_device* -- 表示如果用户的模型 device (在 Driver 中对应为参数 model_device)为 None 时,我们会将数据迁移到 data_device 上; | |||||
注意如果 model_device 为 None,那么 data_device 不会起作用; | |||||
* *torch_ddp_kwargs* -- 用于配置 pytorch 的 DistributedDataParallel 初始化时的参数;仅用于 pytorch ddp 训练。例如传入 | |||||
{'find_unused_parameters': True} 来解决有有参数不参与前向运算导致的报错等。 | |||||
* *set_grad_to_none* -- 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; | |||||
* *use_dist_sampler* -- 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch | |||||
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 | ||||
evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
* *evaluate_use_dist_sampler* -- 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; | |||||
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: | |||||
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 | ||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 | 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 | ||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | ||||
train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 | |||||
train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | |||||
evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 | |||||
evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 | |||||
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 | |||||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 | |||||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 | |||||
* *evaluate_output_mapping* -- 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 | |||||
""" | """ | ||||
self.model = model | self.model = model | ||||
self.marker = marker | self.marker = marker | ||||
@@ -290,14 +292,14 @@ class Trainer(TrainerEventTrigger): | |||||
catch_KeyboardInterrupt=None): | catch_KeyboardInterrupt=None): | ||||
""" | """ | ||||
注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint | 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ModelCheckpoint | ||||
去保存断点重训的文件; | |||||
去保存断点重训的文件; | |||||
:param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 | :param num_train_batch_per_epoch: 每个 epoch 运行多少个 batch 即停止,-1 为根据 dataloader 有多少个 batch 决定。 | ||||
:param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止,-1 为根据 dataloader 有多少个 batch 决定。 | :param num_eval_batch_per_dl: 每个 evaluate dataloader 运行多少个 batch 停止,-1 为根据 dataloader 有多少个 batch 决定。 | ||||
:param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误。为 0 表示不检测。 | :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 是否有错误。为 0 表示不检测。 | ||||
:param resume_from: 从哪个路径下恢复 trainer 的状态 | :param resume_from: 从哪个路径下恢复 trainer 的状态 | ||||
:param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 | :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态。 | ||||
:param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 | :param catch_KeyboardInterrupt: 是否捕获KeyboardInterrupt, 如果捕获的话,不会抛出一场,trainer.run()之后的代码会继续运 | ||||
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) | |||||
行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) | |||||
:return: | :return: | ||||
""" | """ | ||||
@@ -417,39 +419,42 @@ class Trainer(TrainerEventTrigger): | |||||
def on(cls, event: Event, marker: Optional[str] = None): | def on(cls, event: Event, marker: Optional[str] = None): | ||||
r""" | r""" | ||||
函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | 函数修饰器,用户可以使用该函数来方便地将一个函数转变为 callback 函数,从而进行训练流程中的控制; | ||||
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如 | |||||
Trainer.__init__(): | |||||
on_after_trainer_initialized(trainer, driver) | |||||
Trainer.run(): | |||||
if num_eval_sanity_batch>0: | |||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||||
on_sanity_check_end(trainer, sanity_check_res) | |||||
try: | |||||
on_train_begin(trainer) | |||||
while cur_epoch_idx < n_epochs: | |||||
on_train_epoch_begin(trainer) | |||||
while batch_idx_in_epoch<=num_batches_per_epoch: | |||||
on_fetch_data_begin(trainer) | |||||
batch = next(dataloader) | |||||
on_fetch_data_end(trainer) | |||||
on_train_batch_begin(trainer, batch, indices) | |||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 | |||||
on_after_backward(trainer) | |||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_train_batch_end(trainer) | |||||
on_train_epoch_end(trainer) | |||||
except BaseException: | |||||
self.on_exception(trainer, exception) | |||||
finally: | |||||
on_train_end(trainer) | |||||
支持的 event 时机有以下这些,其执行的时机顺序也如下所示。每个时机装饰的函数应该接受的参数列表也如下所示,例如:: | |||||
Trainer.__init__(): | |||||
on_after_trainer_initialized(trainer, driver) | |||||
Trainer.run(): | |||||
if num_eval_sanity_batch>0: | |||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||||
on_sanity_check_end(trainer, sanity_check_res) | |||||
try: | |||||
on_train_begin(trainer) | |||||
while cur_epoch_idx < n_epochs: | |||||
on_train_epoch_begin(trainer) | |||||
while batch_idx_in_epoch<=num_batches_per_epoch: | |||||
on_fetch_data_begin(trainer) | |||||
batch = next(dataloader) | |||||
on_fetch_data_end(trainer) | |||||
on_train_batch_begin(trainer, batch, indices) | |||||
on_before_backward(trainer, outputs) # 其中 outputs 是经过 output_mapping(如果设置了) 后的,否则即为 model 的输出。 | |||||
on_after_backward(trainer) | |||||
on_before_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_zero_grad(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_before_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_after_optimizers_step(trainer, optimizers) # 实际调用受到 accumulation_steps 影响 | |||||
on_train_batch_end(trainer) | |||||
on_train_epoch_end(trainer) | |||||
except BaseException: | |||||
self.on_exception(trainer, exception) | |||||
finally: | |||||
on_train_end(trainer) | |||||
其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ | 其它 callback 例如 on_evaluate_begin(trainer)/on_evaluate_end(trainer, results)/on_save_model(trainer)/ | ||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 | |||||
特定的时间调用。 | |||||
on_load_model(trainer)/on_save_checkpoint(trainer)/on_load_checkpoint(trainer)将根据需要在Trainer.run()中 | |||||
特定的时间调用。 | |||||
Example:: | Example:: | ||||
from fastNLP import Event | from fastNLP import Event | ||||
@Trainer.on(Event.on_save_model()) | @Trainer.on(Event.on_save_model()) | ||||
def do_something_1(trainer): | def do_something_1(trainer): | ||||
@@ -696,7 +701,7 @@ class Trainer(TrainerEventTrigger): | |||||
r""" | r""" | ||||
用于断点重训的加载函数; | 用于断点重训的加载函数; | ||||
注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 | ||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; | |||||
保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; | |||||
注意我们目前不支持单卡到多卡的断点重训; | 注意我们目前不支持单卡到多卡的断点重训; | ||||
@@ -26,7 +26,8 @@ class State(dict): | |||||
为了实现断点重训,用户应当保证其保存的信息都是可序列化的; | 为了实现断点重训,用户应当保证其保存的信息都是可序列化的; | ||||
推荐的使用方式: | |||||
推荐的使用方式:: | |||||
>>> state = State() | >>> state = State() | ||||
>>> state["best_accuracy"] = 0.9 | >>> state["best_accuracy"] = 0.9 | ||||
>>> print(state["best_accuracy"]) | >>> print(state["best_accuracy"]) | ||||
@@ -64,38 +64,40 @@ class JittorDataLoader: | |||||
:param collate_fn: 对取得到的数据进行打包的callable函数 | :param collate_fn: 对取得到的数据进行打包的callable函数 | ||||
:param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 | :param as_numpy: 返回数据是否设置为numpy类型,否则为torch.tensor类型 | ||||
""" | """ | ||||
# TODO 支持fastnlp dataset | |||||
# TODO 验证支持replacesampler (以后完成) | # TODO 验证支持replacesampler (以后完成) | ||||
# 是否为 jittor 类型的 dataset | |||||
# FastNLP Datset, collate_fn not None | |||||
if isinstance(dataset, FDataSet) and collate_fn is None: | |||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||||
if not isinstance(dataset, _JittorDataset): | |||||
self.dataset = _JittorDataset(dataset) | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == "auto": | if collate_fn == "auto": | ||||
if isinstance(dataset, FDataSet): | |||||
self._collate_fn = dataset.collator | |||||
self._collate_fn.set_backend(backend="jittor") | |||||
if isinstance(self.dataset.dataset, FDataSet): | |||||
self.collate_fn = self.dataset.dataset.collator | |||||
self.collate_fn.set_backend(backend="jittor") | |||||
else: | else: | ||||
self._collate_fn = Collator(backend="jittor") | |||||
self.collate_fn = Collator(backend="jittor") | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
elif isinstance(collate_fn, Callable): | elif isinstance(collate_fn, Callable): | ||||
if collate_fn is not collate_batch: | if collate_fn is not collate_batch: | ||||
self._collate_fn = collate_fn | |||||
self.collate_fn = collate_fn | |||||
else: | else: | ||||
self._collate_fn = collate_batch | |||||
self.dataset = _JittorDataset(dataset) | |||||
self.collate_fn = collate_batch | |||||
self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | self.dataset.set_attrs(batch_size=batch_size, shuffle=shuffle, drop_last=drop_last, | ||||
num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, | num_workers=num_workers, buffer_size=buffer_size, stop_grad=stop_grad, | ||||
keep_numpy_array=keep_numpy_array, endless=endless) | keep_numpy_array=keep_numpy_array, endless=endless) | ||||
# 将内部dataset批次设置为1 | |||||
if isinstance(self.dataset.dataset, Dataset): | if isinstance(self.dataset.dataset, Dataset): | ||||
self.dataset.dataset.set_attrs(batch_size=1) | self.dataset.dataset.set_attrs(batch_size=1) | ||||
# 用户提供了 collate_fn,则会自动代替 jittor 提供 collate_batch 函数 | |||||
# self._collate_fn = _collate_fn | |||||
self.cur_batch_indices = None | self.cur_batch_indices = None | ||||
def __iter__(self): | def __iter__(self): | ||||
# TODO 第一次迭代后不能设置collate_fn,设置是无效的 | # TODO 第一次迭代后不能设置collate_fn,设置是无效的 | ||||
self.collate_fn = self._collate_fn | |||||
if self.cur_batch_indices is None: | if self.cur_batch_indices is None: | ||||
self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn)) | self.dataset.set_attrs(collate_batch=indice_collate_wrapper(self.collate_fn)) | ||||
for indices, data in self.dataset.__iter__(): | for indices, data in self.dataset.__iter__(): | ||||
@@ -107,8 +109,8 @@ class JittorDataLoader: | |||||
return len(self.dataset) // self.dataset.batch_size | return len(self.dataset) // self.dataset.batch_size | ||||
return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | return (len(self.dataset) - 1) // self.dataset.batch_size + 1 | ||||
def set_pad(self, field_name:Union[str, tuple], pad_val:Union[int, float, None]=0, dtype=None, backend=None, | |||||
pad_fn:Callable=None) -> Collator: | |||||
def set_pad(self, field_name: Union[str, tuple], pad_val: Union[int, float, None] = 0, dtype=None, backend=None, | |||||
pad_fn: Callable = None) -> "JittorDataLoader": | |||||
""" | """ | ||||
如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 | ||||
@@ -127,16 +129,18 @@ class JittorDataLoader: | |||||
形式,输出将被直接作为结果输出。 | 形式,输出将被直接作为结果输出。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | |||||
self._collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, backend=backend) | |||||
return self._collate_fn | |||||
if isinstance(self.collate_fn, Collator): | |||||
self.collate_fn.set_pad(field_name=field_name, pad_val=pad_val, dtype=dtype, pad_fn=pad_fn, | |||||
backend=backend) | |||||
return self | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_pad() is allowed.") | ||||
def set_ignore(self, *field_names) -> Collator: | |||||
def set_ignore(self, *field_names) -> "JittorDataLoader": | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Example:: | Example:: | ||||
collator.set_ignore('field1', 'field2') | collator.set_ignore('field1', 'field2') | ||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | ||||
@@ -144,9 +148,9 @@ class JittorDataLoader: | |||||
__getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 | ||||
:return: 返回 Collator 自身 | :return: 返回 Collator 自身 | ||||
""" | """ | ||||
if isinstance(self._collate_fn, Collator): | |||||
self._collate_fn.set_ignore(*field_names) | |||||
return self._collate_fn | |||||
if isinstance(self.collate_fn, Collator): | |||||
self.collate_fn.set_ignore(*field_names) | |||||
return self | |||||
else: | else: | ||||
raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | raise ValueError(f"Only when the collate_fn is a fastNLP Collator, set_ignore() is allowed.") | ||||
@@ -158,5 +162,6 @@ class JittorDataLoader: | |||||
""" | """ | ||||
return self.cur_batch_indices | return self.cur_batch_indices | ||||
def prepare_jittor_dataloader(): | def prepare_jittor_dataloader(): | ||||
... | ... |
@@ -9,7 +9,6 @@ from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.io import DataLoader, Dataset, Sampler | from paddle.io import DataLoader, Dataset, Sampler | ||||
from paddle.fluid.dataloader.collate import default_collate_fn | |||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | from fastNLP.core.utils.dummy_class import DummyClass as DataLoader | ||||
@@ -52,6 +51,9 @@ class PaddleDataLoader(DataLoader): | |||||
num_workers: int = 0, use_buffer_reader: bool = True, | num_workers: int = 0, use_buffer_reader: bool = True, | ||||
use_shared_memory: bool = True, timeout: int = 0, | use_shared_memory: bool = True, timeout: int = 0, | ||||
worker_init_fn: Callable = None, persistent_workers=False) -> None: | worker_init_fn: Callable = None, persistent_workers=False) -> None: | ||||
# FastNLP Datset, collate_fn not None | |||||
if isinstance(dataset, FDataSet) and collate_fn is None: | |||||
raise ValueError("When use FastNLP DataSet, collate_fn must be not None") | |||||
if not isinstance(dataset, _PaddleDataset): | if not isinstance(dataset, _PaddleDataset): | ||||
dataset = _PaddleDataset(dataset) | dataset = _PaddleDataset(dataset) | ||||
@@ -66,10 +68,10 @@ class PaddleDataLoader(DataLoader): | |||||
if isinstance(collate_fn, str): | if isinstance(collate_fn, str): | ||||
if collate_fn == 'auto': | if collate_fn == 'auto': | ||||
if isinstance(dataset.dataset, FDataSet): | if isinstance(dataset.dataset, FDataSet): | ||||
self._collate_fn = dataset.dataset.collator | |||||
self._collate_fn.set_backend(backend="paddle") | |||||
collate_fn = dataset.dataset.collator | |||||
collate_fn.set_backend(backend="paddle") | |||||
else: | else: | ||||
self._collate_fn = Collator(backend="paddle") | |||||
collate_fn = Collator(backend="paddle") | |||||
else: | else: | ||||
raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | raise ValueError(f"collate_fn: {collate_fn} must be 'auto'") | ||||
@@ -142,6 +144,7 @@ class PaddleDataLoader(DataLoader): | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Example:: | Example:: | ||||
collator.set_ignore('field1', 'field2') | collator.set_ignore('field1', 'field2') | ||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | ||||
@@ -187,7 +190,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
dl_bundle = {} | dl_bundle = {} | ||||
for name, ds in ds_or_db.iter_datasets(): | for name, ds in ds_or_db.iter_datasets(): | ||||
if 'train' in name: | if 'train' in name: | ||||
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, | |||||
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | |||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=train_batch_size, | batch_sampler=batch_sampler, batch_size=train_batch_size, | ||||
shuffle=shuffle, | shuffle=shuffle, | ||||
@@ -197,7 +200,7 @@ def prepare_paddle_dataloader(ds_or_db, feed_list=None, places=None, | |||||
timeout=timeout, worker_init_fn=worker_init_fn, | timeout=timeout, worker_init_fn=worker_init_fn, | ||||
persistent_workers=persistent_workers) | persistent_workers=persistent_workers) | ||||
else: | else: | ||||
dl_bundle[name] = PaddleDataLoader(ds_or_db, feed_list=feed_list, places=places, | |||||
dl_bundle[name] = PaddleDataLoader(ds, feed_list=feed_list, places=places, | |||||
return_list=return_list, | return_list=return_list, | ||||
batch_sampler=batch_sampler, batch_size=non_train_batch_size, | batch_sampler=batch_sampler, batch_size=non_train_batch_size, | ||||
shuffle=shuffle, | shuffle=shuffle, | ||||
@@ -153,6 +153,7 @@ class TorchDataLoader(DataLoader): | |||||
""" | """ | ||||
如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 | ||||
Example:: | Example:: | ||||
collator.set_ignore('field1', 'field2') | collator.set_ignore('field1', 'field2') | ||||
:param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 | ||||
@@ -706,8 +706,8 @@ class DataSet: | |||||
def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | def concat(self, dataset: 'DataSet', inplace:bool=True, field_mapping:Dict=None) -> 'DataSet': | ||||
""" | """ | ||||
将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target | 将当前dataset与输入的dataset结合成一个更大的dataset,需要保证两个dataset都包含了相同的field。结合后的dataset的input,target | ||||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 | |||||
当前dataset含有field,则会报错。 | |||||
以及collate_fn以当前dataset为准。当dataset中包含的field多于当前的dataset,则多余的field会被忽略;若dataset中未包含所有 | |||||
当前dataset含有field,则会报错。 | |||||
:param DataSet, dataset: 需要和当前dataset concat的dataset | :param DataSet, dataset: 需要和当前dataset concat的dataset | ||||
:param bool, inplace: 是否直接将dataset组合到当前dataset中 | :param bool, inplace: 是否直接将dataset组合到当前dataset中 | ||||
@@ -87,8 +87,8 @@ class Driver(ABC): | |||||
:param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; | ||||
:param fn: 调用该函数进行一次计算。 | :param fn: 调用该函数进行一次计算。 | ||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call | |||||
函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||||
:param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函 | |||||
数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; | |||||
:return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); | ||||
""" | """ | ||||
raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") | ||||
@@ -106,9 +106,10 @@ class Driver(ABC): | |||||
`evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; | ||||
这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: | ||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||||
函数,然后给出 warning; | |||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||||
1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` | |||||
函数,然后给出 warning; | |||||
2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; | |||||
注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 | ||||
forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 | ||||
可能需要额外标记最初传入 driver 的模型是哪种形式的; | 可能需要额外标记最初传入 driver 的模型是哪种形式的; | ||||
@@ -376,7 +377,7 @@ class Driver(ABC): | |||||
的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; | ||||
因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 | 因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 | ||||
pid 的信息; | |||||
pid 的信息; | |||||
""" | """ | ||||
# 单卡 driver 不需要这个函数; | # 单卡 driver 不需要这个函数; | ||||
if self._pids is not None: | if self._pids is not None: | ||||
@@ -33,11 +33,12 @@ class JittorDriver(Driver): | |||||
f"`jittor.Module` type.") | f"`jittor.Module` type.") | ||||
super(JittorDriver, self).__init__(model) | super(JittorDriver, self).__init__(model) | ||||
self.model = model | |||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | ||||
self.grad_scaler = _grad_scaler() | self.grad_scaler = _grad_scaler() | ||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
@staticmethod | @staticmethod | ||||
def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): | ||||
# 在fastnlp中实现了JittorDataLoader | # 在fastnlp中实现了JittorDataLoader | ||||
@@ -152,4 +153,4 @@ class JittorDriver(Driver): | |||||
# def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): | # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx): | ||||
# # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | # # 保证 ddp 训练时的 shuffle=True 时的正确性,因为需要保证每一个进程上的 sampler 的shuffle 的随机数种子是一样的; | ||||
# if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | # if callable(getattr(dataloader.batch_sampler, "set_epoch", None)): | ||||
# dataloader.batch_sampler.set_epoch(cur_epoch_idx) | |||||
# dataloader.batch_sampler.set_epoch(cur_epoch_idx) |
@@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver): | |||||
logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...') | ||||
return fn, None | return fn, None | ||||
elif fn in {"train_step", "evaluate_step"}: | elif fn in {"train_step", "evaluate_step"}: | ||||
logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...') | |||||
return self.model, self.model.forward | |||||
logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...') | |||||
return self.model, self.model.execute | |||||
else: | else: | ||||
raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.") | ||||
@@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver): | |||||
return dataloader | return dataloader | ||||
else: | else: | ||||
return dataloader | return dataloader | ||||
def setup(self): | |||||
""" | |||||
使用单个 GPU 时,jittor 底层自动实现调配,无需额外操作 | |||||
""" | |||||
pass |
@@ -172,6 +172,7 @@ def fastnlp_paddle_all_gather(obj: Any, device=None, group=None) ->List: | |||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | ||||
example:: | example:: | ||||
obj = { | obj = { | ||||
'a': [1, 1], | 'a': [1, 1], | ||||
'b': [[1, 2], [1, 2]], | 'b': [[1, 2], [1, 2]], | ||||
@@ -534,7 +534,7 @@ class TorchDDPDriver(TorchDriver): | |||||
def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | def broadcast_object(self, obj, src:int=0, group=None, **kwargs): | ||||
""" | """ | ||||
从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | 从 src 端将 obj 对象(可能是 tensor ,可能是 object )发送到 dst 处。如果是非 tensor 的对象会尝试使用 pickle 进行打包进行 | ||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
传输,然后再 dst 处再加载回来。仅在分布式的 driver 中有实际意义。 | |||||
:param obj: obj,可能是 Tensor 或 嵌套类型的数据 | :param obj: obj,可能是 Tensor 或 嵌套类型的数据 | ||||
:param int src: source 的 global rank 。 | :param int src: source 的 global rank 。 | ||||
@@ -551,9 +551,10 @@ class TorchDDPDriver(TorchDriver): | |||||
def all_gather(self, obj, group) -> List: | def all_gather(self, obj, group) -> List: | ||||
""" | """ | ||||
将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | 将 obj 互相传送到其它所有的 rank 上,其中 obj 可能是 Tensor,也可能是嵌套结构的 object 。如果不是基础类型的数据,尝试通过 | ||||
pickle 进行序列化,接收到之后再反序列化。 | |||||
pickle 进行序列化,接收到之后再反序列化。 | |||||
example:: | |||||
example: | |||||
obj = { | obj = { | ||||
'a': [1, 1], | 'a': [1, 1], | ||||
'b': [[1, 2], [1, 2]], | 'b': [[1, 2], [1, 2]], | ||||
@@ -175,7 +175,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) - | |||||
""" | """ | ||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | ||||
example: | |||||
example:: | |||||
obj = { | obj = { | ||||
'a': [1, 1], | 'a': [1, 1], | ||||
'b': [[1, 2], [1, 2]], | 'b': [[1, 2], [1, 2]], | ||||
@@ -175,16 +175,18 @@ def _build_fp16_env(dummy=False): | |||||
def replace_sampler(dataloader: "DataLoader", sampler): | def replace_sampler(dataloader: "DataLoader", sampler): | ||||
""" | """ | ||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | |||||
替换 sampler (初始化一个新的 dataloader 的逻辑在于): | |||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | |||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader | |||||
的类,而不是直接的 DataLoader; | |||||
用户可能继承了 dataloader,定制了自己的 dataloader 类,这也是我们为什么先 `inspect.signature(dataloader)` 而不是直接 | |||||
`inspect.signature(DataLoader)` 的原因,因此同时注意到我们在外层重新初始化一个 dataloader 时也是使用的用户传进来的 dataloader | |||||
的类,而不是直接的 DataLoader; | |||||
如果需要定制自己的 dataloader,保证以下两点: | |||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; | |||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 | |||||
来获取实际的参数的值; | |||||
如果需要定制自己的 dataloader,保证以下两点: | |||||
1. 在 __init__ 方法中加入 **kwargs,这是为了方便我们将 sampler 插入到具体的 DataLoader 的构造中; | |||||
2. 在 __init__ 方法中出现的参数,请务必挂为同样名字的实例属性,例如 self.one_arg_name = one_arg_name,这是因为我们只能通过属性 | |||||
来获取实际的参数的值; | |||||
""" | """ | ||||
# 拿到实例属性; | # 拿到实例属性; | ||||
@@ -1,18 +1,20 @@ | |||||
r""" | r""" | ||||
Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, | Logger 是fastNLP中记录日志的模块,logger封装了logging模块的Logger, | ||||
具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API | 具体使用方式与直接使用logging.Logger相同,同时也新增一些简单好用的API | ||||
使用方式: | |||||
from fastNLP import _logger | |||||
# | |||||
# _logger 可以和 logging.Logger 一样使用 | |||||
_logger.info('your msg') | |||||
_logger.error('your msg') | |||||
# _logger 新增的API | |||||
# 将日志输出到文件,以及输出的日志等级 | |||||
_logger.add_file('/path/to/log', level='INFO') | |||||
# 定义在命令行中的显示格式和日志等级 | |||||
_logger.set_stdout('tqdm', level='WARN') | |||||
使用方式:: | |||||
from fastNLP import _logger | |||||
# | |||||
# _logger 可以和 logging.Logger 一样使用 | |||||
_logger.info('your msg') | |||||
_logger.error('your msg') | |||||
# _logger 新增的API | |||||
# 将日志输出到文件,以及输出的日志等级 | |||||
_logger.add_file('/path/to/log', level='INFO') | |||||
# 定义在命令行中的显示格式和日志等级 | |||||
_logger.set_stdout('tqdm', level='WARN') | |||||
""" | """ | ||||
@@ -10,12 +10,13 @@ def print(*args, sep=' ', end='\n', file=None, flush=False): | |||||
用来重定向 print 函数至 logger.info 的函数。 | 用来重定向 print 函数至 logger.info 的函数。 | ||||
Example:: | Example:: | ||||
from fastNLP import print | from fastNLP import print | ||||
print("This is a test") # 等价于调用了 logger.info("This is a test") | print("This is a test") # 等价于调用了 logger.info("This is a test") | ||||
:param args: 需要打印的内容 | :param args: 需要打印的内容 | ||||
:param sep: 存在多个输入时,使用的间隔。 | :param sep: 存在多个输入时,使用的间隔。 | ||||
:param end: 该参数在当前设置无意义,因为结尾一定会被加入 \n 。 | |||||
:param end: 该参数在当前设置无意义,因为结尾一定会被加入 '\\\\n' 。 | |||||
:param file: 该参数无意义。 | :param file: 该参数无意义。 | ||||
:param flush: 该参数无意义。 | :param flush: 该参数无意义。 | ||||
:return: | :return: | ||||
@@ -38,7 +38,7 @@ class Metric: | |||||
def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: | def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: | ||||
""" | """ | ||||
注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 | ||||
tensor 直接进行加减乘除计算即可。 | |||||
tensor 直接进行加减乘除计算即可。 | |||||
注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 | 注意:如果想使得该 metric 可自动扩展到多卡的情况,请一定申明 aggregate_method 。 | ||||
:param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | :param name: 当前 element 的名字,注册后,在 Metric 中可以通过 self.{name} 访问该变量。 | ||||
@@ -48,7 +48,7 @@ class Metric: | |||||
Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | Torch.tensor ; 如果backend 为 paddle 则该对象为 paddle.tensor ;如果 backend 为 jittor , 则该对象为 jittor.Var 。 | ||||
一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | 一般情况下直接默认为 auto 就行了,fastNLP 会根据实际调用 Metric.update() 函数时传入的参数进行合理的初始化,例如当传入 | ||||
的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含 | 的参数中只包含 torch.Tensor 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 torch ;只包含 | ||||
jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 | |||||
jittor.Var 则认为 backend 这一种 tensor 时(可以有其它非 tensor 类型的输入)则认为 backend 为 jittor 。如果没有检测 | |||||
到任何一种 tensor ,就默认使用 float 类型作为 element 。 | 到任何一种 tensor ,就默认使用 float 类型作为 element 。 | ||||
:return: 注册的 Element 对象 | :return: 注册的 Element 对象 | ||||
""" | """ | ||||
@@ -496,7 +496,7 @@ class PollingSampler(MixSampler): | |||||
:param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | :param sampler: 实例化好的sampler,每个dataset对应一个sampler对象 | ||||
:param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | :param drop_last: 是否去掉最后一个batch的数据,其长度小于batch_size | ||||
:param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时, | :param ds_ratio: 当ds_ratio=None时候, 轮流采样dataset列表直至所有的数据集采样完;当ds_ratio='truncate_to_least'时, | ||||
以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 | |||||
以dataset列表最短的ds为基准,长的数据集会被截断;当ds_ratio='pad_to_most'时,以dataset列表最长ds为基准,短的数据集会被重采样 | |||||
""" | """ | ||||
super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size, | super(PollingSampler, self).__init__(dataset=dataset, batch_size=batch_size, | ||||
sampler=sampler, ds_ratio=ds_ratio, | sampler=sampler, ds_ratio=ds_ratio, | ||||
@@ -35,7 +35,9 @@ class NumConsumedSamplesArray: | |||||
def __init__(self, buffer_size=2000, num_consumed_samples=0): | def __init__(self, buffer_size=2000, num_consumed_samples=0): | ||||
""" | """ | ||||
保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | 保留 buffer_size 个 num_consumed_samples 数据,可以索引得到某个 index 下的 num_consumed_samples 多少 | ||||
Example:: | Example:: | ||||
array = NumConsumedSamplesArray(buffer_size=3) | array = NumConsumedSamplesArray(buffer_size=3) | ||||
for i in range(10): | for i in range(10): | ||||
array.push(i) | array.push(i) | ||||
@@ -222,7 +222,7 @@ def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _chec | |||||
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。 | 可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。 | ||||
如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose, | 如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose, | ||||
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称:: | |||||
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称。 | |||||
:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | ||||
函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到 | 函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到 | ||||
@@ -257,12 +257,13 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, | |||||
对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; | 对于 `output_mapping`,该函数会在 `Trainer.train_step` 以及 `Evaluator.train_step` 中得到结果后立刻被调用; | ||||
转换的逻辑按优先级依次为: | 转换的逻辑按优先级依次为: | ||||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; | |||||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; | |||||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; | |||||
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; | |||||
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 | |||||
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 | |||||
1. 如果 `mapping` 是一个函数,那么会直接返回 `mapping(data)`; | |||||
2. 如果 `mapping` 是一个 `Dict`,那么 `data` 的类型只能为以下三种: [`Dict`, `dataclass`, `Sequence`]; | |||||
如果 `data` 是 `Dict`,那么该函数会将 `data` 的 key 替换为 mapping[key]; | |||||
如果 `data` 是 `dataclass`,那么该函数会先使用 `dataclasses.asdict` 函数将其转换为 `Dict`,然后进行转换; | |||||
如果 `data` 是 `Sequence`,那么该函数会先将其转换成一个对应的 `Dict`:{"_0": list[0], "_1": list[1], ...},然后使用 | |||||
mapping对这个 `Dict` 进行转换,如果没有匹配上mapping中的key则保持"_number"这个形式。 | |||||
:param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 | :param mapping: 用于转换的字典或者函数;mapping是函数时,返回值必须为字典类型。 | ||||
:param data: 需要被转换的对象; | :param data: 需要被转换的对象; | ||||
@@ -440,12 +441,16 @@ def _is_iterable(value): | |||||
def pretty_table_printer(dataset_or_ins) -> PrettyTable: | def pretty_table_printer(dataset_or_ins) -> PrettyTable: | ||||
r""" | r""" | ||||
:param dataset_or_ins: 传入一个dataSet或者instance | :param dataset_or_ins: 传入一个dataSet或者instance | ||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||||
+-----------+-----------+-----------------+ | |||||
| field_1 | field_2 | field_3 | | |||||
+-----------+-----------+-----------------+ | |||||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | | |||||
+-----------+-----------+-----------------+ | |||||
.. code-block:: | |||||
ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"]) | |||||
+-----------+-----------+-----------------+ | |||||
| field_1 | field_2 | field_3 | | |||||
+-----------+-----------+-----------------+ | |||||
| [1, 1, 1] | [2, 2, 2] | ['a', 'b', 'c'] | | |||||
+-----------+-----------+-----------------+ | |||||
:return: 以 pretty table的形式返回根据terminal大小进行自动截断 | :return: 以 pretty table的形式返回根据terminal大小进行自动截断 | ||||
""" | """ | ||||
x = PrettyTable() | x = PrettyTable() | ||||
@@ -47,7 +47,7 @@ def rank_zero_call(fn: Callable): | |||||
rank_zero_call(add)(1, 2) | rank_zero_call(add)(1, 2) | ||||
同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何 | 同时,该函数还会设置 FASTNLP_NO_SYNC 为 2,在这个环境下,所有的 fastNLP 内置的 barrier 接口,gather/broadcast 操作都没有任何 | ||||
意义。 | |||||
意义。 | |||||
:param fn: 需要包裹的可执行的函数。 | :param fn: 需要包裹的可执行的函数。 | ||||
:return: | :return: | ||||
@@ -65,7 +65,7 @@ def rank_zero_call(fn: Callable): | |||||
def fastnlp_no_sync_context(level=2): | def fastnlp_no_sync_context(level=2): | ||||
""" | """ | ||||
用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; | 用于让 fastNLP 的 barrier 以及 gather/broadcast等操作等同于只有1卡的多卡程序。如果为 1 表示 fastNLP 里的barrier 操作失效; | ||||
如果为 2 表示 barrier 与 gather/broadcast 都失效。 | |||||
如果为 2 表示 barrier 与 gather/broadcast 都失效。 | |||||
:param int level: 可选 [0, 1, 2] | :param int level: 可选 [0, 1, 2] | ||||
:return: | :return: | ||||
@@ -84,9 +84,10 @@ def all_rank_call_context(): | |||||
""" | """ | ||||
在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 | 在多卡模式下,该环境内,会暂时地将 FASTNLP_GLOBAL_RANK 设置为 "0",使得 rank_zero_call 函数失效,使得每个进程都会运行该函数。 | ||||
# 使用方式 | |||||
with all_rank_call_context(): | |||||
do_something # all rank will do | |||||
使用方式:: | |||||
with all_rank_call_context(): | |||||
do_something # all rank will do | |||||
:param fn: | :param fn: | ||||
:return: | :return: | ||||
@@ -233,8 +233,8 @@ class DataBundle: | |||||
如果为False,则报错 | 如果为False,则报错 | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | ||||
:param progress_desc 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||||
:param show_progress_bar 是否显示tqdm进度条 | |||||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||||
:param show_progress_bar: 是否显示tqdm进度条 | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
@@ -0,0 +1,133 @@ | |||||
import pytest | |||||
from fastNLP.core.controllers.trainer import Trainer | |||||
from fastNLP.core.controllers.trainer import Evaluator | |||||
from fastNLP.core.metrics.accuracy import Accuracy | |||||
from fastNLP.core.callbacks.progress_callback import RichCallback | |||||
from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader | |||||
from fastNLP.envs.imports import _NEED_IMPORT_JITTOR | |||||
if _NEED_IMPORT_JITTOR: | |||||
import jittor as jt | |||||
from jittor import nn, Module | |||||
from jittor.dataset import Dataset | |||||
class JittorNormalModel_Classification(Module): | |||||
""" | |||||
基础的 Jittor 分类模型 | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(JittorNormalModel_Classification, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=64, out_features=32) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=32, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def execute(self, x): | |||||
# It's similar to forward function in Pytorch | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
return x | |||||
def train_step(self, x, y): | |||||
x = self(x) | |||||
return {"loss": self.loss_fn(x, y)} | |||||
def evaluate_step(self, x, y): | |||||
x = self(x) | |||||
return {"pred": x, "target": y.reshape((-1,))} | |||||
class JittorRandomMaxDataset(Dataset): | |||||
def __init__(self, num_samples, num_features): | |||||
super(JittorRandomMaxDataset, self).__init__() | |||||
self.x = jt.randn((num_samples, num_features)) | |||||
self.y = self.x.argmax(dim=1)[0] | |||||
def __len__(self): | |||||
return len(self.y) | |||||
def __getitem__(self, item): | |||||
return {"x": self.x[item], "y": self.y[item]} | |||||
class TrainJittorConfig: | |||||
num_labels: int = 5 | |||||
feature_dimension: int = 5 | |||||
lr = 1e-1 | |||||
batch_size: int = 4 | |||||
shuffle: bool = True | |||||
@pytest.mark.parametrize("driver,device", [("jittor", None)]) | |||||
@pytest.mark.parametrize("callbacks", [[RichCallback(100)]]) | |||||
def test_trainer_jittor( | |||||
driver, | |||||
device, | |||||
callbacks, | |||||
n_epochs=3, | |||||
): | |||||
model = JittorNormalModel_Classification( | |||||
num_labels=TrainJittorConfig.num_labels, | |||||
feature_dimension=TrainJittorConfig.feature_dimension | |||||
) | |||||
optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr) | |||||
train_dataloader = JittorDataLoader( | |||||
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension), | |||||
batch_size=TrainJittorConfig.batch_size, | |||||
shuffle=True, | |||||
# num_workers=4, | |||||
) | |||||
val_dataloader = JittorDataLoader( | |||||
dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension), | |||||
batch_size=TrainJittorConfig.batch_size, | |||||
shuffle=True, | |||||
# num_workers=4, | |||||
) | |||||
test_dataloader = JittorDataLoader( | |||||
dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension), | |||||
batch_size=TrainJittorConfig.batch_size, | |||||
shuffle=True, | |||||
# num_workers=4, | |||||
) | |||||
metrics = {"acc": Accuracy()} | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=optimizer, | |||||
train_dataloader=train_dataloader, | |||||
evaluate_dataloaders=val_dataloader, | |||||
validate_every=-1, | |||||
evaluate_fn="evaluate_step", | |||||
input_mapping=None, | |||||
output_mapping=None, | |||||
metrics=metrics, | |||||
n_epochs=n_epochs, | |||||
callbacks=callbacks, | |||||
# progress_bar="rich" | |||||
) | |||||
trainer.run() | |||||
evaluator = Evaluator( | |||||
model=model, | |||||
driver=driver, | |||||
dataloaders=test_dataloader, | |||||
evaluate_fn="evaluate_step", | |||||
metrics=metrics, | |||||
) | |||||
metric_results = evaluator.run() | |||||
assert metric_results["acc#acc"] > 0.80 | |||||
if __name__ == "__main__": | |||||
# test_trainer_jittor("jittor", None, [RichCallback(100)]) | |||||
pytest.main(['test_trainer_jittor.py']) # 只运行此模块 |
@@ -1,7 +1,6 @@ | |||||
import pytest | import pytest | ||||
import numpy as np | import numpy as np | ||||
from datasets import Dataset as HfDataset | from datasets import Dataset as HfDataset | ||||
from datasets import load_dataset | |||||
from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | from fastNLP.core.dataloaders.jittor_dataloader import JittorDataLoader | ||||
from fastNLP.core.dataset import DataSet as Fdataset | from fastNLP.core.dataset import DataSet as Fdataset | ||||
@@ -23,16 +22,12 @@ class MyDataset(Dataset): | |||||
def __getitem__(self, item): | def __getitem__(self, item): | ||||
return self.data[item] | return self.data[item] | ||||
# return {'x': [[1, 0], [2, 0, 1]]} | |||||
# return np.random.randn(3, 10) | |||||
# def __len__(self): | |||||
# return self.dataset_len | |||||
@pytest.mark.jittor | @pytest.mark.jittor | ||||
class TestJittor: | class TestJittor: | ||||
def test_v1(self): | |||||
def test_jittor_dataset(self): | |||||
""" | """ | ||||
测试jittor类型的dataset使用fdl | 测试jittor类型的dataset使用fdl | ||||
@@ -40,13 +35,13 @@ class TestJittor: | |||||
""" | """ | ||||
dataset = MyDataset() | dataset = MyDataset() | ||||
jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | jtl = JittorDataLoader(dataset, keep_numpy_array=True, batch_size=4) | ||||
# jtl.set_pad_val('x', 'y') | |||||
# jtl.set_input('x') | |||||
for batch in jtl: | for batch in jtl: | ||||
print(batch) | |||||
print(jtl.get_batch_indices()) | |||||
assert batch.size() == [4, 3, 4] | |||||
jtl1 = JittorDataLoader(dataset, keep_numpy_array=False, batch_size=4, num_workers=2) | |||||
for batch in jtl1: | |||||
assert batch.size() == [4, 3, 4] | |||||
def test_v2(self): | |||||
def test_fastnlp_Dataset(self): | |||||
""" | """ | ||||
测试fastnlp的dataset | 测试fastnlp的dataset | ||||
@@ -56,26 +51,27 @@ class TestJittor: | |||||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) | jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True) | ||||
jtl.set_pad("x", -1) | jtl.set_pad("x", -1) | ||||
jtl.set_ignore("y") | jtl.set_ignore("y") | ||||
# jtl.set_pad_val('x', val=-1) | |||||
# jtl.set_input('x', 'y') | |||||
for batch in jtl: | for batch in jtl: | ||||
assert batch['x'].size() == (16, 4) | assert batch['x'].size() == (16, 4) | ||||
jtl = JittorDataLoader(dataset, batch_size=16, drop_last=True, num_workers=2) | |||||
def test_v3(self): | |||||
def test_huggingface_datasets(self): | |||||
dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | dataset = HfDataset.from_dict({'x': [[1, 2], [0], [2, 3, 4, 5]] * 100, 'y': [0, 1, 2] * 100}) | ||||
jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) | jtl = JittorDataLoader(dataset, batch_size=4, drop_last=True) | ||||
# jtl.set_input('x', 'y') | |||||
for batch in jtl: | for batch in jtl: | ||||
print(batch) | |||||
assert batch['x'].size() == [4, 4] | |||||
assert len(batch['y']) == 4 | |||||
def test_v4(self): | |||||
def test_num_workers(self): | |||||
dataset = MyDataset() | dataset = MyDataset() | ||||
dl = JittorDataLoader(dataset, batch_size=4, num_workers=2) | dl = JittorDataLoader(dataset, batch_size=4, num_workers=2) | ||||
print(len(dl)) | |||||
for idx, batch in enumerate(dl): | for idx, batch in enumerate(dl): | ||||
print(batch.shape, idx) | |||||
assert batch.shape == [4, 3, 4] | |||||
for idx, batch in enumerate(dl): | for idx, batch in enumerate(dl): | ||||
print(batch.shape, idx) | |||||
assert batch.shape == [4, 3, 4] | |||||
def test_v5(self): | def test_v5(self): | ||||
dataset = MyDataset() | dataset = MyDataset() | ||||
@@ -6,19 +6,19 @@ from fastNLP.core.dataset import DataSet | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | from fastNLP.envs.imports import _NEED_IMPORT_PADDLE | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
from paddle.io import Dataset, DataLoader | |||||
from paddle.io import Dataset | |||||
import paddle | import paddle | ||||
else: | else: | ||||
from fastNLP.core.utils.dummy_class import DummyClass as Dataset | from fastNLP.core.utils.dummy_class import DummyClass as Dataset | ||||
class RandomDataset(Dataset): | class RandomDataset(Dataset): | ||||
def __getitem__(self, idx): | def __getitem__(self, idx): | ||||
image = np.random.random((10, 5)).astype('float32') | image = np.random.random((10, 5)).astype('float32') | ||||
return {'image': image, 'label': [[0, 1], [1, 2, 3, 4]]} | |||||
return {'image': paddle.to_tensor(image), 'label': [[0, 1], [1, 2, 3, 4]]} | |||||
def __len__(self): | def __len__(self): | ||||
return 10 | return 10 | ||||
@@ -33,16 +33,22 @@ class TestPaddle: | |||||
fdl = PaddleDataLoader(ds, batch_size=2) | fdl = PaddleDataLoader(ds, batch_size=2) | ||||
# fdl = DataLoader(ds, batch_size=2, shuffle=True) | # fdl = DataLoader(ds, batch_size=2, shuffle=True) | ||||
for batch in fdl: | for batch in fdl: | ||||
print(batch) | |||||
assert batch['image'].shape == [2, 10, 5] | |||||
assert batch['label'].shape == [2, 2, 4] | |||||
# print(fdl.get_batch_indices()) | # print(fdl.get_batch_indices()) | ||||
def test_fdl_batch_indices(self): | |||||
def test_fdl_fastnlp_dataset(self): | |||||
ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) | ||||
fdl = PaddleDataLoader(ds, batch_size=4, shuffle=True, drop_last=True) | |||||
fdl = PaddleDataLoader(ds, batch_size=3, shuffle=False, drop_last=True) | |||||
fdl.set_ignore('y') | |||||
fdl.set_pad('x', -1) | |||||
for batch in fdl: | for batch in fdl: | ||||
assert len(fdl.get_batch_indices()) == 4 | |||||
print(batch) | |||||
print(fdl.get_batch_indices()) | |||||
assert len(fdl.get_batch_indices()) == 3 | |||||
assert 'y' not in batch | |||||
assert batch['x'].shape == [3, 3] | |||||
with pytest.raises(ValueError): | |||||
PaddleDataLoader(ds, batch_size=3, collate_fn=None) | |||||
def test_set_inputs_and_set_pad_val(self): | def test_set_inputs_and_set_pad_val(self): | ||||
logger.setLevel("DEBUG") | logger.setLevel("DEBUG") | ||||
@@ -50,11 +56,8 @@ class TestPaddle: | |||||
fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) | fdl = PaddleDataLoader(ds, batch_size=2, drop_last=True) | ||||
fdl.set_pad('label', -1) | fdl.set_pad('label', -1) | ||||
for batch in fdl: | for batch in fdl: | ||||
print(batch['image']) | |||||
assert batch['image'].shape == [2, 10, 5] | assert batch['image'].shape == [2, 10, 5] | ||||
print(batch) | |||||
fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) | fdl1 = PaddleDataLoader(ds, batch_size=4, drop_last=True) | ||||
fdl1.set_ignore('label') | fdl1.set_ignore('label') | ||||
for batch in fdl1: | for batch in fdl1: | ||||
assert batch['image'].shape == [4, 10, 5] | assert batch['image'].shape == [4, 10, 5] | ||||
print(batch) |
@@ -4,6 +4,7 @@ from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_t | |||||
from fastNLP.core.dataset import DataSet | from fastNLP.core.dataset import DataSet | ||||
from fastNLP.io.data_bundle import DataBundle | from fastNLP.io.data_bundle import DataBundle | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core import Trainer | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||