
Enhanced cache_results: a hash of the decorated function's arguments can now be included in the name of the saved cache file

tags/v1.0.0alpha
yh_cc committed 2 years ago
commit 8c5250f5b4
7 changed files with 208 additions and 45 deletions
  1. fastNLP/core/__init__.py (+8, -0)
  2. fastNLP/core/callbacks/__init__.py (+4, -0)
  3. fastNLP/core/callbacks/more_evaluate_callback.py (+4, -5)
  4. fastNLP/core/controllers/trainer.py (+1, -3)
  5. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+2, -2)
  6. fastNLP/core/utils/cache_results.py (+61, -35)
  7. tests/core/utils/test_cache_results.py (+128, -0)
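
At a glance, the feature this commit adds (a minimal sketch distilled from the cache_results docstring changes below; hash prefixes such as 2d145aeb are illustrative)::

    from fastNLP import cache_results

    @cache_results('cache.pkl')    # _hash_param=True is the new default
    def process_data(second=1):
        ...

    process_data(second=1)   # saved to something like 2d145aeb_cache.pkl
    process_data(second=1)   # same arguments -> same name, read from cache
    process_data(second=2)   # changed arguments -> new name, cache regenerated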

fastNLP/core/__init__.py (+8, -0)

@@ -14,6 +14,8 @@ __all__ = [
     'MoreEvaluateCallback',
     "TorchWarmupCallback",
     "TorchGradClipCallback",
+    "MonitorUtility",
+    'HasMonitorCallback',

     # collators
     'Collator',
@@ -40,6 +42,12 @@ __all__ = [
     'Trainer',

     # dataloaders TODO mix_dataloader still needs to be sorted out
+    'TorchDataLoader',
+    'PaddleDataLoader',
+    'JittorDataLoader',
+    'prepare_jittor_dataloader',
+    'prepare_paddle_dataloader',
+    'prepare_torch_dataloader',

     # dataset
     'DataSet',


fastNLP/core/callbacks/__init__.py (+4, -0)

@@ -15,6 +15,9 @@ __all__ = [

     "TorchWarmupCallback",
     "TorchGradClipCallback",
+
+    "MonitorUtility",
+    'HasMonitorCallback'
 ]


@@ -28,4 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback
 from .early_stop_callback import EarlyStopCallback
 from .torch_callbacks import *
 from .more_evaluate_callback import MoreEvaluateCallback
+from .has_monitor_callback import MonitorUtility, HasMonitorCallback


fastNLP/core/callbacks/more_evaluate_callback.py (+4, -5)

@@ -66,7 +66,6 @@ class MoreEvaluateCallback(HasMonitorCallback):
             raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.")
         if watch_monitor is not None and evaluate_every is not None:
             raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.")
-        self.watch_monitor = watch_monitor

         if topk_monitor is not None and topk == 0:
             raise RuntimeError("`topk_monitor` is set, but `topk` is 0.")
@@ -93,8 +92,8 @@

     def on_after_trainer_initialized(self, trainer, driver):
         # if watching is required, an evaluator must be present
-        if self.watch_monitor is not None:
-            assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \
+        if self.monitor is not None:
+            assert trainer.evaluator is not None, f"You set `watch_monitor={self.monitor}`, but no " \
                                                   f"evaluate_dataloaders is provided in Trainer."

         if trainer.evaluate_fn is self.evaluate_fn:
@@ -134,7 +133,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
             self.topk_saver.save_topk(trainer, results)

     def on_train_epoch_end(self, trainer):
-        if self.watch_monitor is not None:
+        if self.monitor is not None:
             return
         if isinstance(self.evaluate_every, int) and self.evaluate_every < 0:
             evaluate_every = -self.evaluate_every
@@ -143,7 +142,7 @@ class MoreEvaluateCallback(HasMonitorCallback):
             self.topk_saver.save_topk(trainer, results)

     def on_train_batch_end(self, trainer):
-        if self.watch_monitor is not None:
+        if self.monitor is not None:
             return
         if callable(self.evaluate_every):
             if self.evaluate_every(trainer):


fastNLP/core/controllers/trainer.py (+1, -3)

@@ -117,6 +117,7 @@ class Trainer(TrainerEventTrigger):
     :param monitor: the default monitor metric name when evaluate_dataloaders is present. A passed-in callback that
         takes a monitor argument left unset at callback initialization will adopt this value. If no exact match is
         found in the evaluation results, the shortest-common-string algorithm picks the best-matching name as the
         monitor. A function may also be passed, taking the evaluation results (a dict) and returning a float to use
         as the monitor value.
+        If neither evaluate_dataloaders nor metrics is provided, this parameter has no effect.
     :param larger_better: whether a larger monitor value is better.
     :param marker: marks this Trainer instance, so that when the user calls `Trainer.on`, the callback can be tied
         to the specific 'trainer' instance; defaults to None;
     :param kwargs: other parameters that may be needed;
@@ -231,7 +232,6 @@ class Trainer(TrainerEventTrigger):
             total_batches=None
         )

-        """ set up the internal Evaluator """
         if metrics is None and evaluate_dataloaders is not None:
             raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

@@ -760,8 +760,6 @@ class Trainer(TrainerEventTrigger):
         self.on_before_backward(outputs)
         loss = self.extract_loss_from_outputs(outputs)
         loss = loss / self.accumulation_steps
-        # with self.get_no_sync_context():
-        #     self.driver.backward(loss)
         self.driver.backward(loss)
         self.on_after_backward()
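
The monitor parameter documented in the first hunk above also accepts a callable; a hedged sketch (the metric key 'acc#acc#dl1' and the surrounding setup are illustrative, not part of this diff)::

    from fastNLP import Trainer, Accuracy

    trainer = Trainer(
        model=model,                      # model and dataloaders assumed to be defined elsewhere
        driver='torch',
        train_dataloader=train_dl,
        evaluate_dataloaders=eval_dl,
        metrics={'acc': Accuracy()},
        # a callable monitor receives the evaluation results dict and returns a float
        monitor=lambda results: results['acc#acc#dl1'],
        larger_better=True,
    )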



fastNLP/core/dataloaders/torch_dataloader/fdl.py (+2, -2)

@@ -165,8 +165,8 @@ class TorchDataLoader(DataLoader):


 def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]],
-                             batch_size: int = 1,
-                             shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
+                             batch_size: int = 16,
+                             shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None,
                              batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None,
                              num_workers: int = 0, collate_fn: Union[str, Callable, None] = None,
                              pin_memory: bool = False, drop_last: bool = False,
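
Callers of prepare_torch_dataloader now get shuffled batches of 16 by default; a small sketch (the toy DataSet is illustrative)::

    from fastNLP import DataSet, prepare_torch_dataloader

    ds = DataSet({'x': list(range(100))})
    dl = prepare_torch_dataloader(ds)                      # now batch_size=16, shuffle=True
    dl = prepare_torch_dataloader(ds, batch_size=1,
                                  shuffle=False)           # the old defaults, stated explicitly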


fastNLP/core/utils/cache_results.py (+61, -35)

@@ -3,6 +3,7 @@ import hashlib
 import _pickle
 import functools
 import os
+import re
 from typing import Callable, List, Any, Optional
 import inspect
 import ast
@@ -126,7 +127,10 @@ def _get_func_and_its_called_func_source_code(func) -> List[str]:
             # some failure
             pass
     del last_frame  #
-    sources.append(inspect.getsource(func))
+    func_source_code = inspect.getsource(func)  # strip the cache_results decorator(s) from this function's source
+    for match in list(re.finditer('@cache_results\(.*\)\\n', func_source_code))[::-1]:
+        func_source_code = func_source_code[:match.start()] + func_source_code[match.end():]
+    sources.append(func_source_code)
     return sources
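
The decorator-stripping pass above exists so that editing the @cache_results(...) arguments themselves does not change the decorated function's hash; a standalone sketch of that regex removal (the sample source string is illustrative)::

    import re

    src = "@cache_results('cache.pkl')\ndef process_data(second=1):\n    return second\n"
    # iterate over matches in reverse so earlier match offsets stay valid while slicing
    for match in list(re.finditer(r'@cache_results\(.*\)\n', src))[::-1]:
        src = src[:match.start()] + src[match.end():]
    print(src)  # only the function definition remains to be hashed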


@@ -163,11 +167,12 @@ def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] =
     if fn_kwargs is None:
         fn_kwargs = {}
     hasher = Hasher()
-    try:
-        sources = _get_func_and_its_called_func_source_code(fn)
-        hasher.update(sources)
-    except:
-        return "can't be hashed"
+    if fn is not None:
+        try:
+            sources = _get_func_and_its_called_func_source_code(fn)
+            hasher.update(sources)
+        except:
+            return "can't be hashed"
     for key in sorted(fn_kwargs):
         hasher.update(key)
         try:
@@ -177,7 +182,7 @@
     return hasher.hexdigest()


-def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
+def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _check_hash=True):
     r"""
     cache_results is fastNLP's decorator for caching data. The example below shows how to use it::

@@ -186,9 +191,9 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True):
         from fastNLP import cache_results

         @cache_results('cache.pkl')
-        def process_data():
+        def process_data(second=1):
             # some time-consuming work, e.g. reading and preprocessing data; time.sleep() stands in for the cost
-            time.sleep(1)
+            time.sleep(second)
             return np.random.randint(10, size=(5,))

         start_time = time.time()
@@ -199,49 +204,49 @@
         print("res =",process_data())
         print(time.time() - start_time)

-        # The output looks like the following; the two results are identical, and the second call took almost no time
-        # Save cache to cache.pkl.
+        start_time = time.time()
+        print("res =",process_data(second=2))
+        print(time.time() - start_time)
+
+        # The output looks like the following. The first two results are identical, and the second call took
+        # almost no time. The third call changed a parameter, so its cached result naturally changed as well.
+        # Save cache to 2d145aeb_cache.pkl.
         # res = [5 4 9 1 8]
         # 1.0042750835418701
-        # Read cache from cache.pkl.
-        # 1.0134737491607666
+        # Read cache from 2d145aeb_cache.pkl (Saved on xxxx).
         # res = [5 4 9 1 8]
         # 0.0040721893310546875
+        # Save cache to 0ead3093_cache.pkl.
+        # res = [1 8 2 5 1]
+        # 2.0086121559143066

-    As you can see, the second run takes only about 0.0001s, because it reads the data directly from the
-    cache.pkl file rather than preprocessing it again::
-
-        # Continuing the example above: to generate a separate cache, e.g. for another dataset, call it like this
-        process_data(_cache_fp='cache2.pkl')  # does not affect the previous 'cache.pkl' at all
-
-    The _cache_fp above is a parameter that cache_results recognizes; it caches/reads the data at 'cache2.pkl',
-    i.e. 'cache2.pkl' here overrides the default 'cache.pkl'. If you put @cache_results() in front of your
-    function, it gains three parameters [_cache_fp, _refresh, _verbose]. These three parameters are not passed
-    into your function, and naturally your function's parameter names cannot include these three names::
-
-        process_data(_cache_fp='cache2.pkl', _refresh=True)  # force a fresh preprocessing cache to be generated
-        # _verbose controls output: 0 prints nothing; 1 reports whether each step read a cache or built a new one
+    As you can see, the second run takes only about 0.0001s, because it reads the data directly from the
+    cache.pkl file rather than preprocessing it again.
+    If a function is decorated with @cache_results(), it gains five parameters [_cache_fp, _hash_param, _refresh,
+    _verbose, _check_hash]. The example above uses _cache_fp. These five parameters are not passed into the
+    decorated function, and naturally the decorated function's parameter names cannot include these five names::

     :param str _cache_fp: where to cache the returned result, or where to read the cache from. If None,
-        cache_results has no effect at all, unless _cache_fp is passed in when the function is called.
-    :param bool _refresh: whether to regenerate the cache.
+        cache_results has no effect at all, unless _cache_fp is passed in when the function is called. The name of
+        the saved file is affected by _hash_param.
+    :param bool _hash_param: whether to hash the str() of the parameters passed to the decorated function and add
+        the result to _cache_fp, so that the cache file changes automatically whenever the function's parameters
+        change.
+    :param bool _refresh: force a new cache to be generated.
     :param int _verbose: whether to print cache information.
     :param bool _check_hash: if True, compare the hash of the source code of the decorated function and of the
         functions it calls. If the hash recorded at save time differs from the current hash, a warning is issued.
         The warning can be a false positive that does not actually affect the result (e.g. adding or removing
         blank lines); conversely, if a change outside the source code does affect the result, no warning can be
         given.

     :return:
     """

     def wrapper_(func):
         signature = inspect.signature(func)
         for key, _ in signature.parameters.items():
-            if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'):
+            if key in ('_cache_fp', "_hash_param", '_refresh', '_verbose', '_check_hash'):
                 raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key))

         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            fn_param = kwargs.copy()
-            if args:
-                params = [p.name for p in inspect.signature(func).parameters.values()]
-                fn_param.update(zip(params, args))
+            # fn_param = kwargs.copy()
+            # if args:
+            #     params = [p.name for p in inspect.signature(func).parameters.values()]
+            #     fn_param.update(zip(params, args))
             if '_cache_fp' in kwargs:
                 cache_filepath = kwargs.pop('_cache_fp')
                 assert isinstance(cache_filepath, str), "_cache_fp can only be str."
@@ -263,10 +268,31 @@
             else:
                 check_hash = _check_hash

+            if '_hash_param' in kwargs:
+                hash_param = kwargs.pop('_hash_param')
+                assert isinstance(hash_param, bool), "_hash_param can only be bool."
+            else:
+                hash_param = _hash_param
+
+            if hash_param and cache_filepath is not None:  # try to hash the parameters
+                try:
+                    params = dict(inspect.getcallargs(func, *args, **kwargs))
+                    if inspect.ismethod(func):  # for a method, skip the first parameter (usually self)
+                        first_key = next(iter(params))
+                        params.pop(first_key)
+                    if len(params):
+                        # sort to keep the hash stable against argument-ordering changes
+                        params = {k: str(v) for k, v in sorted(params.items(), key=lambda item: item[0])}
+                        param_hash = cal_fn_hash_code(None, params)[:8]
+                        head, tail = os.path.split(cache_filepath)
+                        cache_filepath = os.path.join(head, param_hash + '_' + tail)
+                except BaseException as e:
+                    logger.debug(f"Fail to add parameter hash to cache path, because of Exception:{e}")
+
             refresh_flag = True
             new_hash_code = None
             if check_hash:
-                new_hash_code = cal_fn_hash_code(func, fn_param)
+                new_hash_code = cal_fn_hash_code(func, None)

             if cache_filepath is not None and refresh is False:
                 # load data
@@ -281,13 +307,13 @@
                     logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time))
                     if check_hash and old_hash_code != new_hash_code:
                         logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The "
-                                       f"difference may be caused by the sourcecode change of the functions called by this function.",
+                                       f"difference may be caused by the sourcecode change.",
                                        extra={'highlighter': ColorHighlighter('red')})
                 refresh_flag = False

             if refresh_flag:
                 if new_hash_code is None:
-                    new_hash_code = cal_fn_hash_code(func, fn_param)
+                    new_hash_code = cal_fn_hash_code(func, None)
                 results = func(*args, **kwargs)
                 if cache_filepath is not None:
                     if results is None:
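
To make the new naming concrete, here is a hedged, standalone sketch of how the parameter hash lands in the file name; it mirrors the inspect/os calls in the block above but substitutes hashlib.md5 for fastNLP's Hasher/cal_fn_hash_code, so the exact prefix will differ::

    import hashlib
    import inspect
    import os

    def param_hash_path(func, cache_filepath, *args, **kwargs):
        # bind the call's arguments to parameter names, as the diff does via inspect.getcallargs
        params = dict(inspect.getcallargs(func, *args, **kwargs))
        # stringify and sort so argument order cannot change the hash
        params = {k: str(v) for k, v in sorted(params.items())}
        digest = hashlib.md5(str(params).encode('utf8')).hexdigest()[:8]  # stand-in for cal_fn_hash_code
        head, tail = os.path.split(cache_filepath)
        return os.path.join(head, digest + '_' + tail)

    def process_data(second=1):
        return second

    print(param_hash_path(process_data, 'cache.pkl', second=2))
    # -> something like '7f3a91c2_cache.pkl'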


tests/core/utils/test_cache_results.py (+128, -0)

@@ -246,6 +246,106 @@ class TestCacheResults:
             rank_zero_rm('demo.pkl')


+def remove_postfix(folder='.', post_fix='.pkl'):
+    import os
+    for f in os.listdir(folder):
+        fp = os.path.join(folder, f)
+        if os.path.isfile(fp) and fp.endswith(post_fix):
+            os.remove(fp)
+
+
+class TestCacheResultsWithParam:
+    @pytest.mark.parametrize('_refresh', [True, False])
+    @pytest.mark.parametrize('_hash_param', [True, False])
+    @pytest.mark.parametrize('_verbose', [0, 1])
+    @pytest.mark.parametrize('_check_hash', [True, False])
+    def test_cache_save(self, _refresh, _hash_param, _verbose, _check_hash):
+        cache_fp = 'demo.pkl'
+        try:
+            @cache_results(cache_fp, _refresh=_refresh, _hash_param=_hash_param, _verbose=_verbose,
+                           _check_hash=_check_hash)
+            def demo(a=1):
+                print("¥")
+                return 1
+            res = demo()
+
+            with Capturing() as output:
+                res = demo(a=1)
+            if _refresh is False:
+                assert '¥' not in output[0]
+            if _verbose == 0:
+                assert 'read' not in output[0]
+
+            with Capturing() as output:
+                res = demo(1)
+            if _refresh is False:
+                assert '¥' not in output[0]
+
+            with Capturing() as output:
+                res = demo(a=2)
+            if _hash_param is True:  # guaranteed not to match, so the cache must be regenerated
+                assert '¥' in output[0]
+
+        finally:
+            remove_postfix('.')
+
+    def test_cache_complex_param(self):
+        cache_fp = 'demo.pkl'
+        try:
+            @cache_results(cache_fp, _refresh=False)
+            def demo(*args, s=1, **kwargs):
+                print("¥")
+                return 1
+
+            res = demo(1, 2, 3, s=4, d=4)
+            with Capturing() as output:
+                res = demo(1, 2, 3, d=4, s=4)
+            assert '¥' not in output[0]
+        finally:
+            remove_postfix('.')
+
+    def test_wrapper_change(self):
+        cache_fp = 'demo.pkl'
+        test_type = 'wrapper_change'
+        try:
+            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
+            res = get_subprocess_results(cmd)
+            assert "¥" in res
+            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
+            res = get_subprocess_results(cmd)
+            assert "¥" not in res
+            assert 'Read' in res
+            assert 'different' not in res
+
+        finally:
+            remove_postfix('.')
+
+    def test_param_change(self):
+        cache_fp = 'demo.pkl'
+        test_type = 'param_change'
+        try:
+            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
+            res = get_subprocess_results(cmd)
+            assert "¥" in res
+            cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
+            res = get_subprocess_results(cmd)
+            assert "¥" in res
+            assert 'Read' not in res
+        finally:
+            remove_postfix('.')
+
+    def test_create_cache_dir(self):
+        @cache_results('demo/demo.pkl')
+        def cache(s):
+            return 1, 2
+
+        try:
+            results = cache(s=1)
+            assert (1, 2) == results
+        finally:
+            import shutil
+            shutil.rmtree('demo/')
+
+
 if __name__ == '__main__':
     import argparse
     parser = argparse.ArgumentParser()
@@ -294,3 +394,31 @@

     res = demo_func()

+    if test_type == 'wrapper_change':
+        if turn == 0:
+            @cache_results(cache_fp, _refresh=True)
+            def demo_wrapper_change():
+                print("¥")
+                return 1
+        else:
+            @cache_results(cache_fp, _refresh=False)
+            def demo_wrapper_change():
+                print("¥")
+                return 1
+
+        res = demo_wrapper_change()
+
+    if test_type == 'param_change':
+        if turn == 0:
+            @cache_results(cache_fp, _refresh=False)
+            def demo_param_change():
+                print("¥")
+                return 1
+        else:
+            @cache_results(cache_fp, _refresh=False)
+            def demo_param_change(a=1):
+                print("¥")
+                return 1
+
+        res = demo_param_change()

