@@ -14,7 +14,7 @@ __all__ = [
     'MoreEvaluateCallback',
     "TorchWarmupCallback",
     "TorchGradClipCallback",
-    "MonitorUtility",
+    "ResultsMonitor",
     'HasMonitorCallback',

     # collators
@@ -16,7 +16,7 @@ __all__ = [
     "TorchWarmupCallback",
     "TorchGradClipCallback",
-    "MonitorUtility",
+    "ResultsMonitor",
     'HasMonitorCallback'
 ]
@@ -31,5 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback
 from .early_stop_callback import EarlyStopCallback
 from .torch_callbacks import *
 from .more_evaluate_callback import MoreEvaluateCallback
-from .has_monitor_callback import MonitorUtility, HasMonitorCallback
+from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
@@ -57,7 +57,7 @@ def prepare_callbacks(callbacks, progress_bar):
     if has_no_progress and progress_bar is not None:
         callback = choose_progress_callback(progress_bar)
         if callback is not None:
-            _callbacks.append(callback)
+            _callbacks = [callback] + _callbacks  # put it at the very front so different epochs are easy to separate
             has_no_progress = False
     elif has_no_progress is False and progress_bar not in ('auto', None):
         logger.rank_zero_warning(f"Since you have passed in ProgressCallback, progress_bar={progress_bar} will be ignored.")
@@ -1,7 +1,7 @@
 __all__ = [
     'HasMonitorCallback',
     'ExecuteOnceBetterMonitor',
-    'MonitorUtility'
+    'ResultsMonitor'
 ]

 from typing import Dict, Union, Any
@@ -29,12 +29,16 @@ class CanItemDataType(ABC):
         return NotImplemented


-class MonitorUtility:
-    """
-    Helper functions related to computing the monitor
-
-    """
-    def __init__(self, monitor, larger_better):
+class ResultsMonitor:
+    def __init__(self, monitor:Union[Callback, str], larger_better:bool=True):
+        """
+        Monitors a value and, through interfaces such as is_better_results(), checks whether the results have improved.
+
+        :param monitor: the metric value to monitor. If no exactly matching name is found in the evaluation results, the
+            shortest-common-string algorithm is used to pick the best-matching name as the monitor. If None, the monitor
+            configured on the Trainer is used. A function may also be passed in: it receives the evaluation results
+            (a dict) and returns a float as the monitor value; it should return None if the current results contain no
+            relevant monitor value.
+        :param larger_better: whether a larger monitor value is better.
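+
+        Example (a sketch; the actual keys depend on your metrics and dataloader names)::
+
+            ResultsMonitor('acc#acc')                        # an exact key in the evaluation results
+            ResultsMonitor('acc')                            # fuzzy-matched to the closest key
+            ResultsMonitor(lambda res: res.get('acc#acc'))   # a callable over the result dict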
""" | |||||
         self.set_monitor(monitor, larger_better)

     def set_monitor(self, monitor, larger_better):
@@ -53,7 +57,7 @@ class MonitorUtility:
     def itemize_results(self, results):
         """
-        Call .item() on everything in the results that has it, so that the results can be saved
+        Call .item() on everything in the results that has an .item() method, converting tensor values to Python built-in types.

         :param results:
         :return:
@@ -161,7 +165,7 @@ class MonitorUtility:
         return monitor_name


-class HasMonitorCallback(MonitorUtility, Callback):
+class HasMonitorCallback(ResultsMonitor, Callback):
     def __init__(self, monitor, larger_better, must_have_monitor=False):
         """
         This callback is not meant to be used directly; it serves as the base class for other monitor-related callbacks. Callbacks that use a monitor can inherit from it, as it implements
@@ -12,7 +12,7 @@ from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.envs import rank_zero_call
 from fastNLP.envs.env import FASTNLP_EVALUATE_RESULT_FILENAME
-from .has_monitor_callback import MonitorUtility
+from .has_monitor_callback import ResultsMonitor


 class Saver:
@@ -170,7 +170,7 @@ class TopkQueue:
         return self.topk != 0


-class TopkSaver(MonitorUtility, Saver):
+class TopkSaver(ResultsMonitor, Saver):
     def __init__(self, topk:int=0, monitor:str=None, larger_better:bool=True, folder:str=None, save_object:str='model',
                  only_state_dict:bool=True, model_save_fn:Callable=None, save_evaluate_results:bool=True,
                  **kwargs):
@@ -196,7 +196,7 @@ class TopkSaver(MonitorUtility, Saver):
             the fastnlp_evaluate_results.json file recording the current results. Only useful when topk is set; defaults to True.
         :param kwargs: additional parameters passed through to Trainer.save() or Trainer.save_model().
         """
-        MonitorUtility.__init__(self, monitor, larger_better)
+        ResultsMonitor.__init__(self, monitor, larger_better)
         Saver.__init__(self, folder, save_object, only_state_dict, model_save_fn, **kwargs)

         if monitor is not None and topk == 0:
@@ -8,10 +8,10 @@ __all__ = [
 ]

 from fastNLP.core.drivers import Driver
-from fastNLP.core.drivers.utils import choose_driver
+from ..drivers.choose_driver import choose_driver
 from .loops import Loop, EvaluateBatchLoop
 from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
-    match_and_substitute_params, f_rich_progress
+    match_and_substitute_params, f_rich_progress, flat_nest_dict
 from fastNLP.core.metrics import Metric
 from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
 from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
@@ -162,13 +162,15 @@ class Evaluator:
                 self.cur_dataloader_name = dataloader_name
                 results = self.evaluate_batch_loop.run(self, dataloader)
                 self.remove_progress_bar(dataloader_name)
-                metric_results.update(results)
+                metric_results[dataloader_name] = results
                 self.reset()
                 self.driver.barrier()
         except BaseException as e:
             raise e
         finally:
             self.finally_progress_bar()
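+        # flattening sketch: {'dl1': {'acc': {'acc': 0.9}}} becomes {'acc#acc#dl1': 0.9}
+        # with the default '#' separator; a None dataloader name is dropped from the key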
+        metric_results = flat_nest_dict(metric_results, separator=self.separator, compress_none_key=True, top_down=False)

         self.driver.set_model_mode(mode='train')
         if self.verbose:
             if self.progress_bar == 'rich':
@@ -251,14 +253,13 @@ class Evaluator:
         """
         self.metrics_wrapper.update(batch, outputs)

-    def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict:
+    def get_metric(self) -> Dict:
         """
-        Get the metric results of the current dataloader
+        Call get_metric on every metric and return the results; each key is a metric's name and each value is that metric's results.

-        :param str dataloader_name: name of the current dataloader
         :return:
         """
-        return self.metrics_wrapper.get_metric(dataloader_name=dataloader_name, separator=self.separator)
+        return self.metrics_wrapper.get_metric()

     @property
     def metrics_wrapper(self):
@@ -366,15 +367,12 @@ class _MetricsWrapper:
         elif _is_torchmetrics_metric(metric) or _is_paddle_metric(metric) or isinstance(metric, Metric):
             metric.reset()

-    def get_metric(self, dataloader_name: str, separator: str) -> Dict:
+    def get_metric(self) -> Dict:
         """
-        Flatten all metric results into a single-level dict whose keys follow the naming rule
-        indicator_name{separator}metric_name{separator}dataloader_name
-        e.g. f1#F1PreRec#dev
+        Call each metric to obtain its results, returned in the form
+        {'metric_name1': metric_results, 'metric_name2': metric_results}.

-        :param dataloader_name: name of the dataloader the current metric belongs to. If empty, it is omitted from the final key.
-        :param separator: separator between the different name parts.
-        :return: a single-level dict whose keys identify each metric and whose values are the metric values;
+        :return:
         """
         results = {}
         for metric_name, metric in zip(self._metric_names, self._metrics):
@@ -384,37 +382,9 @@ class _MetricsWrapper:
                 _results = metric.get_metric(reset=False)
             elif _is_torchmetrics_metric(metric):
                 _results = metric.compute()
-                # We require that each metric in the evaluator returns a dict; so when the metric is a torchmetrics
-                # metric and the evaluator was not given a func_post_proc, we automatically use the metric's name as
-                # its indicator name and convert the result into a dict;
             elif _is_paddle_metric(metric):
                 _results = metric.accumulate()
-            if not isinstance(_results, Dict):
-                name = _get_metric_res_name(dataloader_name, metric_name, '', separator)
-                results[name] = _results
             else:
-                for indicator_name, value in _results.items():
-                    name = _get_metric_res_name(dataloader_name, metric_name, indicator_name, separator)
-                    results[name] = value
+                raise RuntimeError(f"Not support `{type(metric)}` for now.")
+            results[metric_name] = _results
         return results
-
-
-def _get_metric_res_name(dataloader_name: Optional[str], metric_name: str, indicator_name: str, separator='#') -> str:
-    """
-
-    :param dataloader_name: name of the dataloader
-    :param metric_name: name of the metric
-    :param indicator_name: names of the individual indicators within the metric, e.g. f, precision, recall
-    :param separator: separator placed between the different parts
-    :return:
-    """
-    names = []
-    if indicator_name:
-        names.append(indicator_name)
-    if metric_name:
-        names.append(metric_name)
-    if dataloader_name:
-        names.append(dataloader_name)
-    if len(names) == 0:
-        raise RuntimeError("You cannot use empty `dataloader_name`, `metric_name`, and `monitor` simultaneously.")
-    return separator.join(names)
@@ -40,8 +40,8 @@ class EvaluateBatchLoop(Loop):
                 self.batch_step_fn(evaluator, batch)
                 batch_idx += 1
                 evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
-        # Fetch the metric results. Example of the returned dict: {'f1#F1Metric#dl1': 0.93, 'pre#F1Metric#dl1': 0.95, ...}
-        results = evaluator.get_dataloader_metric(dataloader_name=evaluator.cur_dataloader_name)
+        # Fetch the metric results. Example of the returned dict: {'metric_name1': metric_results, 'metric_name2': metric_results, ...}
+        results = evaluator.get_metric()
         return results

     @staticmethod
@@ -23,7 +23,7 @@ from fastNLP.core.callbacks.callback import _CallbackWrapper
 from fastNLP.core.callbacks.callback_manager import prepare_callbacks
 from fastNLP.core.callbacks.callback_event import Event
 from fastNLP.core.drivers import Driver
-from fastNLP.core.drivers.utils import choose_driver
+from ..drivers.choose_driver import choose_driver
 from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext
 from fastNLP.core.utils.utils import _check_valid_parameters_number
 from fastNLP.envs import rank_zero_call
@@ -0,0 +1,31 @@
+from typing import Union, Optional, List
+
+from .driver import Driver
+
+
+def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
+    r"""
+    Decide the concrete working mode according to the format of the input parameters;
+
+    :param model: the original model used during the run;
+    :param driver: a string or a `Driver` instance indicating the concrete training/evaluation mode to use;
+    :param device: see the docstring of `fastNLP.core.drivers.torch_driver.utils.initialize_torch_driver` for the accepted formats;
+    :param kwargs: the remaining parameters passed on to the `Driver`;
+    """
+    # If the user passes in a Driver instance directly, we return it as-is; for now the user must guarantee the correctness of the driver they pass in;
+    if isinstance(driver, Driver):
+        return driver
+
+    if driver in {"torch", "torch_ddp", "fairscale"}:
+        from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
+        return initialize_torch_driver(driver, device, model, **kwargs)
+    elif driver in {"jittor"}:
+        from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
+        return initialize_jittor_driver(driver, device, model, **kwargs)
+    elif driver in {"paddle", "fleet"}:
+        from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
+        return initialize_paddle_driver(driver, device, model, **kwargs)
+    else:
+        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', "
+                         "'jittor', 'paddle', 'fleet'].")
@@ -33,11 +33,12 @@ class JittorDriver(Driver):
                 f"`jittor.Module` type.")
         super(JittorDriver, self).__init__(model)

-        self.model = model
-
         self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
         self.grad_scaler = _grad_scaler()

+        # controls whether to disable the parameter matching in auto_param_call;
+        self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
     @staticmethod
     def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False):
         # JittorDataLoader is implemented in fastNLP
@@ -152,4 +153,4 @@ class JittorDriver(Driver):
     # def set_sampler_epoch(self, dataloader: JittorDataLoader, cur_epoch_idx):
     #     # ensure correctness when shuffle=True in ddp training, since every process's sampler must shuffle with the same random seed;
     #     if callable(getattr(dataloader.batch_sampler, "set_epoch", None)):
-    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
\ No newline at end of file
+    #         dataloader.batch_sampler.set_epoch(cur_epoch_idx)
@@ -60,8 +60,8 @@ class JittorSingleDriver(JittorDriver):
             logger.debug(f'Use {_get_fun_msg(fn, with_fp=False)}...')
             return fn, None
         elif fn in {"train_step", "evaluate_step"}:
-            logger.debug(f'Use {_get_fun_msg(self.model.forward, with_fp=False)}...')
-            return self.model, self.model.forward
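+            # jittor modules implement execute() rather than forward()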
+            logger.debug(f'Use {_get_fun_msg(self.model.execute, with_fp=False)}...')
+            return self.model, self.model.execute
         else:
             raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")
@@ -98,3 +98,9 @@ class JittorSingleDriver(JittorDriver):
             return dataloader
         else:
             return dataloader
+
+    def setup(self):
+        """
+        With a single GPU, jittor handles device placement automatically under the hood; no extra work is needed.
+        """
+        pass
@@ -1,38 +1,5 @@
-from typing import Optional
-from typing import Union, List
+from typing import List
 import subprocess
-from pathlib import Path
-
-from fastNLP.core.drivers.driver import Driver
-
-__all__ = []
-
-
-def choose_driver(model, driver: Union[str, Driver], device: Optional[Union[int, List[int], str]], **kwargs) -> Driver:
-    r"""
-    Decide the concrete working mode according to the format of the input parameters;
-
-    :param model: the original model used during the run;
-    :param driver: a string or a `Driver` instance indicating the concrete training/evaluation mode to use;
-    :param device: see the docstring of `fastNLP.core.drivers.torch_driver.utils.initialize_torch_driver` for the accepted formats;
-    :param kwargs: the remaining parameters passed on to the `Driver`;
-    """
-    # If the user passes in a Driver instance directly, we return it as-is; for now the user must guarantee the correctness of the driver they pass in;
-    if isinstance(driver, Driver):
-        return driver
-
-    if driver in {"torch", "torch_ddp", "fairscale"}:
-        from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver
-        return initialize_torch_driver(driver, device, model, **kwargs)
-    elif driver in {"jittor"}:
-        from fastNLP.core.drivers.jittor_driver.initialize_jittor_driver import initialize_jittor_driver
-        return initialize_jittor_driver(driver, device, model, **kwargs)
-    elif driver in {"paddle", "fleet"}:
-        from fastNLP.core.drivers.paddle_driver.initialize_paddle_driver import initialize_paddle_driver
-        return initialize_paddle_driver(driver, device, model, **kwargs)
-    else:
-        raise ValueError("Parameter `driver` can only be one of these values: ['torch', 'torch_ddp', 'fairscale', "
-                         "'jittor', 'paddle', 'fleet'].")


 def distributed_open_proc(output_from_new_proc:str, command:List[str], env_copy:dict, rank:int=None):
@@ -24,6 +24,7 @@ __all__ = [
     'Option',
     'deprecated',
     'seq_len_to_mask',
+    "flat_nest_dict"
 ]

 from .cache_results import cache_results
@@ -33,8 +34,6 @@ from .paddle_utils import get_device_from_visible, paddle_to, paddle_move_data_to_device
 from .rich_progress import f_rich_progress
 from .torch_paddle_utils import torch_paddle_move_data_to_device
 from .torch_utils import torch_move_data_to_device
-from .utils import get_fn_arg_names, auto_param_call, check_user_specific_params, \
-    dataclass_to_dict, match_and_substitute_params, apply_to_collection, nullcontext, pretty_table_printer, Option, \
-    deprecated, seq_len_to_mask
+from .utils import *
@@ -35,6 +35,7 @@ __all__ = [
     'Option',
     'deprecated',
     'seq_len_to_mask',
+    "flat_nest_dict"
 ]
@@ -645,4 +646,55 @@ def is_notebook():
     except:
         return False
     else:  # pragma: no cover
-        return True
\ No newline at end of file
+        return True
+
+
+def flat_nest_dict(d:Dict, separator:str='#', compress_none_key:bool=True, top_down:bool=False) -> Dict:
+    """
+    Flatten a nested dict into a flat dict, e.g.
+
+    ex::
+        d = {'test': {'f1': {'f': 0.2, 'rec': 0.1}}} -> {'f#f1#test':0.2, 'rec#f1#test':0.1}
+
+    :param d: the dict to flatten.
+    :param separator: the separator placed between keys of different levels.
+    :param compress_none_key: if a key is None, that level is skipped when joining.
+    :param top_down: whether the joined key orders the levels from the top level down to the bottom; if False, the order is bottom-up.
+    :return:
+    """
+    assert isinstance(d, Dict)
+    assert isinstance(separator, str)
+
+    flat_d = {}
+    for key, value in d.items():
+        if key is None:
+            key = ()
+        else:
+            key = (key, )
+        if isinstance(value, Mapping):
+            flat_d.update(_flat_nest_dict(value, parent_key=key, compress_none_key=compress_none_key))
+        else:
+            flat_d[key] = value
+
+    str_flat_d = {}
+    for key, value in flat_d.items():
+        if top_down:
+            key = map(str, key)
+        else:
+            key = map(str, key[::-1])
+        key = separator.join(key)
+        str_flat_d[key] = value
+    return str_flat_d
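+
+# Usage sketch (default separator '#', bottom-up key order):
+#     flat_nest_dict({'test': {'f1': {'f': 0.2}}})                 -> {'f#f1#test': 0.2}
+#     flat_nest_dict({'test': {'f1': {'f': 0.2}}}, top_down=True)  -> {'test#f1#f': 0.2}
+#     flat_nest_dict({None: {'f1': {'f': 0.2}}})                   -> {'f#f1': 0.2}  (None level is skipped)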
+
+
+def _flat_nest_dict(d:Mapping, parent_key:Tuple, compress_none_key:bool):
+    flat_d = {}
+    for k, v in d.items():
+        _key = parent_key
+        if k is not None:
+            _key = _key + (k,)
+        if isinstance(v, Mapping):
+            _d = _flat_nest_dict(v, parent_key=_key, compress_none_key=compress_none_key)
+            flat_d.update(_d)
+        else:
+            flat_d[_key] = v
+
+    return flat_d
@@ -0,0 +1,133 @@
+import pytest
+
+from fastNLP.core.controllers.trainer import Trainer
+from fastNLP.core.controllers.trainer import Evaluator
+from fastNLP.core.metrics.accuracy import Accuracy
+from fastNLP.core.callbacks.progress_callback import RichCallback
+from fastNLP.core.dataloaders.jittor_dataloader.fdl import JittorDataLoader
+
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR
+if _NEED_IMPORT_JITTOR:
+    import jittor as jt
+    from jittor import nn, Module
+    from jittor.dataset import Dataset
+
+
+class JittorNormalModel_Classification(Module):
+    """
+    A basic Jittor classification model
+    """
+
+    def __init__(self, num_labels, feature_dimension):
+        super(JittorNormalModel_Classification, self).__init__()
+        self.num_labels = num_labels
+
+        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=64)
+        self.ac1 = nn.ReLU()
+        self.linear2 = nn.Linear(in_features=64, out_features=32)
+        self.ac2 = nn.ReLU()
+        self.output = nn.Linear(in_features=32, out_features=num_labels)
+        self.loss_fn = nn.CrossEntropyLoss()
+
+    def execute(self, x):
+        # execute() is Jittor's counterpart of PyTorch's forward()
+        x = self.ac1(self.linear1(x))
+        x = self.ac2(self.linear2(x))
+        x = self.output(x)
+        return x
+
+    def train_step(self, x, y):
+        x = self(x)
+        return {"loss": self.loss_fn(x, y)}
+
+    def evaluate_step(self, x, y):
+        x = self(x)
+        return {"pred": x, "target": y.reshape((-1,))}
+
+
+class JittorRandomMaxDataset(Dataset):
+    def __init__(self, num_samples, num_features):
+        super(JittorRandomMaxDataset, self).__init__()
+        self.x = jt.randn((num_samples, num_features))
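+        # jt.argmax returns a tuple here; [0] keeps the index tensor used as labels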
+        self.y = self.x.argmax(dim=1)[0]
+
+    def __len__(self):
+        return len(self.y)
+
+    def __getitem__(self, item):
+        return {"x": self.x[item], "y": self.y[item]}
+
+
+class TrainJittorConfig:
+    num_labels: int = 5
+    feature_dimension: int = 5
+    lr = 1e-1
+    batch_size: int = 4
+    shuffle: bool = True
+
+
+@pytest.mark.parametrize("driver,device", [("jittor", None)])
+@pytest.mark.parametrize("callbacks", [[RichCallback(100)]])
+def test_trainer_jittor(
+        driver,
+        device,
+        callbacks,
+        n_epochs=3,
+):
+    model = JittorNormalModel_Classification(
+        num_labels=TrainJittorConfig.num_labels,
+        feature_dimension=TrainJittorConfig.feature_dimension
+    )
+    optimizer = nn.SGD(model.parameters(), lr=TrainJittorConfig.lr)
+    train_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    val_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(500, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+    test_dataloader = JittorDataLoader(
+        dataset=JittorRandomMaxDataset(1000, TrainJittorConfig.feature_dimension),
+        batch_size=TrainJittorConfig.batch_size,
+        shuffle=True,
+        # num_workers=4,
+    )
+
+    metrics = {"acc": Accuracy()}
+
+    trainer = Trainer(
+        model=model,
+        driver=driver,
+        device=device,
+        optimizers=optimizer,
+        train_dataloader=train_dataloader,
+        evaluate_dataloaders=val_dataloader,
+        validate_every=-1,
+        evaluate_fn="evaluate_step",
+        input_mapping=None,
+        output_mapping=None,
+        metrics=metrics,
+        n_epochs=n_epochs,
+        callbacks=callbacks,
+        # progress_bar="rich"
+    )
+    trainer.run()
+
+    evaluator = Evaluator(
+        model=model,
+        driver=driver,
+        dataloaders=test_dataloader,
+        evaluate_fn="evaluate_step",
+        metrics=metrics,
+    )
+    metric_results = evaluator.run()
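+    # with a single unnamed dataloader, flat_nest_dict drops the None name,
+    # leaving the flattened key indicator#metric_name, i.e. 'acc#acc'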
+    assert metric_results["acc#acc"] > 0.80
+
+
+if __name__ == "__main__":
+    # test_trainer_jittor("jittor", None, [RichCallback(100)])
+    pytest.main(['test_trainer_jittor.py'])  # run only this module
@@ -174,7 +174,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps(
         dist.destroy_process_group()


 @pytest.mark.torch
-@pytest.mark.parametrize("driver,device", [("torch", 1)])  # ("torch", [0, 1]),("torch", 1)
+@pytest.mark.parametrize("driver,device", [("torch", 'cpu')])  # ("torch", [0, 1]),("torch", 1)
 @magic_argv_env_context
 def test_trainer_validate_every(
     model_and_optimizers: TrainerParameters,
@@ -234,7 +234,7 @@ def test_trainer_on(
         device=device,
         optimizers=model_and_optimizers.optimizers,
         train_dataloader=model_and_optimizers.train_dataloader,
-        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
+        evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
         input_mapping=model_and_optimizers.input_mapping,
         output_mapping=model_and_optimizers.output_mapping,
         metrics=model_and_optimizers.metrics,