
1. Allow `monitor` to accept a callable object for selecting the monitored value; 2. fix the circular import problem in the Sampler modules

tags/v1.0.0alpha
yh_cc 3 years ago
parent commit 16a467393c
30 changed files with 505 additions and 222 deletions
 1. +31  −9   fastNLP/core/callbacks/callback.py
 2. +7   −4   fastNLP/core/callbacks/checkpoint_callback.py
 3. +6   −5   fastNLP/core/callbacks/early_stop_callback.py
 4. +6   −5   fastNLP/core/callbacks/load_best_model_callback.py
 5. +3   −1   fastNLP/core/callbacks/progress_callback.py
 6. +14  −4   fastNLP/core/callbacks/utils.py
 7. +15  −8   fastNLP/core/collators/collator.py
 8. +10  −9   fastNLP/core/controllers/evaluator.py
 9. +26  −38  fastNLP/core/controllers/trainer.py
10. +7   −3   fastNLP/core/controllers/utils/utils.py
11. +8   −4   fastNLP/core/dataset/dataset.py
12. +2   −2   fastNLP/core/drivers/jittor_driver/jittor_driver.py
13. +1   −1   fastNLP/core/drivers/paddle_driver/paddle_driver.py
14. +5   −5   fastNLP/core/drivers/torch_driver/dist_utils.py
15. +1   −1   fastNLP/core/drivers/torch_driver/torch_driver.py
16. +3   −3   fastNLP/core/log/logger.py
17. +2   −3   fastNLP/core/metrics/accuracy.py
18. +3   −3   fastNLP/core/samplers/__init__.py
19. +33  −0   fastNLP/core/samplers/conversion_utils.py
20. +6   −13  fastNLP/core/samplers/reproducible_batch_sampler.py
21. +7   −7   fastNLP/core/samplers/reproducible_sampler.py
22. +1   −33  fastNLP/core/samplers/utils.py
23. +1   −2   fastNLP/core/utils/__init__.py
24. +111 −32  fastNLP/core/utils/utils.py
25. +5   −5   fastNLP/io/data_bundle.py
26. +1   −1   fastNLP/io/pipe/classification.py
27. +1   −1   fastNLP/io/pipe/construct_graph.py
28. +1   −1   fastNLP/io/pipe/pipe.py
29. +187 −0   tests/core/utils/test_utils.py
30. +1   −19  tests/helpers/utils.py

+31 −9 fastNLP/core/callbacks/callback.py

@@ -10,6 +10,7 @@ from .utils import _get_monitor_value
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.log import logger
 from fastNLP.core.utils import apply_to_collection
+from fastNLP.core.utils.utils import _check_valid_parameters_number


 class Callback:
@@ -299,7 +300,11 @@ class HasMonitorCallback(Callback):
         self.must_have_moinitor = must_have_monitor

     def set_monitor(self, monitor, larger_better):
-        self.monitor = str(monitor) if monitor is not None else None
+        if callable(monitor):  # check that it can accept exactly one argument
+            _check_valid_parameters_number(monitor, expected_params=['results'], fn_name='monitor')
+            self.monitor = monitor
+        else:
+            self.monitor = str(monitor) if monitor is not None else None
         self.larger_better = bool(larger_better)
         if larger_better:
             self.monitor_value = float('-inf')
@@ -322,24 +327,33 @@ class HasMonitorCallback(Callback):
             raise RuntimeError(f"No `monitor` is set for {self.__class__.__name__}. "
                                f"You can set it in the initialization or through Trainer.")

-    def get_monitor_value(self, results:Dict)->float:
+    def get_monitor_value(self, results:Dict)->Union[float, None]:
         """
         Get the value of the monitor. If the monitor is not found directly, fuzzy matching is attempted, and the matched key is stored on the self._real_monitor attribute.

         :param results:
-        :return:
+        :return: None means that no suitable monitor was found this time.
         """
         if len(results)==0:
-            return 0
+            return None
         # make sure every tensor has been converted to a native python type
         results = apply_to_collection(results, dtype=CanItemDataType, function=lambda x: x.item())
         use_monitor, monitor_value = _get_monitor_value(monitor=self.monitor,
                                                         real_monitor=self._real_monitor,
                                                         res=results)
-        if self._real_monitor != use_monitor:  # a substitution happened and should be reported
-            logger.warning(
-                f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
-                f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
+        if monitor_value is None:
+            return monitor_value
+        # first run
+        if isinstance(self.monitor, str) and self._real_monitor == self.monitor and use_monitor != self.monitor:
+            logger.warning(f"We can not find `{self.monitor}` in the evaluation result (with keys as {list(results.keys())}), "
+                           f"we use the `{use_monitor}` as the monitor for `{self.__class__.__name__}`.")
+        # the monitor differs from the one used last time
+        elif isinstance(self.monitor, str) and self._real_monitor != self.monitor and use_monitor != self._real_monitor:
+            logger.warning(f"Change of monitor detected for `{self.__class__.__name__}`. "
+                           f"The expected monitor is:`{self.monitor}`, last used monitor is:"
+                           f"`{self._real_monitor}` and current monitor is:`{use_monitor}`. Please consider using a "
+                           f"customized monitor function when the evaluation results are varying between validation.")

         self._real_monitor = use_monitor
         return monitor_value


@@ -347,10 +361,12 @@ class HasMonitorCallback(Callback):
         """
         Check whether monitor_value is better.

-        :param monitor_value:
+        :param monitor_value: the monitor_value to check; if it is None, False is returned.
         :param keep_if_better: if the given monitor_value is better, keep it.
         :return:
         """
+        if monitor_value is None:
+            return False
         better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
         if keep_if_better and better:
             self.monitor_value = monitor_value
@@ -364,6 +380,12 @@ class HasMonitorCallback(Callback):
         :param monitor_value2:
         :return:
         """
+        if monitor_value1 is None and monitor_value2 is None:
+            return True
+        if monitor_value1 is None:
+            return False
+        if monitor_value2 is None:
+            return True
         better = False
         if (self.larger_better and monitor_value1 > monitor_value2) or \
                 (not self.larger_better and monitor_value1 < monitor_value2):
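With this change, any HasMonitorCallback subclass accepts a function in place of a metric name. A minimal usage sketch (not part of the diff; the result key 'acc#dev' is hypothetical):

    from fastNLP.core.callbacks import EarlyStopCallback

    def pick_monitor(results):
        # receives the full evaluation result dict and returns a float
        return results.get('acc#dev', 0.0)  # 'acc#dev' is a hypothetical key

    callback = EarlyStopCallback(monitor=pick_monitor, larger_better=True, patience=10)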


+7 −4 fastNLP/core/callbacks/checkpoint_callback.py

@@ -10,8 +10,7 @@ from copy import deepcopy


 import fastNLP
-from .callback import Callback, HasMonitorCallback
-from fastNLP.core.callbacks.utils import _get_monitor_value
+from .callback import HasMonitorCallback
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
@@ -166,6 +165,8 @@ class CheckpointCallback(HasMonitorCallback):
         """
         if self.save_topk is not None:
             monitor_value = self.get_monitor_value(results=results)
+            if monitor_value is None:
+                return
             folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \
                           f"-{self._real_monitor}_{monitor_value}"

@@ -231,7 +232,8 @@ class ModelCheckpointCallback(CheckpointCallback):
         If model_save_fn is not None, fastNLP passes the absolute path of folder to that function and creates no files under that folder itself.

         :param monitor: name of the metric to monitor. If no exactly matching name is found in the evaluation results, the best match
-            found by the shortest-common-string algorithm is used as the monitor. If None, the value is taken from the Trainer.
+            found by the shortest-common-string algorithm is used as the monitor. If None, the value is taken from the Trainer. A function may also
+            be passed in; it receives the evaluation results (a dict) and returns a float to be used as the monitor value.
         :param save_folder: the folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs
             are saved into different timestamped folders. If None, the current folder is used.
         :param save_every_n_epochs: save once every this many epochs.
@@ -278,7 +280,8 @@ class TrainerCheckpointCallback(CheckpointCallback):
         If model_save_fn is not None, fastNLP only generates the file fastnlp_trainer.pkl.tar under each folder.

         :param monitor: name of the metric to monitor. If no exactly matching name is found in the evaluation results, the best match
-            found by the shortest-common-string algorithm is used as the monitor. If None, the value is taken from the Trainer.
+            found by the shortest-common-string algorithm is used as the monitor. If None, the value is taken from the Trainer. A function may also
+            be passed in; it receives the evaluation results (a dict) and returns a float to be used as the monitor value.
         :param save_folder: the folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs
             are saved into different timestamped folders. If None, the current folder is used.
         :param save_every_n_epochs: save once every this many epochs.


+6 −5 fastNLP/core/callbacks/early_stop_callback.py

@@ -2,17 +2,18 @@ __all__ = [
     'EarlyStopCallback'
 ]

-from typing import Dict
+from typing import Dict, Union, Callable

 from .callback import HasMonitorCallback
 from fastNLP.core.utils.exceptions import EarlyStopException


 class EarlyStopCallback(HasMonitorCallback):
-    def __init__(self, monitor:str=None, larger_better:bool=True, patience:int=10):
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
         """

-        :param str monitor: the metric to monitor. If None, the monitor set in the Trainer is used.
+        :param str monitor: the metric to monitor. If None, the monitor set in the Trainer is used. A function may also be passed in;
+            it receives the evaluation results (a dict) and returns a float to be used as the monitor value.
         :param larger_better: whether a larger monitor value is better.
         :param patience: stop after this many validations without improvement.
         """
@@ -21,9 +22,9 @@ class EarlyStopCallback(HasMonitorCallback):
         self.patience = patience

     def on_validate_end(self, trainer, results):
-        if len(results)==0:
-            return
         monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return
         if self.is_better_monitor_value(monitor_value, keep_if_better=True):
             self.wait = 0
         else:


+6 −5 fastNLP/core/callbacks/load_best_model_callback.py

@@ -3,7 +3,7 @@ __all__ = [
 ]

 import os
-from typing import Optional, Callable
+from typing import Optional, Callable, Union
 from .callback import HasMonitorCallback
 from io import BytesIO
 import shutil
@@ -14,14 +14,15 @@ from fastNLP.envs import all_rank_call


 class LoadBestModelCallback(HasMonitorCallback):
-    def __init__(self, monitor:str=None, larger_better:bool = True, only_state_dict:bool = True,
+    def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool = True, only_state_dict:bool = True,
                  save_folder:Optional[str] = None, model_save_fn:Optional[Callable] = None,
                  model_load_fn:Optional[Callable] = None,
                  delete_after_train:bool = True):
         """
         Save the model with the best monitor value and reload it when training ends. The best model can only be loaded when training finishes normally.

-        :param str monitor: the metric to monitor. If None, the monitor set in the Trainer is used.
+        :param str monitor: the metric to monitor. If None, the monitor set in the Trainer is used. A function may also be passed in;
+            it receives the evaluation results (a dict) and returns a float to be used as the monitor value.
         :param larger_better: whether a larger metric value is better.
         :param save_folder: the folder to save into. If empty, the weights are kept in memory; otherwise a copy of the weights is saved to file. For multi-machine
             training with a non-empty value, make sure the path is accessible from every machine. This value must not be empty when model_save_fn is not None.
@@ -78,9 +79,9 @@ class LoadBestModelCallback(HasMonitorCallback):
             self.get_monitor_value(sanity_check_res)

     def on_validate_end(self, trainer, results):
-        if len(results)==0:
-            return
         monitor_value = self.get_monitor_value(results)
+        if monitor_value is None:
+            return
         if self.is_better_monitor_value(monitor_value, keep_if_better=True):
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,


+3 −1 fastNLP/core/callbacks/progress_callback.py

@@ -45,6 +45,7 @@ class RichCallback(ProgressCallback):
     :param print_every: update the display every this many batches.
     :param loss_round_ndigit: number of digits to keep when displaying the loss.
     :param monitor: when the result under this key improves, it is highlighted in a different color. If None, the monitor set in the trainer is used.
+        A function may also be passed in; it receives the evaluation results (a dict) and returns a float to be used as the monitor value.
     :param larger_better: whether a larger monitor result is better.
     :param format_json: whether to format the json before printing.
     """
@@ -135,7 +136,8 @@ class RawTextCallback(ProgressCallback):

     :param print_every: update the display every this many batches.
     :param loss_round_ndigit: number of digits to keep when displaying the loss.
-    :param monitor: when the result under this key improves, it is highlighted in a different color.
+    :param monitor: when the result under this key improves, it is highlighted in a different color. A function may also be passed in;
+        it receives the evaluation results (a dict) and returns a float to be used as the monitor value.
     :param larger_better: whether a larger monitor result is better.
     :param format_json: whether to format the json before printing.
     """


+14 −4 fastNLP/core/callbacks/utils.py

@@ -1,9 +1,10 @@
-from typing import Optional
+from typing import Optional, Union
 from fastNLP.core.log.logger import logger
 from difflib import SequenceMatcher
+from fastNLP.core.utils.utils import _get_fun_msg


-def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(str, float):
+def _get_monitor_value(monitor: Union[callable, str], real_monitor: Optional[str], res: dict) ->(str, float):
     """
     Look up the monitor in res and return it. If the monitor is not found, _real_monitor is tried; if _real_monitor is None, fuzzy matching against the value of monitor is attempted.
@@ -11,10 +12,19 @@ def _get_monitor_value(monitor: str, real_monitor: Optional[str], res: dict) ->(
     :param monitor:
     :param real_monitor:
     :param res:
-    :return: two values (str, value): the str is the key that is finally used, and the value is what that key maps to.
+    :return: two values (str, value): the str is the key that is finally used, and the value is what that key maps to. If value is None,
+        no matching monitor was found in the current results.
     """
     if len(res)==0:
-        return monitor, 0
+        return monitor, None
+
+    if callable(monitor):
+        try:
+            monitor_value = monitor(res)
+        except BaseException as e:
+            logger.error(f"Exception happens when calling customized monitor function:{_get_fun_msg(monitor)}.")
+            raise e
+        return monitor, monitor_value

     if monitor in res:
         return monitor, res[monitor]
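A standalone sketch of the lookup order implemented above (mirrors the diff; the fuzzy-matching fallback via SequenceMatcher is elided):

    def get_monitor_value_sketch(monitor, res: dict):
        # empty results: no monitor can be found
        if len(res) == 0:
            return monitor, None
        # callable monitor: the user function computes the value itself
        if callable(monitor):
            return monitor, monitor(res)
        # exact key match; fuzzy matching is elided in this sketch
        if monitor in res:
            return monitor, res[monitor]
        return monitor, None

    print(get_monitor_value_sketch(lambda r: r['acc'] - r['loss'], {'acc': 0.9, 'loss': 0.1}))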


+15 −8 fastNLP/core/collators/collator.py

@@ -5,7 +5,7 @@ __all__ = [


 from abc import ABCMeta, abstractmethod
-from typing import Any, Dict, List, Callable, Union
+from typing import Any, Dict, List, Callable, Union, Tuple
 from numbers import Number
 import warnings

@@ -35,7 +35,7 @@ class SetInputOrTargetException(Exception):
         self.field_name = field_name  # marks the name of the current field


-def _get_ele_type_and_dim(cell: Any, dim=0):
+def _get_ele_type_and_dim(cell: Any, dim=0) -> Tuple[Any, int]:
     r"""
     Detect the type of cell and its number of dimensions.

@@ -206,7 +206,7 @@ class AutoCollator(Collator):
     def __init__(self, as_numpy: bool):
         super(AutoCollator, self).__init__()
         self.pad_field_value = {}  # customized padding value per field, defaults to 0
-        self.need_inputs = []  # the required field names
+        self.need_inputs = set()  # the required field names
         self.field_dtypes = None  # dtype of each column's data cells
         self.field_dims = None  # dimensionality of each column's data cells
         self.as_numpy = as_numpy
@@ -214,10 +214,17 @@ class AutoCollator(Collator):
     def __call__(self, ins_lst: List[Dict]) -> dict:
         if len(self.need_inputs) == 0:
             raise ValueError({"set_inputs is None, you should use set_inputs method first!!"})
+        # TODO should first check which fields need padding, then check whether those can be padded

         # case 1: values were configured via set_input
        # case 2: decide whether to pad based on the data type
         if self.field_dtypes is None and self.field_dims is None:
-            self.field_dtypes, self.field_dims = _get_ds_type_dim(ins_lst[0])
+            field_dtypes, field_dims = {}, {}
+            for key, value in ins_lst[0].items():
+                if key in self.need_inputs and self.pad_field_value.get(key, 0) is not None:
+                    field_dtypes[key], field_dims[key] = _get_ele_type_and_dim(value)
+            self.field_dtypes = field_dtypes
+            self.field_dims = field_dims

         pack_ins_lst, pad_ins_lst = {field_name: []
                                      for field_name in ins_lst[0].keys() if field_name in self.need_inputs}, {}
@@ -233,13 +240,13 @@ class AutoCollator(Collator):

         if len(self.pad_field_value.keys()) > 0:
             # drop the columns that need no padding; set_input columns that do not exist are ignored
-            drop_field_names = []
+            non_pad_field_names = []
             for k, v in self.pad_field_value.items():
                 if v is None:
-                    drop_field_names.append(k)
+                    non_pad_field_names.append(k)

             # drop_field_names = list(set(list(ins_lst[0].keys())) - set(drop_fields))
-            for field_name in drop_field_names:
+            for field_name in non_pad_field_names:
                 field_array = pack_ins_lst.pop(field_name)
                 pad_ins_lst[field_name] = np.array(field_array)

@@ -269,7 +276,7 @@ class AutoCollator(Collator):

     def set_input(self, *field_names):
         for field_name in field_names:
-            self.need_inputs.append(field_name)
+            self.need_inputs.add(field_name)


 def pad_content(content, field_name: str, field_type, field_dim: int, pad_val: int, as_numpy: bool):
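A self-contained sketch of the new per-field detection above: only fields that were passed to set_input and whose pad value is not None get their dtype/dim inspected (`_ele_type_and_dim` is a simplified stand-in for the real `_get_ele_type_and_dim`):

    from typing import Any, Tuple

    def _ele_type_and_dim(cell: Any, dim: int = 0) -> Tuple[type, int]:
        # simplified stand-in: recurse into nested lists/tuples to count dims
        if isinstance(cell, (list, tuple)):
            return _ele_type_and_dim(cell[0], dim + 1)
        return type(cell), dim

    def detect_fields(first_ins: dict, need_inputs: set, pad_field_value: dict):
        field_dtypes, field_dims = {}, {}
        for key, value in first_ins.items():
            # only set_input-ed fields whose pad value is not None are inspected
            if key in need_inputs and pad_field_value.get(key, 0) is not None:
                field_dtypes[key], field_dims[key] = _ele_type_and_dim(value)
        return field_dtypes, field_dims

    print(detect_fields({'words': [1, 2, 3], 'raw': 'abc'}, {'words'}, {}))
    # -> ({'words': <class 'int'>}, {'words': 1})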


+10 −9 fastNLP/core/controllers/evaluator.py

@@ -11,11 +11,12 @@ __all__ = [
 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
 from .loops import Loop, EvaluateBatchLoop
-from fastNLP.core.utils import check_fn_not_empty_params, auto_param_call, dataclass_to_dict, \
+from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
     match_and_substitute_params, f_rich_progress
 from fastNLP.core.metrics import Metric
 from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
+from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.core.log import logger

@@ -38,11 +39,11 @@ class Evaluator:
                  driver: Union[str, Driver] = 'single',
                  device: Optional[Union[int, List[int], str]] = None,
                  batch_step_fn: Optional[callable] = None,
-                 mode: str = "validate",
+                 mode: Optional[Union[str, callable]] = 'validate',  # first try evaluate_step; if absent, forward; or a callable
                  input_mapping: Optional[Union[Callable, Dict]] = None,
                  output_mapping: Optional[Union[Callable, Dict]] = None,
                  model_wo_auto_param_call: bool = False,
-                 fp16: Optional[bool] = False,
+                 fp16: bool = False,
                  verbose: int = 1,
                  **kwargs
                 ):
@@ -92,8 +93,8 @@ class Evaluator:
         self.device = device
         self.verbose = verbose

-        assert check_fn_not_empty_params(batch_step_fn, 2), "Parameter `batch_step_fn` should be a callable object with " \
-                                                            "two parameters."
+        if batch_step_fn is not None:
+            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
         self.batch_step_fn = batch_step_fn

         self.mode = mode
@@ -135,6 +136,7 @@ class Evaluator:
         if self.progress_bar == 'auto':
             self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw'

+        self.driver.check_evaluator_mode(self.mode)
         self.driver.barrier()

     def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict:
@@ -154,8 +156,6 @@ class Evaluator:
         assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type."
         assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0."

-        self.driver.check_evaluator_mode(self.mode)
-
         if self.mode == 'validate':
             assert self.driver.has_validate_dataloaders()
         else:
@@ -367,9 +367,10 @@ class _MetricsWrapper:
             raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                f" return a dict from your model or use `output_mapping` to convert it into dict type.")
         if isinstance(metric, Metric):
-            auto_param_call(metric.update, outputs, *args)
+            # this makes the error clearer when auto_param_call fails
+            auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
         elif _is_torchmetrics_metric(metric):
-            auto_param_call(metric.update, outputs, *args)
+            auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__)
         elif _is_allennlp_metric(metric):
             auto_param_call(metric.__call__, outputs, *args)
         elif _is_paddle_metric(metric):
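One reason to pass `signature_fn=metric.update.__wrapped__`: when `update` is wrapped by a decorator, `__wrapped__` (set by `functools.wraps`) still exposes the original parameter names for matching. A standalone illustration, not fastNLP code:

    import functools, inspect

    def logged(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)
        return wrapper

    class MyMetric:
        @logged
        def update(self, pred, target):
            pass

    m = MyMetric()
    # the undecorated function is reachable via __wrapped__, so a caller like
    # auto_param_call can read the real parameter names
    print(inspect.signature(m.update.__wrapped__))  # (self, pred, target)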


+26 −38 fastNLP/core/controllers/trainer.py

@@ -14,6 +14,7 @@ __all__ = [

 from .loops import Loop, TrainBatchLoop
 from .utils import State, TrainerState
+from .utils.utils import check_validate_every
 from .evaluator import Evaluator
 from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _TruncatedDataLoader
 from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList, Filter
@@ -21,7 +22,8 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
 from fastNLP.core.callbacks.callback_events import _SingleEventState
 from fastNLP.core.drivers import Driver
 from fastNLP.core.drivers.utils import choose_driver
-from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
+from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.envs import rank_zero_call
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_MODEL_FILENAME
@@ -42,7 +44,7 @@ class Trainer(TrainerEventTrigger):
                  validate_dataloaders=None,
                  batch_step_fn: Optional[Callable] = None,
                  validate_batch_step_fn: Optional[Callable] = None,
-                 validate_mode: str = "validate",
+                 validate_mode: Union[str, callable] = 'validate',
                  callbacks: Union[List[Callback], Callback, None] = None,
                  metrics: Optional[dict] = None,
                  validate_every: Optional[Union[int, callable]] = -1,
@@ -51,7 +53,7 @@ class Trainer(TrainerEventTrigger):
                  model_wo_auto_param_call: bool = False,
                  accumulation_steps: int = 1,
                  fp16: bool = False,
-                 monitor: str = None,
+                 monitor: Union[str, callable] = None,
                  larger_better: bool = True,
                  marker: Optional[str] = None,
                  **kwargs
@@ -90,11 +92,8 @@ class Trainer(TrainerEventTrigger):
     :param callbacks: callbacks triggered during training; this should be a list whose elements all inherit from `Callback`;
     :param metrics: should be a dict whose keys name the monitors, e.g. {"acc1": AccMetric(), "acc2": AccMetric()};
     :param validate_every: may be a negative int, a positive int, or a function; negative means validate every that many epochs, positive means validate every that many batches;
-        a function means a user-supplied function controlling how often the Trainer validates. Its parameters were (filter, trainer), where the filter object
-        automatically records two variables: filter.num_called, the number of validation attempts so far (in effect, the total number of batches up to now), and
-        filter.num_executed, the number of times validation has actually run; the trainer parameter is the Trainer object. The function returns a bool, True
-        meaning validation should run. For example, (filter.num_called % trainer.num_batches_per_epoch == 0 and trainer.cur_epoch_idx > 10) means
-        validating at the end of every epoch after the 10th epoch.
+        a function means a user-supplied function controlling how often the Trainer validates; it should accept the current trainer object as its parameter and
+        return a bool, True meaning validation should run; it is called after every batch to decide whether to validate.
     :param input_mapping: should be a dict or a function describing how a fetched batch of training data is mapped at the current step; if it is
         a dict and the batch is also a `Dict`, the keys of the batch that also appear in input_mapping are changed to the corresponding input_mapping
         values; if the batch is a `dataclass`, it is first converted to a Dict and then converted as above; if the batch is some other
@@ -111,7 +110,7 @@ class Trainer(TrainerEventTrigger):
     :param fp16: whether to enable mixed-precision training; defaults to False;
     :param monitor: the default monitor metric name when validate_dataloaders exists. Callbacks that take a monitor parameter but do not
         set it at initialization adopt this value. If no exactly matching name is found in the evaluation results, the best match found by the
-        shortest-common-string algorithm is used as the monitor.
+        shortest-common-string algorithm is used as the monitor. A function may also be passed in; it receives the evaluation results (a dict)
+        and returns a float to be used as the monitor value.
     :param larger_better: whether a larger monitor value is better.
     :param marker: used to mark a Trainer instance so that when the user calls `Trainer.on`, the callback function is tied to that specific 'trainer' instance; defaults to None;
     :param kwargs: other parameters that may be needed;
@@ -142,10 +141,9 @@ class Trainer(TrainerEventTrigger):
         self.input_mapping = input_mapping
         self.output_mapping = output_mapping

-        assert check_fn_not_empty_params(batch_step_fn, 2), "`batch_step_fn` should be a callable object with " \
-                                                            "two parameters."
         self.batch_step_fn = batch_step_fn
         if batch_step_fn is not None:
+            _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn')
             self.check_batch_step_fn = partial(self._check_callback_called_legality, check_mode=True)
         else:
             self.check_batch_step_fn = lambda *args, **kwargs: ...
@@ -221,18 +219,11 @@ class Trainer(TrainerEventTrigger):
         if metrics is not None and validate_dataloaders is None:
             raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.")

-        # so that the training loop can check on every pass whether to validate, the functions to run at each point in time are fixed here during trainer initialization;
-        # _epoch_validate validates every few epochs; _step_validate validates every few steps;
         self.evaluator = None
         self.monitor = monitor
         self.larger_better = larger_better
         if metrics is not None and validate_dataloaders is not None:
-            if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
-                raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
-            if callable(validate_every):
-                logger.info("Notice you are using a 'filter function' as the value of parameter `validate_every`, "
-                            "and in this way, the kind of controlling frequency is depending on the 'step'.")
-
+            check_validate_every(validate_every)
             self.evaluator = Evaluator(
                 model=model,
                 dataloaders=validate_dataloaders,
@@ -352,33 +343,32 @@ class Trainer(TrainerEventTrigger):
             _validate_res: dict = validate_fn()
             trainer.on_validate_end(_validate_res)

-        self.validate_fn = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))
+        self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl))

     def step_validate(self):
-        if self.evaluator is not None:
-            should_run_validate = False
+        """
+        Called at the end of every batch; runs evaluate according to the settings.

+        :return:
+        """
+        if self.evaluator is not None:
             if callable(self.validate_every):
                 if self.validate_every(self):
-                    should_run_validate = True
-            elif self.validate_every > 0:
-                if self.global_forward_batches % self.validate_every == 0:
-                    should_run_validate = True
-
-            if should_run_validate:
-                self.validate_fn()
+                    self.run_evaluate()
+            elif self.validate_every > 0 and self.global_forward_batches % self.validate_every == 0:
+                self.run_evaluate()

     def epoch_validate(self):
-        if self.evaluator is not None:
-            should_run_validate = False
+        """
+        Called at the end of every epoch; runs evaluate according to the settings.

+        :return:
+        """
+        if self.evaluator is not None:
             if isinstance(self.validate_every, int) and self.validate_every < 0:
                 validate_every = -self.validate_every
                 if self.cur_epoch_idx % validate_every == 0:
-                    should_run_validate = True
-
-            if should_run_validate:
-                self.validate_fn()
+                    self.run_evaluate()

     def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable):
         r"""
@@ -410,9 +400,7 @@ class Trainer(TrainerEventTrigger):
         def wrapper(fn: Callable) -> Callable:
             cls._custom_callbacks[marker].append((event, fn))
             callback_fn_args = get_fn_arg_names(getattr(Callback, event.value))[1:]
-            assert check_fn_not_empty_params(fn, len(callback_fn_args)), \
-                f"The callback function at `{event.value.lower()}`'s parameters should be {callback_fn_args}, but your "\
-                f"function {fn.__name__} only has these parameters: {get_fn_arg_names(fn)}."
+            _check_valid_parameters_number(fn, callback_fn_args)
             return fn

         return wrapper
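With the change above, `validate_every` may simply be a function of the trainer. A usage sketch (attribute names `cur_epoch_idx`, `num_batches_per_epoch` and `global_forward_batches` as used elsewhere in this codebase):

    def my_validate_every(trainer):
        # validate at the end of each epoch, but only after the 10th epoch
        return (trainer.global_forward_batches % trainer.num_batches_per_epoch == 0
                and trainer.cur_epoch_idx > 10)

    # trainer = Trainer(..., validate_every=my_validate_every)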


+7 −3 fastNLP/core/controllers/utils/utils.py

@@ -1,8 +1,9 @@
-from collections.abc import Iterator
+import inspect
 from typing import Dict

 from fastNLP.core.callbacks import CallbackManager
 from .state import TrainerState
+from fastNLP.core.utils.utils import _check_valid_parameters_number


 class TrainerEventTrigger:
@@ -125,5 +126,8 @@ class _TruncatedDataLoader:
         return getattr(self.dataloader, item)


+def check_validate_every(validate_every):
+    if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0):
+        raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.")
+    if callable(validate_every):
+        _check_valid_parameters_number(validate_every, expected_params=['trainer'])

+8 −4 fastNLP/core/dataset/dataset.py

@@ -178,10 +178,11 @@ class DataSet:
         elif isinstance(idx, slice):
             if idx.start is not None and (idx.start >= len(self) or idx.start <= -len(self)):
                 raise RuntimeError(f"Start index {idx.start} out of range 0-{len(self) - 1}")
-            data_set = DataSet()
+            dataset = DataSet()
             for field_name, field in self.field_arrays.items():
-                data_set.add_field(field_name=field_name, fields=field.content[idx])
-            return data_set
+                dataset.add_field(field_name=field_name, fields=field.content[idx])
+            dataset.collate_fns = deepcopy(self.collate_fns)
+            return dataset
         elif isinstance(idx, str):
             if idx not in self:
                 raise KeyError("No such field called {} in DataSet.".format(idx))
@@ -192,6 +193,7 @@ class DataSet:
                 assert isinstance(i, int), "Only int index allowed."
                 instance = self[i]
                 dataset.append(instance)
+            dataset.collate_fns = deepcopy(self.collate_fns)
             return dataset
         else:
             raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx)))
@@ -674,6 +676,8 @@ class DataSet:
             dev_set.append(self[idx])
         for idx in train_indices:
             train_set.append(self[idx])
+        dev_set.collate_fns = deepcopy(self.collate_fns)
+        train_set.collate_fns = deepcopy(self.collate_fns)

         return dev_set, train_set


@@ -795,7 +799,7 @@ class DataSet:
         :param val: defaults to 0. If None, the field is not padded.
         :return:
         """
-        # TODO needs deduplication
+        # TODO must not be empty
         for field_name in field_names:
             self.collate_fns.set_pad_val(field_name, val=val)



+2 −2 fastNLP/core/drivers/jittor_driver/jittor_driver.py

@@ -66,7 +66,7 @@ class JittorDriver(Driver):
         if mode == "validate":
             if not hasattr(model, "validate_step"):
                 if hasattr(model, "test_step"):
-                    logger.warning(
+                    logger.warning_once(
                         "Your model does not have 'validate_step' method but has 'test_step' method, but you"
                         "are using 'mode=validate', we are going to use 'test_step' to substitute for"
                         "'validate_step'.")
@@ -74,7 +74,7 @@ class JittorDriver(Driver):
         else:
             if not hasattr(model, "test_step"):
                 if hasattr(model, "validate_step"):
-                    logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you"
+                    logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
                         "are using 'mode=test', we are going to use 'validate_step' to substitute for"
                         "'test_step'.")




+1 −1 fastNLP/core/drivers/paddle_driver/paddle_driver.py

@@ -133,7 +133,7 @@ class PaddleDriver(Driver):
         else:
             if not hasattr(model, "test_step"):
                 if hasattr(model, "validate_step"):
-                    logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you"
+                    logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you"
                         "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for"
                         "'test_step'.")




+5 −5 fastNLP/core/drivers/torch_driver/dist_utils.py

@@ -333,10 +333,8 @@ def all_gather_object(object_list, obj, group=None):
         >>> output
         ['foo', 12, {1: 2}]
     """
-    if dist._rank_not_in_group(group):
+    if dist.distributed_c10d._rank_not_in_group(group):
         return

-    input_tensor, local_size = _object_to_tensor(obj)
     if _TORCH_GREATER_EQUAL_1_8:
         current_device = torch.device("cpu")
         is_nccl_backend = _check_for_nccl_backend(group)
@@ -345,10 +343,11 @@ def all_gather_object(object_list, obj, group=None):
             # We cannot simply use my_rank since rank == device is not necessarily
             # true.
             current_device = torch.device("cuda", torch.cuda.current_device())
-            input_tensor = input_tensor.to(current_device)
-            local_size = local_size.to(current_device)
     else:
         current_device = torch.cuda.current_device()

+    input_tensor, local_size = _object_to_tensor(obj, device=current_device)
+
     # Gather all local sizes. This is so that we can find the max size, and index
     # until the correct size when deserializing the tensors.
     group_size = dist.get_world_size(group=group)
@@ -379,3 +378,4 @@ def all_gather_object(object_list, obj, group=None):
         tensor = tensor.cpu()
         tensor_size = object_size_list[i]
         object_list[i] = _tensor_to_object(tensor, tensor_size)
+    return object_list
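Usage sketch for the added return value (requires an initialized torch.distributed process group; previously the gathered objects were only written into the list argument):

    import torch.distributed as dist
    from fastNLP.core.drivers.torch_driver.dist_utils import all_gather_object

    output = [None for _ in range(dist.get_world_size())]
    gathered = all_gather_object(output, {'rank': dist.get_rank()})
    # `gathered` is the same (now filled) list, so the call composes as an expression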

+1 −1 fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -113,7 +113,7 @@ class TorchDriver(Driver):
         if mode == "validate":
             if not hasattr(model, "validate_step"):
                 if hasattr(model, "test_step"):
-                    logger.warning(
+                    logger.warning_once(
                         "Your model does not have 'validate_step' method but has 'test_step' method, but you"
                         "are using 'mode=validate', we are going to use 'test_step' to substitute for"
                         "'validate_step'.")


+3 −3 fastNLP/core/log/logger.py

@@ -125,9 +125,9 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
             self._warning_msgs.add(msg)

     def warn(self, msg, *args, **kwargs):
-        warnings.warn("The 'warn' method is deprecated, "
-                      "use 'warning' instead", DeprecationWarning, 2)
-        self.warning(msg, *args, **kwargs)
+        if self.isEnabledFor(WARNING):
+            kwargs = self._add_rank_info(kwargs)
+            self._log(WARNING, msg, args, **kwargs)

     def error(self, msg, *args, **kwargs):
         """


+2 −3 fastNLP/core/metrics/accuracy.py

@@ -14,8 +14,7 @@ from fastNLP.core.utils.utils import seq_len_to_mask

 class Accuracy(Metric):

-    def __init__(self, backend: Union[str, Backend, None] = 'auto',
-                 aggregate_when_get_metric: bool = True):
+    def __init__(self, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True):
         super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
         self.register_element(name='correct', value=0, aggregate_method='sum', backend=backend)
         self.register_element(name='total', value=0, aggregate_method="sum", backend=backend)
@@ -64,7 +63,7 @@ class Accuracy(Metric):
             warnings.warn("You are not passing `seq_len` to exclude pad when calculate accuracy.")

         else:
-            raise RuntimeError(f"when pred havesize:{pred.shape}, target should have size: {pred.shape} or "
+            raise RuntimeError(f"when pred have size:{pred.shape}, target should have size: {pred.shape} or "
                                f"{pred.shape[:-1]}, got {target.shape}.")

         if masks is not None:


+3 −3 fastNLP/core/samplers/__init__.py

@@ -23,14 +23,14 @@ __all__ = [
     "BucketedBatchSampler",
     "ReproducibleBatchSampler",

-    "re_instantiate_sampler",
-    "conversion_between_reproducible_and_unrepeated_sampler"
+    "re_instantiate_sampler"
 ]

 from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler
 from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
 from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler
 from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler
-from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler
+from .utils import re_instantiate_sampler
+from .conversion_utils import conversion_between_reproducible_and_unrepeated_sampler
 from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler, ReproducibleBatchSampler



+33 −0 fastNLP/core/samplers/conversion_utils.py

@@ -0,0 +1,33 @@
+from fastNLP.core.samplers import re_instantiate_sampler
+from fastNLP.core.samplers.reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, \
+    SortedSampler
+from fastNLP.core.samplers.unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, \
+    UnrepeatedSequentialSampler, UnrepeatedSortedSampler
+
+
+def conversion_between_reproducible_and_unrepeated_sampler(sampler):
+    """
+    Replace the sampler with its reproducible or unrepeated counterpart. If the input is an UnrepeatedSampler
+    with no matching ReproducibleSampler (or vice versa), a TypeError is raised.
+
+    :param sampler:
+    :return:
+    """
+    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
+        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
+    if isinstance(sampler, UnrepeatedSampler):
+        if isinstance(sampler, UnrepeatedRandomSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
+        elif isinstance(sampler, UnrepeatedSequentialSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
+        elif isinstance(sampler, UnrepeatedSortedSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
+        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
+    else:
+        if isinstance(sampler, RandomSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
+        elif isinstance(sampler, SequentialSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
+        elif isinstance(sampler, SortedSampler):
+            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
+        raise TypeError(f"{sampler.__class__} has no reproducible version.")
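Usage sketch of the relocated helper; moving it into its own module breaks the import cycle between samplers/utils.py and the sampler modules (the RandomSampler constructor arguments here are assumptions):

    from fastNLP.core.samplers import (
        RandomSampler, conversion_between_reproducible_and_unrepeated_sampler)

    sampler = RandomSampler(list(range(10)))  # constructor args assumed
    unrepeated = conversion_between_reproducible_and_unrepeated_sampler(sampler)
    print(type(unrepeated).__name__)  # UnrepeatedRandomSampler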

+6 −13 fastNLP/core/samplers/reproducible_batch_sampler.py

@@ -378,7 +378,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
             batch_indices = list(batch_indices[:-1])
         rng = np.random.default_rng(abs(seed))  # prevents the differing bucket lengths from affecting the random-number state
         rng.shuffle(batch_indices)  # shuffle across batches as well; this keeps the batch lengths on every card close to each other
-        batches = (np.array(batches)[batch_indices]).tolist()
+        batches = (np.array(batches, dtype=object)[batch_indices]).tolist()
         if last_batches:
             batches = batches + last_batches
         return batches
@@ -387,19 +387,12 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         if self.old_batch_size != self.batch_size or self.old_num_batch_per_bucket != self.num_batch_per_bucket:
             raise RuntimeError("BucketedBatchSampler does not support saving before last checkpoint states have been"
                                " consumed. ")
-        states = {
-            'seed': self.seed,
-            'epoch': self.epoch,
-            'num_consumed_samples': self.num_consumed_samples,  # note this counts all data trained across all ranks;
-            'sampler_type': self.__class__.__name__,
-            'length': len(self.dataset),
-            'shuffle': self.shuffle,
-            'batch_size': self.batch_size,
-            'num_batch_per_bucket': self.num_batch_per_bucket,
-            'num_replicas': self.num_replicas
-        }
+        states = {'seed': self.seed, 'epoch': self.epoch, 'num_consumed_samples': self.num_consumed_samples,
+                  'sampler_type': self.__class__.__name__, 'length': len(self.dataset), 'shuffle': self.shuffle,
+                  'batch_size': self.batch_size, 'num_batch_per_bucket': self.num_batch_per_bucket,
+                  'num_replicas': self.num_replicas,
+                  'num_consumed_samples_array': getattr(self, 'num_consumed_samples_array', None)}

-        states['num_consumed_samples_array'] = getattr(self, 'num_consumed_samples_array', None)
         return states

     def load_state_dict(self, states: Dict):


+7 −7 fastNLP/core/samplers/reproducible_sampler.py

@@ -1,3 +1,10 @@
+__all__ = [
+    'ReproducibleSampler',
+    'RandomSampler',
+    "SortedSampler",
+    "SequentialSampler"
+]
+
 from typing import Dict, List, Union
 import math
 import os
@@ -10,13 +17,6 @@ from fastNLP.envs.env import FASTNLP_DEQUE_SIZE
 from .utils import NumConsumedSamplesArray


-__all__ = [
-    'ReproducibleSampler',
-    'RandomSampler',
-    "SortedSampler",
-    "SequentialSampler"
-]
-

 class ReproducibleSampler:
     """


+1 −33 fastNLP/core/samplers/utils.py

@@ -1,42 +1,10 @@
 __all__ = [
-    're_instantiate_sampler',
-    'conversion_between_reproducible_and_unrepeated_sampler'
+    're_instantiate_sampler'
 ]
 from array import array
 from typing import Sequence
 from collections import deque

-from fastNLP.core.samplers.unrepeated_sampler import *
-from fastNLP.core.samplers.reproducible_sampler import *
-
-
-def conversion_between_reproducible_and_unrepeated_sampler(sampler):
-    """
-    Replace the sampler with its reproducible or unrepeated counterpart. If the input is an UnrepeatedSampler
-    with no matching ReproducibleSampler (or vice versa), a TypeError is raised.
-
-    :param sampler:
-    :return:
-    """
-    assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \
-        "The sampler must be UnrepeatedSampler or ReproducibleSampler"
-    if isinstance(sampler, UnrepeatedSampler):
-        if isinstance(sampler, UnrepeatedRandomSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler)
-        elif isinstance(sampler, UnrepeatedSequentialSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler)
-        elif isinstance(sampler, UnrepeatedSortedSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler)
-        raise TypeError(f"{sampler.__class__} has no unrepeated version.")
-    else:
-        if isinstance(sampler, RandomSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler)
-        elif isinstance(sampler, SequentialSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler)
-        elif isinstance(sampler, SortedSampler):
-            return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler)
-        raise TypeError(f"{sampler.__class__} has no reproducible version.")
-

 def re_instantiate_sampler(sampler, new_sampler_class=None):
     all_attributes = vars(sampler)
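A standalone sketch of what re_instantiate_sampler does with those attributes: rebuild a sampler of a (possibly different) class from vars() of an existing one, assuming attribute names line up with constructor parameters:

    import inspect

    def re_instantiate_sketch(sampler, new_sampler_class=None):
        all_attributes = vars(sampler).copy()
        cls = new_sampler_class or sampler.__class__
        # keep only the attributes that the target constructor accepts
        accepted = set(inspect.signature(cls.__init__).parameters) - {'self'}
        return cls(**{k: v for k, v in all_attributes.items() if k in accepted})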


+1 −2 fastNLP/core/utils/__init__.py

@@ -13,7 +13,6 @@ __all__ = [
     'torch_paddle_move_data_to_device',
     'torch_move_data_to_device',
     'get_fn_arg_names',
-    'check_fn_not_empty_params',
     'auto_param_call',
     'check_user_specific_params',
     'dataclass_to_dict',
@@ -36,7 +35,7 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi
 from .rich_progress import f_rich_progress
 from .torch_paddle_utils import torch_paddle_move_data_to_device
 from .torch_utils import torch_move_data_to_device
-from .utils import get_fn_arg_names, check_fn_not_empty_params, auto_param_call, check_user_specific_params, \
+from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
     dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
     indice_collate_wrapper, deprecated, seq_len_to_mask, synchronize_safe_rm, synchronize_mkdir




+111 −32 fastNLP/core/utils/utils.py

@@ -1,3 +1,4 @@
+import functools
 import inspect
 from inspect import Parameter
 import dataclasses
@@ -24,10 +25,8 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_GLOBAL_RANK


-
 __all__ = [
     'get_fn_arg_names',
-    'check_fn_not_empty_params',
     'auto_param_call',
     'check_user_specific_params',
     'dataclass_to_dict',
@@ -54,30 +53,6 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
     return list(inspect.signature(fn).parameters)


-def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional[int] = None) -> bool:
-    r"""
-    Check that the given batch_step_fn is valid: (1) it is callable; (2) it has exactly the given number of parameters without default values;
-    the user may also pass in a partial function, as long as it keeps the parameter slots for `trainer` and `batch`;
-
-    :param fn: the function passed in to replace the 'step' function in the Loop;
-    :param param_num: the expected number of parameters without default values;
-
-    :return: bool, whether the given `batch_step_fn` is valid;
-    """
-
-    if fn is None:
-        return True
-    if not callable(fn):
-        return False
-    else:
-        params = inspect.signature(fn).parameters
-        not_default_params = {}
-        for _name, _param in params.items():
-            if _param.default == Parameter.empty:
-                not_default_params[_name] = _param
-        return len(not_default_params) == param_num
-
-
 def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                     mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
     r"""
@@ -95,7 +70,6 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     :param signature_fn: a function used to replace the signature of `fn`; if this parameter is not None, the signature is first extracted
         from this function, the parameter values are picked through that signature, and then passed to `fn` for the actual call;
     :param mapping: a dict used to rename the keys of the preceding dicts;
-    :param wo_auto_param_call: whether to disable the default parameter-matching behaviour;

     :return: the result of running `fn`;
@@ -123,7 +97,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
     _kwargs = None
     for _name, _param in _need_params.items():
         if _param.kind == Parameter.VAR_POSITIONAL:
-            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn.__name__}.")
+            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+            raise ValueError(f"It is not allowed to have parameter `*args` in your function:{fn_msg}.")
         if _param.kind == Parameter.VAR_KEYWORD:
             _kwargs = (_name, _param)

@@ -136,12 +111,17 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
             _default_params[_name] = _param.default

     if mapping is not None:
-        assert isinstance(mapping, Dict), f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        assert isinstance(mapping, Dict), f"Exception happens when calling {fn_msg}. " \
+                                          f"Parameter `mapping` should be of 'Dict' type, instead of {type(mapping)}."

     _has_params = {}
     duplicate_names = []
     for arg in args:
-        assert isinstance(arg, Dict), "The input part of function `auto_param_call` can only be `Dict` type."
+        if not isinstance(arg, Dict):
+            fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+            raise TypeError(f"Exception happens when calling {fn_msg}. "
+                            f"The input part of function `auto_param_call` must be `Dict` type, instead of {type(arg)}.")
         for _name, _value in arg.items():
             if mapping is not None and _name in mapping:
                 _name = mapping[_name]
@@ -153,7 +133,8 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
             elif _name in _need_params and not (_has_params[_name] is _value):
                 duplicate_names.append(_name)
     if duplicate_names:
-        raise ValueError(f"The following key present in several inputs:{duplicate_names}")
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        raise ValueError(f"The following key present in several inputs:{duplicate_names} when calling {fn_msg}.")

     # pass along parameters that have default values and were not overridden by the inputs;
     for _name, _value in _default_params.items():
@@ -162,11 +143,89 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None

     if len(_has_params)<len(_need_params):
         miss_params = list(set(_need_params.keys()) - set(_has_params.keys()))
-        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn.__name__} are not found in the input.")
+        fn_msg = _get_fun_msg(fn if signature_fn is None else signature_fn)
+        _provided_keys = _get_keys(args)
+        raise ValueError(f"The parameters:`{miss_params}` needed by function:{fn_msg} "
+                         f"are not found in the input keys({_provided_keys}).")

     return fn(**_has_params)


+def _get_keys(args:List[Dict]) -> List[List[str]]:
+    """
+    Return the keys of each dict.
+
+    :param args:
+    :return:
+    """
+    _provided_keys = []
+    for arg in args:
+        _provided_keys.append(list(arg.keys()))
+    return _provided_keys
+
+
+def _get_fun_msg(fn)->str:
+    """
+    Get basic information about a function, to make error messages more helpful.
+    ex:
+        print(_get_fun_msg(_get_fun_msg))
+        # `_get_fun_msg(fn) -> str`(In file:/Users/hnyan/Desktop/projects/fastNLP/fastNLP/fastNLP/core/utils/utils.py)
+
+    :param callable fn:
+    :return:
+    """
+    if isinstance(fn, functools.partial):
+        return _get_fun_msg(fn.func)
+    try:
+        fn_name = fn.__qualname__ + str(inspect.signature(fn))
+    except:
+        fn_name = str(fn)
+    try:
+        fp = '(In file:' + os.path.abspath(inspect.getfile(fn)) + ')'
+    except:
+        fp = ''
+    msg = f'`{fn_name}`' + fp
+    return msg
+
+
+def _check_valid_parameters_number(fn, expected_params:List[str], fn_name=None):
+    """
+    Check whether a function can accept the parameters in expected_params (i.e. whether the counts match),
+    after excluding self (when fn is a method) and parameters that have default values. An error is raised
+    if they cannot be matched.
+
+    :param fn: the function to check; it may be a method or a plain function.
+    :param expected_params: the parameters the function is expected to accept.
+    :param fn_name: the name of fn, used for a clearer error when the given fn is not callable.
+    :return:
+    """
+    if fn_name is not None:
+        assert callable(fn), f"{fn_name} should be callable, instead of {type(fn)}."
+
+    parameters = list(inspect.signature(fn).parameters.values())
+    if inspect.ismethod(fn):
+        if len(parameters)>0 and parameters[0].name == 'self':
+            parameters = parameters[1:]  # drop self
+
+    no_var_param = True  # no *args / **kwargs style parameters
+    number_param_need_value = 0
+    for param in parameters:
+        if param.kind is param.VAR_POSITIONAL:
+            no_var_param = False
+        elif param.kind is param.VAR_KEYWORD:
+            no_var_param = False
+        else:
+            if param.default is param.empty:
+                number_param_need_value += 1
+
+    if len(parameters)<len(expected_params) and no_var_param:
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} accepts {len(parameters)} parameters, "
+                           f"but {len(expected_params)} parameters:{expected_params} will be provided.")
+
+    if number_param_need_value>len(expected_params):
+        raise RuntimeError(f"The function:{_get_fun_msg(fn)} expects {len(parameters)} parameters, but only"
+                           f" {len(expected_params)} parameters:{expected_params} will be provided.")


 def check_user_specific_params(user_params: Dict, fn: Callable):
     """
     This function uses the user's input to assign values to the parameters of the given function;
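A minimal sketch of how the new _check_valid_parameters_number behaves (the function name validate_every is illustrative; the tests added at the bottom of this commit cover the same cases):

    from functools import partial
    from fastNLP.core.utils.utils import _check_valid_parameters_number

    def validate_every(trainer, other):
        pass

    # raises RuntimeError: `other` has no default, but only 'trainer' would be provided
    # _check_valid_parameters_number(validate_every, expected_params=['trainer'])

    # binding `other` via partial satisfies the check
    _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])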
@@ -592,4 +651,24 @@ def synchronize_mkdir(path: Optional[Union[str, Path]]):
     wait_to_success(path.exists)


+def get_class_that_defined_method(method):
+    """
+    Given a method, return the class object on which the method is defined.
+
+    :param method:
+    :return:
+    """
+    if isinstance(method, functools.partial):
+        return get_class_that_defined_method(method.func)
+    if inspect.ismethod(method) or (inspect.isbuiltin(method) and getattr(method, '__self__', None) is not None and getattr(method.__self__, '__class__', None)):
+        for cls in inspect.getmro(method.__self__.__class__):
+            if method.__name__ in cls.__dict__:
+                return cls
+    method = getattr(method, '__func__', method)  # fallback to __qualname__ parsing
+    if inspect.isfunction(method):
+        cls = getattr(inspect.getmodule(method),
+                      method.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
+                      None)
+        if isinstance(cls, type):
+            return cls
+    return getattr(method, '__objclass__', None)  # handle special descriptor objects
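get_class_that_defined_method, moved into the library from tests/helpers/utils.py (see the last file in this commit), resolves the defining class of a bound method, unwrapping partials first. A quick sketch (the class name Demo is illustrative):

    from functools import partial
    from fastNLP.core.utils.utils import get_class_that_defined_method

    class Demo:
        def run(self):
            pass

    assert get_class_that_defined_method(Demo().run) is Demo
    assert get_class_that_defined_method(partial(Demo().run)) is Demo  # partials are unwrapped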

+ 5 - 5  fastNLP/io/data_bundle.py  View File

@@ -251,10 +251,10 @@ class DataBundle:
     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
                          ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        Apply the :meth:`~fastNLP.DataSet.apply_field_more` method to every dataset in this :class:`~fastNLP.io.DataBundle`
+        Apply the :method:`~fastNLP.DataSet.apply_field_more` method to every dataset in this :class:`~fastNLP.io.DataBundle`

         .. note::
-            For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of
-            ``apply_more`` vs. ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+            For the difference between ``apply_field_more`` and ``apply_field``, see the discussion of
+            ``apply_more`` vs. ``apply`` in :method:`fastNLP.DataSet.apply_more`.

         :param callable func: takes an ``Instance`` of the ``DataSet`` as its argument and returns a dict whose
             keys are field names and whose values are the corresponding results
@@ -285,7 +285,7 @@ class DataBundle:
     def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
               progress_desc: str = '', show_progress_bar: bool = True, _apply_field: str = None):
         r"""
-        Apply the :meth:`~fastNLP.DataSet.apply` method to every dataset in this :class:`~fastNLP.io.DataBundle`
+        Apply the :method:`~fastNLP.DataSet.apply` method to every dataset in this :class:`~fastNLP.io.DataBundle`

         Apply the apply method to every dataset in the DataBundle

@@ -309,10 +309,10 @@ class DataBundle:
     def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
                    progress_desc: str = '', show_progress_bar: bool = True):
         r"""
-        Apply the :meth:`~fastNLP.DataSet.apply_more` method to every dataset in this :class:`~fastNLP.io.DataBundle`
+        Apply the :method:`~fastNLP.DataSet.apply_more` method to every dataset in this :class:`~fastNLP.io.DataBundle`

         .. note::
-            For the difference between ``apply_more`` and ``apply``, see the discussion of ``apply_more``
-            vs. ``apply`` in :meth:`fastNLP.DataSet.apply_more`.
+            For the difference between ``apply_more`` and ``apply``, see the discussion of ``apply_more``
+            vs. ``apply`` in :method:`fastNLP.DataSet.apply_more`.

         :param callable func: takes an ``Instance`` of the ``DataSet`` as its argument and returns a dict whose
             keys are field names and whose values are the corresponding results

+ 1 - 1  fastNLP/io/pipe/classification.py  View File

@@ -87,7 +87,7 @@ class CLSBasePipe(Pipe):

     def process_from_file(self, paths) -> DataBundle:
         r"""
-        Takes file path(s) and produces a processed DataBundle object. For the path forms that paths supports, see ::meth:`fastNLP.io.Loader.load()`
+        Takes file path(s) and produces a processed DataBundle object. For the path forms that paths supports, see ::method:`fastNLP.io.Loader.load()`

         :param paths:
         :return: DataBundle


+ 1 - 1  fastNLP/io/pipe/construct_graph.py  View File

@@ -164,7 +164,7 @@ class GraphBuilderBase:

     def build_graph_from_file(self, path: str):
         r"""
-        Takes a file path and produces a processed scipy_sparse_matrix object. For the path forms that paths supports, see ::meth:`fastNLP.io.Loader.load()`
+        Takes a file path and produces a processed scipy_sparse_matrix object. For the path forms that paths supports, see ::method:`fastNLP.io.Loader.load()`

         :param path:
         :return: scipy_sparse_matrix


+ 1 - 1  fastNLP/io/pipe/pipe.py  View File

@@ -33,7 +33,7 @@ class Pipe:

     def process_from_file(self, paths: str) -> DataBundle:
         r"""
-        Takes file path(s) and produces a processed DataBundle object. For the path forms that paths supports, see ::meth:`fastNLP.io.Loader.load()`
+        Takes file path(s) and produces a processed DataBundle object. For the path forms that paths supports, see ::method:`fastNLP.io.Loader.load()`

         :param str paths:
         :return: DataBundle
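The same docstring change recurs in each Pipe. For orientation, a hedged sketch of the contract a concrete Pipe implements (MyPipe and the choice of CSVLoader are illustrative assumptions, not part of this commit):

    from fastNLP.io import DataBundle
    from fastNLP.io.loader import CSVLoader
    from fastNLP.io.pipe.pipe import Pipe

    class MyPipe(Pipe):
        def process(self, data_bundle: DataBundle) -> DataBundle:
            # a real pipe would tokenize / index fields here
            return data_bundle

        def process_from_file(self, paths: str) -> DataBundle:
            # any Loader works; its load() accepts the path forms referenced above
            return self.process(CSVLoader().load(paths))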


+ 187 - 0  tests/core/utils/test_utils.py  View File

@@ -0,0 +1,187 @@
+from functools import partial
+
+import pytest
+
+from fastNLP.core.utils.utils import auto_param_call, _check_valid_parameters_number, _get_fun_msg
+from fastNLP.core.metrics import Metric
+
+
+class TestAutoParamCall:
+    def test_basic(self):
+        def fn(x):
+            return x
+        x = {'x': 3, 'y': 4}
+        r = auto_param_call(fn, x)
+        assert r==3
+
+        xs = []
+        for i in range(10):
+            xs.append({f'x{i}': i})
+        def fn(x0, x1, x2, x3):
+            return x0 + x1 + x2 + x3
+        r = auto_param_call(fn, *xs)
+        assert r == 0 + 1+ 2+ 3
+
+        def fn(chongfu1, chongfu2, buChongFu):
+            pass
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(fn, {'chongfu1': 3, "chongfu2":4, 'buChongFu':2},
+                            {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
+        assert 'The following key present in several inputs' in exc_info.value.args[0]
+        assert 'chongfu1' in exc_info.value.args[0] and 'chongfu2' in exc_info.value.args[0]
+
+        # duplicated keys that are never used do not raise
+        def fn(chongfu1, buChongFu):
+            pass
+        auto_param_call(fn, {'chongfu1': 1, "chongfu2":4, 'buChongFu':2},
+                        {'chongfu1': 1, 'chongfu2':2, 'buChongFu':2})
+
+        # signature_fn can be customized
+        def fn1(**kwargs):
+            kwargs.pop('x')
+            kwargs.pop('y')
+            assert len(kwargs)==0
+        def fn(x, y):
+            pass
+        x = {'x': 3, 'y': 4}
+        r = auto_param_call(fn1, x, signature_fn=fn)
+
+        # raises when a needed parameter is not provided
+        def fn(meiti1, meiti2, tigong):
+            pass
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(fn, {'tigong':1})
+        assert 'meiti1' in exc_info.value.args[0] and 'meiti2' in exc_info.value.args[0]
+
+        # default values can be overridden by the input
+        def fn(x, y=100):
+            return x + y
+        r = auto_param_call(fn, {'x': 10, 'y': 20})
+        assert r==30
+        assert auto_param_call(fn, {'x': 10, 'z': 20})==110
+
+        # test the use of mapping
+        def fn(x, y=100):
+            return x + y
+        r = auto_param_call(fn, {'x1': 10, 'y1': 20}, mapping={'x1': 'x', 'y1': 'y', 'meiyong': 'meiyong'})
+        assert r==30
+
+        # a function that needs no parameters at all
+        def fn():
+            return 1
+        assert 1 == auto_param_call(fn, {'x':1})
+
+        # calling a method of a class also works
+        assert 2==auto_param_call(self.call_this, {'x':1 ,'y':1})
+        assert 2==auto_param_call(self.call_this, {'x':1,'y':1, 'z':1},mapping={'z': 'self'})
+
+    def test_msg(self):
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(self.call_this, {'x':1})
+        assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]
+
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(call_this_for_auto_param_call, {'x':1})
+        assert __file__ in exc_info.value.args[0]
+        assert 'call_this_for_auto_param_call' in exc_info.value.args[0]
+
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(self.call_this_two, {'x':1})
+        assert __file__ in exc_info.value.args[0]
+
+        with pytest.raises(BaseException) as exc_info:
+            auto_param_call(call_this_for_auto_param_call, {'x':1}, signature_fn=self.call_this)
+        assert 'TestAutoParamCall.call_this' in exc_info.value.args[0]  # the message should come from the signature_fn
+
+    def call_this(self, x, y):
+        return x + y
+
+    def call_this_two(self, x, y, z=pytest, **kwargs):
+        return x + y
+
+    def test_metric_auto_param_call(self):
+        metric = AutoParamCallMetric()
+        with pytest.raises(BaseException):
+            auto_param_call(metric.update, {'y':1}, signature_fn=metric.update.__wrapped__)
+
+
+class AutoParamCallMetric(Metric):
+    def update(self, x):
+        pass
+
+
+def call_this_for_auto_param_call(x, y):
+    return x + y
+
+
+class TestCheckNumberOfParameters:
+    def test_validate_every(self):
+        def validate_every(trainer):
+            pass
+        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+
+        # an extra parameter without a default raises
+        def validate_every(trainer, other):
+            pass
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+        assert "2 parameters" in exc_info.value.args[0]
+        print(exc_info.value.args[0])
+
+        # fine when the extra parameter has a default value
+        def validate_every(trainer, other=1):
+            pass
+        _check_valid_parameters_number(validate_every, expected_params=['trainer'])
+
+        # more expected parameters than the function accepts
+        def validate_every(trainer):
+            pass
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other'])
+        assert "accepts 1 parameters" in exc_info.value.args[0]
+        print(exc_info.value.args[0])
+
+        # works with partial
+        def validate_every(trainer, other):
+            pass
+        _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer'])
+        _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other'])
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(partial(validate_every, other=1), expected_params=['trainer', 'other', 'more'])
+        assert 'accepts 2 parameters' in exc_info.value.args[0]
+        print(exc_info.value.args[0])
+
+        # extra expected params do not raise when *args or **kwargs is present
+        def validate_every(trainer, *args):
+            pass
+        _check_valid_parameters_number(validate_every, expected_params=['trainer', 'other', 'more'])
+
+        def validate_every(trainer, **kwargs):
+            pass
+        _check_valid_parameters_number(partial(validate_every, trainer=1), expected_params=['trainer', 'other', 'more'])
+
+        # self is dropped for class methods
+        class InnerClass:
+            def demo(self, x):
+                pass
+
+            def no_param(self):
+                pass
+
+            def param_kwargs(self, **kwargs):
+                pass
+
+        inner = InnerClass()
+        with pytest.raises(RuntimeError) as exc_info:
+            _check_valid_parameters_number(inner.demo, expected_params=['trainer', 'other', 'more'])
+        assert 'accepts 1 parameters' in exc_info.value.args[0]
+
+        _check_valid_parameters_number(inner.demo, expected_params=['trainer'])
+
+
+def test_get_fun_msg():
+    def demo(x):
+        pass
+
+    print(_get_fun_msg(_get_fun_msg))

+ 1 - 19  tests/helpers/utils.py  View File

@@ -2,37 +2,19 @@ import os
 import sys
 import __main__
 from functools import wraps
-import inspect
 from inspect import ismethod
-import functools
 from copy import deepcopy
 from io import StringIO
 import time

 import numpy as np

+from fastNLP.core.utils.utils import get_class_that_defined_method
 from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
 from fastNLP.core.drivers.utils import distributed_open_proc
 from fastNLP.core.log import logger


-def get_class_that_defined_method(meth):
-    if isinstance(meth, functools.partial):
-        return get_class_that_defined_method(meth.func)
-    if inspect.ismethod(meth) or (inspect.isbuiltin(meth) and getattr(meth, '__self__', None) is not None and getattr(meth.__self__, '__class__', None)):
-        for cls in inspect.getmro(meth.__self__.__class__):
-            if meth.__name__ in cls.__dict__:
-                return cls
-    meth = getattr(meth, '__func__', meth)  # fallback to __qualname__ parsing
-    if inspect.isfunction(meth):
-        cls = getattr(inspect.getmodule(meth),
-                      meth.__qualname__.split('.<locals>', 1)[0].rsplit('.', 1)[0],
-                      None)
-        if isinstance(cls, type):
-            return cls
-    return getattr(meth, '__objclass__', None)  # handle special descriptor objects


 def recover_logger(fn):
     @wraps(fn)
     def wrapper(*args, **kwargs):
