Browse Source

新增TqdmProgressBar

tags/v1.0.0alpha
yh 2 years ago
parent
commit
2423641010
16 changed files with 522 additions and 100 deletions
  1. +1
    -0
      fastNLP/core/__init__.py
  2. +4
    -1
      fastNLP/core/callbacks/__init__.py
  3. +103
    -5
      fastNLP/core/callbacks/progress_callback.py
  4. +28
    -14
      fastNLP/core/controllers/evaluator.py
  5. +5
    -4
      fastNLP/core/controllers/trainer.py
  6. +37
    -33
      fastNLP/core/dataset/dataset.py
  7. +3
    -1
      fastNLP/core/utils/__init__.py
  8. +8
    -2
      fastNLP/core/utils/rich_progress.py
  9. +160
    -0
      fastNLP/core/utils/tqdm_progress.py
  10. +4
    -4
      fastNLP/core/vocabulary.py
  11. +16
    -16
      fastNLP/io/data_bundle.py
  12. +11
    -11
      fastNLP/transformers/torch/models/auto/configuration_auto.py
  13. +0
    -6
      tests/core/callbacks/test_more_evaluate_callback.py
  14. +123
    -0
      tests/core/callbacks/test_progress_callback_torch.py
  15. +3
    -3
      tests/core/dataset/test_dataset.py
  16. +16
    -0
      tests/core/utils/test_progress.py

+ 1
- 0
fastNLP/core/__init__.py View File

@@ -6,6 +6,7 @@ __all__ = [
'CheckpointCallback',
'ProgressCallback',
'RichCallback',
'TqdmCallback',
"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback",


+ 4
- 1
fastNLP/core/callbacks/__init__.py View File

@@ -4,8 +4,11 @@ __all__ = [
'Filter',
'CheckpointCallback',
'choose_progress_callback',

'ProgressCallback',
'RichCallback',
'TqdmCallback',

"LRSchedCallback",
'LoadBestModelCallback',
"EarlyStopCallback",
@@ -26,7 +29,7 @@ from .callback import Callback
from .callback_event import Event, Filter
from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback
from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback


+ 103
- 5
fastNLP/core/callbacks/progress_callback.py View File

@@ -5,11 +5,14 @@ from typing import Union
__all__ = [
'choose_progress_callback',
'ProgressCallback',
'RichCallback'
'RichCallback',
'TqdmCallback'
]

from ...envs.imports import _module_available, _compare_version

from .has_monitor_callback import HasMonitorCallback
from fastNLP.core.utils import f_rich_progress
from fastNLP.core.utils import f_rich_progress, f_tqdm_progress
from fastNLP.core.log import logger


@@ -24,7 +27,7 @@ class ProgressCallback(HasMonitorCallback):

def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback:
if progress_bar == 'auto':
if not f_rich_progress.dummy_rich:
if not f_rich_progress.dummy:
progress_bar = 'rich'
else:
progress_bar = 'raw'
@@ -32,6 +35,8 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog
return RichCallback()
elif progress_bar == 'raw':
return RawTextCallback()
elif progress_bar == 'tqdm':
return TqdmCallback()
elif isinstance(progress_bar, ProgressCallback):
return progress_bar
else:
@@ -82,7 +87,9 @@ class RichCallback(ProgressCallback):
if 'batch' in self.task2id:
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch)
else:
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', total=trainer.num_batches_per_epoch)
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0',
total=trainer.num_batches_per_epoch,
completed=trainer.batch_idx_in_epoch)

def on_train_epoch_end(self, trainer):
self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}',
@@ -208,4 +215,95 @@ class RawTextCallback(ProgressCallback):

@property
def name(self): # progress bar的名称
return 'raw'
return 'raw'


class TqdmCallback(ProgressCallback):
    """
    Callback that renders tqdm progress bars during training. To customise its
    parameters (e.g. print frequency), instantiate this Callback yourself and pass
    it to the Trainer.

    :param print_every: update the display every this many batches.
    :param loss_round_ndigit: number of digits the displayed loss is rounded to.
    :param monitor: the metric value to monitor. When a better value for this key
        is detected, the evaluation report is framed with a different marker.

        * ``None``
          try to use the ``monitor`` configured on :class:`~fastNLP.Trainer`
          (if one was set).
        * ``str``
          look the name up directly in the ``evaluation`` results; if no exactly
          matching name is found, the longest-common-substring algorithm picks the
          best-matching key from the ``evaluation`` results as the ``monitor``.
        * ``Callable``
          receives the ``evaluation`` results (a dict) and returns a ``float`` to
          be used as the ``monitor`` value; return ``None`` if the current results
          contain no relevant ``monitor`` value.
    :param larger_better: whether larger ``monitor`` values are better.
    :param format_json: whether to format the results as json before printing.
    """
    def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True,
                 format_json=True):
        super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False)
        self.print_every = print_every
        self.progress_bar = f_tqdm_progress
        self.task2id = {}  # maps 'epoch'/'batch' -> tqdm task id
        self.loss = 0  # running loss accumulated between display updates
        self.loss_round_ndigit = loss_round_ndigit
        self.format_json = format_json
        self.num_signs = 10  # number of '+'/'-' characters framing the eval report

    def on_train_begin(self, trainer):
        # `initial` pre-advances the epoch bar proportionally so that resuming from
        # a checkpoint shows the already-completed fraction; +1e-6 avoids /0.
        self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
                                                           bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]',
                                                           initial=trainer.global_forward_batches/(trainer.total_batches+1e-6))

    def on_train_epoch_begin(self, trainer):
        # Fraction of the epoch bar to advance per display update; +1e-6 guards
        # against an empty dataloader.
        self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
        if 'batch' in self.task2id:
            # Reuse the existing batch bar across epochs instead of stacking new ones.
            self.progress_bar.reset(self.task2id['batch'])
        else:
            self.task2id['batch'] = self.progress_bar.add_task(description='Batch', total=trainer.num_batches_per_epoch,
                                                               initial=trainer.batch_idx_in_epoch)
        self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True)

    def on_train_end(self, trainer):
        self.clear_tasks()

    def on_before_backward(self, trainer, outputs):
        # Accumulate the (numeric) loss so the display shows an average over the
        # last `print_every` batches.
        loss = trainer.extract_loss_from_outputs(outputs)
        loss = trainer.driver.tensor_to_numeric(loss, reduce='sum')
        self.loss += loss

    def on_train_batch_end(self, trainer):
        if trainer.global_forward_batches % self.print_every == 0:
            loss = self.loss/self.print_every
            self.loss = 0
            self.progress_bar.update(self.task2id['batch'], advance=self.print_every, refresh=True)
            self.progress_bar.set_postfix_str(self.task2id['batch'], f'Loss:{round(loss, self.loss_round_ndigit)}')
            self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True)

    def on_evaluate_end(self, trainer, results):
        if len(results)==0:
            return
        base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
        text = ''
        if self.monitor is not None:
            monitor_value = self.get_monitor_value(results)
            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
                # Skip the '+' framing on the very first comparison, when
                # monitor_value is still at its +/-inf initial value.
                if abs(self.monitor_value) != float('inf'):
                    text = '+'*self.num_signs + base_text + '+'*self.num_signs
        if len(text) == 0:
            text = '-'*self.num_signs + base_text + '-'*self.num_signs

        logger.info(text)
        if self.format_json:
            # NOTE(review): relies on `json` being imported at module level — the
            # import is not visible in this chunk; confirm it exists in the file.
            logger.info(json.dumps(trainer.driver.tensor_to_numeric(results)))
        else:
            logger.info(results)

    def clear_tasks(self):
        # Close every bar this callback created and reset the accumulated loss.
        for key, taskid in self.task2id.items():
            self.progress_bar.destroy_task(taskid)
        self.task2id = {}
        self.loss = 0

    @property
    def name(self):  # name of this progress bar
        return 'tqdm'

+ 28
- 14
fastNLP/core/controllers/evaluator.py View File

@@ -19,7 +19,7 @@ from fastNLP.core.drivers import Driver, TorchDriver
from ..drivers.choose_driver import choose_driver
from .loops import Loop, EvaluateBatchLoop
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
match_and_substitute_params, f_rich_progress, flat_nest_dict
match_and_substitute_params, f_rich_progress, flat_nest_dict, f_tqdm_progress
from fastNLP.core.metrics import Metric
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
@@ -166,8 +166,9 @@ class Evaluator:
self.dataloaders[name] = dl

self.progress_bar = kwargs.get('progress_bar', 'auto')
assert self.progress_bar in [None, 'rich', 'auto', 'tqdm', 'raw']
if self.progress_bar == 'auto':
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich'
self.progress_bar = 'raw' if f_rich_progress.dummy else 'rich'

self.driver.barrier()

@@ -226,12 +227,15 @@ class Evaluator:
return metric_results

def start_progress_bar(self, total: int, dataloader_name):
if self.progress_bar == 'rich':
if self.progress_bar in ('rich', 'tqdm'):
if dataloader_name is None:
desc = f'Eval. Batch:0'
desc = f'Eval. Batch'
else:
desc = f'Eval. on {dataloader_name} Batch'
if self.progress_bar == 'rich':
self._task_id = f_rich_progress.add_task(description=desc, total=total)
else:
desc = f'Eval. on {dataloader_name} Batch:0'
self._rich_task_id = f_rich_progress.add_task(description=desc, total=total)
self._task_id = f_tqdm_progress.add_task(description=desc, total=total)
elif self.progress_bar == 'raw':
desc = 'Evaluation starts'
if dataloader_name is not None:
@@ -244,19 +248,26 @@ class Evaluator:
else:
desc = f'Eval. on {dataloader_name} Batch:{batch_idx}'
if self.progress_bar == 'rich':
assert hasattr(self, '_rich_task_id'), "You must first call `start_progress_bar()` before calling " \
assert hasattr(self, '_task_id'), "You must first call `start_progress_bar()` before calling " \
"update_progress_bar()"
f_rich_progress.update(self._rich_task_id, description=desc, post_desc=kwargs.get('post_desc', ''),
f_rich_progress.update(self._task_id, description=desc, post_desc=kwargs.get('post_desc', ''),
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True),
visible=kwargs.get('visible', True))
elif self.progress_bar == 'raw':
if self.verbose > 1:
logger.info(desc)
elif self.progress_bar == 'tqdm':
f_tqdm_progress.update(self._task_id, advance=1)

def remove_progress_bar(self, dataloader_name):
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
f_rich_progress.destroy_task(self._rich_task_id)
delattr(self, '_rich_task_id')
if self.progress_bar == 'rich' and hasattr(self, '_task_id'):
f_rich_progress.destroy_task(self._task_id)
delattr(self, '_task_id')

elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'):
f_tqdm_progress.destroy_task(self._task_id)
delattr(self, '_task_id')

elif self.progress_bar == 'raw':
desc = 'Evaluation ends'
if dataloader_name is not None:
@@ -264,9 +275,12 @@ class Evaluator:
logger.info("*" * 10 + desc + '*' * 10 + '\n')

def finally_progress_bar(self):
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
f_rich_progress.destroy_task(self._rich_task_id)
delattr(self, '_rich_task_id')
if self.progress_bar == 'rich' and hasattr(self, '_task_id'):
f_rich_progress.destroy_task(self._task_id)
delattr(self, '_task_id')
elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'):
f_tqdm_progress.destroy_task(self._task_id)
delattr(self, '_task_id')

@property
def evaluate_batch_loop(self):


+ 5
- 4
fastNLP/core/controllers/trainer.py View File

@@ -294,9 +294,9 @@ class Trainer(TrainerEventTrigger):
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error";

注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``;
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 RichCallback, RawTextCallback对象,
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback 对象。如果
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。
@@ -573,7 +573,7 @@ class Trainer(TrainerEventTrigger):

if resume_from is not None:
if os.path.exists(resume_from):
self.load(resume_from, resume_training=resume_training)
self.load_checkpoint(resume_from, resume_training=resume_training)
else:
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.")

@@ -732,6 +732,7 @@ class Trainer(TrainerEventTrigger):
Trainer.__init__():
on_after_trainer_initialized(trainer, driver)
Trainer.run():
# load checkpoint if resume_from is not None
if num_eval_sanity_batch>0:
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch
on_sanity_check_end(trainer, sanity_check_res)


+ 37
- 33
fastNLP/core/dataset/dataset.py View File

@@ -19,10 +19,17 @@ from .field import FieldArray
from .instance import Instance
from fastNLP.core.utils.utils import pretty_table_printer, deprecated
from fastNLP.core.collators import Collator
from fastNLP.core.utils.rich_progress import f_rich_progress
from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress
from ..log import logger


# Maps the user-facing `progress_bar` string to the shared progress-bar
# singleton; call sites use .get(...) with a dummy fallback, so unknown keys
# (e.g. None) silently disable the bar.
progress_bars = {
    'rich': f_rich_progress,
    'tqdm': f_tqdm_progress
}


class ApplyResultException(Exception):
def __init__(self, msg, index=None):
super().__init__(msg)
@@ -30,7 +37,7 @@ class ApplyResultException(Exception):
self.index = index # 标示在哪个数据遭遇到问题了


def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True,
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, progress_bar: str = 'rich',
desc: str = None) -> list:
"""
对数据集进行处理封装函数,以便多进程使用
@@ -39,32 +46,29 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s
:param _apply_field: 需要处理数据集的field_name
:param func: 用户自定义的func
:param desc: 进度条的描述字符
:param show_progress_bar: 是否展示子进程进度条
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:return:
"""
if show_progress_bar:
desc = desc if desc else f"Main"
pg_main = f_rich_progress.add_task(description=desc, total=len(ds), visible=show_progress_bar)
progress_bar = progress_bars.get(progress_bar, DummyFRichProgress())
desc = desc if desc else "Processing"
task_id = progress_bar.add_task(description=desc, total=len(ds))
results = []
idx = -1

try:
# for idx, ins in tqdm(enumerate(ds), total=len(ds), position=0, desc=desc, disable=not show_progress_bar):
for idx, ins in enumerate(ds):
if _apply_field is not None:
results.append(func(ins[_apply_field]))
else:
results.append(func(ins))
if show_progress_bar:
f_rich_progress.update(pg_main, advance=1)
progress_bar.update(task_id, advance=1)

except BaseException as e:
if idx != -1:
logger.error("Exception happens at the `{}`th instance.".format(idx))
raise e
finally:
if show_progress_bar:
f_rich_progress.destroy_task(pg_main)
progress_bar.destroy_task(task_id)
return results


@@ -398,7 +402,7 @@ class DataSet:

def apply_field(self, func: Callable, field_name: str = None,
new_field_name: str = None, num_proc: int = 0,
progress_desc: str = None, show_progress_bar: bool = True):
progress_desc: str = None, progress_bar: str = 'rich'):
r"""
将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name``
中。
@@ -413,8 +417,8 @@ class DataSet:
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。

:param progress_desc: 进度条的描述字符,默认为 ``Main``;
:param show_progress_bar: 是否在处理过程中展示进度条;
:param progress_desc: 进度条的描述字符,默认为 ``Processing``;
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:return: 从函数 ``func`` 中得到的返回值;
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -422,7 +426,7 @@ class DataSet:
raise KeyError("DataSet has no field named `{}`.".format(field_name))

try:
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar,
results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar,
progress_desc=progress_desc, _apply_field=field_name)
except BaseException as e:
raise e
@@ -433,7 +437,7 @@ class DataSet:

def apply_field_more(self, func: Callable = None, field_name: str = None,
modify_fields: bool = True, num_proc: int = 0,
progress_desc: str = None, show_progress_bar: bool = True):
progress_desc: str = None, progress_bar: str = 'rich'):
r"""
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。
func 可以返回一个或多个 field 上的结果。
@@ -446,8 +450,8 @@ class DataSet:
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示进度条,默认展示
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符
:return Dict[str:Field]: 返回一个字典
"""
assert len(self) != 0, "Null DataSet cannot use apply_field()."
@@ -456,7 +460,7 @@ class DataSet:
idx = -1
results = {}
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
show_progress_bar=show_progress_bar, _apply_field=field_name)
progress_bar=progress_bar, _apply_field=field_name)
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。
if not isinstance(apply_out[0], Mapping):
raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}")
@@ -483,13 +487,13 @@ class DataSet:
return results

def _apply_process(self, num_proc: int = 0, func: Callable = None,
show_progress_bar: bool = True, _apply_field: str = None,
progress_bar: str = 'rich', _apply_field: str = None,
progress_desc: str = 'Main') -> list:
"""
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance``
:param _apply_field: 需要传进去func的数据集的field_name
:param show_progress_bar: 是否展示progress进度条,默认为展示
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:param progress_desc: 进度条的描述字符,默认为'Main
"""
if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>":
@@ -499,7 +503,7 @@ class DataSet:

if num_proc < 2:
results = _apply_single(ds=self, _apply_field=_apply_field, func=func,
desc=progress_desc, show_progress_bar=show_progress_bar)
desc=progress_desc, progress_bar=progress_bar)
else:
# TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2
import multiprocessing as mp
@@ -525,25 +529,25 @@ class DataSet:
proc.start()
pool.append(proc)
queues.append(queue)
progress_bar = progress_bars.get(progress_bar, DummyFRichProgress())
total_len = len(self)
task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar)
task_id = progress_bar.add_task(description=progress_desc, total=total_len)
last_count = -1
while counter.value < total_len or last_count == -1:
while counter.value == last_count:
time.sleep(0.1)
advance = counter.value - last_count
last_count = counter.value
f_rich_progress.update(task_id, advance=advance, refresh=True)
progress_bar.update(task_id, advance=advance, refresh=True)

for idx, proc in enumerate(pool):
results.extend(pickle.loads(queues[idx].get()))
proc.join()
f_rich_progress.destroy_task(task_id)
progress_bar.destroy_task(task_id)
return results

def apply_more(self, func: Callable = None, modify_fields: bool = True,
num_proc: int = 0, progress_desc: str = '', show_progress_bar: bool = True):
num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'):
r"""
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。

@@ -558,9 +562,9 @@ class DataSet:

:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param num_proc: 进程的数量
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:return Dict[str:Field]: 返回一个字典
"""
assert callable(func), "The func is not callable."
@@ -570,7 +574,7 @@ class DataSet:

results = {}
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc,
show_progress_bar=show_progress_bar)
progress_bar=progress_bar)
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。
if not isinstance(apply_out[0], dict):
raise Exception("The result of func is not a dict")
@@ -597,21 +601,21 @@ class DataSet:
return results

def apply(self, func: Callable = None, new_field_name: str = None,
num_proc: int = 0, show_progress_bar: bool = True, progress_desc: str = ''):
num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''):
"""

:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示进度条
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`
:param progress_desc: progress bar 显示的值,默认为空。
"""
assert callable(func), "The func you provide is not callable."
assert len(self) != 0, "Null DataSet cannot use apply()."
assert num_proc >= 0, "num_proc must be an integer >= 0."
try:
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar,
results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar,
progress_desc=progress_desc)
except BaseException as e:
raise e


+ 3
- 1
fastNLP/core/utils/__init__.py View File

@@ -22,7 +22,8 @@ __all__ = [
'Option',
'deprecated',
'seq_len_to_mask',
"flat_nest_dict"
"flat_nest_dict",
"f_tqdm_progress"
]

from .cache_results import cache_results
@@ -32,5 +33,6 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi
from .rich_progress import f_rich_progress
from .torch_utils import torch_move_data_to_device
from .utils import *
from .tqdm_progress import f_tqdm_progress



+ 8
- 2
fastNLP/core/utils/rich_progress.py View File

@@ -35,7 +35,7 @@ class DummyFRichProgress:
return None

@property
def dummy_rich(self)->bool:
def dummy(self)->bool:
"""
当前对象是否是 dummy 的 rich 对象。

@@ -122,6 +122,9 @@ class FRichProgress(Progress, metaclass=Singleton):
visible: bool = True,
**fields: Any,
) -> TaskID:
from .tqdm_progress import f_tqdm_progress
assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop."

if self.live._started is False:
self.start()
post_desc = fields.pop('post_desc', '')
@@ -213,7 +216,7 @@ class FRichProgress(Progress, metaclass=Singleton):
self.refresh()

@property
def dummy_rich(self) -> bool:
def dummy(self) -> bool:
"""
当前对象是否是 dummy 的 rich 对象。

@@ -221,6 +224,9 @@ class FRichProgress(Progress, metaclass=Singleton):
"""
return False

def not_empty(self):
return len(self._tasks) != 0


class SpeedColumn(ProgressColumn):
"""


+ 160
- 0
fastNLP/core/utils/tqdm_progress.py View File

@@ -0,0 +1,160 @@
__all__ = [
'f_tqdm_progress'
]

import uuid
import sys
from ...envs.imports import _module_available, _compare_version
from ...envs import get_global_rank
from .utils import is_notebook
from ..log import logger
if _module_available('tqdm'):
from tqdm.autonotebook import tqdm
import operator



class Singleton(type):
    """Metaclass that memoises one instance per class.

    The first call to the class constructs the instance; every subsequent call
    returns that same cached object.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        instance = cls._instances.get(cls)
        if instance is None:
            # First instantiation: build the object and cache it for reuse.
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return instance


# No-op stand-in: when nothing should be printed, the whole progress machinery
# becomes meaningless, so every operation silently does nothing.
class DummyFTqdmProgress:
    """Inert replacement for :class:`TqdmProgress` on non-interactive runs."""

    def __getattr__(self, item):
        # Any attribute access yields another dummy, so arbitrarily chained
        # lookups (e.g. DummyFTqdmProgress.console.print) keep working.
        return DummyFTqdmProgress()

    def __call__(self, *args, **kwargs):
        # Guard against users invoking chained dummies; swallow the call.
        return None

    @property
    def dummy(self) -> bool:
        """
        Whether this object is the dummy tqdm object.

        :return:
        """
        return True


class TqdmProgress(metaclass=Singleton):
    # Process-wide manager of tqdm bars, keyed by uuid task ids, mirroring the
    # FRichProgress interface (add_task / update / destroy_task / ...).
    def __init__(self):
        self.bars = {}  # task_id (uuid str) -> tqdm bar instance

    def add_task(self, iterable=None, description=None, total=None, leave=False,
                 ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None,
                 ascii=None, visible=True, unit='it', unit_scale=False,
                 dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0,
                 postfix=None, unit_divisor=1000, write_bytes=None,
                 lock_args=None, nrows=None, colour=None, gui=False, **kwargs):
        """
        Mostly mirrors the creation of a tqdm bar. To keep the interface as
        consistent as possible with FRichProgress, tqdm's ``desc`` is renamed to
        ``description`` and ``disable`` is inverted into ``visible``; all other
        parameters are forwarded unchanged to :class:`tqdm.tqdm`.

        :param iterable: forwarded to tqdm.
        :param description: bar description (tqdm's ``desc``).
        :param total: total number of expected iterations.
        :param visible: whether the bar is shown (inverse of tqdm's ``disable``).
        :param initial: starting counter value, e.g. when resuming mid-epoch.
        :return: a uuid string identifying the new task.
        """
        assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \
            f"To use {self.__class__.__name__}, tqdm>=4.57 is needed."

        # Deferred import to avoid a circular import between the two progress
        # modules; rich and tqdm bars must not be active at the same time.
        from .rich_progress import f_rich_progress
        assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop."

        # If stdout has already been redirected (below), write the bar through
        # the original stream to avoid feeding it back into itself.
        if hasattr(self, 'orig_out_err'):
            file = self.orig_out_err[0]
        else:
            file = sys.stdout

        # `position=len(self.bars)` stacks each new bar below the existing ones.
        bar = tqdm(iterable=iterable, desc=description, total=total, leave=leave, file=file,
                   ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters,
                   ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale,
                   dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial,
                   position=len(self.bars), postfix=postfix, unit_divisor=unit_divisor, write_bytes=write_bytes,
                   lock_args=lock_args, nrows=nrows, colour=colour, gui=gui, **kwargs)
        _uuid = str(uuid.uuid1())
        self.bars[_uuid] = bar
        # On first bar creation in a terminal, route stdout/stderr through tqdm
        # so ordinary print() calls do not corrupt the bars (not needed in
        # notebooks).
        if not hasattr(self, 'orig_out_err') and not is_notebook():
            from tqdm.contrib import DummyTqdmFile
            self.orig_out_err = sys.stdout, sys.stderr
            sys.stdout, sys.stderr = map(DummyTqdmFile, self.orig_out_err)

        return _uuid

    def update(self, task_id:str, advance:int, refresh=True):
        # `refresh` is accepted for interface parity with FRichProgress but is
        # unused here: tqdm refreshes on update() itself.
        self.bars[task_id].update(advance)

    def set_postfix_str(self, task_id, s, refresh=True):
        self.bars[task_id].set_postfix_str(s=s, refresh=refresh)

    def set_description_str(self, task_id, desc, refresh=True):
        self.bars[task_id].set_description_str(desc=desc, refresh=refresh)

    def destroy_task(self, task_id):
        """
        Close the tqdm bar corresponding to ``task_id`` and forget it.

        :param task_id: uuid returned by :meth:`add_task`.
        :return:
        """
        self.bars[task_id].close()
        self.bars.pop(task_id)
        if len(self.bars) == 0 and hasattr(self, 'orig_out_err'):
            # Restore the normal sys.stdout and sys.stderr once no bar is left.
            sys.stdout, sys.stderr = self.orig_out_err
            delattr(self, 'orig_out_err')

    def reset(self, task_id):
        # Reset the bar's counter to zero (used when a new epoch reuses a bar).
        self.bars[task_id].reset()

    def print(self):
        # Emit an empty line through tqdm without disturbing the bars.
        tqdm.write('')

    def not_empty(self):
        # Whether any bar is still alive; used to forbid mixing rich and tqdm.
        return len(self.bars) != 0

    @property
    def dummy(self) -> bool:
        """
        Whether this object is the dummy tqdm object.

        :return:
        """
        return False


# Instantiate the real TqdmProgress only on the global main process of an
# interactive session (tty or notebook); every other process gets the no-op
# dummy so bars are neither duplicated nor written to non-terminals.
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and get_global_rank() == 0:
    f_tqdm_progress = TqdmProgress()
else:
    f_tqdm_progress = DummyFTqdmProgress()
    logger.debug("Use dummy tqdm...")




+ 4
- 4
fastNLP/core/vocabulary.py View File

@@ -340,7 +340,7 @@ class Vocabulary(object):
try:
for f_n, n_f_n in zip(field_name, new_field_name):
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n,
show_progress_bar=False)
progress_bar=None)
except Exception as e:
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx))
raise e
@@ -396,7 +396,7 @@ class Vocabulary(object):
for idx, dataset in enumerate(datasets):
if isinstance(dataset, DataSet):
try:
dataset.apply(construct_vocab, show_progress_bar=False)
dataset.apply(construct_vocab, progress_bar=None)
except BaseException as e:
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx))
raise e
@@ -406,12 +406,12 @@ class Vocabulary(object):
if no_create_entry_dataset is not None:
partial_construct_vocab = partial(construct_vocab, no_create_entry=True)
if isinstance(no_create_entry_dataset, DataSet):
no_create_entry_dataset.apply(partial_construct_vocab, show_progress_bar=False)
no_create_entry_dataset.apply(partial_construct_vocab, progress_bar=None)
elif isinstance(no_create_entry_dataset, list):
for dataset in no_create_entry_dataset:
if not isinstance(dataset, DataSet):
raise TypeError("Only DataSet type is allowed.")
dataset.apply(partial_construct_vocab, show_progress_bar=False)
dataset.apply(partial_construct_vocab, progress_bar=None)
return self
def _is_word_no_create_entry(self, word:str):


+ 16
- 16
fastNLP/io/data_bundle.py View File

@@ -221,7 +221,7 @@ class DataBundle:
yield field_name, vocab

def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0,
ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True):
ignore_miss_dataset: bool = True, progress_desc: str = '', progress_bar: str = 'rich'):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法

@@ -233,8 +233,8 @@ class DataBundle:
如果为False,则报错
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称
:param show_progress_bar: 是否显示tqdm进度条
:param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。

"""
_progress_desc = progress_desc
@@ -243,13 +243,13 @@ class DataBundle:
progress_desc = _progress_desc + f' for `{name}`'
if dataset.has_field(field_name=field_name):
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc,
progress_desc=progress_desc, show_progress_bar=show_progress_bar)
progress_desc=progress_desc, progress_bar=progress_bar)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name}.")
return self

def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''):
ignore_miss_dataset=True, progress_bar: str = 'rich', progress_desc: str = ''):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法

@@ -263,8 +263,8 @@ class DataBundle:
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet;
如果为False,则报错
:param show_progress_bar: 是否显示进度条
:param progress_desc: 当 ``show_progress_bar`` 为 ``True`` 时,可以显示 ``progress`` 的名称。
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。

:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字

@@ -277,13 +277,13 @@ class DataBundle:
if dataset.has_field(field_name=field_name):
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc,
modify_fields=modify_fields,
show_progress_bar=show_progress_bar, progress_desc=progress_desc)
progress_bar=progress_bar, progress_desc=progress_desc)
elif not ignore_miss_dataset:
raise KeyError(f"{field_name} not found DataSet:{name} .")
return res

def apply(self, func: Callable, new_field_name: str, num_proc: int = 0,
progress_desc: str = '', show_progress_bar: bool = True):
progress_desc: str = '', progress_bar: bool = True):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法

@@ -293,20 +293,20 @@ class DataBundle:
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆
盖之前的field。如果为None则不创建新的field。
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示tqd进度条
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称

"""
_progress_desc = progress_desc
for name, dataset in self.datasets.items():
if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`'
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar,
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar,
progress_desc=progress_desc)
return self

def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0,
progress_desc: str = '', show_progress_bar: bool = True):
progress_desc: str = '', progress_bar: str = 'rich'):
r"""
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法

@@ -317,8 +317,8 @@ class DataBundle:
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。
:param show_progress_bar: 是否显示tqd进度条
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。
:param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称

:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字
"""
@@ -328,7 +328,7 @@ class DataBundle:
if _progress_desc:
progress_desc = _progress_desc + f' for `{name}`'
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc,
show_progress_bar=show_progress_bar, progress_desc=progress_desc)
progress_bar=progress_bar, progress_desc=progress_desc)
return res

def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle":


+ 11
- 11
fastNLP/transformers/torch/models/auto/configuration_auto.py View File

@@ -279,7 +279,7 @@ class _LazyConfigMapping(OrderedDict):
value = self._mapping[key]
module_name = model_type_to_module_name(key)
if module_name not in self._modules:
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models")
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
return getattr(self._modules[module_name], value)

def keys(self):
@@ -318,15 +318,15 @@ class _LazyLoadAllMappings(OrderedDict):
def _initialize(self):
if self._initialized:
return
logger.warn(
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
FutureWarning,
)
# logger.warn(
# "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. "
# "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.",
# FutureWarning,
# )

for model_type, map_name in self._mapping.items():
module_name = model_type_to_module_name(model_type)
module = importlib.import_module(f".{module_name}", "transformers.models")
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models")
mapping = getattr(module, map_name)
self._data.update(mapping)

@@ -362,8 +362,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPP

def _get_class_name(model_class: Union[str, List[str]]):
if isinstance(model_class, (list, tuple)):
return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None])
return f":class:`~transformers.{model_class}`"
return " or ".join([f":class:`~fastNLP.transformers.torch.{c}`" for c in model_class if c is not None])
return f":class:`~fastNLP.transformers.torch.{model_class}`"


def _list_model_options(indent, config_to_class=None, use_model_types=True):
@@ -372,7 +372,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
if use_model_types:
if config_to_class is None:
model_type_to_name = {
model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
model_type: f":class:`~fastNLP.transformers.torch.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items()
}
else:
model_type_to_name = {
@@ -394,7 +394,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True):
config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items()
}
lines = [
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
f"{indent}- :class:`~fastNLP.transformers.torch.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)"
for config_name in sorted(config_to_name.keys())
]
return "\n".join(lines)


+ 0
- 6
tests/core/callbacks/test_more_evaluate_callback.py View File

@@ -4,17 +4,12 @@
(2) 能不能保存 topk 并load进来进行训练

"""
import pytest



import os
import pytest
from typing import Any
from dataclasses import dataclass

from pathlib import Path
import re

from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
@@ -25,7 +20,6 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDataset
from torchmetrics import Accuracy
from fastNLP.core.metrics import Metric
from fastNLP.core.log import logger
from fastNLP.core.callbacks import MoreEvaluateCallback
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:


+ 123
- 0
tests/core/callbacks/test_progress_callback_torch.py View File

@@ -0,0 +1,123 @@
from typing import Any
from dataclasses import dataclass

import pytest

from fastNLP import Metric, Accuracy
from tests.helpers.utils import magic_argv_env_context
from fastNLP import Trainer, Evaluator
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
from torch.utils.data import DataLoader
from torch.optim import SGD
import torch.distributed as dist
import torch

from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.datasets.torch_data import TorchArgMaxDataset


@dataclass
class ArgMaxDatasetConfig:
    """Hyper-parameters for the synthetic arg-max dataset shared by the tests in this module."""
    num_labels: int = 10  # number of target classes produced by the dataset
    feature_dimension: int = 10  # dimensionality of each input sample
    data_num: int = 100  # total number of generated samples
    seed: int = 0  # RNG seed so the generated data is reproducible across runs

    batch_size: int = 4  # batch size used when wrapping the dataset in a DataLoader
    shuffle: bool = True  # intended DataLoader shuffle flag (the fixture below currently hard-codes True)


@dataclass
class TrainerParameters:
    """Mutable bundle of everything a Trainer/Evaluator needs, filled in by the fixture below."""
    model: Any = None  # the torch model under test
    optimizers: Any = None  # optimizer (or list of optimizers) for the model
    train_dataloader: Any = None  # dataloader used for training
    evaluate_dataloaders: Any = None  # dataloader(s) used for evaluation
    input_mapping: Any = None  # optional mapping applied to batch inputs before the model
    output_mapping: Any = None  # optional mapping applied to model outputs
    metrics: Any = None  # metrics passed to the Trainer
    more_metrics: Any = None  # extra metrics (e.g. for MoreEvaluateCallback)


@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
    """Build a module-scoped :class:`TrainerParameters` bundle (model, SGD optimizer,
    arg-max dataloader and metrics) shared by the progress-bar tests."""
    trainer_params = TrainerParameters()

    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)
    dataset = TorchArgMaxDataset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed
    )
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=ArgMaxDatasetConfig.batch_size,
        # Fix: honour the config field instead of hard-coding shuffle=True.
        shuffle=ArgMaxDatasetConfig.shuffle
    )

    class LossMetric(Metric):
        """Accumulates the raw training loss across update() calls."""
        def __init__(self):
            super().__init__()
            self.register_element('loss')

        def update(self, loss):
            # `loss` is expected to be a 0-dim tensor; fold its Python scalar in.
            self.loss += loss.item()

        def get_metric(self) -> dict:
            # NOTE(review): annotated `-> dict` but returns a scalar; fastNLP metrics
            # conventionally return a dict of results — confirm the intended contract.
            return self.loss.item()

    trainer_params.train_dataloader = _dataloader
    trainer_params.evaluate_dataloaders = _dataloader
    trainer_params.metrics = {'loss': LossMetric()}

    trainer_params.more_metrics = {"acc": Accuracy()}

    return trainer_params


@pytest.mark.torch
@pytest.mark.parametrize('device', ['cpu', [0, 1]])
@pytest.mark.parametrize('progress_bar', ['rich', 'auto', None, 'raw', 'tqdm'])
@magic_argv_env_context
def test_run( model_and_optimizers: TrainerParameters, device, progress_bar):
    """Smoke-test every progress_bar option through a full Trainer run followed by a
    stand-alone Evaluator run on the same driver."""
    if device != 'cpu' and not torch.cuda.is_available():
        pytest.skip(f"No cuda for device:{device}")

    epoch_count = 5
    trainer_kwargs = dict(
        model=model_and_optimizers.model,
        driver='torch',
        device=device,
        optimizers=model_and_optimizers.optimizers,
        train_dataloader=model_and_optimizers.train_dataloader,
        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
        input_mapping=model_and_optimizers.input_mapping,
        output_mapping=model_and_optimizers.output_mapping,
        metrics=model_and_optimizers.metrics,
        n_epochs=epoch_count,
        callbacks=None,
        progress_bar=progress_bar,
        output_from_new_proc="all",
        evaluate_fn='train_step',
        larger_better=False,
    )
    trainer = Trainer(**trainer_kwargs)
    trainer.run()

    evaluator = Evaluator(
        model=model_and_optimizers.model,
        dataloaders=model_and_optimizers.train_dataloader,
        driver=trainer.driver,
        metrics=model_and_optimizers.metrics,
        progress_bar=progress_bar,
        evaluate_fn='train_step',
    )
    evaluator.run()

    # Tear down the process group so later tests start from a clean state.
    if dist.is_initialized():
        dist.destroy_process_group()






+ 3
- 3
tests/core/dataset/test_dataset.py View File

@@ -181,7 +181,7 @@ class TestDataSetMethods:
assert ("rx" in ds.field_arrays) == True
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1]

ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False)
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", progress_bar=None)
assert ds.field_arrays["y"].content[0] == 2

res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len")
@@ -198,8 +198,8 @@ class TestDataSetMethods:
def do_nothing(ins):
time.sleep(0.01)

ds.apply(do_nothing, show_progress_bar=True, num_proc=0)
ds.apply_field(do_nothing, field_name='x', show_progress_bar=True)
ds.apply(do_nothing, progress_bar='rich', num_proc=0)
ds.apply_field(do_nothing, field_name='x', progress_bar='rich')

def test_apply_cannot_modify_instance(self):
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})


+ 16
- 0
tests/core/utils/test_progress.py View File

@@ -0,0 +1,16 @@
import pytest
from fastNLP.envs.imports import _module_available
from fastNLP.core.utils import f_tqdm_progress, f_rich_progress

def test_raise():
    """The rich and tqdm progress helpers are mutually exclusive: while one holds a
    live task, adding a task to the other must raise ``AssertionError``."""
    if not _module_available('tqdm') or f_rich_progress.dummy or f_tqdm_progress.dummy:
        pytest.skip('No tqdm')
    t = f_rich_progress.add_task('test', total=10)
    with pytest.raises(AssertionError):
        f_tqdm_progress.add_task('test')

    f_rich_progress.destroy_task(t)

    t = f_tqdm_progress.add_task('test', total=10)
    with pytest.raises(AssertionError):
        f_rich_progress.add_task('test')

    # Fix: also destroy the tqdm task so the global progress state does not leak
    # into tests that run after this one (mirrors the rich clean-up above).
    f_tqdm_progress.destroy_task(t)

Loading…
Cancel
Save