From 2423641010b6af2d21a92ade11225ac50a02ee27 Mon Sep 17 00:00:00 2001 From: yh Date: Sat, 14 May 2022 15:49:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9ETqdmProgressBar?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/__init__.py | 1 + fastNLP/core/callbacks/__init__.py | 5 +- fastNLP/core/callbacks/progress_callback.py | 108 +++++++++++- fastNLP/core/controllers/evaluator.py | 42 +++-- fastNLP/core/controllers/trainer.py | 9 +- fastNLP/core/dataset/dataset.py | 70 ++++---- fastNLP/core/utils/__init__.py | 4 +- fastNLP/core/utils/rich_progress.py | 10 +- fastNLP/core/utils/tqdm_progress.py | 160 ++++++++++++++++++ fastNLP/core/vocabulary.py | 8 +- fastNLP/io/data_bundle.py | 32 ++-- .../torch/models/auto/configuration_auto.py | 22 +-- .../callbacks/test_more_evaluate_callback.py | 6 - .../callbacks/test_progress_callback_torch.py | 123 ++++++++++++++ tests/core/dataset/test_dataset.py | 6 +- tests/core/utils/test_progress.py | 16 ++ 16 files changed, 522 insertions(+), 100 deletions(-) create mode 100644 fastNLP/core/utils/tqdm_progress.py create mode 100644 tests/core/callbacks/test_progress_callback_torch.py create mode 100644 tests/core/utils/test_progress.py diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py index 343313a6..095c314c 100644 --- a/fastNLP/core/__init__.py +++ b/fastNLP/core/__init__.py @@ -6,6 +6,7 @@ __all__ = [ 'CheckpointCallback', 'ProgressCallback', 'RichCallback', + 'TqdmCallback', "LRSchedCallback", 'LoadBestModelCallback', "EarlyStopCallback", diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index caf96af7..48699b68 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -4,8 +4,11 @@ __all__ = [ 'Filter', 'CheckpointCallback', 'choose_progress_callback', + 'ProgressCallback', 'RichCallback', + 'TqdmCallback', + "LRSchedCallback", 'LoadBestModelCallback', "EarlyStopCallback", @@ -26,7 +29,7 @@ from .callback import Callback from .callback_event import Event, Filter from .callback_manager import CallbackManager from .checkpoint_callback import CheckpointCallback -from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback +from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback from .early_stop_callback import EarlyStopCallback diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 2618431f..9fab4dbd 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -5,11 +5,14 @@ from typing import Union __all__ = [ 'choose_progress_callback', 'ProgressCallback', - 'RichCallback' + 'RichCallback', + 'TqdmCallback' ] +from ...envs.imports import _module_available, _compare_version + from .has_monitor_callback import HasMonitorCallback -from fastNLP.core.utils import f_rich_progress +from fastNLP.core.utils import f_rich_progress, f_tqdm_progress from fastNLP.core.log import logger @@ -24,7 +27,7 @@ class ProgressCallback(HasMonitorCallback): def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: if progress_bar == 'auto': - if not f_rich_progress.dummy_rich: + if not f_rich_progress.dummy: progress_bar = 'rich' else: progress_bar = 'raw' @@ -32,6 +35,8 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog return RichCallback() elif progress_bar == 'raw': return RawTextCallback() + elif progress_bar == 'tqdm': + return TqdmCallback() elif isinstance(progress_bar, ProgressCallback): return progress_bar else: @@ -82,7 +87,9 @@ class RichCallback(ProgressCallback): if 'batch' in self.task2id: self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) else: - self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', total=trainer.num_batches_per_epoch) + self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', + total=trainer.num_batches_per_epoch, + completed=trainer.batch_idx_in_epoch) def on_train_epoch_end(self, trainer): self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', @@ -208,4 +215,95 @@ class RawTextCallback(ProgressCallback): @property def name(self): # progress bar的名称 - return 'raw' \ No newline at end of file + return 'raw' + + +class TqdmCallback(ProgressCallback): + """ + 在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 + 参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 + + :param print_every: 多少个 batch 更新一次显示。 + :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 + :param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 + + * 为 ``None`` + 将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 + * 为 ``str`` + 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 + 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 + * 为 ``Callable`` + 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 + 的 ``monitor`` 值请返回 ``None`` 。 + :param larger_better: 是否是 monitor 的结果越大越好。 + :param format_json: 是否格式化 json 再打印 + """ + def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, + format_json=True): + super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) + self.print_every = print_every + self.progress_bar = f_tqdm_progress + self.task2id = {} + self.loss = 0 + self.loss_round_ndigit = loss_round_ndigit + self.format_json = format_json + self.num_signs = 10 + + def on_train_begin(self, trainer): + self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, + bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', + initial=trainer.global_forward_batches/(trainer.total_batches+1e-6)) + + def on_train_epoch_begin(self, trainer): + self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) + if 'batch' in self.task2id: + self.progress_bar.reset(self.task2id['batch']) + else: + self.task2id['batch'] = self.progress_bar.add_task(description='Batch', total=trainer.num_batches_per_epoch, + initial=trainer.batch_idx_in_epoch) + self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) + + def on_train_end(self, trainer): + self.clear_tasks() + + def on_before_backward(self, trainer, outputs): + loss = trainer.extract_loss_from_outputs(outputs) + loss = trainer.driver.tensor_to_numeric(loss, reduce='sum') + self.loss += loss + + def on_train_batch_end(self, trainer): + if trainer.global_forward_batches % self.print_every == 0: + loss = self.loss/self.print_every + self.loss = 0 + self.progress_bar.update(self.task2id['batch'], advance=self.print_every, refresh=True) + self.progress_bar.set_postfix_str(self.task2id['batch'], f'Loss:{round(loss, self.loss_round_ndigit)}') + self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True) + + def on_evaluate_end(self, trainer, results): + if len(results)==0: + return + base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' + text = '' + if self.monitor is not None: + monitor_value = self.get_monitor_value(results) + if self.is_better_monitor_value(monitor_value, keep_if_better=True): + if abs(self.monitor_value) != float('inf'): + text = '+'*self.num_signs + base_text + '+'*self.num_signs + if len(text) == 0: + text = '-'*self.num_signs + base_text + '-'*self.num_signs + + logger.info(text) + if self.format_json: + logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) + else: + logger.info(results) + + def clear_tasks(self): + for key, taskid in self.task2id.items(): + self.progress_bar.destroy_task(taskid) + self.task2id = {} + self.loss = 0 + + @property + def name(self): # progress bar的名称 + return 'tqdm' diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index fd7cb533..908c3564 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -19,7 +19,7 @@ from fastNLP.core.drivers import Driver, TorchDriver from ..drivers.choose_driver import choose_driver from .loops import Loop, EvaluateBatchLoop from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ - match_and_substitute_params, f_rich_progress, flat_nest_dict + match_and_substitute_params, f_rich_progress, flat_nest_dict, f_tqdm_progress from fastNLP.core.metrics import Metric from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader @@ -166,8 +166,9 @@ class Evaluator: self.dataloaders[name] = dl self.progress_bar = kwargs.get('progress_bar', 'auto') + assert self.progress_bar in [None, 'rich', 'auto', 'tqdm', 'raw'] if self.progress_bar == 'auto': - self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' + self.progress_bar = 'raw' if f_rich_progress.dummy else 'rich' self.driver.barrier() @@ -226,12 +227,15 @@ class Evaluator: return metric_results def start_progress_bar(self, total: int, dataloader_name): - if self.progress_bar == 'rich': + if self.progress_bar in ('rich', 'tqdm'): if dataloader_name is None: - desc = f'Eval. Batch:0' + desc = f'Eval. Batch' + else: + desc = f'Eval. on {dataloader_name} Batch' + if self.progress_bar == 'rich': + self._task_id = f_rich_progress.add_task(description=desc, total=total) else: - desc = f'Eval. on {dataloader_name} Batch:0' - self._rich_task_id = f_rich_progress.add_task(description=desc, total=total) + self._task_id = f_tqdm_progress.add_task(description=desc, total=total) elif self.progress_bar == 'raw': desc = 'Evaluation starts' if dataloader_name is not None: @@ -244,19 +248,26 @@ class Evaluator: else: desc = f'Eval. on {dataloader_name} Batch:{batch_idx}' if self.progress_bar == 'rich': - assert hasattr(self, '_rich_task_id'), "You must first call `start_progress_bar()` before calling " \ + assert hasattr(self, '_task_id'), "You must first call `start_progress_bar()` before calling " \ "update_progress_bar()" - f_rich_progress.update(self._rich_task_id, description=desc, post_desc=kwargs.get('post_desc', ''), + f_rich_progress.update(self._task_id, description=desc, post_desc=kwargs.get('post_desc', ''), advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), visible=kwargs.get('visible', True)) elif self.progress_bar == 'raw': if self.verbose > 1: logger.info(desc) + elif self.progress_bar == 'tqdm': + f_tqdm_progress.update(self._task_id, advance=1) def remove_progress_bar(self, dataloader_name): - if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): - f_rich_progress.destroy_task(self._rich_task_id) - delattr(self, '_rich_task_id') + if self.progress_bar == 'rich' and hasattr(self, '_task_id'): + f_rich_progress.destroy_task(self._task_id) + delattr(self, '_task_id') + + elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'): + f_tqdm_progress.destroy_task(self._task_id) + delattr(self, '_task_id') + elif self.progress_bar == 'raw': desc = 'Evaluation ends' if dataloader_name is not None: @@ -264,9 +275,12 @@ class Evaluator: logger.info("*" * 10 + desc + '*' * 10 + '\n') def finally_progress_bar(self): - if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): - f_rich_progress.destroy_task(self._rich_task_id) - delattr(self, '_rich_task_id') + if self.progress_bar == 'rich' and hasattr(self, '_task_id'): + f_rich_progress.destroy_task(self._task_id) + delattr(self, '_task_id') + elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'): + f_tqdm_progress.destroy_task(self._task_id) + delattr(self, '_task_id') @property def evaluate_batch_loop(self): diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 44a74c69..df6bf176 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -294,9 +294,9 @@ class Trainer(TrainerEventTrigger): log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; 注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; - * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, - 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 - 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 + * *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 RichCallback, RawTextCallback等对象, + 默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback 对象。如果 + 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 等对象。 * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 @@ -573,7 +573,7 @@ class Trainer(TrainerEventTrigger): if resume_from is not None: if os.path.exists(resume_from): - self.load(resume_from, resume_training=resume_training) + self.load_checkpoint(resume_from, resume_training=resume_training) else: raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") @@ -732,6 +732,7 @@ class Trainer(TrainerEventTrigger): Trainer.__init__(): on_after_trainer_initialized(trainer, driver) Trainer.run(): + # load checkpoint if resume_from is not None if num_eval_sanity_batch>0: on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch on_sanity_check_end(trainer, sanity_check_res) diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index c592984f..6bec175b 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -19,10 +19,17 @@ from .field import FieldArray from .instance import Instance from fastNLP.core.utils.utils import pretty_table_printer, deprecated from fastNLP.core.collators import Collator -from fastNLP.core.utils.rich_progress import f_rich_progress +from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress +from fastNLP.core.utils.tqdm_progress import f_tqdm_progress from ..log import logger +progress_bars = { + 'rich': f_rich_progress, + 'tqdm': f_tqdm_progress +} + + class ApplyResultException(Exception): def __init__(self, msg, index=None): super().__init__(msg) @@ -30,7 +37,7 @@ class ApplyResultException(Exception): self.index = index # 标示在哪个数据遭遇到问题了 -def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, +def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, progress_bar: str = 'rich', desc: str = None) -> list: """ 对数据集进行处理封装函数,以便多进程使用 @@ -39,32 +46,29 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s :param _apply_field: 需要处理数据集的field_name :param func: 用户自定义的func :param desc: 进度条的描述字符 - :param show_progress_bar: 是否展示子进程进度条 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :return: """ - if show_progress_bar: - desc = desc if desc else f"Main" - pg_main = f_rich_progress.add_task(description=desc, total=len(ds), visible=show_progress_bar) + progress_bar = progress_bars.get(progress_bar, DummyFRichProgress()) + desc = desc if desc else "Processing" + task_id = progress_bar.add_task(description=desc, total=len(ds)) results = [] idx = -1 try: - # for idx, ins in tqdm(enumerate(ds), total=len(ds), position=0, desc=desc, disable=not show_progress_bar): for idx, ins in enumerate(ds): if _apply_field is not None: results.append(func(ins[_apply_field])) else: results.append(func(ins)) - if show_progress_bar: - f_rich_progress.update(pg_main, advance=1) + progress_bar.update(task_id, advance=1) except BaseException as e: if idx != -1: logger.error("Exception happens at the `{}`th instance.".format(idx)) raise e finally: - if show_progress_bar: - f_rich_progress.destroy_task(pg_main) + progress_bar.destroy_task(task_id) return results @@ -398,7 +402,7 @@ class DataSet: def apply_field(self, func: Callable, field_name: str = None, new_field_name: str = None, num_proc: int = 0, - progress_desc: str = None, show_progress_bar: bool = True): + progress_desc: str = None, progress_bar: str = 'rich'): r""" 将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` 中。 @@ -413,8 +417,8 @@ class DataSet: 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。 - :param progress_desc: 进度条的描述字符,默认为 ``Main``; - :param show_progress_bar: 是否在处理过程中展示进度条; + :param progress_desc: 进度条的描述字符,默认为 ``Processing``; + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :return: 从函数 ``func`` 中得到的返回值; """ assert len(self) != 0, "Null DataSet cannot use apply_field()." @@ -422,7 +426,7 @@ class DataSet: raise KeyError("DataSet has no field named `{}`.".format(field_name)) try: - results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, + results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar, progress_desc=progress_desc, _apply_field=field_name) except BaseException as e: raise e @@ -433,7 +437,7 @@ class DataSet: def apply_field_more(self, func: Callable = None, field_name: str = None, modify_fields: bool = True, num_proc: int = 0, - progress_desc: str = None, show_progress_bar: bool = True): + progress_desc: str = None, progress_bar: str = 'rich'): r""" 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 func 可以返回一个或多个 field 上的结果。 @@ -446,8 +450,8 @@ class DataSet: :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param show_progress_bar: 是否显示进度条,默认展示 - :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 :return Dict[str:Field]: 返回一个字典 """ assert len(self) != 0, "Null DataSet cannot use apply_field()." @@ -456,7 +460,7 @@ class DataSet: idx = -1 results = {} apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, - show_progress_bar=show_progress_bar, _apply_field=field_name) + progress_bar=progress_bar, _apply_field=field_name) # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 if not isinstance(apply_out[0], Mapping): raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}") @@ -483,13 +487,13 @@ class DataSet: return results def _apply_process(self, num_proc: int = 0, func: Callable = None, - show_progress_bar: bool = True, _apply_field: str = None, + progress_bar: str = 'rich', _apply_field: str = None, progress_desc: str = 'Main') -> list: """ :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` :param _apply_field: 需要传进去func的数据集的field_name - :param show_progress_bar: 是否展示progress进度条,默认为展示 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: 进度条的描述字符,默认为'Main """ if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "": @@ -499,7 +503,7 @@ class DataSet: if num_proc < 2: results = _apply_single(ds=self, _apply_field=_apply_field, func=func, - desc=progress_desc, show_progress_bar=show_progress_bar) + desc=progress_desc, progress_bar=progress_bar) else: # TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 import multiprocessing as mp @@ -525,25 +529,25 @@ class DataSet: proc.start() pool.append(proc) queues.append(queue) - + progress_bar = progress_bars.get(progress_bar, DummyFRichProgress()) total_len = len(self) - task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar) + task_id = progress_bar.add_task(description=progress_desc, total=total_len) last_count = -1 while counter.value < total_len or last_count == -1: while counter.value == last_count: time.sleep(0.1) advance = counter.value - last_count last_count = counter.value - f_rich_progress.update(task_id, advance=advance, refresh=True) + progress_bar.update(task_id, advance=advance, refresh=True) for idx, proc in enumerate(pool): results.extend(pickle.loads(queues[idx].get())) proc.join() - f_rich_progress.destroy_task(task_id) + progress_bar.destroy_task(task_id) return results def apply_more(self, func: Callable = None, modify_fields: bool = True, - num_proc: int = 0, progress_desc: str = '', show_progress_bar: bool = True): + num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): r""" 将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 @@ -558,9 +562,9 @@ class DataSet: :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 - :param num_proc: 进程的数量 :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 + :param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :return Dict[str:Field]: 返回一个字典 """ assert callable(func), "The func is not callable." @@ -570,7 +574,7 @@ class DataSet: results = {} apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, - show_progress_bar=show_progress_bar) + progress_bar=progress_bar) # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 if not isinstance(apply_out[0], dict): raise Exception("The result of func is not a dict") @@ -597,21 +601,21 @@ class DataSet: return results def apply(self, func: Callable = None, new_field_name: str = None, - num_proc: int = 0, show_progress_bar: bool = True, progress_desc: str = ''): + num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): """ :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 盖之前的field。如果为None则不创建新的field。 :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param show_progress_bar: 是否显示进度条。 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 :param progress_desc: progress bar 显示的值,默认为空。 """ assert callable(func), "The func you provide is not callable." assert len(self) != 0, "Null DataSet cannot use apply()." assert num_proc >= 0, "num_proc must be an integer >= 0." try: - results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, + results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar, progress_desc=progress_desc) except BaseException as e: raise e diff --git a/fastNLP/core/utils/__init__.py b/fastNLP/core/utils/__init__.py index 6c65c8a5..62b4cb7e 100644 --- a/fastNLP/core/utils/__init__.py +++ b/fastNLP/core/utils/__init__.py @@ -22,7 +22,8 @@ __all__ = [ 'Option', 'deprecated', 'seq_len_to_mask', - "flat_nest_dict" + "flat_nest_dict", + "f_tqdm_progress" ] from .cache_results import cache_results @@ -32,5 +33,6 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi from .rich_progress import f_rich_progress from .torch_utils import torch_move_data_to_device from .utils import * +from .tqdm_progress import f_tqdm_progress diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 432efd85..0e6d5a01 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -35,7 +35,7 @@ class DummyFRichProgress: return None @property - def dummy_rich(self)->bool: + def dummy(self)->bool: """ 当前对象是否是 dummy 的 rich 对象。 @@ -122,6 +122,9 @@ class FRichProgress(Progress, metaclass=Singleton): visible: bool = True, **fields: Any, ) -> TaskID: + from .tqdm_progress import f_tqdm_progress + assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop." + if self.live._started is False: self.start() post_desc = fields.pop('post_desc', '') @@ -213,7 +216,7 @@ class FRichProgress(Progress, metaclass=Singleton): self.refresh() @property - def dummy_rich(self) -> bool: + def dummy(self) -> bool: """ 当前对象是否是 dummy 的 rich 对象。 @@ -221,6 +224,9 @@ class FRichProgress(Progress, metaclass=Singleton): """ return False + def not_empty(self): + return len(self._tasks) != 0 + class SpeedColumn(ProgressColumn): """ diff --git a/fastNLP/core/utils/tqdm_progress.py b/fastNLP/core/utils/tqdm_progress.py new file mode 100644 index 00000000..9fcfac94 --- /dev/null +++ b/fastNLP/core/utils/tqdm_progress.py @@ -0,0 +1,160 @@ +__all__ = [ + 'f_tqdm_progress' +] + +import uuid +import sys +from ...envs.imports import _module_available, _compare_version +from ...envs import get_global_rank +from .utils import is_notebook +from ..log import logger +if _module_available('tqdm'): + from tqdm.autonotebook import tqdm +import operator + + + +class Singleton(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +# 如果不打印的时候,使得整个 progress 没有任何意义 +class DummyFTqdmProgress: + def __getattr__(self, item): + return DummyFTqdmProgress() + + def __call__(self, *args, **kwargs): + # 防止用户通过 DummyFRichProgress.console.print() 这种调用 + return None + + @property + def dummy(self)->bool: + """ + 当前对象是否是 dummy 的 tqdm 对象。 + + :return: + """ + return True + + +class TqdmProgress(metaclass=Singleton): + def __init__(self): + self.bars = {} + + def add_task(self, iterable=None, description=None, total=None, leave=False, + ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None, + ascii=None, visible=True, unit='it', unit_scale=False, + dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0, + postfix=None, unit_divisor=1000, write_bytes=None, + lock_args=None, nrows=None, colour=None, gui=False, **kwargs): + """ + 主要就模仿了 tqdm bar 的创建,为了和 FRichProgress 的接口尽量统一,将 desc 重名为了 description,以及 disable 专为了 + visible 。 + + :param iterable: + :param description: + :param total: + :param leave: + :param ncols: + :param mininterval: + :param maxinterval: + :param miniters: + :param ascii: + :param visible: + :param unit: + :param unit_scale: + :param dynamic_ncols: + :param smoothing: + :param bar_format: + :param initial: + :param postfix: + :param unit_divisor: + :param write_bytes: + :param lock_args: + :param nrows: + :param colour: + :param gui: + :param kwargs: + :return: + """ + assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \ + f"To use {self.__class__.__name__}, tqdm>=4.57 is needed." + + from .rich_progress import f_rich_progress + assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop." + + if hasattr(self, 'orig_out_err'): + file = self.orig_out_err[0] + else: + file = sys.stdout + + bar = tqdm(iterable=iterable, desc=description, total=total, leave=leave, file=file, + ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters, + ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale, + dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial, + position=len(self.bars), postfix=postfix, unit_divisor=unit_divisor, write_bytes=write_bytes, + lock_args=lock_args, nrows=nrows, colour=colour, gui=gui, **kwargs) + _uuid = str(uuid.uuid1()) + self.bars[_uuid] = bar + if not hasattr(self, 'orig_out_err') and not is_notebook(): + from tqdm.contrib import DummyTqdmFile + self.orig_out_err = sys.stdout, sys.stderr + sys.stdout, sys.stderr = map(DummyTqdmFile, self.orig_out_err) + + return _uuid + + def update(self, task_id:str, advance:int, refresh=True): + self.bars[task_id].update(advance) + + def set_postfix_str(self, task_id, s, refresh=True): + self.bars[task_id].set_postfix_str(s=s, refresh=refresh) + + def set_description_str(self, task_id, desc, refresh=True): + self.bars[task_id].set_description_str(desc=desc, refresh=refresh) + + def destroy_task(self, task_id): + """ + 关闭 task_id 对应的 tqdm bar 。 + + :param task_id: + :return: + """ + self.bars[task_id].close() + self.bars.pop(task_id) + if len(self.bars) == 0 and hasattr(self, 'orig_out_err'): + # recover 成正常的 sys.stdout 与 sys.stderr + sys.stdout, sys.stderr = self.orig_out_err + delattr(self, 'orig_out_err') + + def reset(self, task_id): + self.bars[task_id].reset() + + def print(self): + tqdm.write('') + + def not_empty(self): + return len(self.bars) != 0 + + @property + def dummy(self) -> bool: + """ + 当前对象是否是 dummy 的 tqdm 对象。 + + :return: + """ + return False + + +if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and get_global_rank() == 0: + f_tqdm_progress = TqdmProgress() +else: + f_tqdm_progress = DummyFTqdmProgress() + logger.debug("Use dummy tqdm...") + + + diff --git a/fastNLP/core/vocabulary.py b/fastNLP/core/vocabulary.py index a8fd11d9..7cb281c2 100644 --- a/fastNLP/core/vocabulary.py +++ b/fastNLP/core/vocabulary.py @@ -340,7 +340,7 @@ class Vocabulary(object): try: for f_n, n_f_n in zip(field_name, new_field_name): dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n, - show_progress_bar=False) + progress_bar=None) except Exception as e: logger.error("When processing the `{}` dataset, the following error occurred.".format(idx)) raise e @@ -396,7 +396,7 @@ class Vocabulary(object): for idx, dataset in enumerate(datasets): if isinstance(dataset, DataSet): try: - dataset.apply(construct_vocab, show_progress_bar=False) + dataset.apply(construct_vocab, progress_bar=None) except BaseException as e: logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) raise e @@ -406,12 +406,12 @@ class Vocabulary(object): if no_create_entry_dataset is not None: partial_construct_vocab = partial(construct_vocab, no_create_entry=True) if isinstance(no_create_entry_dataset, DataSet): - no_create_entry_dataset.apply(partial_construct_vocab, show_progress_bar=False) + no_create_entry_dataset.apply(partial_construct_vocab, progress_bar=None) elif isinstance(no_create_entry_dataset, list): for dataset in no_create_entry_dataset: if not isinstance(dataset, DataSet): raise TypeError("Only DataSet type is allowed.") - dataset.apply(partial_construct_vocab, show_progress_bar=False) + dataset.apply(partial_construct_vocab, progress_bar=None) return self def _is_word_no_create_entry(self, word:str): diff --git a/fastNLP/io/data_bundle.py b/fastNLP/io/data_bundle.py index df194df2..81ff3b84 100644 --- a/fastNLP/io/data_bundle.py +++ b/fastNLP/io/data_bundle.py @@ -221,7 +221,7 @@ class DataBundle: yield field_name, vocab def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, - ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True): + ignore_miss_dataset: bool = True, progress_desc: str = '', progress_bar: str = 'rich'): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 @@ -233,8 +233,8 @@ class DataBundle: 如果为False,则报错 :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 - :param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 - :param show_progress_bar: 是否显示tqdm进度条 + :param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 """ _progress_desc = progress_desc @@ -243,13 +243,13 @@ class DataBundle: progress_desc = _progress_desc + f' for `{name}`' if dataset.has_field(field_name=field_name): dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, - progress_desc=progress_desc, show_progress_bar=show_progress_bar) + progress_desc=progress_desc, progress_bar=progress_bar) elif not ignore_miss_dataset: raise KeyError(f"{field_name} not found DataSet:{name}.") return self def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, - ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''): + ignore_miss_dataset=True, progress_bar: str = 'rich', progress_desc: str = ''): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 @@ -263,8 +263,8 @@ class DataBundle: :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; 如果为False,则报错 - :param show_progress_bar: 是否显示进度条 - :param progress_desc: 当 ``show_progress_bar`` 为 ``True`` 时,可以显示 ``progress`` 的名称。 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 @@ -277,13 +277,13 @@ class DataBundle: if dataset.has_field(field_name=field_name): res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, modify_fields=modify_fields, - show_progress_bar=show_progress_bar, progress_desc=progress_desc) + progress_bar=progress_bar, progress_desc=progress_desc) elif not ignore_miss_dataset: raise KeyError(f"{field_name} not found DataSet:{name} .") return res def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, - progress_desc: str = '', show_progress_bar: bool = True): + progress_desc: str = '', progress_bar: bool = True): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 @@ -293,20 +293,20 @@ class DataBundle: :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 盖之前的field。如果为None则不创建新的field。 :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param show_progress_bar: 是否显示tqd进度条 - :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称 """ _progress_desc = progress_desc for name, dataset in self.datasets.items(): if _progress_desc: progress_desc = _progress_desc + f' for `{name}`' - dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, + dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, progress_desc=progress_desc) return self def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, - progress_desc: str = '', show_progress_bar: bool = True): + progress_desc: str = '', progress_bar: str = 'rich'): r""" 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 @@ -317,8 +317,8 @@ class DataBundle: :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 - :param show_progress_bar: 是否显示tqd进度条 - :param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 + :param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 + :param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称 :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 """ @@ -328,7 +328,7 @@ class DataBundle: if _progress_desc: progress_desc = _progress_desc + f' for `{name}`' res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, - show_progress_bar=show_progress_bar, progress_desc=progress_desc) + progress_bar=progress_bar, progress_desc=progress_desc) return res def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": diff --git a/fastNLP/transformers/torch/models/auto/configuration_auto.py b/fastNLP/transformers/torch/models/auto/configuration_auto.py index bcd7576c..0138aec7 100644 --- a/fastNLP/transformers/torch/models/auto/configuration_auto.py +++ b/fastNLP/transformers/torch/models/auto/configuration_auto.py @@ -279,7 +279,7 @@ class _LazyConfigMapping(OrderedDict): value = self._mapping[key] module_name = model_type_to_module_name(key) if module_name not in self._modules: - self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") + self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") return getattr(self._modules[module_name], value) def keys(self): @@ -318,15 +318,15 @@ class _LazyLoadAllMappings(OrderedDict): def _initialize(self): if self._initialized: return - logger.warn( - "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " - "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", - FutureWarning, - ) + # logger.warn( + # "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " + # "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", + # FutureWarning, + # ) for model_type, map_name in self._mapping.items(): module_name = model_type_to_module_name(model_type) - module = importlib.import_module(f".{module_name}", "transformers.models") + module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") mapping = getattr(module, map_name) self._data.update(mapping) @@ -362,8 +362,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPP def _get_class_name(model_class: Union[str, List[str]]): if isinstance(model_class, (list, tuple)): - return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None]) - return f":class:`~transformers.{model_class}`" + return " or ".join([f":class:`~fastNLP.transformers.torch.{c}`" for c in model_class if c is not None]) + return f":class:`~fastNLP.transformers.torch.{model_class}`" def _list_model_options(indent, config_to_class=None, use_model_types=True): @@ -372,7 +372,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): if use_model_types: if config_to_class is None: model_type_to_name = { - model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() + model_type: f":class:`~fastNLP.transformers.torch.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() } else: model_type_to_name = { @@ -394,7 +394,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() } lines = [ - f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" + f"{indent}- :class:`~fastNLP.transformers.torch.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" for config_name in sorted(config_to_name.keys()) ] return "\n".join(lines) diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 925be172..4fd9d0d3 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -4,17 +4,12 @@ (2) 能不能保存 topk 并load进来进行训练 """ -import pytest - - - import os import pytest from typing import Any from dataclasses import dataclass from pathlib import Path -import re from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK @@ -25,7 +20,6 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.metrics import Metric -from fastNLP.core.log import logger from fastNLP.core.callbacks import MoreEvaluateCallback from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: diff --git a/tests/core/callbacks/test_progress_callback_torch.py b/tests/core/callbacks/test_progress_callback_torch.py new file mode 100644 index 00000000..75d3dbda --- /dev/null +++ b/tests/core/callbacks/test_progress_callback_torch.py @@ -0,0 +1,123 @@ +from typing import Any +from dataclasses import dataclass + +import pytest + +from fastNLP import Metric, Accuracy +from tests.helpers.utils import magic_argv_env_context +from fastNLP import Trainer, Evaluator +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +if _NEED_IMPORT_TORCH: + from torch.utils.data import DataLoader + from torch.optim import SGD + import torch.distributed as dist + import torch + +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchArgMaxDataset + + +@dataclass +class ArgMaxDatasetConfig: + num_labels: int = 10 + feature_dimension: int = 10 + data_num: int = 100 + seed: int = 0 + + batch_size: int = 4 + shuffle: bool = True + + +@dataclass +class TrainerParameters: + model: Any = None + optimizers: Any = None + train_dataloader: Any = None + evaluate_dataloaders: Any = None + input_mapping: Any = None + output_mapping: Any = None + metrics: Any = None + more_metrics: Any = None + + +@pytest.fixture(scope="module", params=[0], autouse=True) +def model_and_optimizers(request): + trainer_params = TrainerParameters() + + trainer_params.model = TorchNormalModel_Classification_1( + num_labels=ArgMaxDatasetConfig.num_labels, + feature_dimension=ArgMaxDatasetConfig.feature_dimension + ) + trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) + dataset = TorchArgMaxDataset( + feature_dimension=ArgMaxDatasetConfig.feature_dimension, + data_num=ArgMaxDatasetConfig.data_num, + seed=ArgMaxDatasetConfig.seed + ) + _dataloader = DataLoader( + dataset=dataset, + batch_size=ArgMaxDatasetConfig.batch_size, + shuffle=True + ) + + class LossMetric(Metric): + def __init__(self): + super().__init__() + self.register_element('loss') + + def update(self, loss): + self.loss += loss.item() + + def get_metric(self) -> dict: + return self.loss.item() + + trainer_params.train_dataloader = _dataloader + trainer_params.evaluate_dataloaders = _dataloader + trainer_params.metrics = {'loss': LossMetric()} + + trainer_params.more_metrics = {"acc": Accuracy()} + + return trainer_params + + +@pytest.mark.torch +@pytest.mark.parametrize('device', ['cpu', [0, 1]]) +@pytest.mark.parametrize('progress_bar', ['rich', 'auto', None, 'raw', 'tqdm']) +@magic_argv_env_context +def test_run( model_and_optimizers: TrainerParameters, device, progress_bar): + + if device != 'cpu' and not torch.cuda.is_available(): + pytest.skip(f"No cuda for device:{device}") + n_epochs = 5 + trainer = Trainer( + model=model_and_optimizers.model, + driver='torch', + device=device, + optimizers=model_and_optimizers.optimizers, + train_dataloader=model_and_optimizers.train_dataloader, + evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders, + input_mapping=model_and_optimizers.input_mapping, + output_mapping=model_and_optimizers.output_mapping, + metrics=model_and_optimizers.metrics, + n_epochs=n_epochs, + callbacks=None, + progress_bar=progress_bar, + output_from_new_proc="all", + evaluate_fn='train_step', + larger_better=False + ) + + trainer.run() + + evaluator = Evaluator(model=model_and_optimizers.model, dataloaders=model_and_optimizers.train_dataloader, + driver=trainer.driver, metrics=model_and_optimizers.metrics, + progress_bar=progress_bar, evaluate_fn='train_step') + evaluator.run() + + if dist.is_initialized(): + dist.destroy_process_group() + + + + + diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index ded60465..95b2d17c 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -181,7 +181,7 @@ class TestDataSetMethods: assert ("rx" in ds.field_arrays) == True assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] - ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) + ds.apply(lambda ins: len(ins["y"]), new_field_name="y", progress_bar=None) assert ds.field_arrays["y"].content[0] == 2 res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len") @@ -198,8 +198,8 @@ class TestDataSetMethods: def do_nothing(ins): time.sleep(0.01) - ds.apply(do_nothing, show_progress_bar=True, num_proc=0) - ds.apply_field(do_nothing, field_name='x', show_progress_bar=True) + ds.apply(do_nothing, progress_bar='rich', num_proc=0) + ds.apply_field(do_nothing, field_name='x', progress_bar='rich') def test_apply_cannot_modify_instance(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) diff --git a/tests/core/utils/test_progress.py b/tests/core/utils/test_progress.py new file mode 100644 index 00000000..ac81a839 --- /dev/null +++ b/tests/core/utils/test_progress.py @@ -0,0 +1,16 @@ +import pytest +from fastNLP.envs.imports import _module_available +from fastNLP.core.utils import f_tqdm_progress, f_rich_progress + +def test_raise(): + if not _module_available('tqdm') or f_rich_progress.dummy or f_tqdm_progress.dummy: + pytest.skip('No tqdm') + t = f_rich_progress.add_task('test', total=10) + with pytest.raises(AssertionError): + f_tqdm_progress.add_task('test') + + f_rich_progress.destroy_task(t) + + t = f_tqdm_progress.add_task('test', total=10) + with pytest.raises(AssertionError): + f_rich_progress.add_task('test') \ No newline at end of file