Browse Source

更新cache_results的测试

tags/v1.0.0alpha
yh_cc 2 years ago
parent
commit
4ed012bf1a
2 changed files with 75 additions and 88 deletions
  1. +68
    -77
      tests/core/utils/test_cache_results.py
  2. +7
    -11
      tests/helpers/common/utils.py

+ 68
- 77
tests/core/utils/test_cache_results.py View File

@@ -1,29 +1,16 @@
import time
import os import os
import pytest import pytest
from subprocess import Popen, PIPE
import subprocess
from io import StringIO from io import StringIO
import sys import sys


from fastNLP.core.utils.cache_results import cache_results from fastNLP.core.utils.cache_results import cache_results
from tests.helpers.common.utils import check_time_elapse

from fastNLP.core import rank_zero_rm from fastNLP.core import rank_zero_rm




def get_subprocess_results(cmd): def get_subprocess_results(cmd):
pipe = Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE)
output, err = pipe.communicate()
if output:
output = output.decode('utf8')
else:
output = ''
if err:
err = err.decode('utf8')
else:
err = ''
res = output + err
return res
output = subprocess.check_output(cmd, shell=True)
return output.decode('utf8')




class Capturing(list): class Capturing(list):
@@ -48,12 +35,12 @@ class TestCacheResults:
try: try:
@cache_results(cache_fp) @cache_results(cache_fp)
def demo(): def demo():
time.sleep(1)
print("¥")
return 1 return 1

res = demo() res = demo()
with check_time_elapse(1, op='lt'):
with Capturing() as output:
res = demo() res = demo()
assert '¥' not in output[0]


finally: finally:
rank_zero_rm(cache_fp) rank_zero_rm(cache_fp)
@@ -63,12 +50,13 @@ class TestCacheResults:
try: try:
@cache_results(cache_fp, _refresh=True) @cache_results(cache_fp, _refresh=True)
def demo(): def demo():
time.sleep(1.5)
print("¥")
return 1 return 1


res = demo() res = demo()
with check_time_elapse(1, op='ge'):
with Capturing() as output:
res = demo() res = demo()
assert '¥' in output[0]
finally: finally:
rank_zero_rm(cache_fp) rank_zero_rm(cache_fp)


@@ -77,19 +65,21 @@ class TestCacheResults:
try: try:
@cache_results(cache_fp) @cache_results(cache_fp)
def demo(): def demo():
time.sleep(2)
print('¥')
return 1 return 1


with check_time_elapse(1, op='gt'):
with Capturing() as output:
res = demo() res = demo()
assert '¥' in output[0]


@cache_results(cache_fp) @cache_results(cache_fp)
def demo(): def demo():
time.sleep(2)
print('¥')
return 1 return 1


with check_time_elapse(1, op='lt'):
with Capturing() as output:
res = demo() res = demo()
assert '¥' not in output[0]
finally: finally:
rank_zero_rm('demo.pkl') rank_zero_rm('demo.pkl')


@@ -98,27 +88,28 @@ class TestCacheResults:
try: try:
@cache_results(cache_fp) @cache_results(cache_fp)
def demo(): def demo():
time.sleep(2)
print('¥')
return 1 return 1


with check_time_elapse(1, op='gt'):
with Capturing() as output:
res = demo() res = demo()
assert '¥' in output[0]


@cache_results(cache_fp) @cache_results(cache_fp)
def demo(): def demo():
time.sleep(1)
print('¥¥')
return 1 return 1


with check_time_elapse(1, op='lt'):
with Capturing() as output:
res = demo()
assert 'is different from its last cache' in output[0]
with Capturing() as output:
res = demo()
assert 'different' in output[0]
assert '¥' not in output[0]


# 关闭check_hash应该不warning的 # 关闭check_hash应该不warning的
with check_time_elapse(1, op='lt'):
with Capturing() as output:
res = demo(_check_hash=0)
assert 'is different from its last cache' not in output[0]
with Capturing() as output:
res = demo(_check_hash=0)
assert 'different' not in output[0]
assert '¥' not in output[0]


finally: finally:
rank_zero_rm('demo.pkl') rank_zero_rm('demo.pkl')
@@ -128,28 +119,29 @@ class TestCacheResults:
try: try:
@cache_results(cache_fp, _check_hash=False) @cache_results(cache_fp, _check_hash=False)
def demo(): def demo():
time.sleep(2)
print('¥')
return 1 return 1


with check_time_elapse(1, op='gt'):
res = demo()
with Capturing() as output:
res = demo(_check_hash=0)
assert '¥' in output[0]


@cache_results(cache_fp, _check_hash=False) @cache_results(cache_fp, _check_hash=False)
def demo(): def demo():
time.sleep(1)
print('¥¥')
return 1 return 1


# 默认不会check # 默认不会check
with check_time_elapse(1, op='lt'):
with Capturing() as output:
res = demo()
assert 'is different from its last cache' not in output[0]
with Capturing() as output:
res = demo()
assert 'different' not in output[0]
assert '¥' not in output[0]


# check也可以 # check也可以
with check_time_elapse(1, op='lt'):
with Capturing() as output:
res = demo(_check_hash=True)
assert 'is different from its last cache' in output[0]
with Capturing() as output:
res = demo(_check_hash=True)
assert 'different' in output[0]
assert '¥' not in output[0]


finally: finally:
rank_zero_rm('demo.pkl') rank_zero_rm('demo.pkl')
@@ -159,22 +151,22 @@ class TestCacheResults:
cache_fp = 'demo.pkl' cache_fp = 'demo.pkl'
test_type = 'func_refer_fun_change' test_type = 'func_refer_fun_change'
try: try:
with check_time_elapse(3, op='gt'):
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)

cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
assert "¥" in res
# 引用的function没有变化 # 引用的function没有变化
with check_time_elapse(2, op='lt'):
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
assert 'Read cache from' in res
assert 'is different from its last cache' not in res
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
assert "¥" not in res

assert 'Read' in res
assert 'different' not in res


# 引用的function有变化 # 引用的function有变化
with check_time_elapse(2, op='lt'):
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
res = get_subprocess_results(cmd)
assert 'is different from its last cache' in res
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
res = get_subprocess_results(cmd)
assert "¥" not in res
assert 'different' in res


finally: finally:
rank_zero_rm(cache_fp) rank_zero_rm(cache_fp)
@@ -184,22 +176,21 @@ class TestCacheResults:
cache_fp = 'demo.pkl' cache_fp = 'demo.pkl'
test_type = 'refer_class_method_change' test_type = 'refer_class_method_change'
try: try:
with check_time_elapse(3, op='gt'):
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
assert "¥" in res


# 引用的class没有变化 # 引用的class没有变化
with check_time_elapse(2, op='lt'):
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
assert 'Read cache from' in res
assert 'is different from its last cache' not in res

# 引用的class有变化
with check_time_elapse(2, op='lt'):
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
res = get_subprocess_results(cmd)
assert 'is different from its last cache' in res
cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 0'
res = get_subprocess_results(cmd)
assert 'Read' in res
assert 'different' not in res
assert "¥" not in res

cmd = f'python {__file__} --cache_fp {cache_fp} --test_type {test_type} --turn 1'
res = get_subprocess_results(cmd)
assert 'different' in res
assert "¥" not in res


finally: finally:
rank_zero_rm(cache_fp) rank_zero_rm(cache_fp)
@@ -278,8 +269,8 @@ if __name__ == '__main__':


@cache_results(cache_fp) @cache_results(cache_fp)
def demo_refer_other_func(): def demo_refer_other_func():
time.sleep(3)
b = demo() b = demo()
print("¥")
return b return b


res = demo_refer_other_func() res = demo_refer_other_func()
@@ -296,7 +287,7 @@ if __name__ == '__main__':
# pdb.set_trace() # pdb.set_trace()
@cache_results(cache_fp) @cache_results(cache_fp)
def demo_func(): def demo_func():
time.sleep(3)
print("¥")
b = demo.demo() b = demo.demo()
return b return b




+ 7
- 11
tests/helpers/common/utils.py View File

@@ -3,11 +3,11 @@ from contextlib import contextmanager




@contextmanager @contextmanager
def check_time_elapse(seconds, op='lt'):
def check_time_elapse(seconds:float, op='lt'):
""" """
检测某一段程序所花费的时间,是否 op 给定的seconds 检测某一段程序所花费的时间,是否 op 给定的seconds


:param int seconds:
:param seconds:
:param str op: :param str op:
:return: :return:
""" """
@@ -15,19 +15,15 @@ def check_time_elapse(seconds, op='lt'):
yield yield
end = time.time() end = time.time()
if op == 'lt': if op == 'lt':
assert end-start < seconds
assert end-start < seconds, (end-start, seconds)
elif op == 'gt': elif op == 'gt':
assert end-start > seconds
assert end-start > seconds, (end-start, seconds)
elif op == 'eq': elif op == 'eq':
assert end - start == seconds
assert end - start == seconds, (end-start, seconds)
elif op == 'le': elif op == 'le':
assert end - start <= seconds
assert end - start <= seconds, (end-start, seconds)
elif op == 'ge': elif op == 'ge':
assert end - start >= seconds
assert end - start >= seconds, (end-start, seconds)
else: else:
raise ValueError("Only supports lt,gt,eq,le,ge.") raise ValueError("Only supports lt,gt,eq,le,ge.")







Loading…
Cancel
Save