You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_dataset_upload.py 5.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import shutil
  4. import tempfile
  5. import unittest
  6. import zipfile
  7. from modelscope.msdatasets import MsDataset
  8. from modelscope.msdatasets.utils.dataset_utils import list_dataset_objects
  9. from modelscope.utils import logger as logging
  10. from modelscope.utils.constant import DEFAULT_DATASET_REVISION, ModelFile
  11. from modelscope.utils.test_utils import test_level
  12. logger = logging.get_logger(__name__)
  13. KEY_EXTRACTED = 'extracted'
  14. class DatasetUploadTest(unittest.TestCase):
  15. def setUp(self):
  16. self.old_dir = os.getcwd()
  17. self.dataset_name = 'small_coco_for_test'
  18. self.dataset_file_name = self.dataset_name
  19. self.prepared_dataset_name = 'pets_small'
  20. self.token = os.getenv('TEST_UPLOAD_MS_TOKEN')
  21. error_msg = 'The modelscope token can not be empty, please set env variable: TEST_UPLOAD_MS_TOKEN'
  22. self.assertIsNotNone(self.token, msg=error_msg)
  23. from modelscope.hub.api import HubApi
  24. from modelscope.hub.api import ModelScopeConfig
  25. self.api = HubApi()
  26. self.api.login(self.token)
  27. # get user info
  28. self.namespace, _ = ModelScopeConfig.get_user_info()
  29. self.temp_dir = tempfile.mkdtemp()
  30. self.test_work_dir = os.path.join(self.temp_dir, self.dataset_name)
  31. self.test_meta_dir = os.path.join(self.test_work_dir, 'meta')
  32. if not os.path.exists(self.test_work_dir):
  33. os.makedirs(self.test_work_dir)
  34. def tearDown(self):
  35. os.chdir(self.old_dir)
  36. shutil.rmtree(self.temp_dir, ignore_errors=True)
  37. logger.info(
  38. f'Temporary directory {self.temp_dir} successfully removed!')
  39. @staticmethod
  40. def get_raw_downloaded_file_path(extracted_path):
  41. raw_downloaded_file_path = ''
  42. raw_data_dir = os.path.abspath(
  43. os.path.join(extracted_path, '../../..'))
  44. for root, dirs, files in os.walk(raw_data_dir):
  45. if KEY_EXTRACTED in dirs:
  46. for file in files:
  47. curr_file_path = os.path.join(root, file)
  48. if zipfile.is_zipfile(curr_file_path):
  49. raw_downloaded_file_path = curr_file_path
  50. return raw_downloaded_file_path
  51. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  52. def test_ds_upload(self):
  53. # Get the prepared data from hub, using default modelscope namespace
  54. ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train')
  55. config_res = ms_ds_train._hf_ds.config_kwargs
  56. extracted_path = config_res.get('split_config').get('train')
  57. raw_zipfile_path = self.get_raw_downloaded_file_path(extracted_path)
  58. MsDataset.upload(
  59. object_name=self.dataset_file_name + '.zip',
  60. local_file_path=raw_zipfile_path,
  61. dataset_name=self.dataset_name,
  62. namespace=self.namespace)
  63. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  64. def test_ds_upload_dir(self):
  65. ms_ds_train = MsDataset.load(self.prepared_dataset_name, split='train')
  66. config_train = ms_ds_train._hf_ds.config_kwargs
  67. extracted_path_train = config_train.get('split_config').get('train')
  68. MsDataset.upload(
  69. object_name='train',
  70. local_file_path=os.path.join(extracted_path_train,
  71. 'Pets/images/train'),
  72. dataset_name=self.dataset_name,
  73. namespace=self.namespace)
  74. MsDataset.upload(
  75. object_name='val',
  76. local_file_path=os.path.join(extracted_path_train,
  77. 'Pets/images/val'),
  78. dataset_name=self.dataset_name,
  79. namespace=self.namespace)
  80. objects = list_dataset_objects(
  81. hub_api=self.api,
  82. max_limit=-1,
  83. is_recursive=True,
  84. dataset_name=self.dataset_name,
  85. namespace=self.namespace,
  86. version=DEFAULT_DATASET_REVISION)
  87. logger.info(f'{len(objects)} objects have been uploaded: {objects}')
  88. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  89. def test_ds_download_dir(self):
  90. test_ds = MsDataset.load(self.dataset_name, self.namespace)
  91. assert test_ds.config_kwargs['split_config'].values()
  92. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  93. def test_ds_clone_meta(self):
  94. MsDataset.clone_meta(
  95. dataset_work_dir=self.test_meta_dir,
  96. dataset_id=os.path.join(self.namespace, self.dataset_name))
  97. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  98. def test_ds_upload_meta(self):
  99. # Clone dataset meta repo first.
  100. MsDataset.clone_meta(
  101. dataset_work_dir=self.test_meta_dir,
  102. dataset_id=os.path.join(self.namespace, self.dataset_name))
  103. with open(os.path.join(self.test_meta_dir, ModelFile.README),
  104. 'a') as f:
  105. f.write('\nThis is a line for unit test.')
  106. MsDataset.upload_meta(
  107. dataset_work_dir=self.test_meta_dir,
  108. commit_message='Update for unit test.')
  109. if __name__ == '__main__':
  110. unittest.main()