
test_dataset_upload.py

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
import zipfile

from modelscope.msdatasets import MsDataset
from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects
from modelscope.utils import logger as logging
from modelscope.utils.constant import (DEFAULT_DATASET_REVISION, DownloadMode,
                                       ModelFile)
from modelscope.utils.test_utils import test_level

logger = logging.get_logger(__name__)

KEY_EXTRACTED = 'extracted'


class DatasetUploadTest(unittest.TestCase):

    def setUp(self):
        self.old_dir = os.getcwd()
        self.dataset_name = 'small_coco_for_test'
        self.dataset_file_name = self.dataset_name
        self.prepared_dataset_name = 'pets_small'
        self.token = os.getenv('TEST_UPLOAD_MS_TOKEN')
        error_msg = 'The modelscope token can not be empty, please set env variable: TEST_UPLOAD_MS_TOKEN'
        self.assertIsNotNone(self.token, msg=error_msg)

        from modelscope.hub.api import HubApi
        from modelscope.hub.api import ModelScopeConfig
        self.api = HubApi()
        self.api.login(self.token)
        # get user info
        self.namespace, _ = ModelScopeConfig.get_user_info()

        # Temporary working directories for dataset files and cloned meta.
        self.temp_dir = tempfile.mkdtemp()
        self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name)
        self.test_meta_dir = os.path.join(self.test_work_dir, 'meta')
        if not os.path.exists(self.test_work_dir):
            os.makedirs(self.test_work_dir)

    def tearDown(self):
        os.chdir(self.old_dir)
        shutil.rmtree(self.temp_dir, ignore_errors=True)
        logger.info(
            f'Temporary directory {self.temp_dir} successfully removed!')

    @staticmethod
    def get_raw_downloaded_file_path(extracted_path):
        # Walk up from the extracted directory to the download cache root and
        # return the path of the original zip archive that was downloaded.
        raw_downloaded_file_path = ''
        raw_data_dir = os.path.abspath(
            os.path.join(extracted_path, '../../..'))
        for root, dirs, files in os.walk(raw_data_dir):
            if KEY_EXTRACTED in dirs:
                for file in files:
                    curr_file_path = os.path.join(root, file)
                    if zipfile.is_zipfile(curr_file_path):
                        raw_downloaded_file_path = curr_file_path
        return raw_downloaded_file_path

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_upload(self):
        # Get the prepared data from hub, using default modelscope namespace
        ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train')
        config_res = ms_ds_train._hf_ds.config_kwargs
        extracted_path = config_res.get('split_config').get('train')
        raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path)

        # Upload the raw zip archive as a single object.
        MsDataset.upload(
            object_name=self.dataset_file_name + '.zip',
            local_file_path=raw_zipfile_path,
            dataset_name=self.dataset_name,
            namespace=self.namespace)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_upload_dir(self):
        # Upload the train and val image directories as separate objects.
        ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train')
        config_train = ms_ds_train._hf_ds.config_kwargs
        extracted_path_train = config_train.get('split_config').get('train')

        MsDataset.upload(
            object_name='train',
            local_file_path=os.path.join(extracted_path_train,
                                         'Pets/images/train'),
            dataset_name=self.dataset_name,
            namespace=self.namespace)
        MsDataset.upload(
            object_name='val',
            local_file_path=os.path.join(extracted_path_train,
                                         'Pets/images/val'),
            dataset_name=self.dataset_name,
            namespace=self.namespace)

        objects = list_dataset_objects(
            hub_api=self.api,
            max_limit=-1,
            is_recursive=True,
            dataset_name=self.dataset_name,
            namespace=self.namespace,
            version=DEFAULT_DATASET_REVISION)

        logger.info(f'{len(objects)} objects have been uploaded: {objects}')

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_download_dir(self):
        test_ds = MsDataset.load(
            self.dataset_name,
            namespace=self.namespace,
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        assert test_ds.config_kwargs['split_config'].values()

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_clone_meta(self):
        MsDataset.clone_meta(
            dataset_work_dir=self.test_meta_dir,
            dataset_id=os.path.join(self.namespace, self.dataset_name))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_upload_meta(self):
        # Clone dataset meta repo first.
        MsDataset.clone_meta(
            dataset_work_dir=self.test_meta_dir,
            dataset_id=os.path.join(self.namespace, self.dataset_name))

        # Append a line to the README and push the change back.
        with open(os.path.join(self.test_meta_dir, ModelFile.README),
                  'a') as f:
            f.write('\nThis is a line for unit test.')

        MsDataset.upload_meta(
            dataset_work_dir=self.test_meta_dir,
            commit_message='Update for unit test.')


if __name__ == '__main__':
    unittest.main()
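
For reference, the sketch below distills the upload flow that these tests exercise into a standalone script. It reuses only calls that appear above (HubApi.login, ModelScopeConfig.get_user_info, MsDataset.upload) and assumes the same TEST_UPLOAD_MS_TOKEN environment variable; the archive path and dataset name are placeholders, not values taken from this repository.

import os

from modelscope.hub.api import HubApi, ModelScopeConfig
from modelscope.msdatasets import MsDataset

# Log in with a ModelScope SDK token, exactly as setUp() does.
api = HubApi()
api.login(os.environ['TEST_UPLOAD_MS_TOKEN'])
namespace, _ = ModelScopeConfig.get_user_info()

# Upload a local zip archive as a single object of a dataset repo
# (placeholder names; substitute your own archive and dataset).
MsDataset.upload(
    object_name='small_coco_for_test.zip',
    local_file_path='/path/to/small_coco_for_test.zip',
    dataset_name='small_coco_for_test',
    namespace=namespace)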