@@ -0,0 +1,7 @@
+fastNLP.core.callbacks.fitlog\_callback module
+==============================================
+
+.. automodule:: fastNLP.core.callbacks.fitlog_callback
+   :members:
+   :undoc-members:
+   :show-inheritance:
@@ -25,6 +25,7 @@ Submodules
    fastNLP.core.callbacks.callback_manager
    fastNLP.core.callbacks.checkpoint_callback
    fastNLP.core.callbacks.early_stop_callback
+   fastNLP.core.callbacks.fitlog_callback
    fastNLP.core.callbacks.has_monitor_callback
    fastNLP.core.callbacks.load_best_model_callback
    fastNLP.core.callbacks.lr_scheduler_callback
@@ -0,0 +1,15 @@
+fastNLP.modules.mix\_modules package
+====================================
+
+.. automodule:: fastNLP.modules.mix_modules
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Submodules
+----------
+
+.. toctree::
+   :maxdepth: 4
+
+   fastNLP.modules.mix_modules.utils
@@ -0,0 +1,7 @@
+fastNLP.modules.mix\_modules.utils module
+=========================================
+
+.. automodule:: fastNLP.modules.mix_modules.utils
+   :members:
+   :undoc-members:
+   :show-inheritance:
@@ -0,0 +1,15 @@
+fastNLP.modules package
+=======================
+
+.. automodule:: fastNLP.modules
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+Subpackages
+-----------
+
+.. toctree::
+   :maxdepth: 4
+
+   fastNLP.modules.mix_modules
@@ -15,3 +15,4 @@ Subpackages
    fastNLP.core
    fastNLP.envs
    fastNLP.io
+   fastNLP.modules
@@ -14,6 +14,7 @@ __all__ = [
     "TorchGradClipCallback",
     "ResultsMonitor",
     'HasMonitorCallback',
+    "FitlogCallback",

     # collators
     'Collator',
@@ -68,6 +69,7 @@ __all__ = [
     # metrics
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',
@@ -17,7 +17,9 @@ __all__ = [
     "TorchGradClipCallback",
     "ResultsMonitor",
-    'HasMonitorCallback'
+    'HasMonitorCallback',
+    "FitlogCallback"
 ]
@@ -32,4 +34,5 @@ from .early_stop_callback import EarlyStopCallback
 from .torch_callbacks import *
 from .more_evaluate_callback import MoreEvaluateCallback
 from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
+from .fitlog_callback import FitlogCallback
@@ -25,7 +25,7 @@ def _transfer(func):
         for callback_fn in manager.callback_fns[func.__name__]:
             try:
                 callback_fn(*arg, **kwargs)
-            except EarlyStopException as e:
+            except (EarlyStopException, KeyboardInterrupt) as e:
                 raise e
             except BaseException as e:
                 logger.error(f"The following callback_fn raise exception:{_get_fun_msg(callback_fn)}.")
@@ -9,12 +9,13 @@ import sys
 from fastNLP.core.log import logger
 from .topk_saver import TopkSaver
 from .callback import Callback
+from ..utils.exceptions import EarlyStopException


 class CheckpointCallback(Callback):
     def __init__(self, folder: Optional[Union[str, Path]] = None, every_n_epochs: Optional[int] = None,
-                 every_n_batches: Optional[int] = None, last: bool = False,
-                 on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = None, topk: int = 0,
+                 every_n_batches: Optional[int] = None, last: bool = False, topk: int = 0,
+                 on_exceptions: Optional[Union[BaseException, Sequence[BaseException]]] = [EarlyStopException],
                  monitor: Optional[Union[str, Callable]] = None, larger_better: bool = True,
                  only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, save_object: str = 'model',
                  save_evaluate_results=True, **kwargs):
@@ -33,16 +34,23 @@ class CheckpointCallback(Callback):
             fastNLP passes the absolute path of the folder to that function and does not save the model under the folder itself. By default this
             checkpoint only saves the state of the model; if you also need to save the Trainer state for resuming training, use ``save_object='trainer'``.
-        :param monitor: The metric value to monitor. If no exact match is found in the evaluation results, the longest-common-substring
-            algorithm is used to pick the closest match as the monitor. If None, the monitor set on the Trainer is used. A function can also be passed;
-            it takes the evaluation results (a dict) and returns a float as the monitor value, and should return None when no relevant monitor value exists.
+        :param monitor: The metric value to monitor.
+
+            * ``None``
+                fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+            * ``str``
+                The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+                longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+            * ``Callable``
+                A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+                value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
         :param folder: The folder to save into. fastNLP creates a timestamped subfolder under it and saves there, so different runs can be
             saved into different timestamped folders. If None, the current folder is used.
         :param every_n_epochs: Save once every this many epochs.
         :param every_n_batches: Save once every this many batches.
         :param last: If True, save at the end of every epoch, overwriting the previous save.
         :param topk: Keep the top-k saves according to the monitor value.
-        :param on_exceptions: Whether to save when an exception is raised. Pass the exception classes that should be caught.
+        :param on_exceptions: Whether to save when an exception is raised. Pass the exception classes that should be caught. By default EarlyStopException is caught.
         :param larger_better: Whether a larger monitor value is better.
         :param only_state_dict: Whether to save only the state_dict of the model. Ignored when model_save_fn is not None.
         :param model_save_fn: A custom save function; whenever a save is triggered it is called with a folder as its argument and returns nothing.
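For reference, a minimal construction sketch reflecting the reordered ``topk`` argument and the new ``on_exceptions`` default; the folder path and the metric name ``'acc#acc'`` are illustrative assumptions, not part of this change::

    from fastNLP import CheckpointCallback

    # keep the three best checkpoints according to the monitored metric;
    # EarlyStopException is now caught by default, so a checkpoint is also
    # written when early stopping is triggered
    checkpoint_cb = CheckpointCallback(folder='checkpoints/', every_n_epochs=1,
                                       topk=3, monitor='acc#acc', larger_better=True)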
@@ -12,9 +12,16 @@ class EarlyStopCallback(HasMonitorCallback):
     def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10):
         """
-        :param str monitor: The metric value to monitor. If no exact match is found in the evaluation results, the longest-common-substring
-            algorithm is used to pick the closest match as the monitor. If None, the monitor set on the Trainer is used. A function can also be passed;
-            it takes the evaluation results (a dict) and returns a float as the monitor value.
+        :param monitor: The metric value to monitor.
+
+            * ``None``
+                fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+            * ``str``
+                The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+                longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+            * ``Callable``
+                A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+                value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
         :param larger_better: Whether a larger monitor value is better.
         :param patience: Stop after this many evaluations without improvement.
         """
@@ -0,0 +1,66 @@
+__all__ = [
+    'FitlogCallback'
+]
+from .has_monitor_callback import HasMonitorCallback
+from ...envs import _module_available
+from ...envs import get_global_rank
+
+if _module_available('fitlog'):
+    import fitlog
+
+
+class FitlogCallback(HasMonitorCallback):
+    """
+    Automatically records ``evaluation`` results to ``fitlog``. The result of every ``evaluate`` call is logged, and the best
+    result so far is tracked according to ``monitor``. In addition, ``fitlog`` is automatically switched to ``debug`` mode on
+    every process whose ``rank`` is not 0.
+
+    :param monitor: The metric value to monitor.
+
+        * ``None``
+            fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+        * ``str``
+            The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+            longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+        * ``Callable``
+            A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+            value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
+    :param larger_better: Whether larger values are better.
+    :param log_exception: Whether to log the ``exception`` if one is raised.
+    :param log_loss_every: Record the loss to ``fitlog`` once every this many batches.
+    """
+    def __init__(self, monitor=None, larger_better: bool = True, log_exception:bool=True, log_loss_every:int=0):
+        assert _module_available('fitlog'), "fitlog is not installed."
+        super().__init__(monitor=monitor, larger_better=larger_better)
+        self.log_exception = log_exception
+        self.log_loss_every = log_loss_every
+        self.avg_loss = 0
+
+    def on_after_trainer_initialized(self, trainer, driver):
+        if get_global_rank() != 0:  # switch fitlog to debug mode when the global rank is not 0
+            fitlog.debug()
+
+    def on_evaluate_end(self, trainer, results):
+        results = self.itemize_results(results)
+        fitlog.add_metric(results, step=trainer.global_forward_batches, epoch=trainer.cur_epoch_idx)
+        if self.is_better_results(results, keep_if_better=True):
+            results['step'] = trainer.global_forward_batches
+            results['epoch'] = trainer.cur_epoch_idx
+            fitlog.add_best_metric(results)
+
+    def on_before_backward(self, trainer, outputs):
+        if self.log_loss_every > 0:
+            loss = trainer.extract_loss_from_outputs(outputs)
+            self.avg_loss += loss.item()
+            if trainer.global_forward_batches % self.log_loss_every == 0:
+                fitlog.add_loss(self.avg_loss / self.log_loss_every * trainer.accumulation_steps, name='loss',
+                                step=trainer.global_forward_batches,
+                                epoch=trainer.cur_epoch_idx)
+                self.avg_loss = 0
+
+    def on_train_end(self, trainer):
+        fitlog.finish()
+
+    def on_exception(self, trainer, exception):
+        fitlog.finish(status=1)
+        if self.log_exception:
+            fitlog.add_other(repr(exception), name='except_info')
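A minimal usage sketch for the new callback; the log directory, metric name and ``Trainer`` arguments below are illustrative assumptions rather than part of this change, and ``model``, ``train_dl`` and ``val_dl`` are assumed to be built elsewhere::

    import fitlog
    from fastNLP import Trainer, Accuracy, FitlogCallback

    fitlog.set_log_dir('logs/')          # fitlog must be initialised before training starts
    callback = FitlogCallback(monitor='acc#acc', larger_better=True, log_loss_every=100)
    trainer = Trainer(model=model, driver='torch',
                      train_dataloader=train_dl, evaluate_dataloaders=val_dl,
                      metrics={'acc': Accuracy()}, callbacks=[callback])
    trainer.run()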
@@ -171,9 +171,16 @@ class HasMonitorCallback(ResultsMonitor, Callback):
     This callback is not used directly; it serves as the parent class of other monitor-related callbacks. Callbacks that use a monitor can inherit from it; it implements
     (1) validating the monitor, and (2) setting its own monitor name from the trainer's monitor when needed.

-    :param monitor: The metric value to monitor. If no exact match is found in the evaluation results, the longest-common-substring
-        algorithm is used to pick the closest match as the monitor. If None, the monitor set on the Trainer is used. A function can also be passed;
-        it takes the evaluation results (a dict) and returns a float as the monitor value, and should return None when no relevant monitor value exists.
+    :param monitor: The metric value to monitor.
+
+        * ``None``
+            fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+        * ``str``
+            The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+            longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+        * ``Callable``
+            A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+            value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
     :param larger_better: Whether a larger monitor value is better.
     :param must_have_monitor: Whether this callback must have a monitor set. If True and no monitor is detected, an error is raised.
     """
@@ -22,9 +22,16 @@ class LoadBestModelCallback(HasMonitorCallback):
     Saves the model with the best monitor value and reloads it when training ends; by default the weight file is deleted after loading. The best
     model can only be loaded when training finishes normally.

-    :param str monitor: The metric value to monitor. If no exact match is found in the evaluation results, the longest-common-substring
-        algorithm is used to pick the closest match as the monitor. If None, the monitor set on the Trainer is used. A function can also be passed;
-        it takes the evaluation results (a dict) and returns a float as the monitor value, and should return None when no relevant monitor value exists.
+    :param monitor: The metric value to monitor.
+
+        * ``None``
+            fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+        * ``str``
+            The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+            longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+        * ``Callable``
+            A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+            value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
     :param larger_better: Whether a larger metric value is better.
     :param save_folder: The folder to save into. If empty, the weights are kept in memory; otherwise a copy of the weights is written to file. For multi-machine
         training with a non-empty value, make sure the path is accessible from every machine. This value must not be empty when model_save_fn is not None.
@@ -45,10 +45,16 @@ class RichCallback(ProgressCallback):
     :param print_every: Update the display every this many batches.
     :param loss_round_ndigit: Number of significant digits kept when displaying the loss
-    :param monitor: When a better result is detected for this key, it is highlighted in a different colour. The metric value to monitor. If no exact match is found
-        in the evaluation results, the longest-common-substring algorithm is used to pick the closest match as the monitor. If None, the monitor set on
-        the Trainer is used. A function can also be passed; it takes the evaluation results (a dict) and returns a float as the monitor value, and should
-        return None when no relevant monitor value exists.
+    :param monitor: The metric value to monitor. When a better result is detected for this key, it is highlighted in a different colour.
+
+        * ``None``
+            fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+        * ``str``
+            The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+            longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+        * ``Callable``
+            A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+            value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
     :param larger_better: Whether a larger monitor value is better.
     :param format_json: Whether to format the json before printing
     """
@@ -140,10 +146,16 @@ class RawTextCallback(ProgressCallback):
     :param print_every: Update the display every this many batches.
     :param loss_round_ndigit: Number of significant digits kept when displaying the loss
-    :param monitor: When a better result is detected for this key, it is highlighted in a different colour. The metric value to monitor. If no exact match is found
-        in the evaluation results, the longest-common-substring algorithm is used to pick the closest match as the monitor. If None, the monitor set on
-        the Trainer is used. A function can also be passed; it takes the evaluation results (a dict) and returns a float as the monitor value, and should
-        return None when no relevant monitor value exists.
+    :param monitor: The metric value to monitor. When a better result is detected for this key, it is highlighted in a different colour.
+
+        * ``None``
+            fastNLP will try to use the ``monitor`` set in :class:`~fastNLP.Trainer` (if one was set).
+        * ``str``
+            The name is looked up directly in the ``evaluation`` results; if no exact match exists, the
+            longest-common-substring algorithm is used to pick the closest match in the ``evaluation`` results as the ``monitor``.
+        * ``Callable``
+            A function that takes the ``evaluation`` results (a dict) and returns a ``float`` as the ``monitor``
+            value; it should return ``None`` when no relevant ``monitor`` value exists in the current results.
     :param larger_better: Whether a larger monitor value is better.
     :param format_json: Whether to format the json before printing
     """
@@ -7,7 +7,7 @@ __all__ = [
     'Evaluator'
 ]

-from fastNLP.core.drivers import Driver
+from fastNLP.core.drivers import Driver, TorchDriver
 from ..drivers.choose_driver import choose_driver
 from .loops import Loop, EvaluateBatchLoop
 from fastNLP.core.utils import auto_param_call, dataclass_to_dict, \
@@ -23,7 +23,7 @@ class Evaluator:
     driver: Driver
     _evaluate_batch_loop: Loop

-    def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None,
+    def __init__(self, model, dataloaders, metrics: Optional[Dict] = None,
                  driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None,
                  evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None,
                  input_mapping: Optional[Union[Callable, Dict]] = None,
@@ -316,15 +316,15 @@ class _MetricsWrapper:
                 raise TypeError("Parameter `metrics` can only be `Dict` type.")
             for metric_name, metric in metrics.items():
                 # torchmetrics metrics are nn.Module instances, so they must first be moved to the corresponding device;
-                if _is_torchmetrics_metric(metric):
+                if _is_torchmetrics_metric(metric) and isinstance(evaluator.driver, TorchDriver):
                     # torchmetrics enables multi-card aggregation automatically by default
                     evaluator.driver.move_model_to_device(metric, evaluator.driver.data_device)
                 elif isinstance(metric, Metric):
                     # if the data is distributed but results are not aggregated, there may be problems
                     if evaluator._dist_sampler is not None and metric.aggregate_when_get_metric is False:
-                        logger.warning_once(
-                            "You have replace the sampler as distributed sampler when evaluation, but your "
-                            f"metric {metric_name}:{metric.__class__.__name__}' `aggregate_when_get_metric` is False.")
+                        logger.rank_zero_warning(
+                            "You have replace the sampler as distributed sampler when evaluation, but your metric "
+                            f"{metric_name}:{metric.__class__.__name__}'s `aggregate_when_get_metric` is False.", once=True)
                     if metric.aggregate_when_get_metric is None:
                         metric.aggregate_when_get_metric = evaluator._dist_sampler is not None
@@ -388,5 +388,8 @@ class _MetricsWrapper:
                 _results = metric.accumulate()
             else:
                 raise RuntimeError(f"Not support `{type(metric)}` for now.")
-            results[metric_name] = _results
+            if _results is not None:
+                results[metric_name] = _results
+            else:
+                logger.warning_once(f"Metric:{metric_name} returns None when getting metric results.")
         return results
@@ -27,19 +27,21 @@ class EvaluateBatchLoop(Loop):
         while True:
             try:
                 batch = next(iterator)
-                batch = match_and_substitute_params(evaluator.input_mapping, batch)
-                batch = evaluator.move_data_to_device(batch)
             except StopIteration:
                 break
+            try:
+                batch = match_and_substitute_params(evaluator.input_mapping, batch)
+                batch = evaluator.move_data_to_device(batch)
+                self.batch_step_fn(evaluator, batch)
+                batch_idx += 1
+                evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
             except BaseException as e:
                 if callable(getattr(dataloader, 'get_batch_indices', None)):
                     indices = dataloader.get_batch_indices()
                     logger.error(f"Exception happens when evaluating on samples: {indices}")
                 raise e
-            self.batch_step_fn(evaluator, batch)
-            batch_idx += 1
-            evaluator.update_progress_bar(batch_idx, evaluator.cur_dataloader_name)
         # Get the metric results. The returned dict looks like {'metric_name1': metric_results, 'metric_name2': metric_results, ...}
         results = evaluator.get_metric()
         return results
@@ -19,30 +19,31 @@ class TrainBatchLoop(Loop):
         get_batch_indices = dataloader.get_batch_indices if callable(getattr(dataloader, 'get_batch_indices', None))\
             else lambda *args, **kwargs: None
         dataloader = iter(dataloader)
+        indices = None
         while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch:
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
                 indices = get_batch_indices()
+            except StopIteration:
+                break
+            try:
                 trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
-            except StopIteration:
-                break
+                trainer.on_train_batch_begin(batch, indices)
+                with trainer.get_no_sync_context():  # sync may need to be turned off when using multiple cards
+                    self.batch_step_fn(trainer, batch)
+                trainer.global_forward_batches += 1
+                trainer.batch_idx_in_epoch += 1
+                trainer.check_batch_step_fn()
+                trainer.on_train_batch_end()
             except BaseException as e:
-                if indices and not isinstance(e, EarlyStopException):
+                if indices is not None and not isinstance(e, (EarlyStopException, KeyboardInterrupt)):
                     logger.error(f"Exception happens when running on samples: {indices}")
                 raise e
-            trainer.on_train_batch_begin(batch, indices)
-            with trainer.get_no_sync_context():  # sync may need to be turned off when using multiple cards
-                self.batch_step_fn(trainer, batch)
-            trainer.global_forward_batches += 1
-            trainer.batch_idx_in_epoch += 1
-            trainer.check_batch_step_fn()
-            trainer.on_train_batch_end()
             trainer.step_evaluate()
         trainer.batch_idx_in_epoch = 0
@@ -256,7 +256,6 @@ class Trainer(TrainerEventTrigger):
         :kwargs:
             * *torch_kwargs* -- parameters used to configure the concrete driver instance when ``driver`` is set to 'torch':

                 * ddp_kwargs -- parameters passed to ``DistributedDataParallel`` at initialisation when ``TorchDDPDriver`` is used; for example, pass
                   {'find_unused_parameters': True} to resolve errors caused by parameters that do not take part in the forward computation;
                 * set_grad_to_none -- whether to set grad to None after every optimizer update during training;
@@ -47,7 +47,7 @@ class JittorDataLoader:
     DataLoader provided for the jittor framework; it offers auto_collate functionality and supports datasets that implement __getitem__ and __len__
     """

-    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = False,
+    def __init__(self, dataset, batch_size: int = 16, shuffle: bool = True,
                  drop_last: bool = False, num_workers: int = 0, buffer_size: int = 512 * 1024 * 1024,
                  stop_grad: bool = True, keep_numpy_array: bool = False, endless: bool = False,
                  collate_fn: Union[None, str, Callable] = "auto") -> None:
@@ -47,7 +47,7 @@ class PaddleDataLoader(DataLoader):
     def __init__(self, dataset, feed_list=None, places=None,
                  return_list: bool = True, batch_sampler=None,
-                 batch_size: int = 1, shuffle: bool = False,
+                 batch_size: int = 1, shuffle: bool = True,
                  drop_last: bool = False, collate_fn: Union[str, Callable, None] = 'auto',
                  num_workers: int = 0, use_buffer_reader: bool = True,
                  use_shared_memory: bool = True, timeout: int = 0,
@@ -14,7 +14,7 @@ from ...envs import FASTNLP_BACKEND, SUPPORT_BACKENDS
 from ..log import logger


-def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = False, drop_last: bool = False,
+def prepare_dataloader(dataset, batch_size: int = 16, shuffle: bool = True, drop_last: bool = False,
                        collate_fn: Union[Callable, str, None] = 'auto', num_workers: int = 0,
                        seed: int = 0, backend: str = 'auto'):
     """
@@ -179,7 +179,7 @@ class TorchDataLoader(DataLoader):
 def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping[str, DataSet]],
                              batch_size: int = 1,
-                             shuffle: bool = False,
+                             shuffle: bool = True,
                              sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              num_workers: int = 0, collate_fn: Union[str, Callable, None] = 'auto',
@@ -236,8 +236,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, Sequence[DataSet], Mapping
                                                   persistent_workers=persistent_workers,
                                                   )
             else:
-                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size,
-                                                  shuffle=shuffle, sampler=non_train_sampler,
+                dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size if non_train_batch_size else batch_size,
+                                                  shuffle=shuffle, sampler=non_train_sampler if non_train_sampler else sampler,
                                                   batch_sampler=batch_sampler,
                                                   num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory,
                                                   drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn,
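A rough usage sketch of the fallback behaviour fixed above; the ``data_bundle`` object, the split names ``'train'``/``'dev'`` and the top-level import are assumptions for illustration::

    from fastNLP import prepare_torch_dataloader

    # when non_train_batch_size / non_train_sampler are left unset, the non-train
    # loaders now fall back to batch_size and sampler instead of receiving None
    dls = prepare_torch_dataloader(data_bundle, batch_size=32, shuffle=True)
    train_dl, dev_dl = dls['train'], dls['dev']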
@@ -1,9 +1,10 @@
+from typing import Callable

 __all__ = [
     "indice_collate_wrapper"
 ]


-def indice_collate_wrapper(func):
+def indice_collate_wrapper(func:Callable):
     """
     Wraps collate_fn with one more layer: it separates the tuple data fetched from the dataset and packs the idx values into indices.
@@ -40,8 +40,8 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[
         if user_visible_devices is None:
             raise RuntimeError("To run paddle distributed training, please set `FASTNLP_BACKEND` to 'paddle' before using FastNLP.")
         if device is not None:
-            logger.warning_once("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
-                                "up your script. And we will directly get the local device via environment variables.")
+            logger.rank_zero_warning("Parameter `device` would be ignored when you are using `paddle.distributed.launch` to pull "
+                                     "up your script. And we will directly get the local device via environment variables.", once=True)
         _visible_list = user_visible_devices.split(",")
         device = [ f"gpu:{_visible_list.index(g) }" for g in os.environ["CUDA_VISIBLE_DEVICES"].split(",")]
         # TODO currently each process corresponds to exactly one card, so a single device is passed for now
@@ -26,9 +26,9 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, "torch.devi
     # world_size and rank
     if FASTNLP_BACKEND_LAUNCH in os.environ:
         if device is not None:
-            logger.warning_once("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
+            logger.rank_zero_warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull "
                                 "up your script. And we will directly get the local device via "
-                                "`os.environ['LOCAL_RANK']`.")
+                                "`os.environ['LOCAL_RANK']`.", once=True)
         return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs)

     if driver not in {"torch", "fairscale"}:
@@ -1,11 +1,13 @@
 import os
 from typing import Dict, Union, Callable, Tuple, Optional

 from fastNLP.envs.imports import _NEED_IMPORT_TORCH

 if _NEED_IMPORT_TORCH:
     import torch
     from torch.nn import DataParallel
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import RandomSampler as TorchRandomSampler
+    from torch.utils.data import SequentialSampler as TorchSequentialSampler

 __all__ = [
     'TorchSingleDriver'
@@ -15,7 +17,8 @@ from .torch_driver import TorchDriver
 from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
 from fastNLP.core.utils import auto_param_call
 from fastNLP.core.utils.utils import _get_fun_msg
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, ReproduceBatchSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, \
+    ReproduceBatchSampler
 from fastNLP.core.samplers import RandomSampler
 from fastNLP.core.log import logger
@@ -24,6 +27,7 @@ class TorchSingleDriver(TorchDriver):
     r"""
     Driver for CPU and single-GPU computation;
     """
+
     def __init__(self, model, device: "torch.device", fp16: bool = False, **kwargs):
         if isinstance(model, DistributedDataParallel):
             raise ValueError("`DistributedDataParallel` is not supported in `TorchSingleDriver`")
@@ -88,7 +92,8 @@ class TorchSingleDriver(TorchDriver):
         else:
             raise RuntimeError(f"There is no `{fn}` method in your {type(self.model)}.")

-    def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None,
+    def set_dist_repro_dataloader(self, dataloader,
+                                  dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler] = None,
                                   reproducible: bool = False):
         # if dist is a ReproducibleBatchSampler or ReproducibleIterator, driver.load is being called to resume from a checkpoint;
@@ -108,17 +113,24 @@ class TorchSingleDriver(TorchDriver):
         if reproducible:
             if isinstance(args.sampler, TorchRandomSampler):
-                # it is already random, so just replace it.
-                sampler = RandomSampler(args.sampler.data_source)
-                logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                if getattr(args.sampler, '_num_samples', None) is None \
+                        and getattr(args.sampler, 'replacements', False) is False \
+                        and getattr(args.sampler, 'generator', None) is None:
+                    # it is already random and not customised, so just replace it.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                    logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                    return replace_sampler(dataloader, sampler)
+            elif isinstance(args.sampler, TorchSequentialSampler):
+                # it has to be replaced with a sampler that does not shuffle.
+                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
                 return replace_sampler(dataloader, sampler)
-            else:
-                batch_sampler = ReproduceBatchSampler(
-                    batch_sampler=args.batch_sampler,
-                    batch_size=args.batch_size,
-                    drop_last=args.drop_last
-                )
-                return replace_batch_sampler(dataloader, batch_sampler)
+            batch_sampler = ReproduceBatchSampler(
+                batch_sampler=args.batch_sampler,
+                batch_size=args.batch_size,
+                drop_last=args.drop_last
+            )
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             return dataloader
@@ -138,9 +150,3 @@ class TorchSingleDriver(TorchDriver):
     def is_distributed(self):
         return False
@@ -24,4 +24,4 @@ def print(*args, sep=' ', end='\n', file=None, flush=False):
     line = sep.join(map(str, args))
     if logger.isEnabledFor(INFO):
         kwargs = logger._add_rank_info({})
-        logger._log(INFO, line, args, **kwargs)
+        logger._log(INFO, line, None, **kwargs)
@@ -1,11 +1,12 @@
 __all__ = [
     "Metric",
     "Accuracy",
+    "TransformersAccuracy",
     'SpanFPreRecMetric',
     'ClassifyFPreRecMetric',
 ]

 from .metric import Metric
-from .accuracy import Accuracy
+from .accuracy import Accuracy, TransformersAccuracy
 from .span_f1_pre_rec_metric import SpanFPreRecMetric
 from .classify_f1_pre_rec_metric import ClassifyFPreRecMetric
@@ -1,5 +1,6 @@
 __all__ = [
-    'Accuracy'
+    'Accuracy',
+    "TransformersAccuracy"
 ]

 from typing import Union
@@ -17,9 +18,9 @@ class Accuracy(Metric):
         """
         Metric that computes accuracy.

-        :param str backend: Four kinds of backend are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided
+        :param backend: Four kinds of backend are currently supported: ['auto', 'torch', 'paddle', 'jittor']. 'auto' means the concrete backend is decided
             by the arguments actually passed when Metric.update() is called; in most cases using 'auto' directly is enough.
-        :param bool aggregate_when_get_metric: Whether, when the metric is computed, the values of the same element on each process are automatically aggregated first;
+        :param aggregate_when_get_metric: Whether, when the metric is computed, the values of the same element on each process are automatically aggregated first;
             meaningless when the backend does not support distributed execution. If None, it is set automatically in the Evaluator according to whether the sampler is distributed.
         """
         super(Accuracy, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric)
@@ -39,11 +40,11 @@ class Accuracy(Metric):
         r"""
         The update function accumulates the evaluation metric over the predictions of one batch

-        :param torch.Tensor pred: the predicted tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
+        :param pred: the predicted tensor; its shape can be torch.Size([B,]), torch.Size([B, n_classes]),
             torch.Size([B, max_len]), or torch.Size([B, max_len, n_classes])
-        :param torch.Tensor target: the ground-truth tensor; its shape can be torch.Size([B,]),
+        :param target: the ground-truth tensor; its shape can be torch.Size([B,]),
             torch.Size([B,]), torch.Size([B, max_len]), or torch.Size([B, max_len])
-        :param torch.Tensor seq_len: sequence-length marker; its shape can be None, None, torch.Size([B]), or torch.Size([B]).
+        :param seq_len: sequence-length marker; its shape can be None, None, torch.Size([B]), or torch.Size([B]).
             If mask is also passed in, seq_len is ignored.
         """
         # To stay compatible with different frameworks, all inputs are converted to numpy for the computation.
@@ -79,3 +80,20 @@ class Accuracy(Metric):
             else:
                 self.total += np.prod(list(pred.shape)).item()
                 self.correct += (target == pred).sum().item()
+
+
+class TransformersAccuracy(Accuracy):
+    """
+    Accuracy metric adapted to the relevant models in transformers.
+    """
+    def update(self, logits, labels, attention_mask=None):
+        r"""
+        The update function accumulates the evaluation metric over the predictions of one batch
+
+        :param logits: shape ``[B, n_classes]`` or ``[B, max_len, n_classes]``.
+        :param labels: shape ``[B, ]`` or ``[B, max_len]``
+        :param attention_mask: sequence-length mask.
+        """
+        seq_len = attention_mask.sum(dim=-1)
+        super().update(pred=logits, target=labels, seq_len=seq_len)
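A rough usage sketch for the new metric; calling ``update``/``get_metric`` standalone outside an ``Evaluator``, and the exact contents of the returned dict, are illustrative assumptions::

    import torch
    from fastNLP import TransformersAccuracy

    metric = TransformersAccuracy()
    logits = torch.randn(2, 5, 4)                                      # [B, max_len, n_classes]
    labels = torch.randint(0, 4, (2, 5))                               # [B, max_len]
    attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])  # padding positions are ignored
    metric.update(logits=logits, labels=labels, attention_mask=attention_mask)
    print(metric.get_metric())                                         # e.g. {'acc': ...}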
@@ -22,9 +22,9 @@ from .utils import apply_to_collection


 def _convert_data_device(device: Union[str, int]) -> str:
     """
-    Function for converting the ``data_device`` of a ``driver``. If the user sets ``FASTNLP_BACKEND=paddle``, ``fastNLP`` stores
+    Function for converting the ``data_device`` of a ``driver``. If the user sets ``FASTNLP_BACKEND=paddle``, **fastNLP** stores
     the visible devices in ``USER_CUDA_VISIBLE_DEVICES`` and sets ``CUDA_VISIBLE_DEVICES`` to the first visible card; this is done
-    so that distributed training with ``paddle`` runs smoothly.
+    so that distributed training with **paddle** runs smoothly.

     In that case, simply using ``driver.data_device`` does not work. For example, when the devices for distributed training are set to ``[0,2,3]`` and the user sets
     ``CUDA_VISIBLE_DEVICES=3,4,5,6``, then in the ``rank1`` process we have::
@@ -127,7 +127,7 @@ def get_paddle_device_id(device: Union[str, int]) -> int:

 def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) -> Any:
     r"""
-    Moves a collection of ``paddle`` data to the given device. Only :class:`paddle.Tensor` objects are moved; everything else is left unchanged.
+    Moves a collection of **paddle** data to the given device. Only :class:`paddle.Tensor` objects are moved; everything else is left unchanged.

     :param batch: the data collection to move;
     :param device: the target device. Either a GPU device index or a string of the form ``cpu``, ``gpu`` or ``gpu:x``; when this parameter
@@ -145,20 +145,20 @@ def paddle_move_data_to_device(batch: Any, device: Optional[Union[str, int]]) ->

 def is_in_paddle_dist() -> bool:
     """
-    Checks whether the current process runs under ``paddle`` distributed mode, judged via ``PADDLE_RANK_IN_NODE`` and ``FLAGS_selected_gpus``.
+    Checks whether the current process runs under **paddle** distributed mode, judged via ``PADDLE_RANK_IN_NODE`` and ``FLAGS_selected_gpus``.
     """
     return ('PADDLE_RANK_IN_NODE' in os.environ and 'FLAGS_selected_gpus' in os.environ)


 def is_in_fnlp_paddle_dist() -> bool:
     """
-    Checks whether the current process is a ``paddle`` distributed process launched by ``fastNLP``
+    Checks whether the current process is a **paddle** distributed process launched by **fastNLP**
     """
     return FASTNLP_DISTRIBUTED_CHECK in os.environ


 def is_in_paddle_launch_dist() -> bool:
     """
-    Checks whether the current process is a ``paddle`` distributed process started via ``python -m paddle.distributed.launch``
+    Checks whether the current process is a **paddle** distributed process started via ``python -m paddle.distributed.launch``
     """
     return FASTNLP_BACKEND_LAUNCH in os.environ
@@ -1,5 +1,5 @@
 """
-This file provides a unified ``progress bar`` manager for ``fastNLP``; by sharing one ``Task`` object, the ``progress bar`` in :class:`~fastNLP.core.Trainer`
+This file provides a unified ``progress bar`` manager for **fastNLP**; by sharing one ``Task`` object, the ``progress bar`` in :class:`~fastNLP.core.Trainer`
 and the ``progress bar`` in :class:`~fastNLP.core.Evaluator` can avoid conflicting with each other
 """
 import sys
@@ -44,11 +44,11 @@ class TorchTransferableDataType(ABC):
 def torch_move_data_to_device(batch: Any, device: Optional[Union[str, "torch.device"]] = None,
                               non_blocking: Optional[bool] = True) -> Any:
     r"""
-    Moves the data collection ``batch`` to the given device in ``pytorch``. Any object that defines a ``to(device)`` method is moved, and all other objects in the collection are left unchanged;
+    Moves the data collection ``batch`` to the given device in **pytorch**. Any object that defines a ``to(device)`` method is moved, and all other objects in the collection are left unchanged;

     :param batch: the data to move;
     :param device: the device the data should be moved to; when this parameter is ``None``, nothing is done;
-    :param non_blocking: parameter of ``pytorch``'s data-transfer method ``to``;
+    :param non_blocking: parameter of **pytorch**'s data-transfer method ``to``;
     :return: the data collection on the new device;
     """
     if device is None:
@@ -55,7 +55,7 @@ def get_fn_arg_names(fn: Callable) -> List[str]:
 def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
                     mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
     r"""
-    This function finds values matching the parameter names of the input function in ``*args`` (all of ``dict`` type) and calls it with them; if the data passed in does not match the parameters of ``fn``, it can be converted via
+    This function finds values matching the parameter names of the input function in ``*args`` (all of **dict** type) and calls it with them; if the data passed in does not match the parameters of ``fn``, it can be converted via
     the ``mapping`` argument. A ``(key, value)`` pair in ``mapping`` means: find the value corresponding to ``key`` in ``*args`` and pass that value to the parameter
     named ``value``.
@@ -259,21 +259,21 @@ def dataclass_to_dict(data: "dataclasses.dataclass") -> Dict:
 def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None, data: Optional[Any] = None) -> Any:
     r"""
-    Implements replacing the keys of the input ``batch`` or the output ``outputs`` via ``mapping``;
+    Implements replacing the keys of the input **batch** or the output **outputs** via ``mapping``;
     This function is used for ``input_mapping`` and ``output_mapping``;

     * For ``input_mapping``, it is called in :class:`~fastNLP.core.controllers.TrainBatchLoop` right after the data has been fetched;
     * For ``output_mapping``, it is called right after the result is obtained in :meth:`~fastNLP.core.Trainer.train_step` of :class:`~fastNLP.core.Trainer`
-      and :meth:`~fastNLP.core.Evaluator.train_step` of :class:`~fastNLP.core.Evaluator`;
+      and :meth:`~fastNLP.core.Evaluator.train_step` of :class:`~fastNLP.core.Evaluator`;

     The conversion logic, in order of priority:

-    1. If ``mapping`` is a function, ``mapping(data)`` is returned directly;
-    2. If ``mapping`` is a ``Dict``, then ``data`` may only be one of the following three types: ``[Dict, dataclass, Sequence]``;
+    1. If ``mapping`` is a function, **mapping(data)** is returned directly;
+    2. If ``mapping`` is a **Dict**, then ``data`` may only be one of the following three types: ``[Dict, dataclass, Sequence]``;

-        * If ``data`` is a ``Dict``, the function replaces each ``key`` of ``data`` with ``mapping[key]``;
-        * If ``data`` is a ``dataclass``, the function first converts it to a ``Dict`` using :func:`dataclasses.asdict` and then performs the conversion;
-        * If ``data`` is a ``Sequence``, the function first converts it into a corresponding dict::
+        * If ``data`` is a **Dict**, the function replaces each ``key`` of ``data`` with **mapping[key]**;
+        * If ``data`` is a **dataclass**, the function first converts it to a **Dict** using :func:`dataclasses.asdict` and then performs the conversion;
+        * If ``data`` is a **Sequence**, the function first converts it into a corresponding dict::

             {
                 "_0": list[0],
@@ -281,7 +281,7 @@ def match_and_substitute_params(mapping: Optional[Union[Callable, Dict]] = None,
                 ...
             }

-    Then ``mapping`` is used to convert this ``Dict``; keys that do not match any ``key`` in ``mapping`` keep the ``\'\_number\'`` form.
+    Then ``mapping`` is used to convert this dict; keys that do not match any ``key`` in ``mapping`` keep the ``'_number'`` form.

     :param mapping: the dict or function used for the conversion; when ``mapping`` is a function, its return value must be a dict;
     :param data: the object to be converted;
@@ -459,7 +459,7 @@ def _is_iterable(value):
 def pretty_table_printer(dataset_or_ins) -> PrettyTable:
     r"""
-    Function used to display data in ``fastNLP``::
+    Function used to display data in **fastNLP**::

         >>> ins = Instance(field_1=[1, 1, 1], field_2=[2, 2, 2], field_3=["a", "b", "c"])
         +-----------+-----------+-----------------+
@@ -249,7 +249,7 @@ class DataBundle:
         return self

     def apply_field_more(self, func: Callable, field_name: str, num_proc: int = 0, modify_fields=True,
-                         ignore_miss_dataset=True, progress_desc: str = '', show_progress_bar: bool = True):
+                         ignore_miss_dataset=True, show_progress_bar: bool = True, progress_desc: str = ''):
         r"""
         Applies :meth:`~fastNLP.DataSet.apply_field_more` to every dataset in the :class:`~fastNLP.io.DataBundle`
@@ -263,8 +263,8 @@ class DataBundle:
         :param num_proc: number of processes. Note that, because of how Python works, using N processes multiplies the memory usage by N.
         :param bool ignore_miss_dataset: when the field name does not exist in some dataset: if True, that DataSet is simply skipped;
             if False, an error is raised
-        :param show_progress_bar: whether to show a tqdm progress bar
-        :param progress_desc: when show_progress_barm is True, the name currently being processed can be shown by tqdm
+        :param show_progress_bar: whether to show a progress bar
+        :param progress_desc: when ``show_progress_bar`` is ``True``, the name shown for the ``progress``.

         :return Dict[str:Dict[str:Field]]: returns a dict of dicts; the first-level key is the dataset name and the second-level key is the field name
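A rough usage sketch of the reordered keyword arguments; the ``data_bundle`` object, the ``raw_words`` field and the ``tokenize`` helper are illustrative assumptions::

    def tokenize(raw_words):
        words = raw_words.split()
        # apply_field_more expects the function to return a dict; each key becomes a new field
        return {'words': words, 'seq_len': len(words)}

    data_bundle.apply_field_more(tokenize, field_name='raw_words',
                                 show_progress_bar=True, progress_desc='tokenizing')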
@@ -0,0 +1,242 @@
+import warnings
+from typing import Any, Optional, Union
+
+import numpy as np
+
+from fastNLP.core.utils import paddle_to, apply_to_collection
+from fastNLP.core.log import logger
+from fastNLP.envs.imports import _NEED_IMPORT_JITTOR, _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE
+
+if _NEED_IMPORT_PADDLE:
+    import paddle
+
+if _NEED_IMPORT_JITTOR:
+    import jittor
+
+if _NEED_IMPORT_TORCH:
+    import torch
+
+__all__ = [
+    "paddle2torch",
+    "torch2paddle",
+    "jittor2torch",
+    "torch2jittor",
+]
+def _paddle2torch(paddle_tensor: 'paddle.Tensor', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor':
+    """
+    Converts a :class:`paddle.Tensor` to a :class:`torch.Tensor`, keeping the gradient so that backpropagation still works
+
+    :param paddle_tensor: the **paddle** tensor to convert;
+    :param device: the device to move the converted tensor to; when ``None``, the same as the input tensor;
+    :param no_gradient: whether to keep the gradient of the original tensor. When ``None``, the new tensor is consistent with the input tensor;
+        when ``True``, no gradient is kept; when ``False``, the gradient is kept;
+    :return: the converted **torch** tensor;
+    """
+    no_gradient = paddle_tensor.stop_gradient if no_gradient is None else no_gradient
+    paddle_numpy = paddle_tensor.numpy()
+    if not np.issubdtype(paddle_numpy.dtype, np.inexact):
+        no_gradient = True
+    if device is None:
+        if paddle_tensor.place.is_gpu_place():
+            # paddlepaddle has two kinds of Place, each with its own way of getting the device id
+            if hasattr(paddle_tensor.place, "gpu_device_id"):
+                # paddle.fluid.core_avx.Place
+                # tensors created in a gpu environment have a place of this type
+                device = f"cuda:{paddle_tensor.place.gpu_device_id()}"
+            else:
+                # paddle.CUDAPlace
+                device = f"cuda:{paddle_tensor.place.get_device_id()}"
+        else:
+            # TODO: devices such as xpu may need to be supported
+            device = "cpu"
+
+    if not no_gradient:
+        # keep the gradient and keep backpropagation working
+        # torch.tensor keeps the dtype of the numpy array
+        torch_tensor = torch.tensor(paddle_numpy, requires_grad=True, device=device)
+        hook = torch_tensor.register_hook(
+            lambda grad: paddle.autograd.backward(paddle_tensor, paddle.to_tensor(grad.cpu().numpy()))
+        )
+    else:
+        # do not keep the gradient
+        torch_tensor = torch.tensor(paddle_numpy, requires_grad=False, device=device)
+
+    return torch_tensor
+def _torch2paddle(torch_tensor: 'torch.Tensor', device: str = None, no_gradient: bool = None) -> 'paddle.Tensor':
+    """
+    Converts a :class:`torch.Tensor` to a :class:`paddle.Tensor`, keeping the gradient so that backpropagation still works.
+
+    :param torch_tensor: the **torch** tensor to convert;
+    :param device: the device to move the converted tensor to; when ``None``, the same as the input tensor;
+    :param no_gradient: whether to keep the gradient of the original tensor. When ``None``, the new tensor is consistent with the input tensor;
+        when ``True``, no gradient is kept; when ``False``, the gradient is kept;
+    :return: the converted **paddle** tensor;
+    """
+    no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient
+    if device is None:
+        if torch_tensor.is_cuda:
+            device = f"gpu:{torch_tensor.device.index}"
+        else:
+            device = "cpu"
+
+    if not no_gradient:
+        # keep the gradient and keep backpropagation working
+        # paddle's stop_gradient behaves the opposite way to torch's requires_grad
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=False)
+        hook = paddle_tensor.register_hook(
+            lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
+        )
+    else:
+        paddle_tensor = paddle.to_tensor(torch_tensor.detach().numpy(), stop_gradient=True)
+    paddle_tensor = paddle_to(paddle_tensor, device)
+
+    return paddle_tensor
+def _jittor2torch(jittor_var: 'jittor.Var', device: Optional[Union[str, int]] = None, no_gradient: bool = None) -> 'torch.Tensor':
+    """
+    Converts a :class:`jittor.Var` to a :class:`torch.Tensor`.
+
+    :param jittor_var: the **jittor** variable to convert;
+    :param device: the device to move the converted tensor to; when ``None``, decided by ``jittor.flags.use_cuda``;
+    :param no_gradient: whether to keep the gradient of the original tensor. When ``None``, the new tensor is consistent with the input tensor;
+        when ``True``, no gradient is kept; when ``False``, the gradient is kept;
+    :return: the converted **torch** tensor;
+    """
+    # TODO: warning: the gradient cannot be kept
+    # jittor gradients can be propagated through a callback
+    # differentiation is possible if outputs has a _grad key
+    no_gradient = not jittor_var.requires_grad if no_gradient is None else no_gradient
+    if no_gradient == False:
+        warnings.warn("The result tensor will not keep gradients due to differences between jittor and pytorch.")
+    jittor_numpy = jittor_var.numpy()
+    if not np.issubdtype(jittor_numpy.dtype, np.inexact):
+        no_gradient = True
+    if device is None:
+        # jittor assigns devices automatically
+        # decide according to use_cuda
+        if jittor.flags.use_cuda:
+            device = "cuda:0"
+        else:
+            device = "cpu"
+    torch_tensor = torch.tensor(jittor_numpy, requires_grad=not no_gradient, device=device)
+
+    return torch_tensor
+def _torch2jittor(torch_tensor: 'torch.Tensor', no_gradient: bool = None) -> 'jittor.Var':
+    """
+    Converts a :class:`torch.Tensor` to a :class:`jittor.Var`.
+
+    :param torch_tensor: the **torch** tensor to convert;
+    :param no_gradient: whether to keep the gradient of the original tensor. When ``None``, the new tensor is consistent with the input tensor;
+        when ``True``, no gradient is kept; when ``False``, the gradient is kept;
+    :return: the converted **jittor** variable;
+    """
+    no_gradient = not torch_tensor.requires_grad if no_gradient is None else no_gradient
+
+    if not no_gradient:
+        # keep the gradient and keep backpropagation working
+        jittor_var = jittor.Var(torch_tensor.detach().numpy())
+        jittor_var.requires_grad = True
+        hook = jittor_var.register_hook(
+            lambda grad: torch.autograd.backward(torch_tensor, torch.tensor(grad.numpy()))
+        )
+    else:
+        jittor_var = jittor.Var(torch_tensor.detach().numpy())
+        jittor_var.requires_grad = False
+
+    return jittor_var
def torch2paddle(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
    """
    Recursively convert every :class:`torch.Tensor` contained in the input to a :class:`paddle.Tensor`.

    :param batch: a collection of data containing :class:`torch.Tensor` objects;
    :param device: the device to move the converted tensors to. When ``None``, the device matches the input;
    :param no_gradient: whether to keep the original tensors' gradients. When ``None``, the new tensors
        match the input tensors; when ``True``, gradients are not kept; when ``False``, gradients are kept;
    :return: the converted data;
    """
    return apply_to_collection(
        batch,
        dtype=torch.Tensor,
        function=_torch2paddle,
        device=device,
        no_gradient=no_gradient,
    )
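

# Editor's sketch (not part of the original module): because ``apply_to_collection`` walks nested
# containers, a whole batch can be converted in one call. The keys and tensor shapes below are
# made up for illustration, and the helper name ``_demo_torch2paddle_batch`` is hypothetical.
def _demo_torch2paddle_batch():
    batch = {
        "input_ids": torch.randint(0, 100, (2, 8)),
        "labels": [torch.tensor([0, 1]), torch.tensor([1, 0])],
    }
    paddle_batch = torch2paddle(batch, no_gradient=True)
    # Containers are preserved; only the torch tensors are replaced by paddle tensors.
    assert isinstance(paddle_batch["input_ids"], paddle.Tensor)
    assert isinstance(paddle_batch["labels"][0], paddle.Tensor)
    return paddle_batch
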
def paddle2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
    """
    Recursively convert every :class:`paddle.Tensor` contained in the input to a :class:`torch.Tensor`.

    :param batch: a collection of data containing :class:`paddle.Tensor` objects;
    :param device: the device to move the converted tensors to. When ``None``, the device matches the input;
    :param no_gradient: whether to keep the original tensors' gradients. When ``None``, the new tensors
        match the input tensors; when ``True``, gradients are not kept; when ``False``, gradients are kept;
    :return: the converted data;
    """
    return apply_to_collection(
        batch,
        dtype=paddle.Tensor,
        function=_paddle2torch,
        device=device,
        no_gradient=no_gradient,
    )
def jittor2torch(batch: Any, device: str = None, no_gradient: bool = None) -> Any:
    """
    Recursively convert every :class:`jittor.Var` contained in the input to a :class:`torch.Tensor`.

    .. note::

        Due to differences between **pytorch** and **jittor**, the original tensors' gradients
        cannot be preserved when converting from :class:`jittor.Var` to :class:`torch.Tensor`.

    :param batch: a collection of data containing :class:`jittor.Var` objects;
    :param device: the device to move the converted tensors to. When ``None``, the device matches the input;
    :param no_gradient: whether to keep the original tensors' gradients; this parameter has no effect
        in this function.
    :return: the converted data;
    """
    return apply_to_collection(
        batch,
        dtype=jittor.Var,
        function=_jittor2torch,
        device=device,
        no_gradient=no_gradient,
    )
def torch2jittor(batch: Any, no_gradient: bool = None) -> Any:
    """
    Recursively convert every :class:`torch.Tensor` contained in the input to a :class:`jittor.Var`.

    .. note::

        **jittor** automatically assigns a device to the variables it creates.

    :param batch: a collection of data containing :class:`torch.Tensor` objects;
    :param no_gradient: whether to keep the original tensors' gradients. When ``None``, the new tensors
        match the input tensors; when ``True``, gradients are not kept; when ``False``, gradients are kept;
    :return: the converted data;
    """
    return apply_to_collection(
        batch,
        dtype=torch.Tensor,
        function=_torch2jittor,
        no_gradient=no_gradient,
    )
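

A usage sketch for the jittor helpers above (an editor's addition, assuming torch and jittor are
installed; the import path follows the ``fastNLP.modules.mix_modules.utils`` module documented in
this change). The conversion is symmetric, except that gradients do not survive the
jittor-to-torch direction::

    import numpy as np
    import torch
    from fastNLP.modules.mix_modules.utils import torch2jittor, jittor2torch

    x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
    var = torch2jittor(x)                    # jittor.Var, device chosen by jittor
    back = jittor2torch(var)                 # torch.Tensor again, requires_grad=False
    assert np.allclose(back.cpu().numpy(), x.numpy())
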
@@ -314,7 +314,7 @@ class PretrainedConfig: | |||||
# TPU arguments | # TPU arguments | ||||
if kwargs.pop("xla_device", None) is not None: | if kwargs.pop("xla_device", None) is not None: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
"The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " | "The `xla_device` argument has been deprecated in v4.4.0 of Transformers. It is ignored and you can " | ||||
"safely remove it from your `config.json` file." | "safely remove it from your `config.json` file." | ||||
) | ) | ||||
@@ -474,7 +474,7 @@ class PretrainedConfig: | |||||
""" | """ | ||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) | config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) | ||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: | if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: | ||||
logger.warn( | |||||
logger.rank_zero_warning( | |||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " | f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " | ||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." | f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." | ||||
) | ) | ||||
@@ -564,9 +564,9 @@ class PretrainedConfig: | |||||
raise EnvironmentError(msg) | raise EnvironmentError(msg) | ||||
if resolved_config_file == config_file: | if resolved_config_file == config_file: | ||||
logger.info(f"loading configuration file {config_file}") | |||||
logger.debug(f"loading configuration file {config_file}") | |||||
else: | else: | ||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") | |||||
logger.debug(f"loading configuration file {config_file} from cache at {resolved_config_file}") | |||||
return config_dict, kwargs | return config_dict, kwargs | ||||
@@ -603,7 +603,7 @@ class PretrainedConfig: | |||||
for key in to_remove: | for key in to_remove: | ||||
kwargs.pop(key, None) | kwargs.pop(key, None) | ||||
logger.info(f"Model config {config}") | |||||
logger.debug(f"Model config {config}") | |||||
if return_unused_kwargs: | if return_unused_kwargs: | ||||
return config, kwargs | return config, kwargs | ||||
else: | else: | ||||
@@ -17,7 +17,7 @@ from enum import Enum | |||||
from functools import partial | from functools import partial | ||||
from hashlib import sha256 | from hashlib import sha256 | ||||
from pathlib import Path | from pathlib import Path | ||||
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union | |||||
from typing import Any, BinaryIO, Dict, Optional, Tuple, Union, List | |||||
from urllib.parse import urlparse | from urllib.parse import urlparse | ||||
from uuid import uuid4 | from uuid import uuid4 | ||||
from zipfile import ZipFile, is_zipfile | from zipfile import ZipFile, is_zipfile | ||||
@@ -750,6 +750,78 @@ def get_from_cache( | |||||
return cache_path | return cache_path | ||||
def get_list_of_files( | |||||
path_or_repo: Union[str, os.PathLike], | |||||
revision: Optional[str] = None, | |||||
use_auth_token: Optional[Union[bool, str]] = None, | |||||
local_files_only: bool = False, | |||||
) -> List[str]: | |||||
""" | |||||
Gets the list of files inside :obj:`path_or_repo`. | |||||
Args: | |||||
path_or_repo (:obj:`str` or :obj:`os.PathLike`): | |||||
Can be either the id of a repo on huggingface.co or a path to a `directory`. | |||||
revision (:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
identifier allowed by git. | |||||
use_auth_token (:obj:`str` or `bool`, `optional`): | |||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token | |||||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). | |||||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
Whether or not to only rely on local files and not to attempt to download any files. | |||||
Returns: | |||||
:obj:`List[str]`: The list of files available in :obj:`path_or_repo`. | |||||
""" | |||||
path_or_repo = str(path_or_repo) | |||||
# If path_or_repo is a folder, we just return what is inside (subdirectories included). | |||||
if os.path.isdir(path_or_repo): | |||||
list_of_files = [] | |||||
for path, dir_names, file_names in os.walk(path_or_repo): | |||||
list_of_files.extend([os.path.join(path, f) for f in file_names]) | |||||
return list_of_files | |||||
# Can't grab the files if we are on offline mode. | |||||
if is_offline_mode() or local_files_only: | |||||
return [] | |||||
# Otherwise we grab the token and use the model_info method. | |||||
if isinstance(use_auth_token, str): | |||||
token = use_auth_token | |||||
elif use_auth_token is True: | |||||
# token = HfFolder.get_token() | |||||
path_token = os.path.expanduser("~/.huggingface/token") | |||||
try: | |||||
with open(path_token, "r") as f: | |||||
token = f.read() | |||||
except FileNotFoundError: | |||||
token = None | |||||
else: | |||||
token = None | |||||
# model_info = HfApi(endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT).model_info( | |||||
# path_or_repo, revision=revision, token=token | |||||
# ) | |||||
endpoint=HUGGINGFACE_CO_RESOLVE_ENDPOINT | |||||
path = ( | |||||
f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}" | |||||
if revision is None | |||||
else f"{HUGGINGFACE_CO_RESOLVE_ENDPOINT}/api/models/{path_or_repo}/revision/{revision}" | |||||
) | |||||
headers = {"authorization": f"Bearer {token}"} if token is not None else None | |||||
status_query_param = None | |||||
r = requests.get( | |||||
path, headers=headers, timeout=None, params=status_query_param | |||||
) | |||||
r.raise_for_status() | |||||
d = r.json() | |||||
siblings = d.get("siblings", None) | |||||
rfilenames = ( | |||||
[x["rfilename"] for x in siblings] if siblings is not None else None | |||||
) | |||||
return rfilenames | |||||
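
A hedged usage sketch of the ``get_list_of_files`` helper added above (the path and repo id are
illustrative only): for a local directory it simply walks the tree and returns full paths, while
for a hub identifier it queries the huggingface.co model-info endpoint and returns the sibling
filenames::

    # local checkout: every file under the directory, subdirectories included
    local_files = get_list_of_files("./my-local-model")

    # hub repository: plain filenames reported by the huggingface.co API
    remote_files = get_list_of_files("bert-base-uncased", revision="main")
    has_fast_tokenizer = "tokenizer.json" in remote_files
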
def is_torch_fx_available(): | def is_torch_fx_available(): | ||||
return _TORCH_GREATER_EQUAL_1_8 and _compare_version("torch", operator.lt, "1.9.0") | return _TORCH_GREATER_EQUAL_1_8 and _compare_version("torch", operator.lt, "1.9.0") | ||||
@@ -122,7 +122,7 @@ def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_leng | |||||
stopping_max_length = stopping_criteria.max_length | stopping_max_length = stopping_criteria.max_length | ||||
new_stopping_criteria = deepcopy(stopping_criteria) | new_stopping_criteria = deepcopy(stopping_criteria) | ||||
if stopping_max_length is not None and stopping_max_length != max_length: | if stopping_max_length is not None and stopping_max_length != max_length: | ||||
logger.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) | |||||
logger.rank_zero_warning("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) | |||||
elif stopping_max_length is None: | elif stopping_max_length is None: | ||||
new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) | new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) | ||||
return new_stopping_criteria | return new_stopping_criteria |
@@ -429,7 +429,7 @@ class GenerationMixin: | |||||
def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: | def _get_pad_token_id(self, pad_token_id: int = None, eos_token_id: int = None) -> int: | ||||
if pad_token_id is None and eos_token_id is not None: | if pad_token_id is None and eos_token_id is not None: | ||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
pad_token_id = eos_token_id | pad_token_id = eos_token_id | ||||
return pad_token_id | return pad_token_id | ||||
@@ -912,7 +912,7 @@ class GenerationMixin: | |||||
# special case if pad_token_id is not defined | # special case if pad_token_id is not defined | ||||
if pad_token_id is None and eos_token_id is not None: | if pad_token_id is None and eos_token_id is not None: | ||||
logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
logger.rank_zero_warning(f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation.") | |||||
pad_token_id = eos_token_id | pad_token_id = eos_token_id | ||||
# Storing encoder_input_ids for logits_processor that could use them | # Storing encoder_input_ids for logits_processor that could use them | ||||
@@ -352,7 +352,7 @@ class ModuleUtilsMixin: | |||||
if token_inputs: | if token_inputs: | ||||
return sum([token_input.numel() for token_input in token_inputs]) | return sum([token_input.numel() for token_input in token_inputs]) | ||||
else: | else: | ||||
logger.warn( | |||||
logger.rank_zero_warning( | |||||
"Could not estimate the number of tokens of the input, floating-point operations will not be computed" | "Could not estimate the number of tokens of the input, floating-point operations will not be computed" | ||||
) | ) | ||||
return 0 | return 0 | ||||
@@ -646,7 +646,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
# tie weights recursively | # tie weights recursively | ||||
tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) | tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) | ||||
if len(uninitialized_encoder_weights) > 0: | if len(uninitialized_encoder_weights) > 0: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" | ||||
) | ) | ||||
@@ -1260,9 +1260,9 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
raise EnvironmentError(msg) | raise EnvironmentError(msg) | ||||
if resolved_archive_file == archive_file: | if resolved_archive_file == archive_file: | ||||
logger.info(f"loading weights file {archive_file}") | |||||
logger.debug(f"loading weights file {archive_file}") | |||||
else: | else: | ||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") | |||||
logger.debug(f"loading weights file {archive_file} from cache at {resolved_archive_file}") | |||||
else: | else: | ||||
resolved_archive_file = None | resolved_archive_file = None | ||||
@@ -1486,7 +1486,7 @@ class PreTrainedModel(Module, ModuleUtilsMixin, GenerationMixin): | |||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") | ||||
if len(unexpected_keys) > 0: | if len(unexpected_keys) > 0: | ||||
logger.warning( | |||||
logger.rank_zero_warning( | |||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " | f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when " | ||||
f"initializing {model.__class__.__name__}: {unexpected_keys}\n" | f"initializing {model.__class__.__name__}: {unexpected_keys}\n" | ||||
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " | f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task " | ||||
@@ -171,7 +171,7 @@ class BartConfig(PretrainedConfig): | |||||
# ensure backward compatibility for BART CNN models | # ensure backward compatibility for BART CNN models | ||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): | if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False): | ||||
self.forced_bos_token_id = self.bos_token_id | self.forced_bos_token_id = self.bos_token_id | ||||
logger.warn( | |||||
logger.rank_zero_warning( | |||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." | f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions." | ||||
"The config can simply be saved and uploaded again to be fixed." | "The config can simply be saved and uploaded again to be fixed." | ||||
) | ) |
@@ -44,6 +44,8 @@ from .file_utils import ( | |||||
cached_path, | cached_path, | ||||
is_offline_mode, | is_offline_mode, | ||||
is_remote_url, | is_remote_url, | ||||
get_list_of_files, | |||||
hf_bucket_url, | |||||
is_tokenizers_available, | is_tokenizers_available, | ||||
to_py_obj, | to_py_obj, | ||||
) | ) | ||||
@@ -100,7 +102,7 @@ TOKENIZER_CONFIG_FILE = "tokenizer_config.json" | |||||
# Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file | # Fast tokenizers (provided by HuggingFace tokenizer's library) can be saved in a single file | ||||
FULL_TOKENIZER_FILE = "tokenizer.json" | FULL_TOKENIZER_FILE = "tokenizer.json" | ||||
_re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json") | |||||
class TruncationStrategy(ExplicitEnum): | class TruncationStrategy(ExplicitEnum): | ||||
""" | """ | ||||
@@ -1607,8 +1609,41 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
file_id = list(cls.vocab_files_names.keys())[0] | file_id = list(cls.vocab_files_names.keys())[0] | ||||
vocab_files[file_id] = pretrained_model_name_or_path | vocab_files[file_id] = pretrained_model_name_or_path | ||||
else: | else: | ||||
raise RuntimeError("At this point pretrained_model_name_or_path is either a directory or a model identifier name, ", | |||||
"which is not supported in fastNLP now.") | |||||
# raise RuntimeError("At this point pretrained_model_name_or_path is either a directory or a model identifier name, ", | |||||
# "which is not supported in fastNLP now.") | |||||
# At this point pretrained_model_name_or_path is either a directory or a model identifier name | |||||
fast_tokenizer_file = get_fast_tokenizer_file( | |||||
pretrained_model_name_or_path, | |||||
revision=revision, | |||||
use_auth_token=use_auth_token, | |||||
local_files_only=local_files_only, | |||||
) | |||||
additional_files_names = { | |||||
"added_tokens_file": ADDED_TOKENS_FILE, | |||||
"special_tokens_map_file": SPECIAL_TOKENS_MAP_FILE, | |||||
"tokenizer_config_file": TOKENIZER_CONFIG_FILE, | |||||
"tokenizer_file": fast_tokenizer_file, | |||||
} | |||||
# Look for the tokenizer files | |||||
for file_id, file_name in {**cls.vocab_files_names, **additional_files_names}.items(): | |||||
if os.path.isdir(pretrained_model_name_or_path): | |||||
if subfolder is not None: | |||||
full_file_name = os.path.join(pretrained_model_name_or_path, subfolder, file_name) | |||||
else: | |||||
full_file_name = os.path.join(pretrained_model_name_or_path, file_name) | |||||
if not os.path.exists(full_file_name): | |||||
logger.info(f"Didn't find file {full_file_name}. We won't load it.") | |||||
full_file_name = None | |||||
else: | |||||
full_file_name = hf_bucket_url( | |||||
pretrained_model_name_or_path, | |||||
filename=file_name, | |||||
subfolder=subfolder, | |||||
revision=revision, | |||||
mirror=None, | |||||
) | |||||
vocab_files[file_id] = full_file_name | |||||
# Get files from url, cache, or disk depending on the case | # Get files from url, cache, or disk depending on the case | ||||
resolved_vocab_files = {} | resolved_vocab_files = {} | ||||
@@ -1665,9 +1700,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): | |||||
continue | continue | ||||
if file_path == resolved_vocab_files[file_id]: | if file_path == resolved_vocab_files[file_id]: | ||||
logger.info(f"loading file {file_path}") | |||||
logger.debug(f"loading file {file_path}") | |||||
else: | else: | ||||
logger.info(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") | |||||
logger.debug(f"loading file {file_path} from cache at {resolved_vocab_files[file_id]}") | |||||
return cls._from_pretrained( | return cls._from_pretrained( | ||||
resolved_vocab_files, | resolved_vocab_files, | ||||
@@ -3349,3 +3384,52 @@ For a more complete example, see the implementation of `prepare_seq2seq_batch`. | |||||
) | ) | ||||
model_inputs["labels"] = labels["input_ids"] | model_inputs["labels"] = labels["input_ids"] | ||||
return model_inputs | return model_inputs | ||||
def get_fast_tokenizer_file( | |||||
path_or_repo: Union[str, os.PathLike], | |||||
revision: Optional[str] = None, | |||||
use_auth_token: Optional[Union[bool, str]] = None, | |||||
local_files_only: bool = False, | |||||
) -> str: | |||||
""" | |||||
Get the tokenizer file to use for this version of transformers. | |||||
Args: | |||||
path_or_repo (:obj:`str` or :obj:`os.PathLike`): | |||||
Can be either the id of a repo on huggingface.co or a path to a `directory`. | |||||
revision(:obj:`str`, `optional`, defaults to :obj:`"main"`): | |||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a | |||||
git-based system for storing models and other artifacts on huggingface.co, so ``revision`` can be any | |||||
identifier allowed by git. | |||||
use_auth_token (:obj:`str` or `bool`, `optional`): | |||||
The token to use as HTTP bearer authorization for remote files. If :obj:`True`, will use the token | |||||
generated when running :obj:`transformers-cli login` (stored in :obj:`~/.huggingface`). | |||||
local_files_only (:obj:`bool`, `optional`, defaults to :obj:`False`): | |||||
Whether or not to only rely on local files and not to attempt to download any files. | |||||
Returns: | |||||
:obj:`str`: The tokenizer file to use. | |||||
""" | |||||
# Inspect all files from the repo/folder. | |||||
all_files = get_list_of_files( | |||||
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only | |||||
) | |||||
tokenizer_files_map = {} | |||||
for file_name in all_files: | |||||
search = _re_tokenizer_file.search(file_name) | |||||
if search is not None: | |||||
v = search.groups()[0] | |||||
tokenizer_files_map[v] = file_name | |||||
available_versions = sorted(tokenizer_files_map.keys()) | |||||
# Defaults to FULL_TOKENIZER_FILE and then try to look at some newer versions. | |||||
tokenizer_file = FULL_TOKENIZER_FILE | |||||
transformers_version = version.parse(__version__) | |||||
for v in available_versions: | |||||
if version.parse(v) <= transformers_version: | |||||
tokenizer_file = tokenizer_files_map[v] | |||||
else: | |||||
# No point going further since the versions are sorted. | |||||
break | |||||
return tokenizer_file |
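
To make the selection rule in ``get_fast_tokenizer_file`` concrete (an editor's sketch with made-up
file names, using ``packaging.version`` directly): among the versioned ``tokenizer.<version>.json``
files, the last one whose version does not exceed the running ``transformers`` version is chosen,
and plain ``tokenizer.json`` is the fallback::

    import re
    from packaging import version

    _re_tokenizer_file = re.compile(r"tokenizer\.(.*)\.json")
    files = ["tokenizer.json", "tokenizer.4.0.0.json", "tokenizer.9.0.0.json"]

    current = version.parse("4.6.0")        # pretend this is the installed version
    chosen = "tokenizer.json"               # default fallback
    for name in sorted(f for f in files if _re_tokenizer_file.search(f)):
        v = _re_tokenizer_file.search(name).groups()[0]
        if version.parse(v) <= current:
            chosen = name                   # newest compatible versioned file so far
    # chosen == "tokenizer.4.0.0.json"
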
@@ -0,0 +1,8 @@ | |||||
from fastNLP import print | |||||
def test_print(): | |||||
print("a") | |||||
print([1, 2, 3]) | |||||
print([1,2,3], [4,5,6], 'a') | |||||
print(print) |