@@ -6,6 +6,7 @@ __all__ = [ | |||
'CheckpointCallback', | |||
'ProgressCallback', | |||
'RichCallback', | |||
'TqdmCallback', | |||
"LRSchedCallback", | |||
'LoadBestModelCallback', | |||
"EarlyStopCallback", | |||
@@ -4,8 +4,11 @@ __all__ = [ | |||
'Filter', | |||
'CheckpointCallback', | |||
'choose_progress_callback', | |||
'ProgressCallback', | |||
'RichCallback', | |||
'TqdmCallback', | |||
"LRSchedCallback", | |||
'LoadBestModelCallback', | |||
"EarlyStopCallback", | |||
@@ -26,7 +29,7 @@ from .callback import Callback | |||
from .callback_event import Event, Filter | |||
from .callback_manager import CallbackManager | |||
from .checkpoint_callback import CheckpointCallback | |||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback | |||
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback | |||
from .lr_scheduler_callback import LRSchedCallback | |||
from .load_best_model_callback import LoadBestModelCallback | |||
from .early_stop_callback import EarlyStopCallback | |||
@@ -5,11 +5,14 @@ from typing import Union | |||
__all__ = [ | |||
'choose_progress_callback', | |||
'ProgressCallback', | |||
'RichCallback' | |||
'RichCallback', | |||
'TqdmCallback' | |||
] | |||
from ...envs.imports import _module_available, _compare_version | |||
from .has_monitor_callback import HasMonitorCallback | |||
from fastNLP.core.utils import f_rich_progress | |||
from fastNLP.core.utils import f_rich_progress, f_tqdm_progress | |||
from fastNLP.core.log import logger | |||
@@ -24,7 +27,7 @@ class ProgressCallback(HasMonitorCallback): | |||
def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: | |||
if progress_bar == 'auto': | |||
if not f_rich_progress.dummy_rich: | |||
if not f_rich_progress.dummy: | |||
progress_bar = 'rich' | |||
else: | |||
progress_bar = 'raw' | |||
@@ -32,6 +35,8 @@ def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> Prog | |||
return RichCallback() | |||
elif progress_bar == 'raw': | |||
return RawTextCallback() | |||
elif progress_bar == 'tqdm': | |||
return TqdmCallback() | |||
elif isinstance(progress_bar, ProgressCallback): | |||
return progress_bar | |||
else: | |||
@@ -82,7 +87,9 @@ class RichCallback(ProgressCallback): | |||
if 'batch' in self.task2id: | |||
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) | |||
else: | |||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', total=trainer.num_batches_per_epoch) | |||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', | |||
total=trainer.num_batches_per_epoch, | |||
completed=trainer.batch_idx_in_epoch) | |||
def on_train_epoch_end(self, trainer): | |||
self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', | |||
@@ -208,4 +215,95 @@ class RawTextCallback(ProgressCallback): | |||
@property | |||
def name(self): # progress bar的名称 | |||
return 'raw' | |||
return 'raw' | |||
class TqdmCallback(ProgressCallback): | |||
""" | |||
在训练过程中打印 tqdm progress bar 的 callback 。在 Trainer 中,默认就会使用这个 callback 来显示进度。如果需要定制这个 Callback 的 | |||
参数,请通过实例化本 Callback 并传入到 Trainer 中实现。 | |||
:param print_every: 多少个 batch 更新一次显示。 | |||
:param loss_round_ndigit: 显示的 loss 保留多少位有效数字 | |||
:param monitor: 监控的 metric 值。当检测到这个key的结果更好时,会打印出不同的颜色进行提示。 | |||
* 为 ``None`` | |||
将尝试使用 :class:`~fastNLP.Trainer` 中设置 `monitor` 值(如果有设置)。 | |||
* 为 ``str`` | |||
尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将 | |||
使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。 | |||
* 为 ``Callable`` | |||
接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关 | |||
的 ``monitor`` 值请返回 ``None`` 。 | |||
:param larger_better: 是否是 monitor 的结果越大越好。 | |||
:param format_json: 是否格式化 json 再打印 | |||
""" | |||
def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, | |||
format_json=True): | |||
super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) | |||
self.print_every = print_every | |||
self.progress_bar = f_tqdm_progress | |||
self.task2id = {} | |||
self.loss = 0 | |||
self.loss_round_ndigit = loss_round_ndigit | |||
self.format_json = format_json | |||
self.num_signs = 10 | |||
def on_train_begin(self, trainer): | |||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | |||
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', | |||
initial=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||
def on_train_epoch_begin(self, trainer): | |||
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | |||
if 'batch' in self.task2id: | |||
self.progress_bar.reset(self.task2id['batch']) | |||
else: | |||
self.task2id['batch'] = self.progress_bar.add_task(description='Batch', total=trainer.num_batches_per_epoch, | |||
initial=trainer.batch_idx_in_epoch) | |||
self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | |||
def on_train_end(self, trainer): | |||
self.clear_tasks() | |||
def on_before_backward(self, trainer, outputs): | |||
loss = trainer.extract_loss_from_outputs(outputs) | |||
loss = trainer.driver.tensor_to_numeric(loss, reduce='sum') | |||
self.loss += loss | |||
def on_train_batch_end(self, trainer): | |||
if trainer.global_forward_batches % self.print_every == 0: | |||
loss = self.loss/self.print_every | |||
self.loss = 0 | |||
self.progress_bar.update(self.task2id['batch'], advance=self.print_every, refresh=True) | |||
self.progress_bar.set_postfix_str(self.task2id['batch'], f'Loss:{round(loss, self.loss_round_ndigit)}') | |||
self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True) | |||
def on_evaluate_end(self, trainer, results): | |||
if len(results)==0: | |||
return | |||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||
text = '' | |||
if self.monitor is not None: | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if abs(self.monitor_value) != float('inf'): | |||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||
if len(text) == 0: | |||
text = '-'*self.num_signs + base_text + '-'*self.num_signs | |||
logger.info(text) | |||
if self.format_json: | |||
logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) | |||
else: | |||
logger.info(results) | |||
def clear_tasks(self): | |||
for key, taskid in self.task2id.items(): | |||
self.progress_bar.destroy_task(taskid) | |||
self.task2id = {} | |||
self.loss = 0 | |||
@property | |||
def name(self): # progress bar的名称 | |||
return 'tqdm' |
@@ -19,7 +19,7 @@ from fastNLP.core.drivers import Driver, TorchDriver | |||
from ..drivers.choose_driver import choose_driver | |||
from .loops import Loop, EvaluateBatchLoop | |||
from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \ | |||
match_and_substitute_params, f_rich_progress, flat_nest_dict | |||
match_and_substitute_params, f_rich_progress, flat_nest_dict, f_tqdm_progress | |||
from fastNLP.core.metrics import Metric | |||
from fastNLP.core.metrics.utils import _is_torchmetrics_metric, _is_paddle_metric, _is_allennlp_metric | |||
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader | |||
@@ -166,8 +166,9 @@ class Evaluator: | |||
self.dataloaders[name] = dl | |||
self.progress_bar = kwargs.get('progress_bar', 'auto') | |||
assert self.progress_bar in [None, 'rich', 'auto', 'tqdm', 'raw'] | |||
if self.progress_bar == 'auto': | |||
self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' | |||
self.progress_bar = 'raw' if f_rich_progress.dummy else 'rich' | |||
self.driver.barrier() | |||
@@ -226,12 +227,15 @@ class Evaluator: | |||
return metric_results | |||
def start_progress_bar(self, total: int, dataloader_name): | |||
if self.progress_bar == 'rich': | |||
if self.progress_bar in ('rich', 'tqdm'): | |||
if dataloader_name is None: | |||
desc = f'Eval. Batch:0' | |||
desc = f'Eval. Batch' | |||
else: | |||
desc = f'Eval. on {dataloader_name} Batch' | |||
if self.progress_bar == 'rich': | |||
self._task_id = f_rich_progress.add_task(description=desc, total=total) | |||
else: | |||
desc = f'Eval. on {dataloader_name} Batch:0' | |||
self._rich_task_id = f_rich_progress.add_task(description=desc, total=total) | |||
self._task_id = f_tqdm_progress.add_task(description=desc, total=total) | |||
elif self.progress_bar == 'raw': | |||
desc = 'Evaluation starts' | |||
if dataloader_name is not None: | |||
@@ -244,19 +248,26 @@ class Evaluator: | |||
else: | |||
desc = f'Eval. on {dataloader_name} Batch:{batch_idx}' | |||
if self.progress_bar == 'rich': | |||
assert hasattr(self, '_rich_task_id'), "You must first call `start_progress_bar()` before calling " \ | |||
assert hasattr(self, '_task_id'), "You must first call `start_progress_bar()` before calling " \ | |||
"update_progress_bar()" | |||
f_rich_progress.update(self._rich_task_id, description=desc, post_desc=kwargs.get('post_desc', ''), | |||
f_rich_progress.update(self._task_id, description=desc, post_desc=kwargs.get('post_desc', ''), | |||
advance=kwargs.get('advance', 1), refresh=kwargs.get('refresh', True), | |||
visible=kwargs.get('visible', True)) | |||
elif self.progress_bar == 'raw': | |||
if self.verbose > 1: | |||
logger.info(desc) | |||
elif self.progress_bar == 'tqdm': | |||
f_tqdm_progress.update(self._task_id, advance=1) | |||
def remove_progress_bar(self, dataloader_name): | |||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
delattr(self, '_rich_task_id') | |||
if self.progress_bar == 'rich' and hasattr(self, '_task_id'): | |||
f_rich_progress.destroy_task(self._task_id) | |||
delattr(self, '_task_id') | |||
elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'): | |||
f_tqdm_progress.destroy_task(self._task_id) | |||
delattr(self, '_task_id') | |||
elif self.progress_bar == 'raw': | |||
desc = 'Evaluation ends' | |||
if dataloader_name is not None: | |||
@@ -264,9 +275,12 @@ class Evaluator: | |||
logger.info("*" * 10 + desc + '*' * 10 + '\n') | |||
def finally_progress_bar(self): | |||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
delattr(self, '_rich_task_id') | |||
if self.progress_bar == 'rich' and hasattr(self, '_task_id'): | |||
f_rich_progress.destroy_task(self._task_id) | |||
delattr(self, '_task_id') | |||
elif self.progress_bar == 'tqdm' and hasattr(self, '_task_id'): | |||
f_tqdm_progress.destroy_task(self._task_id) | |||
delattr(self, '_task_id') | |||
@property | |||
def evaluate_batch_loop(self): | |||
@@ -294,9 +294,9 @@ class Trainer(TrainerEventTrigger): | |||
log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; | |||
注意该参数仅当使用分布式的 ``driver`` 时才有效,例如 ``TorchDDPDriver``; | |||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, | |||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback对象。如果 | |||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 | |||
* *progress_bar* -- 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto', 'tqdm'] 或者 RichCallback, RawTextCallback等对象, | |||
默认为 auto , auto 表示如果检测到当前 terminal 为交互型则使用 RichCallback,否则使用 RawTextCallback 对象。如果 | |||
需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 等对象。 | |||
* *train_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Trainer`` 中。与 input_mapping 互斥。 | |||
* *train_output_mapping* -- 与 output_mapping 一致,但是只用于 ``Trainer`` 中。与 output_mapping 互斥。 | |||
* *evaluate_input_mapping* -- 与 input_mapping 一致,但是只用于 ``Evaluator`` 中。与 input_mapping 互斥。 | |||
@@ -573,7 +573,7 @@ class Trainer(TrainerEventTrigger): | |||
if resume_from is not None: | |||
if os.path.exists(resume_from): | |||
self.load(resume_from, resume_training=resume_training) | |||
self.load_checkpoint(resume_from, resume_training=resume_training) | |||
else: | |||
raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") | |||
@@ -732,6 +732,7 @@ class Trainer(TrainerEventTrigger): | |||
Trainer.__init__(): | |||
on_after_trainer_initialized(trainer, driver) | |||
Trainer.run(): | |||
# load checkpoint if resume_from is not None | |||
if num_eval_sanity_batch>0: | |||
on_sanity_check_begin(trainer) # 如果设置了num_eval_sanity_batch | |||
on_sanity_check_end(trainer, sanity_check_res) | |||
@@ -19,10 +19,17 @@ from .field import FieldArray | |||
from .instance import Instance | |||
from fastNLP.core.utils.utils import pretty_table_printer, deprecated | |||
from fastNLP.core.collators import Collator | |||
from fastNLP.core.utils.rich_progress import f_rich_progress | |||
from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress | |||
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress | |||
from ..log import logger | |||
progress_bars = { | |||
'rich': f_rich_progress, | |||
'tqdm': f_tqdm_progress | |||
} | |||
class ApplyResultException(Exception): | |||
def __init__(self, msg, index=None): | |||
super().__init__(msg) | |||
@@ -30,7 +37,7 @@ class ApplyResultException(Exception): | |||
self.index = index # 标示在哪个数据遭遇到问题了 | |||
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, show_progress_bar: bool = True, | |||
def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, progress_bar: str = 'rich', | |||
desc: str = None) -> list: | |||
""" | |||
对数据集进行处理封装函数,以便多进程使用 | |||
@@ -39,32 +46,29 @@ def _apply_single(ds=None, _apply_field=None, func: Optional[Callable] = None, s | |||
:param _apply_field: 需要处理数据集的field_name | |||
:param func: 用户自定义的func | |||
:param desc: 进度条的描述字符 | |||
:param show_progress_bar: 是否展示子进程进度条 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:return: | |||
""" | |||
if show_progress_bar: | |||
desc = desc if desc else f"Main" | |||
pg_main = f_rich_progress.add_task(description=desc, total=len(ds), visible=show_progress_bar) | |||
progress_bar = progress_bars.get(progress_bar, DummyFRichProgress()) | |||
desc = desc if desc else "Processing" | |||
task_id = progress_bar.add_task(description=desc, total=len(ds)) | |||
results = [] | |||
idx = -1 | |||
try: | |||
# for idx, ins in tqdm(enumerate(ds), total=len(ds), position=0, desc=desc, disable=not show_progress_bar): | |||
for idx, ins in enumerate(ds): | |||
if _apply_field is not None: | |||
results.append(func(ins[_apply_field])) | |||
else: | |||
results.append(func(ins)) | |||
if show_progress_bar: | |||
f_rich_progress.update(pg_main, advance=1) | |||
progress_bar.update(task_id, advance=1) | |||
except BaseException as e: | |||
if idx != -1: | |||
logger.error("Exception happens at the `{}`th instance.".format(idx)) | |||
raise e | |||
finally: | |||
if show_progress_bar: | |||
f_rich_progress.destroy_task(pg_main) | |||
progress_bar.destroy_task(task_id) | |||
return results | |||
@@ -398,7 +402,7 @@ class DataSet: | |||
def apply_field(self, func: Callable, field_name: str = None, | |||
new_field_name: str = None, num_proc: int = 0, | |||
progress_desc: str = None, show_progress_bar: bool = True): | |||
progress_desc: str = None, progress_bar: str = 'rich'): | |||
r""" | |||
将 :class:`~DataSet` 每个 ``instance`` 中为 ``field_name`` 的 ``field`` 传给函数 ``func``,并写入到 ``new_field_name`` | |||
中。 | |||
@@ -413,8 +417,8 @@ class DataSet: | |||
由于 ``python`` 语言的特性,设置该参数后会导致相应倍数的内存增长,这可能会对您程序的执行带来一定的影响。 | |||
:param progress_desc: 进度条的描述字符,默认为 ``Main``; | |||
:param show_progress_bar: 是否在处理过程中展示进度条; | |||
:param progress_desc: 进度条的描述字符,默认为 ``Processing``; | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:return: 从函数 ``func`` 中得到的返回值; | |||
""" | |||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | |||
@@ -422,7 +426,7 @@ class DataSet: | |||
raise KeyError("DataSet has no field named `{}`.".format(field_name)) | |||
try: | |||
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, | |||
results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar, | |||
progress_desc=progress_desc, _apply_field=field_name) | |||
except BaseException as e: | |||
raise e | |||
@@ -433,7 +437,7 @@ class DataSet: | |||
def apply_field_more(self, func: Callable = None, field_name: str = None, | |||
modify_fields: bool = True, num_proc: int = 0, | |||
progress_desc: str = None, show_progress_bar: bool = True): | |||
progress_desc: str = None, progress_bar: str = 'rich'): | |||
r""" | |||
将 ``DataSet`` 中的每个 ``Instance`` 中的名为 `field_name` 的field 传给 func,并获取它的返回值。 | |||
func 可以返回一个或多个 field 上的结果。 | |||
@@ -446,8 +450,8 @@ class DataSet: | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param modify_fields: 是否用结果修改 `DataSet` 中的 `Field`, 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示进度条,默认展示 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条描述字符 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,显示当前正在处理的进度条描述字符 | |||
:return Dict[str:Field]: 返回一个字典 | |||
""" | |||
assert len(self) != 0, "Null DataSet cannot use apply_field()." | |||
@@ -456,7 +460,7 @@ class DataSet: | |||
idx = -1 | |||
results = {} | |||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | |||
show_progress_bar=show_progress_bar, _apply_field=field_name) | |||
progress_bar=progress_bar, _apply_field=field_name) | |||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | |||
if not isinstance(apply_out[0], Mapping): | |||
raise Exception(f"The result of func is not a Mapping, but a {type(apply_out[0])}") | |||
@@ -483,13 +487,13 @@ class DataSet: | |||
return results | |||
def _apply_process(self, num_proc: int = 0, func: Callable = None, | |||
show_progress_bar: bool = True, _apply_field: str = None, | |||
progress_bar: str = 'rich', _apply_field: str = None, | |||
progress_desc: str = 'Main') -> list: | |||
""" | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param func: 用户自定义处理函数,参数是 ``DataSet`` 中的 ``Instance`` | |||
:param _apply_field: 需要传进去func的数据集的field_name | |||
:param show_progress_bar: 是否展示progress进度条,默认为展示 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 进度条的描述字符,默认为'Main | |||
""" | |||
if isinstance(func, LambdaType) and num_proc>1 and func.__name__ == "<lambda>": | |||
@@ -499,7 +503,7 @@ class DataSet: | |||
if num_proc < 2: | |||
results = _apply_single(ds=self, _apply_field=_apply_field, func=func, | |||
desc=progress_desc, show_progress_bar=show_progress_bar) | |||
desc=progress_desc, progress_bar=progress_bar) | |||
else: | |||
# TODO 1. desc这个需要修改一下,应该把 subprocess 的 desc 修改一下。修改成Process 1 / Process 2 | |||
import multiprocessing as mp | |||
@@ -525,25 +529,25 @@ class DataSet: | |||
proc.start() | |||
pool.append(proc) | |||
queues.append(queue) | |||
progress_bar = progress_bars.get(progress_bar, DummyFRichProgress()) | |||
total_len = len(self) | |||
task_id = f_rich_progress.add_task(description=progress_desc, total=total_len, visible=show_progress_bar) | |||
task_id = progress_bar.add_task(description=progress_desc, total=total_len) | |||
last_count = -1 | |||
while counter.value < total_len or last_count == -1: | |||
while counter.value == last_count: | |||
time.sleep(0.1) | |||
advance = counter.value - last_count | |||
last_count = counter.value | |||
f_rich_progress.update(task_id, advance=advance, refresh=True) | |||
progress_bar.update(task_id, advance=advance, refresh=True) | |||
for idx, proc in enumerate(pool): | |||
results.extend(pickle.loads(queues[idx].get())) | |||
proc.join() | |||
f_rich_progress.destroy_task(task_id) | |||
progress_bar.destroy_task(task_id) | |||
return results | |||
def apply_more(self, func: Callable = None, modify_fields: bool = True, | |||
num_proc: int = 0, progress_desc: str = '', show_progress_bar: bool = True): | |||
num_proc: int = 0, progress_desc: str = '', progress_bar: str = 'rich'): | |||
r""" | |||
将 ``DataSet`` 中每个 ``Instance`` 传入到func中,并获取它的返回值。func可以返回一个或多个 field 上的结果。 | |||
@@ -558,9 +562,9 @@ class DataSet: | |||
:param modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param num_proc: 进程的数量 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前正在处理的进度条名称 | |||
:param progress_desc: 当 progress_bar 不为 None 时,可以显示当前正在处理的进度条名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:return Dict[str:Field]: 返回一个字典 | |||
""" | |||
assert callable(func), "The func is not callable." | |||
@@ -570,7 +574,7 @@ class DataSet: | |||
results = {} | |||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | |||
show_progress_bar=show_progress_bar) | |||
progress_bar=progress_bar) | |||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | |||
if not isinstance(apply_out[0], dict): | |||
raise Exception("The result of func is not a dict") | |||
@@ -597,21 +601,21 @@ class DataSet: | |||
return results | |||
def apply(self, func: Callable = None, new_field_name: str = None, | |||
num_proc: int = 0, show_progress_bar: bool = True, progress_desc: str = ''): | |||
num_proc: int = 0, progress_bar: str = 'rich', progress_desc: str = ''): | |||
""" | |||
:param func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示进度条。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: progress bar 显示的值,默认为空。 | |||
""" | |||
assert callable(func), "The func you provide is not callable." | |||
assert len(self) != 0, "Null DataSet cannot use apply()." | |||
assert num_proc >= 0, "num_proc must be an integer >= 0." | |||
try: | |||
results = self._apply_process(num_proc=num_proc, func=func, show_progress_bar=show_progress_bar, | |||
results = self._apply_process(num_proc=num_proc, func=func, progress_bar=progress_bar, | |||
progress_desc=progress_desc) | |||
except BaseException as e: | |||
raise e | |||
@@ -22,7 +22,8 @@ __all__ = [ | |||
'Option', | |||
'deprecated', | |||
'seq_len_to_mask', | |||
"flat_nest_dict" | |||
"flat_nest_dict", | |||
"f_tqdm_progress" | |||
] | |||
from .cache_results import cache_results | |||
@@ -32,5 +33,6 @@ from .paddle_utils import paddle_to, paddle_move_data_to_device, get_paddle_devi | |||
from .rich_progress import f_rich_progress | |||
from .torch_utils import torch_move_data_to_device | |||
from .utils import * | |||
from .tqdm_progress import f_tqdm_progress | |||
@@ -35,7 +35,7 @@ class DummyFRichProgress: | |||
return None | |||
@property | |||
def dummy_rich(self)->bool: | |||
def dummy(self)->bool: | |||
""" | |||
当前对象是否是 dummy 的 rich 对象。 | |||
@@ -122,6 +122,9 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
visible: bool = True, | |||
**fields: Any, | |||
) -> TaskID: | |||
from .tqdm_progress import f_tqdm_progress | |||
assert not f_tqdm_progress.not_empty(), "Cannot use rich before tqdm finish loop." | |||
if self.live._started is False: | |||
self.start() | |||
post_desc = fields.pop('post_desc', '') | |||
@@ -213,7 +216,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
self.refresh() | |||
@property | |||
def dummy_rich(self) -> bool: | |||
def dummy(self) -> bool: | |||
""" | |||
当前对象是否是 dummy 的 rich 对象。 | |||
@@ -221,6 +224,9 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
""" | |||
return False | |||
def not_empty(self): | |||
return len(self._tasks) != 0 | |||
class SpeedColumn(ProgressColumn): | |||
""" | |||
@@ -0,0 +1,160 @@ | |||
__all__ = [ | |||
'f_tqdm_progress' | |||
] | |||
import uuid | |||
import sys | |||
from ...envs.imports import _module_available, _compare_version | |||
from ...envs import get_global_rank | |||
from .utils import is_notebook | |||
from ..log import logger | |||
if _module_available('tqdm'): | |||
from tqdm.autonotebook import tqdm | |||
import operator | |||
class Singleton(type): | |||
_instances = {} | |||
def __call__(cls, *args, **kwargs): | |||
if cls not in cls._instances: | |||
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) | |||
return cls._instances[cls] | |||
# 如果不打印的时候,使得整个 progress 没有任何意义 | |||
class DummyFTqdmProgress: | |||
def __getattr__(self, item): | |||
return DummyFTqdmProgress() | |||
def __call__(self, *args, **kwargs): | |||
# 防止用户通过 DummyFRichProgress.console.print() 这种调用 | |||
return None | |||
@property | |||
def dummy(self)->bool: | |||
""" | |||
当前对象是否是 dummy 的 tqdm 对象。 | |||
:return: | |||
""" | |||
return True | |||
class TqdmProgress(metaclass=Singleton): | |||
def __init__(self): | |||
self.bars = {} | |||
def add_task(self, iterable=None, description=None, total=None, leave=False, | |||
ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None, | |||
ascii=None, visible=True, unit='it', unit_scale=False, | |||
dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0, | |||
postfix=None, unit_divisor=1000, write_bytes=None, | |||
lock_args=None, nrows=None, colour=None, gui=False, **kwargs): | |||
""" | |||
主要就模仿了 tqdm bar 的创建,为了和 FRichProgress 的接口尽量统一,将 desc 重名为了 description,以及 disable 专为了 | |||
visible 。 | |||
:param iterable: | |||
:param description: | |||
:param total: | |||
:param leave: | |||
:param ncols: | |||
:param mininterval: | |||
:param maxinterval: | |||
:param miniters: | |||
:param ascii: | |||
:param visible: | |||
:param unit: | |||
:param unit_scale: | |||
:param dynamic_ncols: | |||
:param smoothing: | |||
:param bar_format: | |||
:param initial: | |||
:param postfix: | |||
:param unit_divisor: | |||
:param write_bytes: | |||
:param lock_args: | |||
:param nrows: | |||
:param colour: | |||
:param gui: | |||
:param kwargs: | |||
:return: | |||
""" | |||
assert _module_available('tqdm') and _compare_version('tqdm', operator.ge, '4.57'), \ | |||
f"To use {self.__class__.__name__}, tqdm>=4.57 is needed." | |||
from .rich_progress import f_rich_progress | |||
assert not f_rich_progress.not_empty(), "Cannot use tqdm before rich finish loop." | |||
if hasattr(self, 'orig_out_err'): | |||
file = self.orig_out_err[0] | |||
else: | |||
file = sys.stdout | |||
bar = tqdm(iterable=iterable, desc=description, total=total, leave=leave, file=file, | |||
ncols=ncols, mininterval=mininterval, maxinterval=maxinterval, miniters=miniters, | |||
ascii=ascii, disable=not visible, unit=unit, unit_scale=unit_scale, | |||
dynamic_ncols=dynamic_ncols, smoothing=smoothing, bar_format=bar_format, initial=initial, | |||
position=len(self.bars), postfix=postfix, unit_divisor=unit_divisor, write_bytes=write_bytes, | |||
lock_args=lock_args, nrows=nrows, colour=colour, gui=gui, **kwargs) | |||
_uuid = str(uuid.uuid1()) | |||
self.bars[_uuid] = bar | |||
if not hasattr(self, 'orig_out_err') and not is_notebook(): | |||
from tqdm.contrib import DummyTqdmFile | |||
self.orig_out_err = sys.stdout, sys.stderr | |||
sys.stdout, sys.stderr = map(DummyTqdmFile, self.orig_out_err) | |||
return _uuid | |||
def update(self, task_id:str, advance:int, refresh=True): | |||
self.bars[task_id].update(advance) | |||
def set_postfix_str(self, task_id, s, refresh=True): | |||
self.bars[task_id].set_postfix_str(s=s, refresh=refresh) | |||
def set_description_str(self, task_id, desc, refresh=True): | |||
self.bars[task_id].set_description_str(desc=desc, refresh=refresh) | |||
def destroy_task(self, task_id): | |||
""" | |||
关闭 task_id 对应的 tqdm bar 。 | |||
:param task_id: | |||
:return: | |||
""" | |||
self.bars[task_id].close() | |||
self.bars.pop(task_id) | |||
if len(self.bars) == 0 and hasattr(self, 'orig_out_err'): | |||
# recover 成正常的 sys.stdout 与 sys.stderr | |||
sys.stdout, sys.stderr = self.orig_out_err | |||
delattr(self, 'orig_out_err') | |||
def reset(self, task_id): | |||
self.bars[task_id].reset() | |||
def print(self): | |||
tqdm.write('') | |||
def not_empty(self): | |||
return len(self.bars) != 0 | |||
@property | |||
def dummy(self) -> bool: | |||
""" | |||
当前对象是否是 dummy 的 tqdm 对象。 | |||
:return: | |||
""" | |||
return False | |||
if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and get_global_rank() == 0: | |||
f_tqdm_progress = TqdmProgress() | |||
else: | |||
f_tqdm_progress = DummyFTqdmProgress() | |||
logger.debug("Use dummy tqdm...") | |||
@@ -340,7 +340,7 @@ class Vocabulary(object): | |||
try: | |||
for f_n, n_f_n in zip(field_name, new_field_name): | |||
dataset.apply_field(index_instance, field_name=f_n, new_field_name=n_f_n, | |||
show_progress_bar=False) | |||
progress_bar=None) | |||
except Exception as e: | |||
logger.error("When processing the `{}` dataset, the following error occurred.".format(idx)) | |||
raise e | |||
@@ -396,7 +396,7 @@ class Vocabulary(object): | |||
for idx, dataset in enumerate(datasets): | |||
if isinstance(dataset, DataSet): | |||
try: | |||
dataset.apply(construct_vocab, show_progress_bar=False) | |||
dataset.apply(construct_vocab, progress_bar=None) | |||
except BaseException as e: | |||
logger.error("When processing the `{}` dataset, the following error occurred:".format(idx)) | |||
raise e | |||
@@ -406,12 +406,12 @@ class Vocabulary(object): | |||
if no_create_entry_dataset is not None: | |||
partial_construct_vocab = partial(construct_vocab, no_create_entry=True) | |||
if isinstance(no_create_entry_dataset, DataSet): | |||
no_create_entry_dataset.apply(partial_construct_vocab, show_progress_bar=False) | |||
no_create_entry_dataset.apply(partial_construct_vocab, progress_bar=None) | |||
elif isinstance(no_create_entry_dataset, list): | |||
for dataset in no_create_entry_dataset: | |||
if not isinstance(dataset, DataSet): | |||
raise TypeError("Only DataSet type is allowed.") | |||
dataset.apply(partial_construct_vocab, show_progress_bar=False) | |||
dataset.apply(partial_construct_vocab, progress_bar=None) | |||
return self | |||
def _is_word_no_create_entry(self, word:str): | |||
@@ -221,7 +221,7 @@ class DataBundle: | |||
yield field_name, vocab | |||
def apply_field(self, func: Callable, field_name: str, new_field_name: str, num_proc: int = 0, | |||
ignore_miss_dataset: bool = True, progress_desc: str = '', show_progress_bar: bool = True): | |||
ignore_miss_dataset: bool = True, progress_desc: str = '', progress_bar: str = 'rich'): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的dataset使用 :meth:`~fastNLP.DataSet.apply_field` 方法 | |||
@@ -233,8 +233,8 @@ class DataBundle: | |||
如果为False,则报错 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param ignore_miss_dataset: 如果 dataset 没有 {field_name} ,就直接跳过这个 dataset 。 | |||
:param progress_desc: 当show_progress_barm为True时,可以显示当前tqdm正在处理的名称 | |||
:param show_progress_bar: 是否显示tqdm进度条 | |||
:param progress_desc: 当显示 progress 时,可以显示当前正在处理的名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
""" | |||
_progress_desc = progress_desc | |||
@@ -243,13 +243,13 @@ class DataBundle: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
if dataset.has_field(field_name=field_name): | |||
dataset.apply_field(func=func, field_name=field_name, new_field_name=new_field_name, num_proc=num_proc, | |||
progress_desc=progress_desc, show_progress_bar=show_progress_bar) | |||
progress_desc=progress_desc, progress_bar=progress_bar) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name}.") | |||
return self | |||
def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True, | |||
ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''): | |||
ignore_miss_dataset=True, progress_bar: str = 'rich', progress_desc: str = ''): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_field_more` 方法 | |||
@@ -263,8 +263,8 @@ class DataBundle: | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param bool ignore_miss_dataset: 当某个field名称在某个dataset不存在时,如果为True,则直接忽略该DataSet; | |||
如果为False,则报错 | |||
:param show_progress_bar: 是否显示进度条 | |||
:param progress_desc: 当 ``show_progress_bar`` 为 ``True`` 时,可以显示 ``progress`` 的名称。 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,可以显示 ``progress`` 的名称。 | |||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
@@ -277,13 +277,13 @@ class DataBundle: | |||
if dataset.has_field(field_name=field_name): | |||
res[name] = dataset.apply_field_more(func=func, field_name=field_name, num_proc=num_proc, | |||
modify_fields=modify_fields, | |||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | |||
progress_bar=progress_bar, progress_desc=progress_desc) | |||
elif not ignore_miss_dataset: | |||
raise KeyError(f"{field_name} not found DataSet:{name} .") | |||
return res | |||
def apply(self, func: Callable, new_field_name: str, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True): | |||
progress_desc: str = '', progress_bar: bool = True): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply` 方法 | |||
@@ -293,20 +293,20 @@ class DataBundle: | |||
:param str new_field_name: 将func返回的内容放入到 `new_field_name` 这个field中,如果名称与已有的field相同,则覆 | |||
盖之前的field。如果为None则不创建新的field。 | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示tqd进度条 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress bar 时,可以显示当前正在处理的名称 | |||
""" | |||
_progress_desc = progress_desc | |||
for name, dataset in self.datasets.items(): | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, show_progress_bar=show_progress_bar, | |||
dataset.apply(func, new_field_name=new_field_name, num_proc=num_proc, progress_bar=progress_bar, | |||
progress_desc=progress_desc) | |||
return self | |||
def apply_more(self, func: Callable, modify_fields=True, num_proc: int = 0, | |||
progress_desc: str = '', show_progress_bar: bool = True): | |||
progress_desc: str = '', progress_bar: str = 'rich'): | |||
r""" | |||
对 :class:`~fastNLP.io.DataBundle` 中所有的 dataset 使用 :meth:`~fastNLP.DataSet.apply_more` 方法 | |||
@@ -317,8 +317,8 @@ class DataBundle: | |||
:param callable func: 参数是 ``DataSet`` 中的 ``Instance`` ,返回值是一个字典,key 是field 的名字,value 是对应的结果 | |||
:param bool modify_fields: 是否用结果修改 ``DataSet`` 中的 ``Field`` , 默认为 True | |||
:param num_proc: 进程的数量。请注意,由于python语言的特性,多少进程就会导致多少倍内存的增长。 | |||
:param show_progress_bar: 是否显示tqd进度条 | |||
:param progress_desc: 当show_progress_bar为True时,可以显示当前tqd正在处理的名称 | |||
:param progress_bar: 显示 progress_bar 的方式,支持 `["rich", "tqdm", None]`。 | |||
:param progress_desc: 当显示 progress_bar 时,可以显示当前正在处理的名称 | |||
:return Dict[str:Dict[str:Field]]: 返回一个字典套字典,第一层的 key 是 dataset 的名字,第二层的 key 是 field 的名字 | |||
""" | |||
@@ -328,7 +328,7 @@ class DataBundle: | |||
if _progress_desc: | |||
progress_desc = _progress_desc + f' for `{name}`' | |||
res[name] = dataset.apply_more(func, modify_fields=modify_fields, num_proc=num_proc, | |||
show_progress_bar=show_progress_bar, progress_desc=progress_desc) | |||
progress_bar=progress_bar, progress_desc=progress_desc) | |||
return res | |||
def set_pad(self, field_name, pad_val=0, dtype=None, backend=None, pad_fn=None) -> "DataBundle": | |||
@@ -279,7 +279,7 @@ class _LazyConfigMapping(OrderedDict): | |||
value = self._mapping[key] | |||
module_name = model_type_to_module_name(key) | |||
if module_name not in self._modules: | |||
self._modules[module_name] = importlib.import_module(f".{module_name}", "transformers.models") | |||
self._modules[module_name] = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||
return getattr(self._modules[module_name], value) | |||
def keys(self): | |||
@@ -318,15 +318,15 @@ class _LazyLoadAllMappings(OrderedDict): | |||
def _initialize(self): | |||
if self._initialized: | |||
return | |||
logger.warn( | |||
"ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | |||
"It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | |||
FutureWarning, | |||
) | |||
# logger.warn( | |||
# "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP is deprecated and will be removed in v5 of Transformers. " | |||
# "It does not contain all available model checkpoints, far from it. Checkout hf.co/models for that.", | |||
# FutureWarning, | |||
# ) | |||
for model_type, map_name in self._mapping.items(): | |||
module_name = model_type_to_module_name(model_type) | |||
module = importlib.import_module(f".{module_name}", "transformers.models") | |||
module = importlib.import_module(f".{module_name}", "fastNLP.transformers.torch.models") | |||
mapping = getattr(module, map_name) | |||
self._data.update(mapping) | |||
@@ -362,8 +362,8 @@ ALL_PRETRAINED_CONFIG_ARCHIVE_MAP = _LazyLoadAllMappings(CONFIG_ARCHIVE_MAP_MAPP | |||
def _get_class_name(model_class: Union[str, List[str]]): | |||
if isinstance(model_class, (list, tuple)): | |||
return " or ".join([f":class:`~transformers.{c}`" for c in model_class if c is not None]) | |||
return f":class:`~transformers.{model_class}`" | |||
return " or ".join([f":class:`~fastNLP.transformers.torch.{c}`" for c in model_class if c is not None]) | |||
return f":class:`~fastNLP.transformers.torch.{model_class}`" | |||
def _list_model_options(indent, config_to_class=None, use_model_types=True): | |||
@@ -372,7 +372,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): | |||
if use_model_types: | |||
if config_to_class is None: | |||
model_type_to_name = { | |||
model_type: f":class:`~transformers.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() | |||
model_type: f":class:`~fastNLP.transformers.torch.{config}`" for model_type, config in CONFIG_MAPPING_NAMES.items() | |||
} | |||
else: | |||
model_type_to_name = { | |||
@@ -394,7 +394,7 @@ def _list_model_options(indent, config_to_class=None, use_model_types=True): | |||
config: MODEL_NAMES_MAPPING[model_type] for model_type, config in CONFIG_MAPPING_NAMES.items() | |||
} | |||
lines = [ | |||
f"{indent}- :class:`~transformers.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" | |||
f"{indent}- :class:`~fastNLP.transformers.torch.{config_name}` configuration class: {config_to_name[config_name]} ({config_to_model_name[config_name]} model)" | |||
for config_name in sorted(config_to_name.keys()) | |||
] | |||
return "\n".join(lines) | |||
@@ -4,17 +4,12 @@ | |||
(2) 能不能保存 topk 并load进来进行训练 | |||
""" | |||
import pytest | |||
import os | |||
import pytest | |||
from typing import Any | |||
from dataclasses import dataclass | |||
from pathlib import Path | |||
import re | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
@@ -25,7 +20,6 @@ from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||
from torchmetrics import Accuracy | |||
from fastNLP.core.metrics import Metric | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.callbacks import MoreEvaluateCallback | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
@@ -0,0 +1,123 @@ | |||
from typing import Any | |||
from dataclasses import dataclass | |||
import pytest | |||
from fastNLP import Metric, Accuracy | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP import Trainer, Evaluator | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
from torch.utils.data import DataLoader | |||
from torch.optim import SGD | |||
import torch.distributed as dist | |||
import torch | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.datasets.torch_data import TorchArgMaxDataset | |||
@dataclass
class ArgMaxDatasetConfig:
    """Configuration constants for the synthetic argmax dataset used in tests."""
    num_labels: int = 10          # number of target classes
    feature_dimension: int = 10   # size of each input feature vector
    data_num: int = 100           # total number of samples generated
    seed: int = 0                 # RNG seed so the dataset is reproducible
    batch_size: int = 4           # dataloader batch size
    shuffle: bool = True          # whether the dataloader should shuffle
@dataclass
class TrainerParameters:
    """Bundle of the objects needed to construct a Trainer/Evaluator in tests.

    All fields default to ``None`` and are populated by the
    ``model_and_optimizers`` fixture before use.
    """
    model: Any = None                 # the torch model under test
    optimizers: Any = None            # optimizer (e.g. SGD) bound to the model
    train_dataloader: Any = None      # dataloader used for training
    evaluate_dataloaders: Any = None  # dataloader(s) used for evaluation
    input_mapping: Any = None         # optional Trainer input mapping
    output_mapping: Any = None        # optional Trainer output mapping
    metrics: Any = None               # metrics passed to Trainer/Evaluator
    more_metrics: Any = None          # extra metrics (e.g. for MoreEvaluateCallback)
@pytest.fixture(scope="module", params=[0], autouse=True)
def model_and_optimizers(request):
    """Build the model, optimizer, dataloaders and metrics shared by this module's tests.

    Returns a populated :class:`TrainerParameters`; ``more_metrics`` carries an
    additional ``Accuracy`` metric for the more-evaluate style tests.
    """
    trainer_params = TrainerParameters()

    trainer_params.model = TorchNormalModel_Classification_1(
        num_labels=ArgMaxDatasetConfig.num_labels,
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
    )
    trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001)

    dataset = TorchArgMaxDataset(
        feature_dimension=ArgMaxDatasetConfig.feature_dimension,
        data_num=ArgMaxDatasetConfig.data_num,
        seed=ArgMaxDatasetConfig.seed,
    )
    # NOTE(review): shuffle is hard-coded True here instead of reading
    # ArgMaxDatasetConfig.shuffle — confirm this is intentional.
    _dataloader = DataLoader(
        dataset=dataset,
        batch_size=ArgMaxDatasetConfig.batch_size,
        shuffle=True,
    )

    class LossMetric(Metric):
        # Accumulates the summed loss over an evaluation run via a
        # registered element, so it aggregates correctly across updates.
        def __init__(self):
            super().__init__()
            self.register_element('loss')

        def update(self, loss):
            self.loss += loss.item()

        def get_metric(self) -> dict:
            return self.loss.item()

    trainer_params.train_dataloader = _dataloader
    trainer_params.evaluate_dataloaders = _dataloader
    trainer_params.metrics = {'loss': LossMetric()}
    trainer_params.more_metrics = {"acc": Accuracy()}

    return trainer_params
@pytest.mark.torch
@pytest.mark.parametrize('device', ['cpu', [0, 1]])
@pytest.mark.parametrize('progress_bar', ['rich', 'auto', None, 'raw', 'tqdm'])
@magic_argv_env_context
def test_run(model_and_optimizers: TrainerParameters, device, progress_bar):
    """Smoke-test Trainer.run and Evaluator.run under every progress_bar backend."""
    if device != 'cpu' and not torch.cuda.is_available():
        pytest.skip(f"No cuda for device:{device}")
    n_epochs = 5

    trainer = Trainer(
        model=model_and_optimizers.model,
        driver='torch',
        device=device,
        optimizers=model_and_optimizers.optimizers,
        train_dataloader=model_and_optimizers.train_dataloader,
        evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
        input_mapping=model_and_optimizers.input_mapping,
        output_mapping=model_and_optimizers.output_mapping,
        metrics=model_and_optimizers.metrics,
        n_epochs=n_epochs,
        callbacks=None,
        progress_bar=progress_bar,
        output_from_new_proc="all",
        evaluate_fn='train_step',
        larger_better=False,
    )
    trainer.run()

    # Reuse the trainer's driver so the evaluator shares the same device setup.
    evaluator = Evaluator(
        model=model_and_optimizers.model,
        dataloaders=model_and_optimizers.train_dataloader,
        driver=trainer.driver,
        metrics=model_and_optimizers.metrics,
        progress_bar=progress_bar,
        evaluate_fn='train_step',
    )
    evaluator.run()

    # Tear down any process group left over from the distributed run.
    if dist.is_initialized():
        dist.destroy_process_group()
@@ -181,7 +181,7 @@ class TestDataSetMethods: | |||
assert ("rx" in ds.field_arrays) == True | |||
assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] | |||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) | |||
ds.apply(lambda ins: len(ins["y"]), new_field_name="y", progress_bar=None) | |||
assert ds.field_arrays["y"].content[0] == 2 | |||
res = ds.apply(lambda ins: len(ins["x"]), num_proc=2, progress_desc="len") | |||
@@ -198,8 +198,8 @@ class TestDataSetMethods: | |||
def do_nothing(ins): | |||
time.sleep(0.01) | |||
ds.apply(do_nothing, show_progress_bar=True, num_proc=0) | |||
ds.apply_field(do_nothing, field_name='x', show_progress_bar=True) | |||
ds.apply(do_nothing, progress_bar='rich', num_proc=0) | |||
ds.apply_field(do_nothing, field_name='x', progress_bar='rich') | |||
def test_apply_cannot_modify_instance(self): | |||
ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) | |||
@@ -0,0 +1,16 @@ | |||
import pytest | |||
from fastNLP.envs.imports import _module_available | |||
from fastNLP.core.utils import f_tqdm_progress, f_rich_progress | |||
def test_raise(): | |||
if not _module_available('tqdm') or f_rich_progress.dummy or f_tqdm_progress.dummy: | |||
pytest.skip('No tqdm') | |||
t = f_rich_progress.add_task('test', total=10) | |||
with pytest.raises(AssertionError): | |||
f_tqdm_progress.add_task('test') | |||
f_rich_progress.destroy_task(t) | |||
t = f_tqdm_progress.add_task('test', total=10) | |||
with pytest.raises(AssertionError): | |||
f_rich_progress.add_task('test') |