You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

test_ms_dataset.py 5.4 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import unittest
  3. from modelscope.models import Model
  4. from modelscope.msdatasets import MsDataset
  5. from modelscope.preprocessors import SequenceClassificationPreprocessor
  6. from modelscope.preprocessors.base import Preprocessor
  7. from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode
  8. from modelscope.utils.test_utils import require_tf, require_torch, test_level
  9. class ImgPreprocessor(Preprocessor):
  10. def __init__(self, *args, **kwargs):
  11. super().__init__(*args, **kwargs)
  12. self.path_field = kwargs.pop('image_path', 'image_path')
  13. self.width = kwargs.pop('width', 'width')
  14. self.height = kwargs.pop('height', 'width')
  15. def __call__(self, data):
  16. import cv2
  17. image_path = data.get(self.path_field)
  18. if not image_path:
  19. return None
  20. img = cv2.imread(image_path)
  21. return {
  22. 'image':
  23. cv2.resize(img,
  24. (data.get(self.height, 128), data.get(self.width, 128)))
  25. }
  26. class MsDatasetTest(unittest.TestCase):
  27. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  28. def test_movie_scene_seg_toydata(self):
  29. ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train')
  30. print(ms_ds_train._hf_ds.config_kwargs)
  31. assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))
  32. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  33. def test_coco(self):
  34. ms_ds_train = MsDataset.load(
  35. 'pets_small',
  36. namespace=DEFAULT_DATASET_NAMESPACE,
  37. download_mode=DownloadMode.FORCE_REDOWNLOAD,
  38. split='train')
  39. print(ms_ds_train.config_kwargs)
  40. assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))
  41. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  42. def test_ms_csv_basic(self):
  43. ms_ds_train = MsDataset.load(
  44. 'clue', subset_name='afqmc',
  45. split='train').to_hf_dataset().select(range(5))
  46. print(next(iter(ms_ds_train)))
  47. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  48. def test_ds_basic(self):
  49. ms_ds_full = MsDataset.load(
  50. 'xcopa', subset_name='translation-et', namespace='damotest')
  51. ms_ds = MsDataset.load(
  52. 'xcopa',
  53. subset_name='translation-et',
  54. namespace='damotest',
  55. split='test')
  56. print(next(iter(ms_ds_full['test'])))
  57. print(next(iter(ms_ds)))
  58. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  59. @require_torch
  60. def test_to_torch_dataset_text(self):
  61. model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
  62. nlp_model = Model.from_pretrained(model_id)
  63. preprocessor = SequenceClassificationPreprocessor(
  64. nlp_model.model_dir,
  65. first_sequence='premise',
  66. second_sequence=None,
  67. padding='max_length')
  68. ms_ds_train = MsDataset.load(
  69. 'xcopa',
  70. subset_name='translation-et',
  71. namespace='damotest',
  72. split='test')
  73. pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
  74. import torch
  75. dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
  76. print(next(iter(dataloader)))
  77. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  78. @require_tf
  79. def test_to_tf_dataset_text(self):
  80. import tensorflow as tf
  81. tf.compat.v1.enable_eager_execution()
  82. model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
  83. nlp_model = Model.from_pretrained(model_id)
  84. preprocessor = SequenceClassificationPreprocessor(
  85. nlp_model.model_dir,
  86. first_sequence='premise',
  87. second_sequence=None)
  88. ms_ds_train = MsDataset.load(
  89. 'xcopa',
  90. subset_name='translation-et',
  91. namespace='damotest',
  92. split='test')
  93. tf_dataset = ms_ds_train.to_tf_dataset(
  94. batch_size=5,
  95. shuffle=True,
  96. preprocessors=preprocessor,
  97. drop_remainder=True)
  98. print(next(iter(tf_dataset)))
  99. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  100. @require_torch
  101. def test_to_torch_dataset_img(self):
  102. ms_image_train = MsDataset.load(
  103. 'fixtures_image_utils', namespace='damotest', split='test')
  104. pt_dataset = ms_image_train.to_torch_dataset(
  105. preprocessors=ImgPreprocessor(image_path='file'))
  106. import torch
  107. dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
  108. print(next(iter(dataloader)))
  109. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  110. @require_tf
  111. def test_to_tf_dataset_img(self):
  112. import tensorflow as tf
  113. tf.compat.v1.enable_eager_execution()
  114. ms_image_train = MsDataset.load(
  115. 'fixtures_image_utils', namespace='damotest', split='test')
  116. tf_dataset = ms_image_train.to_tf_dataset(
  117. batch_size=5,
  118. shuffle=True,
  119. preprocessors=ImgPreprocessor(image_path='file'),
  120. drop_remainder=True,
  121. )
  122. print(next(iter(tf_dataset)))
  123. if __name__ == '__main__':
  124. unittest.main()