Browse Source

增强了cache_results, 现在可以将函数的参数hash进入保存的文件名字中

tags/v1.0.0alpha
yh_cc 3 years ago
parent
commit
8c5250f5b4
7 changed files with 208 additions and 45 deletions
  1. +8
    -0
      fastNLP/core/__init__.py
  2. +4
    -0
      fastNLP/core/callbacks/__init__.py
  3. +4
    -5
      fastNLP/core/callbacks/more_evaluate_callback.py
  4. +1
    -3
      fastNLP/core/controllers/trainer.py
  5. +2
    -2
      fastNLP/core/dataloaders/torch_dataloader/fdl.py
  6. +61
    -35
      fastNLP/core/utils/cache_results.py
  7. +128
    -0
      tests/core/utils/test_cache_results.py

+ 8
- 0
fastNLP/core/__init__.py View File

@@ -14,6 +14,8 @@ __all__ = [
'MoreEvaluateCallback', 'MoreEvaluateCallback',
"TorchWarmupCallback", "TorchWarmupCallback",
"TorchGradClipCallback", "TorchGradClipCallback",
"MonitorUtility",
'HasMonitorCallback',


# collators # collators
'Collator', 'Collator',
@@ -40,6 +42,12 @@ __all__ = [
'Trainer', 'Trainer',


# dataloaders TODO 需要把 mix_dataloader 的搞定 # dataloaders TODO 需要把 mix_dataloader 的搞定
'TorchDataLoader',
'PaddleDataLoader',
'JittorDataLoader',
'prepare_jittor_dataloader',
'prepare_paddle_dataloader',
'prepare_torch_dataloader',


# dataset # dataset
'DataSet', 'DataSet',


+ 4
- 0
fastNLP/core/callbacks/__init__.py View File

@@ -15,6 +15,9 @@ __all__ = [


"TorchWarmupCallback", "TorchWarmupCallback",
"TorchGradClipCallback", "TorchGradClipCallback",

"MonitorUtility",
'HasMonitorCallback'
] ]




@@ -28,4 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback from .early_stop_callback import EarlyStopCallback
from .torch_callbacks import * from .torch_callbacks import *
from .more_evaluate_callback import MoreEvaluateCallback from .more_evaluate_callback import MoreEvaluateCallback
from .has_monitor_callback import MonitorUtility, HasMonitorCallback



+ 4
- 5
fastNLP/core/callbacks/more_evaluate_callback.py View File

@@ -66,7 +66,6 @@ class MoreEvaluateCallback(HasMonitorCallback):
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
if watch_monitor is not None and evaluate_every is not None: if watch_monitor is not None and evaluate_every is not None:
raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
self.watch_monitor = watch_monitor


if topk_monitor is not None and topk == 0: if topk_monitor is not None and topk == 0:
raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
@@ -93,8 +92,8 @@ class MoreEvaluateCallback(HasMonitorCallback):


def on_after_trainer_initialized(self, trainer, driver): def on_after_trainer_initialized(self, trainer, driver):
# 如果是需要 watch 的,不能没有 evaluator # 如果是需要 watch 的,不能没有 evaluator
if self.watch_monitor is not None:
assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \
if self.monitor is not None:
assert trainer.evaluator is not None, f"You set `watch_monitor={self.monitor}`, but no " \
f"evaluate_dataloaders is provided in Trainer." f"evaluate_dataloaders is provided in Trainer."


if trainer.evaluate_fn is self.evaluate_fn: if trainer.evaluate_fn is self.evaluate_fn:
@@ -134,7 +133,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
self.topk_saver.save_topk(trainer, results) self.topk_saver.save_topk(trainer, results)


def on_train_epoch_end(self, trainer): def on_train_epoch_end(self, trainer):
if self.watch_monitor is not None:
if self.monitor is not None:
return return
if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
evaluate_every = -self.evaluate_every evaluate_every = -self.evaluate_every
@@ -143,7 +142,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
self.topk_saver.save_topk(trainer, results) self.topk_saver.save_topk(trainer, results)


def on_train_batch_end(self, trainer): def on_train_batch_end(self, trainer):
if self.watch_monitor is not None:
if self.monitor is not None:
return return
if callable(self.evaluate_every): if callable(self.evaluate_every):
if self.evaluate_every(trainer): if self.evaluate_every(trainer):


+ 1
- 3
fastNLP/core/controllers/trainer.py View File

@@ -117,6 +117,7 @@ class Trainer(TrainerEventTrigger):
:param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有
在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配
的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。
如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。
:param larger_better: monitor 的值是否是越大越好。 :param larger_better: monitor 的值是否是越大越好。
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
:param kwargs: 一些其它的可能需要的参数; :param kwargs: 一些其它的可能需要的参数;
@@ -231,7 +232,6 @@ class Trainer(TrainerEventTrigger):
total_batches=None total_batches=None
) )


""" 设置内部的 Evaluator """
if metrics is None and evaluate_dataloaders is not None: if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")


@@ -760,8 +760,6 @@ class Trainer(TrainerEventTrigger):
self.on_before_backward(outputs) self.on_before_backward(outputs)
loss = self.extract_loss_from_outputs(outputs) loss = self.extract_loss_from_outputs(outputs)
loss = loss / self.accumulation_steps loss = loss / self.accumulation_steps
# with self.get_no_sync_context():
# self.driver.backward(loss)
self.driver.backward(loss) self.driver.backward(loss)
self.on_after_backward() self.on_after_backward()




+ 2
- 2
fastNLP/core/dataloaders/torch_dataloader/fdl.py View File

@@ -165,8 +165,8 @@ class TorchDataLoader(DataLoader):




def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
batch_size: int = 1,
shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_size: int = 16,
shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
pin_memory: bool = False, drop_last: bool = False, pin_memory: bool = False, drop_last: bool = False,


+ 61
- 35
fastNLP/core/utils/cache_results.py View File

@@ -3,6 +3,7 @@ import hashlib
import _pickle import _pickle
import functools import functools
import os import os
import re
from typing import Callable, List, Any, Optional from typing import Callable, List, Any, Optional
import inspect import inspect
import ast import ast
@@ -126,7 +127,10 @@ def _get_func_and_its_called_func_source_code(func) -> List[str]:
# some failure # some failure
pass pass
del last_frame # del last_frame #
sources.append(inspect.getsource(func))
func_source_code = inspect.getsource(func) # 将这个函数中的 cache_results 装饰删除掉。
for match in list(re.finditer('@cache_results\(.*\)\\n', func_source_code))[::-1]:
func_source_code = func_source_code[:match.start()] + func_source_code[match.end():]
sources.append(func_source_code)
return sources return sources




@@ -163,11 +167,12 @@ def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] =
if fn_kwargs is None: if fn_kwargs is None:
fn_kwargs = {} fn_kwargs = {}
hasher = Hasher() hasher = Hasher()
try:
sources = _get_func_and_its_called_func_source_code(fn)
hasher.update(sources)
except:
return "can't be hashed"
if fn is not None:
try:
sources = _get_func_and_its_called_func_source_code(fn)
hasher.update(sources)
except:
return "can't be hashed"
for key in sorted(fn_kwargs): for key in sorted(fn_kwargs):
hasher.update(key) hasher.update(key)
try: try:
@@ -177,7 +182,7 @@ def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] =
return hasher.hexdigest() return hasher.hexdigest()




def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _check_hash=True):
r""" r"""
cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用::


@@ -186,9 +191,9 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
from fastNLP import cache_results from fastNLP import cache_results


@cache_results('cache.pkl') @cache_results('cache.pkl')
def process_data():
def process_data(second=1):
# 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时
time.sleep(1)
time.sleep(second)
return np.random.randint(10, size=(5,)) return np.random.randint(10, size=(5,))


start_time = time.time() start_time = time.time()
@@ -199,49 +204,49 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
print("res =",process_data()) print("res =",process_data())
print(time.time() - start_time) print(time.time() - start_time)


# 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间
# Save cache to cache.pkl.
start_time = time.time()
print("res =",process_data(second=2))
print(time.time() - start_time)

# 输出内容如下,可以看到前两次结果相同,且第二次几乎没有花费时间。第三次由于参数变化了,所以cache的结果也就自然变化了。
# Save cache to 2d145aeb_cache.pkl.
# res = [5 4 9 1 8] # res = [5 4 9 1 8]
# 1.0042750835418701
# Read cache from cache.pkl.
# 1.0134737491607666
# Read cache from 2d145aeb_cache.pkl (Saved on xxxx).
# res = [5 4 9 1 8] # res = [5 4 9 1 8]
# 0.0040721893310546875 # 0.0040721893310546875
# Save cache to 0ead3093_cache.pkl.
# res = [1 8 2 5 1]
# 2.0086121559143066


可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理::

# 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可
process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl'

上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的
'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。
上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称::

process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。
# _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache
可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理。
如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose,
_check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称::


:param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在
函数调用的时候传入_cache_fp这个参数。
:param bool _refresh: 是否重新生成cache。
函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到
:param bool _hash_param: 是否将传入给被装饰函数的 parameter 进行 str 之后的 hash 结果加入到 _cache_fp 中,这样每次函数的
parameter 改变的时候,cache 文件就自动改变了。
:param bool _refresh: 强制重新生成新的 cache 。
:param int _verbose: 是否打印cache的信息。 :param int _verbose: 是否打印cache的信息。
:param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值 :param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值
与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然 与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然
该修改对结果有影响,但无法做出warning。 该修改对结果有影响,但无法做出warning。

:return: :return:
""" """


def wrapper_(func): def wrapper_(func):
signature = inspect.signature(func) signature = inspect.signature(func)
for key, _ in signature.parameters.items(): for key, _ in signature.parameters.items():
if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'):
if key in ('_cache_fp', "_hash_param", '_refresh', '_verbose', '_check_hash'):
raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))


@functools.wraps(func) @functools.wraps(func)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
fn_param = kwargs.copy()
if args:
params = [p.name for p in inspect.signature(func).parameters.values()]
fn_param.update(zip(params, args))
# fn_param = kwargs.copy()
# if args:
# params = [p.name for p in inspect.signature(func).parameters.values()]
# fn_param.update(zip(params, args))
if '_cache_fp' in kwargs: if '_cache_fp' in kwargs:
cache_filepath = kwargs.pop('_cache_fp') cache_filepath = kwargs.pop('_cache_fp')
assert isinstance(cache_filepath, str), "_cache_fp can only be str." assert isinstance(cache_filepath, str), "_cache_fp can only be str."
@@ -263,10 +268,31 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
else: else:
check_hash = _check_hash check_hash = _check_hash


if '_hash_param' in kwargs:
hash_param = kwargs.pop('_hash_param')
assert isinstance(hash_param, bool), "_hash_param can only be bool."
else:
hash_param = _hash_param

if hash_param and cache_filepath is not None: # 尝试将parameter给hash一下
try:
params = dict(inspect.getcallargs(func, *args, **kwargs))
if inspect.ismethod(func): # 如果是 method 的话第一个参数(一般就是 self )就不考虑了
first_key = next(iter(params.items()))
params.pop(first_key)
if len(params):
# sort 一下防止顺序改变
params = {k: str(v) for k, v in sorted(params.items(), key=lambda item: item[0])}
param_hash = cal_fn_hash_code(None, params)[:8]
head, tail = os.path.split(cache_filepath)
cache_filepath = os.path.join(head, param_hash + '_' + tail)
except BaseException as e:
logger.debug(f"Fail to add parameter hash to cache path, because of Exception:{e}")

refresh_flag = True refresh_flag = True
new_hash_code = None new_hash_code = None
if check_hash: if check_hash:
new_hash_code = cal_fn_hash_code(func, fn_param)
new_hash_code = cal_fn_hash_code(func, None)


if cache_filepath is not None and refresh is False: if cache_filepath is not None and refresh is False:
# load data # load data
@@ -281,13 +307,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
if check_hash and old_hash_code != new_hash_code: if check_hash and old_hash_code != new_hash_code:
logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The "
f"difference may caused by the sourcecode change of the functions by this function.",
f"difference may caused by the sourcecode change.",
extra={'highlighter': ColorHighlighter('red')}) extra={'highlighter': ColorHighlighter('red')})
refresh_flag = False refresh_flag = False


if refresh_flag: if refresh_flag:
if new_hash_code is None: if new_hash_code is None:
new_hash_code = cal_fn_hash_code(func, fn_param)
new_hash_code = cal_fn_hash_code(func, None)
results = func(*args, **kwargs) results = func(*args, **kwargs)
if cache_filepath is not None: if cache_filepath is not None:
if results is None: if results is None:


+ 128
- 0
tests/core/utils/test_cache_results.py View File

@@ -246,6 +246,106 @@ class TestCacheResults:
rank_zero_rm('demo.pkl') rank_zero_rm('demo.pkl')




def remove_postfix(folder='.', post_fix='.pkl'):
    """Delete every regular file directly inside ``folder`` whose name ends with ``post_fix``.

    Sub-directories are never removed, even when their name matches, and the
    search does not recurse.

    :param folder: directory to clean; defaults to the current working directory.
    :param post_fix: file-name suffix that selects the files to delete.
    """
    import os
    for name in os.listdir(folder):
        # ``os.listdir`` returns bare names, so they must be joined with
        # ``folder`` before any filesystem check; testing the bare name would
        # resolve it against the current working directory instead.
        path = os.path.join(folder, name)
        if name.endswith(post_fix) and os.path.isfile(path):
            os.remove(path)


class TestCacheResultsWithParam:
    """Tests for ``cache_results`` when call arguments take part in the cache file name."""

    @pytest.mark.parametrize('_refresh', [True, False])
    @pytest.mark.parametrize('_hash_param', [True, False])
    @pytest.mark.parametrize('_verbose', [0, 1])
    @pytest.mark.parametrize('_check_hash', [True, False])
    def test_cache_save(self, _refresh, _hash_param, _verbose, _check_hash):
        """Call the decorated function with equal and different arguments under every flag combination."""
        cache_fp = 'demo.pkl'
        try:
            @cache_results(cache_fp, _refresh=_refresh, _hash_param=_hash_param, _verbose=_verbose,
                           _check_hash=_check_hash)
            def demo(a=1):
                print("¥")
                return 1
            res = demo()

            # Same argument passed by keyword: the function body must not run
            # again unless a refresh is forced.
            with Capturing() as output:
                res = demo(a=1)
            if _refresh is False:
                assert '¥' not in output[0]
            # Compare by value: identity comparison against an int literal is
            # implementation-dependent and raises SyntaxWarning on Python 3.8+.
            if _verbose == 0:
                # NOTE(review): the log message visible in cache_results starts
                # with a capital "Read", so this lowercase check may be vacuous
                # - confirm the intended needle.
                assert 'read' not in output[0]

            # Same argument passed positionally.
            with Capturing() as output:
                res = demo(1)
            if _refresh is False:
                assert '¥' not in output[0]

            with Capturing() as output:
                res = demo(a=2)
            if _hash_param is True:  # the hash cannot match, so the result must be regenerated
                assert '¥' in output[0]

        finally:
            remove_postfix('.')

    def test_cache_complex_param(self):
        """Reordering keyword arguments must not trigger a second execution of the function."""
        cache_fp = 'demo.pkl'
        try:
            @cache_results(cache_fp, _refresh=False)
            def demo(*args, s=1, **kwargs):
                print("¥")
                return 1

            res = demo(1, 2, 3, s=4, d=4)
            with Capturing() as output:
                res = demo(1, 2, 3, d=4, s=4)
            assert '¥' not in output[0]
        finally:
            remove_postfix('.')

    def test_wrapper_change(self):
        """Run this file twice as a script; only the decorator arguments differ between the turns.

        The second run must read the cache and must not report the function as different.
        """
        cache_fp = 'demo.pkl'
        test_type = 'wrapper_change'
        try:
            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
            res = get_subprocess_results(cmd)
            assert "¥" in res
            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
            res = get_subprocess_results(cmd)
            assert "¥" not in res
            assert 'Read' in res
            assert 'different' not in res

        finally:
            remove_postfix('.')

    def test_param_change(self):
        """Run this file twice as a script; the second turn's function gains a parameter.

        The second run must execute the function again instead of reading the cache.
        """
        cache_fp = 'demo.pkl'
        test_type = 'param_change'
        try:
            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
            res = get_subprocess_results(cmd)
            assert "¥" in res
            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
            res = get_subprocess_results(cmd)
            assert "¥" in res
            assert 'Read' not in res
        finally:
            remove_postfix('.')

    def test_create_cache_dir(self):
        """Caching to a path inside the ``demo/`` directory must still return the function's result."""
        @cache_results('demo/demo.pkl')
        def cache(s):
            return 1, 2

        try:
            results = cache(s=1)
            assert (1, 2) == results
        finally:
            import shutil
            shutil.rmtree('demo/')


if __name__ == '__main__': if __name__ == '__main__':
import argparse import argparse
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -294,3 +394,31 @@ if __name__ == '__main__':


res = demo_func() res = demo_func()


if test_type == 'wrapper_change':
if turn == 0:
@cache_results(cache_fp, _refresh=True)
def demo_wrapper_change():
print("¥")
return 1
else:
@cache_results(cache_fp, _refresh=False)
def demo_wrapper_change():
print("¥")
return 1

res = demo_wrapper_change()

if test_type == 'param_change':
if turn == 0:
@cache_results(cache_fp, _refresh=False)
def demo_param_change():
print("¥")
return 1
else:
@cache_results(cache_fp, _refresh=False)
def demo_param_change(a=1):
print("¥")
return 1

res = demo_param_change()


Loading…
Cancel
Save