From 3791ee7ad2a1e4cc8f5586c7de138ef58a2db3db Mon Sep 17 00:00:00 2001 From: "mulin.lyh" Date: Sat, 29 Oct 2022 13:44:47 +0800 Subject: [PATCH] [to #45821936]fix: fix block user specify revision after release_datetime Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10572162 --- modelscope/hub/api.py | 11 ++- tests/hub/test_hub_revision_release_mode.py | 84 ++++++++++++++++++++- 2 files changed, 90 insertions(+), 5 deletions(-) diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 5923319d..dca6d099 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -382,10 +382,11 @@ class HubApi: logger.info('Model revision not specified, use default: %s in development mode' % revision) if revision not in branches and revision not in tags: raise NotExistError('The model: %s has no branch or tag : %s .' % revision) + logger.info('Development mode use revision: %s' % revision) else: - revisions = self.list_model_revisions( - model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) - if revision is None: + if revision is None: # user not specified revision, use latest revision before release time + revisions = self.list_model_revisions( + model_id, cutoff_timestamp=release_timestamp, use_cookies=False if cookies is None else cookies) if len(revisions) == 0: raise NoValidRevisionError('The model: %s has no valid revision!' % model_id) # tags (revisions) returned from backend are guaranteed to be ordered by create-time @@ -393,9 +394,13 @@ class HubApi: revision = revisions[0] logger.info('Model revision not specified, use the latest revision: %s' % revision) else: + # use user-specified revision + revisions = self.list_model_revisions( + model_id, cutoff_timestamp=current_timestamp, use_cookies=False if cookies is None else cookies) if revision not in revisions: raise NotExistError( 'The model: %s has no revision: %s !' % (model_id, revision)) + logger.info('Use user-specified model revision: %s' % revision) return revision def get_model_branches_and_tags( diff --git a/tests/hub/test_hub_revision_release_mode.py b/tests/hub/test_hub_revision_release_mode.py index 729a1861..73a0625e 100644 --- a/tests/hub/test_hub_revision_release_mode.py +++ b/tests/hub/test_hub_revision_release_mode.py @@ -115,7 +115,7 @@ class HubRevisionTest(unittest.TestCase): time.sleep(10) self.add_new_file_and_tag_to_repo() t2 = datetime.now().isoformat(sep=' ', timespec='seconds') - logger.info('Secnod time: %s' % t2) + logger.info('Second time: %s' % t2) # set release_datetime_backup = version.__release_datetime__ logger.info('Origin __release_datetime__: %s' @@ -142,6 +142,43 @@ class HubRevisionTest(unittest.TestCase): finally: version.__release_datetime__ = release_datetime_backup + def test_snapshot_download_revision_user_set_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Secnod time: %s' % t2) + # set + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + logger.info('Setting __release_datetime__ to: %s' % t1) + version.__release_datetime__ = t1 + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert not os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + with tempfile.TemporaryDirectory() as temp_cache_dir: + snapshot_path = snapshot_download( + self.model_id, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name)) + assert os.path.exists( + os.path.join(snapshot_path, download_model_file_name2)) + finally: + version.__release_datetime__ = release_datetime_backup + def test_file_download_revision(self): with mock.patch.dict(os.environ, self.modified_environ, clear=True): self.prepare_repo_data_and_tag() @@ -175,7 +212,6 @@ class HubRevisionTest(unittest.TestCase): self.model_id, download_model_file_name, cache_dir=temp_cache_dir) - print('Downloaded file path: %s' % file_path) assert os.path.exists(file_path) file_path = model_file_download( self.model_id, @@ -185,6 +221,50 @@ class HubRevisionTest(unittest.TestCase): finally: version.__release_datetime__ = release_datetime_backup + def test_file_download_revision_user_set_revision(self): + with mock.patch.dict(os.environ, self.modified_environ, clear=True): + self.prepare_repo_data_and_tag() + t1 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('First time stamp: %s' % t1) + time.sleep(10) + self.add_new_file_and_tag_to_repo() + t2 = datetime.now().isoformat(sep=' ', timespec='seconds') + logger.info('Second time: %s' % t2) + release_datetime_backup = version.__release_datetime__ + logger.info('Origin __release_datetime__: %s' + % version.__release_datetime__) + try: + version.__release_datetime__ = t1 + logger.info('Setting __release_datetime__ to: %s' % t1) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + with self.assertRaises(NotExistError): + model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision, + cache_dir=temp_cache_dir) + with tempfile.TemporaryDirectory() as temp_cache_dir: + file_path = model_file_download( + self.model_id, + download_model_file_name, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + file_path = model_file_download( + self.model_id, + download_model_file_name2, + revision=self.revision2, + cache_dir=temp_cache_dir) + assert os.path.exists(file_path) + finally: + version.__release_datetime__ = release_datetime_backup + if __name__ == '__main__': unittest.main()