@@ -6,6 +6,7 @@ __all__ = [ | |||||
'CheckpointCallback', | 'CheckpointCallback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
'TqdmCallback', | |||||
"LRSchedCallback", | "LRSchedCallback", | ||||
'LoadBestModelCallback', | 'LoadBestModelCallback', | ||||
"EarlyStopCallback", | "EarlyStopCallback", | ||||
@@ -4,8 +4,11 @@ __all__ = [ | |||||
'Filter', | 'Filter', | ||||
'CheckpointCallback', | 'CheckpointCallback', | ||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback', | 'RichCallback', | ||||
'TqdmCallback', | |||||
"LRSchedCallback", | "LRSchedCallback", | ||||
'LoadBestModelCallback', | 'LoadBestModelCallback', | ||||
"EarlyStopCallback", | "EarlyStopCallback", | ||||
@@ -26,7 +29,7 @@ from .callback import Callback | |||||
from .callback_event import Event, Filter | from .callback_event import Event, Filter | ||||
from .callback_manager import CallbackManager | from .callback_manager import CallbackManager | ||||
from .checkpoint_callback import CheckpointCallback | from .checkpoint_callback import CheckpointCallback | ||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback | |||||
from .lr_scheduler_callback import LRSchedCallback | from .lr_scheduler_callback import LRSchedCallback | ||||
from .load_best_model_callback import LoadBestModelCallback | from .load_best_model_callback import LoadBestModelCallback | ||||
from .early_stop_callback import EarlyStopCallback | from .early_stop_callback import EarlyStopCallback | ||||
@@ -5,11 +5,14 @@ from typing import Union | |||||
__all__ = [ | __all__ = [ | ||||
'choose_progress_callback', | 'choose_progress_callback', | ||||
'ProgressCallback', | 'ProgressCallback', | ||||
'RichCallback' | |||||
'RichCallback', | |||||
'TqdmCallback' | |||||
] | ] | ||||
from ...envs.imports import _module_available, _compare_version | |||||
from .has_monitor_callback import HasMonitorCallback | from .has_monitor_callback import HasMonitorCallback | ||||
from fastNLP.core.utils import f_rich_progress | |||||
from fastNLP.core.utils import f_rich_progress, f_tqdm_progress | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -24,7 +27,7 @@ class ProgressCallback(HasMonitorCallback): | |||||
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: | def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: | ||||
if progress_bar == 'auto': | if progress_bar == 'auto': | ||||
if not f_rich_progress.dummy_rich: | |||||
if not f_rich_progress.dummy: | |||||
progress_bar = 'rich' | progress_bar = 'rich' | ||||
else: | else: | ||||
progress_bar = 'raw' | progress_bar = 'raw' | ||||
@@ -32,6 +35,8 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||||
return RichCallback() | return RichCallback() | ||||
elif progress_bar == 'raw': | elif progress_bar == 'raw': | ||||
return RawTextCallback() | return RawTextCallback() | ||||
elif progress_bar == 'tqdm': | |||||
return TqdmCallback() | |||||
elif isinstance(progress_bar, ProgressCallback): | elif isinstance(progress_bar, ProgressCallback): | ||||
return progress_bar | return progress_bar | ||||
else: | else: | ||||
@@ -82,7 +87,9 @@ class RichCallback(ProgressCallback): | |||||
if 'batch' in self.task2id: | if 'batch' in self.task2id: | ||||
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) | self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) | ||||
else: | else: | ||||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', total=trainer.num_batches_per_epoch) | |||||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', | |||||
total=trainer.num_batches_per_epoch, | |||||
completed=trainer.batch_idx_in_epoch) | |||||
def on_train_epoch_end(self, trainer): | def on_train_epoch_end(self, trainer): | ||||
self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | ||||
@@ -208,4 +215,95 @@ class RawTextCallback(ProgressCallback): | |||||
@property | @property | ||||
def name(self): # progress bar的名称 | def name(self): # progress bar的名称 | ||||
return 'raw' | |||||
return 'raw' | |||||
class TqdmCallback(ProgressCallback): | |||||
""" | |||||
在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||||
:param print_every: 多少个 batch 更新一次显示。 | |||||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||||
* 为 ``None`` | |||||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||||
* 为 ``str`` | |||||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||||
* 为 ``Callable`` | |||||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||||
的 ``monitor`` 值请返回 ``None`` 。 | |||||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||||
:param format_json: 是否格式化 json 再打印 | |||||
""" | |||||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||||
format_json=True): | |||||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||||
self.print_every = print_every | |||||
self.progress_bar = f_tqdm_progress | |||||
self.task2id = {} | |||||
self.loss = 0 | |||||
self.loss_round_ndigit = loss_round_ndigit | |||||
self.format_json = format_json | |||||
self.num_signs = 10 | |||||
def on_train_begin(self, trainer): | |||||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | |||||
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', | |||||
initial=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||||
def on_train_epoch_begin(self, trainer): | |||||
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | |||||
if 'batch' in self.task2id: | |||||
self.progress_bar.reset(self.task2id['batch']) | |||||
else: | |||||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch', total=trainer.num_batches_per_epoch, | |||||
initial=trainer.batch_idx_in_epoch) | |||||
self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | |||||
def on_train_end(self, trainer): | |||||
self.clear_tasks() | |||||
def on_before_backward(self, trainer, outputs): | |||||
loss = trainer.extract_loss_from_outputs(outputs) | |||||
loss = trainer.driver.tensor_to_numeric(loss, reduce='sum') | |||||
self.loss += loss | |||||
def on_train_batch_end(self, trainer): | |||||
if trainer.global_forward_batches % self.print_every == 0: | |||||
loss = self.loss/self.print_every | |||||
self.loss = 0 | |||||
self.progress_bar.update(self.task2id['batch'], advance=self.print_every, refresh=True) | |||||
self.progress_bar.set_postfix_str(self.task2id['batch'], f'Loss:{round(loss, self.loss_round_ndigit)}') | |||||
self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True) | |||||
def on_evaluate_end(self, trainer, results): | |||||
if len(results)==0: | |||||
return | |||||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||||
text = '' | |||||
if self.monitor is not None: | |||||
monitor_value = self.get_monitor_value(results) | |||||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||||
if abs(self.monitor_value) != float('inf'): | |||||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||||
if len(text) == 0: | |||||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||||
logger.info(text) | |||||
if self.format_json: | |||||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||||
else: | |||||
logger.info(results) | |||||
def clear_tasks(self): | |||||
for key, taskid in self.task2id.items(): | |||||
self.progress_bar.destroy_task(taskid) | |||||
self.task2id = {} | |||||
self.loss = 0 | |||||
@property | |||||
def name(self): # progress bar的名称 | |||||
return 'tqdm' |
@@ -19,7 +19,7 @@ from fastNLP.core.drivers import Driver, TorchDriver | |||||
from ..drivers.choose_driver import choose_driver | from ..drivers.choose_driver import choose_driver | ||||
from .loops import Loop, EvaluateBatchLoop | from .loops import Loop, EvaluateBatchLoop | ||||
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | ||||
match_and_substitute_params, f_rich_progress, flat_nest_dict | |||||
match_and_substitute_params, f_rich_progress, flat_nest_dict, f_tqdm_progress | |||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | ||||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | ||||
@@ -166,8 +166,9 @@ class Evaluator: | |||||
self.dataloaders[name] = dl | self.dataloaders[name] = dl | ||||
self.progress_bar = kwargs.get('progress_bar', 'auto') | self.progress_bar = kwargs.get('progress_bar', 'auto') | ||||
assert self.progress_bar in [None, 'rich', 'auto', 'tqdm', 'raw'] | |||||
if self.progress_bar == 'auto': | if self.progress_bar == 'auto': | ||||
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' | |||||
self.progress_bar = 'raw' if f_rich_progress.dummy else 'rich' | |||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -226,12 +227,15 @@ class Evaluator: | |||||
return metric_results | return metric_results | ||||
def start_progress_bar(self, total: int, dataloader_name): | def start_progress_bar(self, total: int, dataloader_name): | ||||
if self.progress_bar == 'rich': | |||||
if self.progress_bar in ('rich', 'tqdm'): | |||||
if dataloader_name is None: | if dataloader_name is None: | ||||
desc = f'Eval. Batch:0' | |||||
desc = f'Eval. Batch' | |||||
else: | |||||
desc = f'Eval. on {dataloader_name} Batch' | |||||
if self.progress_bar == 'rich': | |||||
self._task_id = f_rich_progress.add_task(description=desc, total=total) | |||||
else: | else: | ||||
desc = f'Eval. on {dataloader_name} Batch:0' | |||||
self._rich_task_id = f_rich_progress.add_task(description=desc, total=total) | |||||
self._task_id = f_tqdm_progress.add_task(description=desc, total=total) | |||||
elif self.progress_bar == 'raw': | elif self.progress_bar == 'raw': | ||||
desc = 'Evaluation starts' | desc = 'Evaluation starts' | ||||
if dataloader_name is not None: | if dataloader_name is not None: | ||||
@@ -244,19 +248,26 @@ class Evaluator: | |||||
else: | else: | ||||
desc = f'Eval. on {dataloader_name} Batch:{batch_idx}' | desc = f'Eval. on {dataloader_name} Batch:{batch_idx}' | ||||
if self.progress_bar == 'rich': | if self.progress_bar == 'rich': | ||||
assert hasattr(self, '_rich_task_id'), "You must first call `start_progress_bar()` before calling " \ | |||||
assert hasattr(self, '_task_id'), "You must first call `start_progress_bar()` before calling " \ | |||||
"update_progress_bar()" | "update_progress_bar()" | ||||
f_rich_progress.update(self._rich_task_id, description=desc, post_desc=kwargs.get('post_desc', ''), | |||||
f_rich_progress.update(self._task_id, description=desc, post_desc=kwargs.get('post_desc', ''), | |||||
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | ||||
visible=kwargs.get('visible', True)) | visible=kwargs.get('visible', True)) | ||||
elif self.progress_bar == 'raw': | elif self.progress_bar == 'raw': | ||||
if self.verbose > 1: | if self.verbose > 1: | ||||
logger.info(desc) | logger.info(desc) | ||||
elif self.progress_bar == 'tqdm': | |||||
f_tqdm_progress.update(self._task_id, advance=1) | |||||
def remove_progress_bar(self, dataloader_name): | def remove_progress_bar(self, dataloader_name): | ||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||||
f_rich_progress.destroy_task(self._rich_task_id) | |||||
delattr(self, '_rich_task_id') | |||||
if self.progress_bar == 'rich' and hasattr(self, '_task_id'): | |||||
f_rich_progress.destroy_task(self._task_id) | |||||
delattr(self, '_task_id') | |||||
elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'): | |||||
f_tqdm_progress.destroy_task(self._task_id) | |||||
delattr(self, '_task_id') | |||||
elif self.progress_bar == 'raw': | elif self.progress_bar == 'raw': | ||||
desc = 'Evaluation ends' | desc = 'Evaluation ends' | ||||
if dataloader_name is not None: | if dataloader_name is not None: | ||||
@@ -264,9 +275,12 @@ class Evaluator: | |||||
logger.info("*" * 10 + desc + '*' * 10 + '\n') | logger.info("*" * 10 + desc + '*' * 10 + '\n') | ||||
def finally_progress_bar(self): | def finally_progress_bar(self): | ||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||||
f_rich_progress.destroy_task(self._rich_task_id) | |||||
delattr(self, '_rich_task_id') | |||||
if self.progress_bar == 'rich' and hasattr(self, '_task_id'): | |||||
f_rich_progress.destroy_task(self._task_id) | |||||
delattr(self, '_task_id') | |||||
elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'): | |||||
f_tqdm_progress.destroy_task(self._task_id) | |||||
delattr(self, '_task_id') | |||||
@property | @property | ||||
def evaluate_batch_loop(self): | def evaluate_batch_loop(self): | ||||
@@ -294,9 +294,9 @@ class Trainer(TrainerEventTrigger): | |||||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | ||||
注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; | 注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; | ||||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | |||||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 RichCallback, RawTextCallback等对象, | |||||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback 对象。如果 | |||||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 等对象。 | |||||
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 | * *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 | ||||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | * *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | ||||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 | * *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 | ||||
@@ -573,7 +573,7 @@ class Trainer(TrainerEventTrigger): | |||||
if resume_from is not None: | if resume_from is not None: | ||||
if os.path.exists(resume_from): | if os.path.exists(resume_from): | ||||
self.load(resume_from, resume_training=resume_training) | |||||
self.load_checkpoint(resume_from, resume_training=resume_training) | |||||
else: | else: | ||||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | ||||
@@ -732,6 +732,7 @@ class Trainer(TrainerEventTrigger): | |||||
Trainer.__init__(): | Trainer.__init__(): | ||||
on_after_trainer_initialized(trainer, driver) | on_after_trainer_initialized(trainer, driver) | ||||
Trainer.run(): | Trainer.run(): | ||||
# load checkpoint if resume_from is not None | |||||
if num_eval_sanity_batch>0: | if num_eval_sanity_batch>0: | ||||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | ||||
on_sanity_check_end(trainer, sanity_check_res) | on_sanity_check_end(trainer, sanity_check_res) | ||||
@@ -19,10 +19,17 @@ from .field import FieldArray | |||||
from .instance import Instance | from .instance import Instance | ||||
from fastNLP.core.utils.utils import pretty_table_printer, deprecated | from fastNLP.core.utils.utils import pretty_table_printer, deprecated | ||||
from fastNLP.core.collators import Collator | from fastNLP.core.collators import Collator | ||||
from fastNLP.core.utils.rich_progress import f_rich_progress | |||||
from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress | |||||
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress | |||||
from ..log import logger | from ..log import logger | ||||
progress_bars = { | |||||
'rich': f_rich_progress, | |||||
'tqdm': f_tqdm_progress | |||||
} | |||||
class ApplyResultException(Exception): | class ApplyResultException(Exception): | ||||
def __init__(self, msg, index=None): | def __init__(self, msg, index=None): | ||||
super().__init__(msg) | super().__init__(msg) | ||||
@@ -30,7 +37,7 @@ class ApplyResultException(Exception): | |||||
self.index = index # 标示在哪个数据遭遇到问题了 | self.index = index # 标示在哪个数据遭遇到问题了 | ||||
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, | |||||
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, progress_bar: str = 'rich', | |||||
desc: str = None) -> list: | desc: str = None) -> list: | ||||
""" | """ | ||||
对数据集进行处理封装函数,以便多进程使用 | 对数据集进行处理封装函数,以便多进程使用 | ||||
@@ -39,32 +46,29 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s | |||||
:param _apply_field: 需要处理数据集的field_name | :param _apply_field: 需要处理数据集的field_name | ||||
:param func: 用户自定义的func | :param func: 用户自定义的func | ||||
:param desc: 进度条的描述字符 | :param desc: 进度条的描述字符 | ||||
:param show_progress_bar: 是否展示子进程进度条 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:return: | :return: | ||||
""" | """ | ||||
if show_progress_bar: | |||||
desc = desc if desc else f"Main" | |||||
pg_main = f_rich_progress.add_task(description=desc, total=len(ds), visible=show_progress_bar) | |||||
progress_bar = progress_bars.get(progress_bar, DummyFRichProgress()) | |||||
desc = desc if desc else "Processing" | |||||
task_id = progress_bar.add_task(description=desc, total=len(ds)) | |||||
results = [] | results = [] | ||||
idx = -1 | idx = -1 | ||||
try: | try: | ||||
# for idx, ins in tqdm(enumerate(ds), total=len(ds), position=0, desc=desc, disable=not show_progress_bar): | |||||
for idx, ins in enumerate(ds): | for idx, ins in enumerate(ds): | ||||
if _apply_field is not None: | if _apply_field is not None: | ||||
results.append(func(ins[_apply_field])) | results.append(func(ins[_apply_field])) | ||||
else: | else: | ||||
results.append(func(ins)) | results.append(func(ins)) | ||||
if show_progress_bar: | |||||
f_rich_progress.update(pg_main, advance=1) | |||||
progress_bar.update(task_id, advance=1) | |||||
except BaseException as e: | except BaseException as e: | ||||
if idx != -1: | if idx != -1: | ||||
logger.error("Exception happens at the `{}`th instance.".format(idx)) | logger.error("Exception happens at the `{}`th instance.".format(idx)) | ||||
raise e | raise e | ||||
finally: | finally: | ||||
if show_progress_bar: | |||||
f_rich_progress.destroy_task(pg_main) | |||||
progress_bar.destroy_task(task_id) | |||||
return results | return results | ||||
@@ -398,7 +402,7 @@ class DataSet: | |||||
def apply_field(self, func: Callable, field_name: str = None, | def apply_field(self, func: Callable, field_name: str = None, | ||||
new_field_name: str = None, num_proc: int = 0, | new_field_name: str = None, num_proc: int = 0, | ||||
progress_desc: str = None, show_progress_bar: bool = True): | |||||
progress_desc: str = None, progress_bar: str = 'rich'): | |||||
r""" | r""" | ||||
将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` | 将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` | ||||
中。 | 中。 | ||||
@@ -413,8 +417,8 @@ class DataSet: | |||||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。 | 由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。 | ||||
:param progress_desc: 进度条的描述字符,默认为 ``Main``; | |||||
:param show_progress_bar: 是否在处理过程中展示进度条; | |||||
:param progress_desc: 进度条的描述字符,默认为 ``Processing``; | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:return: 从函数 ``func`` 中得到的返回值; | :return: 从函数 ``func`` 中得到的返回值; | ||||
""" | """ | ||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
@@ -422,7 +426,7 @@ class DataSet: | |||||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | raise KeyError("DataSet has no field named `{}`.".format(field_name)) | ||||
try: | try: | ||||
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, | |||||
results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar, | |||||
progress_desc=progress_desc, _apply_field=field_name) | progress_desc=progress_desc, _apply_field=field_name) | ||||
except BaseException as e: | except BaseException as e: | ||||
raise e | raise e | ||||
@@ -433,7 +437,7 @@ class DataSet: | |||||
def apply_field_more(self, func: Callable = None, field_name: str = None, | def apply_field_more(self, func: Callable = None, field_name: str = None, | ||||
modify_fields: bool = True, num_proc: int = 0, | modify_fields: bool = True, num_proc: int = 0, | ||||
progress_desc: str = None, show_progress_bar: bool = True): | |||||
progress_desc: str = None, progress_bar: str = 'rich'): | |||||
r""" | r""" | ||||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | 将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | ||||
func 可以返回一个或多个 field 上的结果。 | func 可以返回一个或多个 field 上的结果。 | ||||
@@ -446,8 +450,8 @@ class DataSet: | |||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | :param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param show_progress_bar: 是否显示进度条,默认展示 | |||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 | |||||
:return Dict[str:Field]: 返回一个字典 | :return Dict[str:Field]: 返回一个字典 | ||||
""" | """ | ||||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | assert len(self) != 0, "Null DataSet cannot use apply_field()." | ||||
@@ -456,7 +460,7 @@ class DataSet: | |||||
idx = -1 | idx = -1 | ||||
results = {} | results = {} | ||||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | ||||
show_progress_bar=show_progress_bar, _apply_field=field_name) | |||||
progress_bar=progress_bar, _apply_field=field_name) | |||||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | ||||
if not isinstance(apply_out[0], Mapping): | if not isinstance(apply_out[0], Mapping): | ||||
raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}") | raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}") | ||||
@@ -483,13 +487,13 @@ class DataSet: | |||||
return results | return results | ||||
def _apply_process(self, num_proc: int = 0, func: Callable = None, | def _apply_process(self, num_proc: int = 0, func: Callable = None, | ||||
show_progress_bar: bool = True, _apply_field: str = None, | |||||
progress_bar: str = 'rich', _apply_field: str = None, | |||||
progress_desc: str = 'Main') -> list: | progress_desc: str = 'Main') -> list: | ||||
""" | """ | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | :param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | ||||
:param _apply_field: 需要传进去func的数据集的field_name | :param _apply_field: 需要传进去func的数据集的field_name | ||||
:param show_progress_bar: 是否展示progress进度条,默认为展示 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 进度条的描述字符,默认为'Main | :param progress_desc: 进度条的描述字符,默认为'Main | ||||
""" | """ | ||||
if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>": | if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>": | ||||
@@ -499,7 +503,7 @@ class DataSet: | |||||
if num_proc < 2: | if num_proc < 2: | ||||
results = _apply_single(ds=self, _apply_field=_apply_field, func=func, | results = _apply_single(ds=self, _apply_field=_apply_field, func=func, | ||||
desc=progress_desc, show_progress_bar=show_progress_bar) | |||||
desc=progress_desc, progress_bar=progress_bar) | |||||
else: | else: | ||||
# TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 | # TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 | ||||
import multiprocessing as mp | import multiprocessing as mp | ||||
@@ -525,25 +529,25 @@ class DataSet: | |||||
proc.start() | proc.start() | ||||
pool.append(proc) | pool.append(proc) | ||||
queues.append(queue) | queues.append(queue) | ||||
progress_bar = progress_bars.get(progress_bar, DummyFRichProgress()) | |||||
total_len = len(self) | total_len = len(self) | ||||
task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar) | |||||
task_id = progress_bar.add_task(description=progress_desc, total=total_len) | |||||
last_count = -1 | last_count = -1 | ||||
while counter.value < total_len or last_count == -1: | while counter.value < total_len or last_count == -1: | ||||
while counter.value == last_count: | while counter.value == last_count: | ||||
time.sleep(0.1) | time.sleep(0.1) | ||||
advance = counter.value - last_count | advance = counter.value - last_count | ||||
last_count = counter.value | last_count = counter.value | ||||
f_rich_progress.update(task_id, advance=advance, refresh=True) | |||||
progress_bar.update(task_id, advance=advance, refresh=True) | |||||
for idx, proc in enumerate(pool): | for idx, proc in enumerate(pool): | ||||
results.extend(pickle.loads(queues[idx].get())) | results.extend(pickle.loads(queues[idx].get())) | ||||
proc.join() | proc.join() | ||||
f_rich_progress.destroy_task(task_id) | |||||
progress_bar.destroy_task(task_id) | |||||
return results | return results | ||||
def apply_more(self, func: Callable = None, modify_fields: bool = True, | def apply_more(self, func: Callable = None, modify_fields: bool = True, | ||||
num_proc: int = 0, progress_desc: str = '', show_progress_bar: bool = True): | |||||
num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): | |||||
r""" | r""" | ||||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | 将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | ||||
@@ -558,9 +562,9 @@ class DataSet: | |||||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | :param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | ||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param num_proc: 进程的数量 | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||||
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:return Dict[str:Field]: 返回一个字典 | :return Dict[str:Field]: 返回一个字典 | ||||
""" | """ | ||||
assert callable(func), "The func is not callable." | assert callable(func), "The func is not callable." | ||||
@@ -570,7 +574,7 @@ class DataSet: | |||||
results = {} | results = {} | ||||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | ||||
show_progress_bar=show_progress_bar) | |||||
progress_bar=progress_bar) | |||||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | # 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | ||||
if not isinstance(apply_out[0], dict): | if not isinstance(apply_out[0], dict): | ||||
raise Exception("The result of func is not a dict") | raise Exception("The result of func is not a dict") | ||||
@@ -597,21 +601,21 @@ class DataSet: | |||||
return results | return results | ||||
def apply(self, func: Callable = None, new_field_name: str = None, | def apply(self, func: Callable = None, new_field_name: str = None, | ||||
num_proc: int = 0, show_progress_bar: bool = True, progress_desc: str = ''): | |||||
num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): | |||||
""" | """ | ||||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | :param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | ||||
盖之前的field。如果为None则不创建新的field。 | 盖之前的field。如果为None则不创建新的field。 | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param show_progress_bar: 是否显示进度条。 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: progress bar 显示的值,默认为空。 | :param progress_desc: progress bar 显示的值,默认为空。 | ||||
""" | """ | ||||
assert callable(func), "The func you provide is not callable." | assert callable(func), "The func you provide is not callable." | ||||
assert len(self) != 0, "Null DataSet cannot use apply()." | assert len(self) != 0, "Null DataSet cannot use apply()." | ||||
assert num_proc >= 0, "num_proc must be an integer >= 0." | assert num_proc >= 0, "num_proc must be an integer >= 0." | ||||
try: | try: | ||||
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, | |||||
results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar, | |||||
progress_desc=progress_desc) | progress_desc=progress_desc) | ||||
except BaseException as e: | except BaseException as e: | ||||
raise e | raise e | ||||
@@ -22,7 +22,8 @@ __all__ = [ | |||||
'Option', | 'Option', | ||||
'deprecated', | 'deprecated', | ||||
'seq_len_to_mask', | 'seq_len_to_mask', | ||||
"flat_nest_dict" | |||||
"flat_nest_dict", | |||||
"f_tqdm_progress" | |||||
] | ] | ||||
from .cache_results import cache_results | from .cache_results import cache_results | ||||
@@ -32,5 +33,6 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||||
from .rich_progress import f_rich_progress | from .rich_progress import f_rich_progress | ||||
from .torch_utils import torch_move_data_to_device | from .torch_utils import torch_move_data_to_device | ||||
from .utils import * | from .utils import * | ||||
from .tqdm_progress import f_tqdm_progress | |||||
@@ -35,7 +35,7 @@ class DummyFRichProgress: | |||||
return None | return None | ||||
@property | @property | ||||
def dummy_rich(self)->bool: | |||||
def dummy(self)->bool: | |||||
""" | """ | ||||
当前对象是否是 dummy 的 rich 对象。 | 当前对象是否是 dummy 的 rich 对象。 | ||||
@@ -122,6 +122,9 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
visible: bool = True, | visible: bool = True, | ||||
**fields: Any, | **fields: Any, | ||||
) -> TaskID: | ) -> TaskID: | ||||
from .tqdm_progress import f_tqdm_progress | |||||
assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop." | |||||
if self.live._started is False: | if self.live._started is False: | ||||
self.start() | self.start() | ||||
post_desc = fields.pop('post_desc', '') | post_desc = fields.pop('post_desc', '') | ||||
@@ -213,7 +216,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
self.refresh() | self.refresh() | ||||
@property | @property | ||||
def dummy_rich(self) -> bool: | |||||
def dummy(self) -> bool: | |||||
""" | """ | ||||
当前对象是否是 dummy 的 rich 对象。 | 当前对象是否是 dummy 的 rich 对象。 | ||||
@@ -221,6 +224,9 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
""" | """ | ||||
return False | return False | ||||
def not_empty(self): | |||||
return len(self._tasks) != 0 | |||||
class SpeedColumn(ProgressColumn): | class SpeedColumn(ProgressColumn): | ||||
""" | """ | ||||
@@ -0,0 +1,160 @@ | |||||
__all__ = [ | |||||
'f_tqdm_progress' | |||||
] | |||||
import uuid | |||||
import sys | |||||
from ...envs.imports import _module_available, _compare_version | |||||
from ...envs import get_global_rank | |||||
from .utils import is_notebook | |||||
from ..log import logger | |||||
if _module_available('tqdm'): | |||||
from tqdm.autonotebook import tqdm | |||||
import operator | |||||
class Singleton(type):
    """Metaclass that caches a single shared instance per class.

    The first call to ``Cls()`` constructs the instance; every later call
    returns that same cached object.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        try:
            return Singleton._instances[cls]
        except KeyError:
            instance = super().__call__(*args, **kwargs)
            Singleton._instances[cls] = instance
            return instance
# No-op stand-in used when progress output is disabled: every attribute
# access and every call silently does nothing, so callers never need to
# check which backend they got.
class DummyFTqdmProgress:

    def __getattr__(self, name):
        # Any attribute chain (e.g. ``dummy.console.print``) yields yet
        # another dummy, so arbitrarily deep access never raises.
        return DummyFTqdmProgress()

    def __call__(self, *args, **kwargs):
        # Swallow calls such as ``DummyFRichProgress.console.print()``.
        return None

    @property
    def dummy(self) -> bool:
        """Whether this object is the dummy tqdm progress.

        :return: always ``True`` for this class.
        """
        return True
class TqdmProgress(metaclass=Singleton):
    """Process-wide manager of ``tqdm`` progress bars.

    Mirrors the task-based API of ``FRichProgress`` (``add_task`` /
    ``update`` / ``destroy_task``) so callers can swap between the two
    progress backends. Being built on ``Singleton``, at most one instance
    exists per process.
    """

    def __init__(self):
        # task id (uuid string) -> live tqdm bar
        self.bars = {}

    def add_task(self, iterable=None, description=None, total=None, leave=False,
                 ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None,
                 ascii=None, visible=True, unit='it', unit_scale=False,
                 dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0,
                 postfix=None, unit_divisor=1000, write_bytes=None,
                 lock_args=None, nrows=None, colour=None, gui=False, **kwargs):
        """
        Create a new tqdm bar and return its task id.

        Mostly mimics the ``tqdm`` constructor; to keep the interface as
        close as possible to ``FRichProgress``, tqdm's ``desc`` is renamed
        ``description`` here and ``disable`` is inverted into ``visible``.
        Every other parameter is forwarded to ``tqdm`` unchanged — see the
        tqdm documentation for their semantics.

        :param iterable: optional iterable the bar wraps.
        :param description: text shown in front of the bar (tqdm's ``desc``).
        :param total: expected number of iterations.
        :param visible: when ``False`` the bar is created disabled
            (tqdm's ``disable=True``).
        :return: a uuid string identifying the bar; pass it to
            :meth:`update`, :meth:`set_postfix_str`, :meth:`destroy_task` etc.
        """
        assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \
            f"To use {self.__class__.__name__}, tqdm>=4.57 is needed."
        # Imported lazily to avoid a circular import between the two
        # progress modules.
        from .rich_progress import f_rich_progress
        # rich and tqdm redraw the terminal differently, so mixing live
        # bars from both backends would corrupt the output — forbid it.
        assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop."

        # While bars are alive, sys.stdout/sys.stderr are replaced (see
        # below); the bar itself must write to the *original* stream.
        if hasattr(self, 'orig_out_err'):
            file = self.orig_out_err[0]
        else:
            file = sys.stdout

        bar = tqdm(iterable=iterable, desc=description, total=total, leave=leave, file=file,
                   ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters,
                   ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale,
                   dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial,
                   position=len(self.bars), postfix=postfix, unit_divisor=unit_divisor, write_bytes=write_bytes,
                   lock_args=lock_args, nrows=nrows, colour=colour, gui=gui, **kwargs)
        _uuid = str(uuid.uuid1())
        self.bars[_uuid] = bar
        # On the first bar (outside notebooks), route ordinary print output
        # through tqdm so it does not garble the bar rendering; keep the
        # original streams so they can be restored in destroy_task().
        if not hasattr(self, 'orig_out_err') and not is_notebook():
            from tqdm.contrib import DummyTqdmFile
            self.orig_out_err = sys.stdout, sys.stderr
            sys.stdout, sys.stderr = map(DummyTqdmFile, self.orig_out_err)

        return _uuid

    def update(self, task_id: str, advance: int, refresh=True):
        # Advance the bar by ``advance`` steps. ``refresh`` is accepted for
        # API parity with FRichProgress but is not forwarded to tqdm —
        # NOTE(review): confirm this is intentional.
        self.bars[task_id].update(advance)

    def set_postfix_str(self, task_id, s, refresh=True):
        # Set the text displayed after the bar's statistics.
        self.bars[task_id].set_postfix_str(s=s, refresh=refresh)

    def set_description_str(self, task_id, desc, refresh=True):
        # Replace the text displayed in front of the bar.
        self.bars[task_id].set_description_str(desc=desc, refresh=refresh)

    def destroy_task(self, task_id):
        """
        Close and forget the tqdm bar associated with ``task_id``.

        When the last bar is removed, the original ``sys.stdout`` /
        ``sys.stderr`` (swapped out in :meth:`add_task`) are restored.

        :param task_id: id returned by :meth:`add_task`.
        :return: ``None``
        """
        self.bars[task_id].close()
        self.bars.pop(task_id)
        if len(self.bars) == 0 and hasattr(self, 'orig_out_err'):
            # restore the real sys.stdout and sys.stderr
            sys.stdout, sys.stderr = self.orig_out_err
            delattr(self, 'orig_out_err')

    def reset(self, task_id):
        # Rewind the bar's counter back to its ``initial`` value.
        self.bars[task_id].reset()

    def print(self):
        # Emit an empty line without breaking the live bar layout.
        tqdm.write('')

    def not_empty(self):
        # True while at least one bar is alive (used by FRichProgress to
        # refuse mixing backends).
        return len(self.bars) != 0

    @property
    def dummy(self) -> bool:
        """Whether this object is the dummy tqdm progress.

        :return: always ``False`` for the real implementation.
        """
        return False
# Only show real progress bars on the main process (global rank 0) of an
# interactive session (a tty or a notebook); every other case — piped
# output, non-zero ranks in distributed runs — gets the silent dummy.
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and get_global_rank() == 0:
    f_tqdm_progress = TqdmProgress()
else:
    f_tqdm_progress = DummyFTqdmProgress()
    logger.debug("Use dummy tqdm...")
@@ -340,7 +340,7 @@ class Vocabulary(object): | |||||
try: | try: | ||||
for f_n, n_f_n in zip(field_name, new_field_name): | for f_n, n_f_n in zip(field_name, new_field_name): | ||||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n, | dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n, | ||||
show_progress_bar=False) | |||||
progress_bar=None) | |||||
except Exception as e: | except Exception as e: | ||||
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx)) | logger.error("When processing the `{}` dataset, the following error occurred.".format(idx)) | ||||
raise e | raise e | ||||
@@ -396,7 +396,7 @@ class Vocabulary(object): | |||||
for idx, dataset in enumerate(datasets): | for idx, dataset in enumerate(datasets): | ||||
if isinstance(dataset, DataSet): | if isinstance(dataset, DataSet): | ||||
try: | try: | ||||
dataset.apply(construct_vocab, show_progress_bar=False) | |||||
dataset.apply(construct_vocab, progress_bar=None) | |||||
except BaseException as e: | except BaseException as e: | ||||
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) | logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) | ||||
raise e | raise e | ||||
@@ -406,12 +406,12 @@ class Vocabulary(object): | |||||
if no_create_entry_dataset is not None: | if no_create_entry_dataset is not None: | ||||
partial_construct_vocab = partial(construct_vocab, no_create_entry=True) | partial_construct_vocab = partial(construct_vocab, no_create_entry=True) | ||||
if isinstance(no_create_entry_dataset, DataSet): | if isinstance(no_create_entry_dataset, DataSet): | ||||
no_create_entry_dataset.apply(partial_construct_vocab, show_progress_bar=False) | |||||
no_create_entry_dataset.apply(partial_construct_vocab, progress_bar=None) | |||||
elif isinstance(no_create_entry_dataset, list): | elif isinstance(no_create_entry_dataset, list): | ||||
for dataset in no_create_entry_dataset: | for dataset in no_create_entry_dataset: | ||||
if not isinstance(dataset, DataSet): | if not isinstance(dataset, DataSet): | ||||
raise TypeError("Only DataSet type is allowed.") | raise TypeError("Only DataSet type is allowed.") | ||||
dataset.apply(partial_construct_vocab, show_progress_bar=False) | |||||
dataset.apply(partial_construct_vocab, progress_bar=None) | |||||
return self | return self | ||||
def _is_word_no_create_entry(self, word:str): | def _is_word_no_create_entry(self, word:str): | ||||
@@ -221,7 +221,7 @@ class DataBundle: | |||||
yield field_name, vocab | yield field_name, vocab | ||||
def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, | def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, | ||||
ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True): | |||||
ignore_miss_dataset: bool = True, progress_desc: str = '', progress_bar: str = 'rich'): | |||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 | 对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 | ||||
@@ -233,8 +233,8 @@ class DataBundle: | |||||
如果为False,则报错 | 如果为False,则报错 | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | :param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | ||||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||||
:param show_progress_bar: 是否显示tqdm进度条 | |||||
:param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
@@ -243,13 +243,13 @@ class DataBundle: | |||||
progress_desc = _progress_desc + f' for `{name}`' | progress_desc = _progress_desc + f' for `{name}`' | ||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | ||||
progress_desc=progress_desc, show_progress_bar=show_progress_bar) | |||||
progress_desc=progress_desc, progress_bar=progress_bar) | |||||
elif not ignore_miss_dataset: | elif not ignore_miss_dataset: | ||||
raise KeyError(f"{field_name} not found DataSet:{name}.") | raise KeyError(f"{field_name} not found DataSet:{name}.") | ||||
return self | return self | ||||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | ||||
ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''): | |||||
ignore_miss_dataset=True, progress_bar: str = 'rich', progress_desc: str = ''): | |||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | ||||
@@ -263,8 +263,8 @@ class DataBundle: | |||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | :param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | ||||
如果为False,则报错 | 如果为False,则报错 | ||||
:param show_progress_bar: 是否显示进度条 | |||||
:param progress_desc: 当 ``show_progress_bar`` 为 ``True`` 时,可以显示 ``progress`` 的名称。 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。 | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
@@ -277,13 +277,13 @@ class DataBundle: | |||||
if dataset.has_field(field_name=field_name): | if dataset.has_field(field_name=field_name): | ||||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | ||||
modify_fields=modify_fields, | modify_fields=modify_fields, | ||||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | |||||
progress_bar=progress_bar, progress_desc=progress_desc) | |||||
elif not ignore_miss_dataset: | elif not ignore_miss_dataset: | ||||
raise KeyError(f"{field_name} not found DataSet:{name} .") | raise KeyError(f"{field_name} not found DataSet:{name} .") | ||||
return res | return res | ||||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | ||||
progress_desc: str = '', show_progress_bar: bool = True): | |||||
progress_desc: str = '', progress_bar: bool = True): | |||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | ||||
@@ -293,20 +293,20 @@ class DataBundle: | |||||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | :param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | ||||
盖之前的field。如果为None则不创建新的field。 | 盖之前的field。如果为None则不创建新的field。 | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param show_progress_bar: 是否显示tqd进度条 | |||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称 | |||||
""" | """ | ||||
_progress_desc = progress_desc | _progress_desc = progress_desc | ||||
for name, dataset in self.datasets.items(): | for name, dataset in self.datasets.items(): | ||||
if _progress_desc: | if _progress_desc: | ||||
progress_desc = _progress_desc + f' for `{name}`' | progress_desc = _progress_desc + f' for `{name}`' | ||||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | |||||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | |||||
progress_desc=progress_desc) | progress_desc=progress_desc) | ||||
return self | return self | ||||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | ||||
progress_desc: str = '', show_progress_bar: bool = True): | |||||
progress_desc: str = '', progress_bar: str = 'rich'): | |||||
r""" | r""" | ||||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | 对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | ||||
@@ -317,8 +317,8 @@ class DataBundle: | |||||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | :param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | ||||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | :param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | ||||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | :param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | ||||
:param show_progress_bar: 是否显示tqd进度条 | |||||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||||
:param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称 | |||||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | :return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | ||||
""" | """ | ||||
@@ -328,7 +328,7 @@ class DataBundle: | |||||
if _progress_desc: | if _progress_desc: | ||||
progress_desc = _progress_desc + f' for `{name}`' | progress_desc = _progress_desc + f' for `{name}`' | ||||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | ||||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | |||||
progress_bar=progress_bar, progress_desc=progress_desc) | |||||
return res | return res | ||||
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | ||||
@@ -279,7 +279,7 @@ class _LazyConfigMapping(OrderedDict): | |||||
value = self._mapping[key] | value = self._mapping[key] | ||||
module_name = model_type_to_module_name(key) | module_name = model_type_to_module_name(key) | ||||
if module_name not in self._modules: | if module_name not in self._modules: | ||||
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") | |||||
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||||
return getattr(self._modules[module_name], value) | return getattr(self._modules[module_name], value) | ||||
def keys(self): | def keys(self): | ||||
@@ -318,15 +318,15 @@ class _LazyLoadAllMappings(OrderedDict): | |||||
def _initialize(self): | def _initialize(self): | ||||
if self._initialized: | if self._initialized: | ||||
return | return | ||||
logger.warn( | |||||
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | |||||
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | |||||
FutureWarning, | |||||
) | |||||
# logger.warn( | |||||
# "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | |||||
# "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | |||||
# FutureWarning, | |||||
# ) | |||||
for model_type, map_name in self._mapping.items(): | for model_type, map_name in self._mapping.items(): | ||||
module_name = model_type_to_module_name(model_type) | module_name = model_type_to_module_name(model_type) | ||||
module = importlib.import_module(f".{module_name}", "transformers.models") | |||||
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||||
mapping = getattr(module, map_name) | mapping = getattr(module, map_name) | ||||
self._data.update(mapping) | self._data.update(mapping) | ||||
@@ -362,8 +362,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPP | |||||
def _get_class_name(model_class: Union[str, List[str]]): | def _get_class_name(model_class: Union[str, List[str]]): | ||||
if isinstance(model_class, (list, tuple)): | if isinstance(model_class, (list, tuple)): | ||||
return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None]) | |||||
return f":class:`~transformers.{model_class}`" | |||||
return " or ".join([f":class:`~fastNLP.transformers.torch.{c}`" for c in model_class if c is not None]) | |||||
return f":class:`~fastNLP.transformers.torch.{model_class}`" | |||||
def _list_model_options(indent, config_to_class=None, use_model_types=True): | def _list_model_options(indent, config_to_class=None, use_model_types=True): | ||||
@@ -372,7 +372,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): | |||||
if use_model_types: | if use_model_types: | ||||
if config_to_class is None: | if config_to_class is None: | ||||
model_type_to_name = { | model_type_to_name = { | ||||
model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() | |||||
model_type: f":class:`~fastNLP.transformers.torch.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() | |||||
} | } | ||||
else: | else: | ||||
model_type_to_name = { | model_type_to_name = { | ||||
@@ -394,7 +394,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): | |||||
config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() | config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() | ||||
} | } | ||||
lines = [ | lines = [ | ||||
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" | |||||
f"{indent}- :class:`~fastNLP.transformers.torch.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" | |||||
for config_name in sorted(config_to_name.keys()) | for config_name in sorted(config_to_name.keys()) | ||||
] | ] | ||||
return "\n".join(lines) | return "\n".join(lines) | ||||
@@ -4,17 +4,12 @@ | |||||
(2) 能不能保存 topk 并load进来进行训练 | (2) 能不能保存 topk 并load进来进行训练 | ||||
""" | """ | ||||
import pytest | |||||
import os | import os | ||||
import pytest | import pytest | ||||
from typing import Any | from typing import Any | ||||
from dataclasses import dataclass | from dataclasses import dataclass | ||||
from pathlib import Path | from pathlib import Path | ||||
import re | |||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | ||||
@@ -25,7 +20,6 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | from tests.helpers.datasets.torch_data import TorchArgMaxDataset | ||||
from torchmetrics import Accuracy | from torchmetrics import Accuracy | ||||
from fastNLP.core.metrics import Metric | from fastNLP.core.metrics import Metric | ||||
from fastNLP.core.log import logger | |||||
from fastNLP.core.callbacks import MoreEvaluateCallback | from fastNLP.core.callbacks import MoreEvaluateCallback | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
@@ -0,0 +1,123 @@ | |||||
from typing import Any | |||||
from dataclasses import dataclass | |||||
import pytest | |||||
from fastNLP import Metric, Accuracy | |||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP import Trainer, Evaluator | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
if _NEED_IMPORT_TORCH: | |||||
from torch.utils.data import DataLoader | |||||
from torch.optim import SGD | |||||
import torch.distributed as dist | |||||
import torch | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||||
@dataclass
class ArgMaxDatasetConfig:
    """Shared configuration for the synthetic arg-max dataset used below."""
    num_labels: int = 10           # number of target classes
    feature_dimension: int = 10    # size of each input feature vector
    data_num: int = 100            # total number of samples to generate
    seed: int = 0                  # seed used when generating the data
    batch_size: int = 4            # DataLoader batch size
    shuffle: bool = True           # whether the DataLoader shuffles
@dataclass
class TrainerParameters:
    """Bundle of everything needed to construct a Trainer in these tests."""
    model: Any = None                  # model under training
    optimizers: Any = None             # optimizer(s) for the model
    train_dataloader: Any = None       # dataloader used for training
    evaluate_dataloaders: Any = None   # dataloader(s) used for evaluation
    input_mapping: Any = None          # optional mapping applied to batch inputs
    output_mapping: Any = None         # optional mapping applied to model outputs
    metrics: Any = None                # metrics passed to Trainer/Evaluator
    more_metrics: Any = None           # extra metrics used by some tests
@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
    """Build the model, optimizer, dataloader and metrics shared by the tests.

    Module-scoped so the same objects are reused across all parametrized
    test cases in this file.

    :return: a populated :class:`TrainerParameters`.
    """
    trainer_params = TrainerParameters()

    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
    dataset = TorchArgMaxDataset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=ArgMaxDatasetConfig.batch_size,
        shuffle=True
    )

    class LossMetric(Metric):
        # Accumulates the summed loss value over an evaluation run.
        def __init__(self):
            super().__init__()
            self.register_element('loss')

        def update(self, loss):
            self.loss += loss.item()

        def get_metric(self) -> dict:
            # NOTE(review): annotated ``-> dict`` but returns a scalar via
            # ``item()`` — confirm the declared type against Metric's contract.
            return self.loss.item()

    trainer_params.train_dataloader = _dataloader
    # The same dataloader doubles as the evaluation dataloader.
    trainer_params.evaluate_dataloaders = _dataloader
    trainer_params.metrics = {'loss': LossMetric()}
    trainer_params.more_metrics = {"acc": Accuracy()}

    return trainer_params
@pytest.mark.torch
@pytest.mark.parametrize('device', ['cpu', [0, 1]])
@pytest.mark.parametrize('progress_bar', ['rich', 'auto', None, 'raw', 'tqdm'])
@magic_argv_env_context
def test_run( model_and_optimizers: TrainerParameters, device, progress_bar):
    """Smoke-test Trainer.run and Evaluator.run under every progress_bar mode.

    Runs a short training loop followed by a standalone evaluation, on CPU
    and (when available) on two GPUs, for each supported progress backend.
    """
    # Multi-GPU cases only make sense when CUDA is present.
    if device != 'cpu' and not torch.cuda.is_available():
        pytest.skip(f"No cuda for device:{device}")
    n_epochs = 5
    trainer = Trainer(
        model=model_and_optimizers.model,
        driver='torch',
        device=device,
        optimizers=model_and_optimizers.optimizers,
        train_dataloader=model_and_optimizers.train_dataloader,
        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
        input_mapping=model_and_optimizers.input_mapping,
        output_mapping=model_and_optimizers.output_mapping,
        metrics=model_and_optimizers.metrics,

        n_epochs=n_epochs,
        callbacks=None,
        progress_bar=progress_bar,
        output_from_new_proc="all",
        evaluate_fn='train_step',
        larger_better=False
    )

    trainer.run()

    # Re-use the trainer's driver so the evaluator runs on the same setup.
    evaluator = Evaluator(model=model_and_optimizers.model, dataloaders=model_and_optimizers.train_dataloader,
                          driver=trainer.driver, metrics=model_and_optimizers.metrics,
                          progress_bar=progress_bar, evaluate_fn='train_step')
    evaluator.run()

    # Tear down the process group so later distributed tests start clean.
    if dist.is_initialized():
        dist.destroy_process_group()
@@ -181,7 +181,7 @@ class TestDataSetMethods: | |||||
assert ("rx" in ds.field_arrays) == True | assert ("rx" in ds.field_arrays) == True | ||||
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] | assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] | ||||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) | |||||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", progress_bar=None) | |||||
assert ds.field_arrays["y"].content[0] == 2 | assert ds.field_arrays["y"].content[0] == 2 | ||||
res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len") | res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len") | ||||
@@ -198,8 +198,8 @@ class TestDataSetMethods: | |||||
def do_nothing(ins): | def do_nothing(ins): | ||||
time.sleep(0.01) | time.sleep(0.01) | ||||
ds.apply(do_nothing, show_progress_bar=True, num_proc=0) | |||||
ds.apply_field(do_nothing, field_name='x', show_progress_bar=True) | |||||
ds.apply(do_nothing, progress_bar='rich', num_proc=0) | |||||
ds.apply_field(do_nothing, field_name='x', progress_bar='rich') | |||||
def test_apply_cannot_modify_instance(self): | def test_apply_cannot_modify_instance(self): | ||||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | ||||
@@ -0,0 +1,16 @@ | |||||
import pytest | |||||
from fastNLP.envs.imports import _module_available | |||||
from fastNLP.core.utils import f_tqdm_progress, f_rich_progress | |||||
def test_raise():
    """The rich and tqdm progress backends must refuse to run concurrently."""
    # Skip when tqdm is unavailable or either backend is the no-op dummy
    # (non-interactive session / non-zero global rank).
    if not _module_available('tqdm') or f_rich_progress.dummy or f_tqdm_progress.dummy:
        pytest.skip('No tqdm')
    # A live rich task must block the creation of a tqdm task ...
    t = f_rich_progress.add_task('test', total=10)
    with pytest.raises(AssertionError):
        f_tqdm_progress.add_task('test')
    f_rich_progress.destroy_task(t)
    # ... and a live tqdm task must block the creation of a rich task.
    t = f_tqdm_progress.add_task('test', total=10)
    with pytest.raises(AssertionError):
        f_rich_progress.add_task('test')
    # Bug fix: the tqdm task was previously never destroyed, leaving a live
    # bar and the swapped sys.stdout/sys.stderr (installed by add_task,
    # restored only when the last bar closes) leaking into later tests.
    f_tqdm_progress.destroy_task(t)