| @@ -14,6 +14,8 @@ __all__ = [ | |||||
| 'MoreEvaluateCallback', | 'MoreEvaluateCallback', | ||||
| "TorchWarmupCallback", | "TorchWarmupCallback", | ||||
| "TorchGradClipCallback", | "TorchGradClipCallback", | ||||
| "MonitorUtility", | |||||
| 'HasMonitorCallback', | |||||
| # collators | # collators | ||||
| 'Collator', | 'Collator', | ||||
| @@ -40,6 +42,12 @@ __all__ = [ | |||||
| 'Trainer', | 'Trainer', | ||||
| # dataloaders TODO 需要把 mix_dataloader 的搞定 | # dataloaders TODO 需要把 mix_dataloader 的搞定 | ||||
| 'TorchDataLoader', | |||||
| 'PaddleDataLoader', | |||||
| 'JittorDataLoader', | |||||
| 'prepare_jittor_dataloader', | |||||
| 'prepare_paddle_dataloader', | |||||
| 'prepare_torch_dataloader', | |||||
| # dataset | # dataset | ||||
| 'DataSet', | 'DataSet', | ||||
| @@ -15,6 +15,9 @@ __all__ = [ | |||||
| "TorchWarmupCallback", | "TorchWarmupCallback", | ||||
| "TorchGradClipCallback", | "TorchGradClipCallback", | ||||
| "MonitorUtility", | |||||
| 'HasMonitorCallback' | |||||
| ] | ] | ||||
| @@ -28,4 +31,5 @@ from .load_best_model_callback import LoadBestModelCallback | |||||
| from .early_stop_callback import EarlyStopCallback | from .early_stop_callback import EarlyStopCallback | ||||
| from .torch_callbacks import * | from .torch_callbacks import * | ||||
| from .more_evaluate_callback import MoreEvaluateCallback | from .more_evaluate_callback import MoreEvaluateCallback | ||||
| from .has_monitor_callback import MonitorUtility, HasMonitorCallback | |||||
| @@ -66,7 +66,6 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") | raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be None at the same time.") | ||||
| if watch_monitor is not None and evaluate_every is not None: | if watch_monitor is not None and evaluate_every is not None: | ||||
| raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") | raise RuntimeError("`evaluate_every` and `watch_monitor` cannot be set at the same time.") | ||||
| self.watch_monitor = watch_monitor | |||||
| if topk_monitor is not None and topk == 0: | if topk_monitor is not None and topk == 0: | ||||
| raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") | raise RuntimeError("`topk_monitor` is set, but `topk` is 0.") | ||||
| @@ -93,8 +92,8 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| def on_after_trainer_initialized(self, trainer, driver): | def on_after_trainer_initialized(self, trainer, driver): | ||||
| # 如果是需要 watch 的,不能没有 evaluator | # 如果是需要 watch 的,不能没有 evaluator | ||||
| if self.watch_monitor is not None: | |||||
| assert trainer.evaluator is not None, f"You set `watch_monitor={self.watch_monitor}`, but no " \ | |||||
| if self.monitor is not None: | |||||
| assert trainer.evaluator is not None, f"You set `watch_monitor={self.monitor}`, but no " \ | |||||
| f"evaluate_dataloaders is provided in Trainer." | f"evaluate_dataloaders is provided in Trainer." | ||||
| if trainer.evaluate_fn is self.evaluate_fn: | if trainer.evaluate_fn is self.evaluate_fn: | ||||
| @@ -134,7 +133,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| self.topk_saver.save_topk(trainer, results) | self.topk_saver.save_topk(trainer, results) | ||||
| def on_train_epoch_end(self, trainer): | def on_train_epoch_end(self, trainer): | ||||
| if self.watch_monitor is not None: | |||||
| if self.monitor is not None: | |||||
| return | return | ||||
| if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: | ||||
| evaluate_every = -self.evaluate_every | evaluate_every = -self.evaluate_every | ||||
| @@ -143,7 +142,7 @@ class MoreEvaluateCallback(HasMonitorCallback): | |||||
| self.topk_saver.save_topk(trainer, results) | self.topk_saver.save_topk(trainer, results) | ||||
| def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
| if self.watch_monitor is not None: | |||||
| if self.monitor is not None: | |||||
| return | return | ||||
| if callable(self.evaluate_every): | if callable(self.evaluate_every): | ||||
| if self.evaluate_every(trainer): | if self.evaluate_every(trainer): | ||||
| @@ -117,6 +117,7 @@ class Trainer(TrainerEventTrigger): | |||||
| :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 | ||||
| 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 | ||||
| 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 | ||||
| 如果 evaluate_dataloaders 与 metrics 没有提供,该参数无意义。 | |||||
| :param larger_better: monitor 的值是否是越大越好。 | :param larger_better: monitor 的值是否是越大越好。 | ||||
| :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
| :param kwargs: 一些其它的可能需要的参数; | :param kwargs: 一些其它的可能需要的参数; | ||||
| @@ -231,7 +232,6 @@ class Trainer(TrainerEventTrigger): | |||||
| total_batches=None | total_batches=None | ||||
| ) | ) | ||||
| """ 设置内部的 Evaluator """ | |||||
| if metrics is None and evaluate_dataloaders is not None: | if metrics is None and evaluate_dataloaders is not None: | ||||
| raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") | ||||
| @@ -760,8 +760,6 @@ class Trainer(TrainerEventTrigger): | |||||
| self.on_before_backward(outputs) | self.on_before_backward(outputs) | ||||
| loss = self.extract_loss_from_outputs(outputs) | loss = self.extract_loss_from_outputs(outputs) | ||||
| loss = loss / self.accumulation_steps | loss = loss / self.accumulation_steps | ||||
| # with self.get_no_sync_context(): | |||||
| # self.driver.backward(loss) | |||||
| self.driver.backward(loss) | self.driver.backward(loss) | ||||
| self.on_after_backward() | self.on_after_backward() | ||||
| @@ -165,8 +165,8 @@ class TorchDataLoader(DataLoader): | |||||
| def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], | ||||
| batch_size: int = 1, | |||||
| shuffle: bool = False, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
| batch_size: int = 16, | |||||
| shuffle: bool = True, sampler: Union["Sampler[int]", ReproducibleSampler, UnrepeatedSampler] = None, | |||||
| batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | batch_sampler: Union["Sampler[Sequence[int]]", ReproducibleBatchSampler] = None, | ||||
| num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, | num_workers: int = 0, collate_fn: Union[str, Callable, None] = None, | ||||
| pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
| @@ -3,6 +3,7 @@ import hashlib | |||||
| import _pickle | import _pickle | ||||
| import functools | import functools | ||||
| import os | import os | ||||
| import re | |||||
| from typing import Callable, List, Any, Optional | from typing import Callable, List, Any, Optional | ||||
| import inspect | import inspect | ||||
| import ast | import ast | ||||
| @@ -126,7 +127,10 @@ def _get_func_and_its_called_func_source_code(func) -> List[str]: | |||||
| # some failure | # some failure | ||||
| pass | pass | ||||
| del last_frame # | del last_frame # | ||||
| sources.append(inspect.getsource(func)) | |||||
| func_source_code = inspect.getsource(func) # 将这个函数中的 cache_results 装饰删除掉。 | |||||
| for match in list(re.finditer('@cache_results\(.*\)\\n', func_source_code))[::-1]: | |||||
| func_source_code = func_source_code[:match.start()] + func_source_code[match.end():] | |||||
| sources.append(func_source_code) | |||||
| return sources | return sources | ||||
| @@ -163,11 +167,12 @@ def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = | |||||
| if fn_kwargs is None: | if fn_kwargs is None: | ||||
| fn_kwargs = {} | fn_kwargs = {} | ||||
| hasher = Hasher() | hasher = Hasher() | ||||
| try: | |||||
| sources = _get_func_and_its_called_func_source_code(fn) | |||||
| hasher.update(sources) | |||||
| except: | |||||
| return "can't be hashed" | |||||
| if fn is not None: | |||||
| try: | |||||
| sources = _get_func_and_its_called_func_source_code(fn) | |||||
| hasher.update(sources) | |||||
| except: | |||||
| return "can't be hashed" | |||||
| for key in sorted(fn_kwargs): | for key in sorted(fn_kwargs): | ||||
| hasher.update(key) | hasher.update(key) | ||||
| try: | try: | ||||
| @@ -177,7 +182,7 @@ def cal_fn_hash_code(fn: Optional[Callable] = None, fn_kwargs: Optional[dict] = | |||||
| return hasher.hexdigest() | return hasher.hexdigest() | ||||
| def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||||
| def cache_results(_cache_fp, _hash_param=True, _refresh=False, _verbose=1, _check_hash=True): | |||||
| r""" | r""" | ||||
| cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | cache_results是fastNLP中用于cache数据的装饰器。通过下面的例子看一下如何使用:: | ||||
| @@ -186,9 +191,9 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||||
| from fastNLP import cache_results | from fastNLP import cache_results | ||||
| @cache_results('cache.pkl') | @cache_results('cache.pkl') | ||||
| def process_data(): | |||||
| def process_data(second=1): | |||||
| # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | # 一些比较耗时的工作,比如读取数据,预处理数据等,这里用time.sleep()代替耗时 | ||||
| time.sleep(1) | |||||
| time.sleep(second) | |||||
| return np.random.randint(10, size=(5,)) | return np.random.randint(10, size=(5,)) | ||||
| start_time = time.time() | start_time = time.time() | ||||
| @@ -199,49 +204,49 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||||
| print("res =",process_data()) | print("res =",process_data()) | ||||
| print(time.time() - start_time) | print(time.time() - start_time) | ||||
| # 输出内容如下,可以看到两次结果相同,且第二次几乎没有花费时间 | |||||
| # Save cache to cache.pkl. | |||||
| start_time = time.time() | |||||
| print("res =",process_data(second=2)) | |||||
| print(time.time() - start_time) | |||||
| # 输出内容如下,可以看到前两次结果相同,且第二次几乎没有花费时间。第三次由于参数变化了,所以cache的结果也就自然变化了。 | |||||
| # Save cache to 2d145aeb_cache.pkl. | |||||
| # res = [5 4 9 1 8] | # res = [5 4 9 1 8] | ||||
| # 1.0042750835418701 | |||||
| # Read cache from cache.pkl. | |||||
| # 1.0134737491607666 | |||||
| # Read cache from 2d145aeb_cache.pkl (Saved on xxxx). | |||||
| # res = [5 4 9 1 8] | # res = [5 4 9 1 8] | ||||
| # 0.0040721893310546875 | # 0.0040721893310546875 | ||||
| # Save cache to 0ead3093_cache.pkl. | |||||
| # res = [1 8 2 5 1] | |||||
| # 2.0086121559143066 | |||||
| 可以看到第二次运行的时候,只用了0.0001s左右,是由于第二次运行将直接从cache.pkl这个文件读取数据,而不会经过再次预处理:: | |||||
| # 还是以上面的例子为例,如果需要重新生成另一个cache,比如另一个数据集的内容,通过如下的方式调用即可 | |||||
| process_data(_cache_fp='cache2.pkl') # 完全不影响之前的‘cache.pkl' | |||||
| 上面的_cache_fp是cache_results会识别的参数,它将从'cache2.pkl'这里缓存/读取数据,即这里的'cache2.pkl'覆盖默认的 | |||||
| 'cache.pkl'。如果在你的函数前面加上了@cache_results()则你的函数会增加三个参数[_cache_fp, _refresh, _verbose]。 | |||||
| 上面的例子即为使用_cache_fp的情况,这三个参数不会传入到你的函数中,当然你写的函数参数名也不可能包含这三个名称:: | |||||
| process_data(_cache_fp='cache2.pkl', _refresh=True) # 这里强制重新生成一份对预处理的cache。 | |||||
| # _verbose是用于控制输出信息的,如果为0,则不输出任何内容;如果为1,则会提醒当前步骤是读取的cache还是生成了新的cache | |||||
| 可以看到第二次运行的时候,只用了0.004s左右,是由于第二次运行将直接从缓存文件(上例中为 2d145aeb_cache.pkl)读取数据,而不会经过再次预处理。 | |||||
| 如果在函数加上了装饰器@cache_results(),则函数会增加五个参数[_cache_fp, _hash_param, _refresh, _verbose, | |||||
| _check_hash]。上面的例子即为使用_cache_fp的情况,这五个参数不会传入到被装饰函数中,当然被装饰函数参数名也不能包含这五个名称。 | |||||
| :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | :param str _cache_fp: 将返回结果缓存到什么位置;或从什么位置读取缓存。如果为None,cache_results没有任何效用,除非在 | ||||
| 函数调用的时候传入_cache_fp这个参数。 | |||||
| :param bool _refresh: 是否重新生成cache。 | |||||
| 函数调用的时候传入 _cache_fp 这个参数。保存文件的名称会受到 _hash_param 参数的影响。 | |||||
| :param bool _hash_param: 是否将传入给被装饰函数的 parameter 进行 str 之后的 hash 结果加入到 _cache_fp 中,这样每次函数的 | |||||
| parameter 改变的时候,cache 文件就自动改变了。 | |||||
| :param bool _refresh: 强制重新生成新的 cache 。 | |||||
| :param int _verbose: 是否打印cache的信息。 | :param int _verbose: 是否打印cache的信息。 | ||||
| :param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值 | :param bool _check_hash: 如果为 True 将尝试对比修饰的函数的源码以及该函数内部调用的函数的源码的hash值。如果发现保存时的hash值 | ||||
| 与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然 | 与当前的hash值有差异,会报warning。但该warning可能出现实质上并不影响结果的误报(例如增删空白行);且在修改不涉及源码时,虽然 | ||||
| 该修改对结果有影响,但无法做出warning。 | 该修改对结果有影响,但无法做出warning。 | ||||
| :return: | :return: | ||||
| """ | """ | ||||
| def wrapper_(func): | def wrapper_(func): | ||||
| signature = inspect.signature(func) | signature = inspect.signature(func) | ||||
| for key, _ in signature.parameters.items(): | for key, _ in signature.parameters.items(): | ||||
| if key in ('_cache_fp', '_refresh', '_verbose', '_check_hash'): | |||||
| if key in ('_cache_fp', "_hash_param", '_refresh', '_verbose', '_check_hash'): | |||||
| raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | raise RuntimeError("The function decorated by cache_results cannot have keyword `{}`.".format(key)) | ||||
| @functools.wraps(func) | @functools.wraps(func) | ||||
| def wrapper(*args, **kwargs): | def wrapper(*args, **kwargs): | ||||
| fn_param = kwargs.copy() | |||||
| if args: | |||||
| params = [p.name for p in inspect.signature(func).parameters.values()] | |||||
| fn_param.update(zip(params, args)) | |||||
| # fn_param = kwargs.copy() | |||||
| # if args: | |||||
| # params = [p.name for p in inspect.signature(func).parameters.values()] | |||||
| # fn_param.update(zip(params, args)) | |||||
| if '_cache_fp' in kwargs: | if '_cache_fp' in kwargs: | ||||
| cache_filepath = kwargs.pop('_cache_fp') | cache_filepath = kwargs.pop('_cache_fp') | ||||
| assert isinstance(cache_filepath, str), "_cache_fp can only be str." | assert isinstance(cache_filepath, str), "_cache_fp can only be str." | ||||
| @@ -263,10 +268,31 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||||
| else: | else: | ||||
| check_hash = _check_hash | check_hash = _check_hash | ||||
| if '_hash_param' in kwargs: | |||||
| hash_param = kwargs.pop('_hash_param') | |||||
| assert isinstance(hash_param, bool), "_hash_param can only be bool." | |||||
| else: | |||||
| hash_param = _hash_param | |||||
| if hash_param and cache_filepath is not None: # 尝试将parameter给hash一下 | |||||
| try: | |||||
| params = dict(inspect.getcallargs(func, *args, **kwargs)) | |||||
| if inspect.ismethod(func): # 如果是 method 的话第一个参数(一般就是 self )就不考虑了 | |||||
| first_key = next(iter(params.items())) | |||||
| params.pop(first_key) | |||||
| if len(params): | |||||
| # sort 一下防止顺序改变 | |||||
| params = {k: str(v) for k, v in sorted(params.items(), key=lambda item: item[0])} | |||||
| param_hash = cal_fn_hash_code(None, params)[:8] | |||||
| head, tail = os.path.split(cache_filepath) | |||||
| cache_filepath = os.path.join(head, param_hash + '_' + tail) | |||||
| except BaseException as e: | |||||
| logger.debug(f"Fail to add parameter hash to cache path, because of Exception:{e}") | |||||
| refresh_flag = True | refresh_flag = True | ||||
| new_hash_code = None | new_hash_code = None | ||||
| if check_hash: | if check_hash: | ||||
| new_hash_code = cal_fn_hash_code(func, fn_param) | |||||
| new_hash_code = cal_fn_hash_code(func, None) | |||||
| if cache_filepath is not None and refresh is False: | if cache_filepath is not None and refresh is False: | ||||
| # load data | # load data | ||||
| @@ -281,13 +307,13 @@ def cache_results(_cache_fp, _refresh=False, _verbose=1, _check_hash=True): | |||||
| logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | logger.info("Read cache from {} (Saved on {}).".format(cache_filepath, save_time)) | ||||
| if check_hash and old_hash_code != new_hash_code: | if check_hash and old_hash_code != new_hash_code: | ||||
| logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | logger.warning(f"The function `{func.__name__}` is different from its last cache (Save on {save_time}). The " | ||||
| f"difference may caused by the sourcecode change of the functions by this function.", | |||||
| f"difference may caused by the sourcecode change.", | |||||
| extra={'highlighter': ColorHighlighter('red')}) | extra={'highlighter': ColorHighlighter('red')}) | ||||
| refresh_flag = False | refresh_flag = False | ||||
| if refresh_flag: | if refresh_flag: | ||||
| if new_hash_code is None: | if new_hash_code is None: | ||||
| new_hash_code = cal_fn_hash_code(func, fn_param) | |||||
| new_hash_code = cal_fn_hash_code(func, None) | |||||
| results = func(*args, **kwargs) | results = func(*args, **kwargs) | ||||
| if cache_filepath is not None: | if cache_filepath is not None: | ||||
| if results is None: | if results is None: | ||||
| @@ -246,6 +246,106 @@ class TestCacheResults: | |||||
| rank_zero_rm('demo.pkl') | rank_zero_rm('demo.pkl') | ||||
| def remove_postfix(folder='.', post_fix='.pkl'): | |||||
| import os | |||||
| for f in os.listdir(folder): | |||||
| if os.path.isfile(f) and f.endswith(post_fix): | |||||
| os.remove(os.path.join(folder, f)) | |||||
| class TestCacheResultsWithParam: | |||||
| @pytest.mark.parametrize('_refresh', [True, False]) | |||||
| @pytest.mark.parametrize('_hash_param', [True, False]) | |||||
| @pytest.mark.parametrize('_verbose', [0, 1]) | |||||
| @pytest.mark.parametrize('_check_hash', [True, False]) | |||||
| def test_cache_save(self, _refresh, _hash_param, _verbose, _check_hash): | |||||
| cache_fp = 'demo.pkl' | |||||
| try: | |||||
| @cache_results(cache_fp, _refresh=_refresh, _hash_param=_hash_param, _verbose=_verbose, | |||||
| _check_hash=_check_hash) | |||||
| def demo(a=1): | |||||
| print("¥") | |||||
| return 1 | |||||
| res = demo() | |||||
| with Capturing() as output: | |||||
| res = demo(a=1) | |||||
| if _refresh is False: | |||||
| assert '¥' not in output[0] | |||||
| if _verbose is 0: | |||||
| assert 'read' not in output[0] | |||||
| with Capturing() as output: | |||||
| res = demo(1) | |||||
| if _refresh is False: | |||||
| assert '¥' not in output[0] | |||||
| with Capturing() as output: | |||||
| res = demo(a=2) | |||||
| if _hash_param is True: # 一定对不上,需要重新生成 | |||||
| assert '¥' in output[0] | |||||
| finally: | |||||
| remove_postfix('.') | |||||
| def test_cache_complex_param(self): | |||||
| cache_fp = 'demo.pkl' | |||||
| try: | |||||
| @cache_results(cache_fp, _refresh=False) | |||||
| def demo(*args, s=1, **kwargs): | |||||
| print("¥") | |||||
| return 1 | |||||
| res = demo(1,2,3, s=4, d=4) | |||||
| with Capturing() as output: | |||||
| res = demo(1,2,3,d=4, s=4) | |||||
| assert '¥' not in output[0] | |||||
| finally: | |||||
| remove_postfix('.') | |||||
| def test_wrapper_change(self): | |||||
| cache_fp = 'demo.pkl' | |||||
| test_type = 'wrapper_change' | |||||
| try: | |||||
| cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
| res = get_subprocess_results(cmd) | |||||
| assert "¥" in res | |||||
| cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
| res = get_subprocess_results(cmd) | |||||
| assert "¥" not in res | |||||
| assert 'Read' in res | |||||
| assert 'different' not in res | |||||
| finally: | |||||
| remove_postfix('.') | |||||
| def test_param_change(self): | |||||
| cache_fp = 'demo.pkl' | |||||
| test_type = 'param_change' | |||||
| try: | |||||
| cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0' | |||||
| res = get_subprocess_results(cmd) | |||||
| assert "¥" in res | |||||
| cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1' | |||||
| res = get_subprocess_results(cmd) | |||||
| assert "¥" in res | |||||
| assert 'Read' not in res | |||||
| finally: | |||||
| remove_postfix('.') | |||||
| def test_create_cache_dir(self): | |||||
| @cache_results('demo/demo.pkl') | |||||
| def cache(s): | |||||
| return 1, 2 | |||||
| try: | |||||
| results = cache(s=1) | |||||
| assert (1, 2) == results | |||||
| finally: | |||||
| import shutil | |||||
| shutil.rmtree('demo/') | |||||
| if __name__ == '__main__': | if __name__ == '__main__': | ||||
| import argparse | import argparse | ||||
| parser = argparse.ArgumentParser() | parser = argparse.ArgumentParser() | ||||
| @@ -294,3 +394,31 @@ if __name__ == '__main__': | |||||
| res = demo_func() | res = demo_func() | ||||
| if test_type == 'wrapper_change': | |||||
| if turn == 0: | |||||
| @cache_results(cache_fp, _refresh=True) | |||||
| def demo_wrapper_change(): | |||||
| print("¥") | |||||
| return 1 | |||||
| else: | |||||
| @cache_results(cache_fp, _refresh=False) | |||||
| def demo_wrapper_change(): | |||||
| print("¥") | |||||
| return 1 | |||||
| res = demo_wrapper_change() | |||||
| if test_type == 'param_change': | |||||
| if turn == 0: | |||||
| @cache_results(cache_fp, _refresh=False) | |||||
| def demo_param_change(): | |||||
| print("¥") | |||||
| return 1 | |||||
| else: | |||||
| @cache_results(cache_fp, _refresh=False) | |||||
| def demo_param_change(a=1): | |||||
| print("¥") | |||||
| return 1 | |||||
| res = demo_param_change() | |||||