Browse Source

[to #45821936]fix: fix block user specify revision after release_datetime

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10572162
master
mulin.lyh yingda.chen 2 years ago
parent
commit
3791ee7ad2
2 changed files with 90 additions and 5 deletions
  1. +8
    -3
      modelscope/hub/api.py
  2. +82
    -2
      tests/hub/test_hub_revision_release_mode.py

+ 8
- 3
modelscope/hub/api.py View File

@@ -382,10 +382,11 @@ class HubApi:
logger.info('Model revision not specified, use default: %s in development mode' % revision)
if revision not in branches and revision not in tags:
raise NotExistError('The model: %s has no branch or tag : %s .' % revision)
logger.info('Development mode use revision: %s' % revision)
else:
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies)
if revision is None:
if revision is None: # user not specified revision, use latest revision before release time
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies)
if len(revisions) == 0:
raise NoValidRevisionError('The model: %s has no valid revision!' % model_id)
# tags (revisions) returned from backend are guaranteed to be ordered by create-time
@@ -393,9 +394,13 @@ class HubApi:
revision = revisions[0]
logger.info('Model revision not specified, use the latest revision: %s' % revision)
else:
# use user-specified revision
revisions = self.list_model_revisions(
model_id, cutoff_timestamp=current_timestamp, use_cookies=False if cookies is None else cookies)
if revision not in revisions:
raise NotExistError(
'The model: %s has no revision: %s !' % (model_id, revision))
logger.info('Use user-specified model revision: %s' % revision)
return revision

def get_model_branches_and_tags(


+ 82
- 2
tests/hub/test_hub_revision_release_mode.py View File

@@ -115,7 +115,7 @@ class HubRevisionTest(unittest.TestCase):
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Secnod time: %s' % t2)
logger.info('Second time: %s' % t2)
# set
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
@@ -142,6 +142,43 @@ class HubRevisionTest(unittest.TestCase):
finally:
version.__release_datetime__ = release_datetime_backup

def test_snapshot_download_revision_user_set_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time: %s' % t1)
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Secnod time: %s' % t2)
# set
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
% version.__release_datetime__)
try:
logger.info('Setting __release_datetime__ to: %s' % t1)
version.__release_datetime__ = t1
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=self.revision,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert not os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
with tempfile.TemporaryDirectory() as temp_cache_dir:
snapshot_path = snapshot_download(
self.model_id,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name))
assert os.path.exists(
os.path.join(snapshot_path, download_model_file_name2))
finally:
version.__release_datetime__ = release_datetime_backup

def test_file_download_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
@@ -175,7 +212,6 @@ class HubRevisionTest(unittest.TestCase):
self.model_id,
download_model_file_name,
cache_dir=temp_cache_dir)
print('Downloaded file path: %s' % file_path)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id,
@@ -185,6 +221,50 @@ class HubRevisionTest(unittest.TestCase):
finally:
version.__release_datetime__ = release_datetime_backup

def test_file_download_revision_user_set_revision(self):
with mock.patch.dict(os.environ, self.modified_environ, clear=True):
self.prepare_repo_data_and_tag()
t1 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('First time stamp: %s' % t1)
time.sleep(10)
self.add_new_file_and_tag_to_repo()
t2 = datetime.now().isoformat(sep=' ', timespec='seconds')
logger.info('Second time: %s' % t2)
release_datetime_backup = version.__release_datetime__
logger.info('Origin __release_datetime__: %s'
% version.__release_datetime__)
try:
version.__release_datetime__ = t1
logger.info('Setting __release_datetime__ to: %s' % t1)
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=self.revision,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
with self.assertRaises(NotExistError):
model_file_download(
self.model_id,
download_model_file_name2,
revision=self.revision,
cache_dir=temp_cache_dir)
with tempfile.TemporaryDirectory() as temp_cache_dir:
file_path = model_file_download(
self.model_id,
download_model_file_name,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
file_path = model_file_download(
self.model_id,
download_model_file_name2,
revision=self.revision2,
cache_dir=temp_cache_dir)
assert os.path.exists(file_path)
finally:
version.__release_datetime__ = release_datetime_backup


if __name__ == '__main__':
unittest.main()

Loading…
Cancel
Save