You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_ms_dataset.py 4.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132
  1. import unittest
  2. from modelscope.models import Model
  3. from modelscope.msdatasets import MsDataset
  4. from modelscope.preprocessors import SequenceClassificationPreprocessor
  5. from modelscope.preprocessors.base import Preprocessor
  6. from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode
  7. from modelscope.utils.test_utils import require_tf, require_torch, test_level
  8. class ImgPreprocessor(Preprocessor):
  9. def __init__(self, *args, **kwargs):
  10. super().__init__(*args, **kwargs)
  11. self.path_field = kwargs.pop('image_path', 'image_path')
  12. self.width = kwargs.pop('width', 'width')
  13. self.height = kwargs.pop('height', 'width')
  14. def __call__(self, data):
  15. import cv2
  16. image_path = data.get(self.path_field)
  17. if not image_path:
  18. return None
  19. img = cv2.imread(image_path)
  20. return {
  21. 'image':
  22. cv2.resize(img,
  23. (data.get(self.height, 128), data.get(self.width, 128)))
  24. }
  25. class MsDatasetTest(unittest.TestCase):
  26. @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
  27. def test_coco(self):
  28. ms_ds_train = MsDataset.load(
  29. 'pets_small',
  30. namespace=DEFAULT_DATASET_NAMESPACE,
  31. download_mode=DownloadMode.FORCE_REDOWNLOAD,
  32. split='train')
  33. print(ms_ds_train.config_kwargs)
  34. assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))
  35. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  36. def test_ms_csv_basic(self):
  37. ms_ds_train = MsDataset.load(
  38. 'afqmc_small', namespace='userxiaoming', split='train')
  39. print(next(iter(ms_ds_train)))
  40. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  41. def test_ds_basic(self):
  42. ms_ds_full = MsDataset.load(
  43. 'xcopa', subset_name='translation-et', namespace='damotest')
  44. ms_ds = MsDataset.load(
  45. 'xcopa',
  46. subset_name='translation-et',
  47. namespace='damotest',
  48. split='test')
  49. print(next(iter(ms_ds_full['test'])))
  50. print(next(iter(ms_ds)))
  51. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  52. @require_torch
  53. def test_to_torch_dataset_text(self):
  54. model_id = 'damo/bert-base-sst2'
  55. nlp_model = Model.from_pretrained(model_id)
  56. preprocessor = SequenceClassificationPreprocessor(
  57. nlp_model.model_dir,
  58. first_sequence='premise',
  59. second_sequence=None)
  60. ms_ds_train = MsDataset.load(
  61. 'xcopa',
  62. subset_name='translation-et',
  63. namespace='damotest',
  64. split='test')
  65. pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
  66. import torch
  67. dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
  68. print(next(iter(dataloader)))
  69. @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
  70. @require_tf
  71. def test_to_tf_dataset_text(self):
  72. import tensorflow as tf
  73. tf.compat.v1.enable_eager_execution()
  74. model_id = 'damo/bert-base-sst2'
  75. nlp_model = Model.from_pretrained(model_id)
  76. preprocessor = SequenceClassificationPreprocessor(
  77. nlp_model.model_dir,
  78. first_sequence='premise',
  79. second_sequence=None)
  80. ms_ds_train = MsDataset.load(
  81. 'xcopa',
  82. subset_name='translation-et',
  83. namespace='damotest',
  84. split='test')
  85. tf_dataset = ms_ds_train.to_tf_dataset(
  86. batch_size=5,
  87. shuffle=True,
  88. preprocessors=preprocessor,
  89. drop_remainder=True)
  90. print(next(iter(tf_dataset)))
  91. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  92. @require_torch
  93. def test_to_torch_dataset_img(self):
  94. ms_image_train = MsDataset.load(
  95. 'fixtures_image_utils', namespace='damotest', split='test')
  96. pt_dataset = ms_image_train.to_torch_dataset(
  97. preprocessors=ImgPreprocessor(image_path='file'))
  98. import torch
  99. dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
  100. print(next(iter(dataloader)))
  101. @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
  102. @require_tf
  103. def test_to_tf_dataset_img(self):
  104. import tensorflow as tf
  105. tf.compat.v1.enable_eager_execution()
  106. ms_image_train = MsDataset.load(
  107. 'fixtures_image_utils', namespace='damotest', split='test')
  108. tf_dataset = ms_image_train.to_tf_dataset(
  109. batch_size=5,
  110. shuffle=True,
  111. preprocessors=ImgPreprocessor(image_path='file'),
  112. drop_remainder=True,
  113. )
  114. print(next(iter(tf_dataset)))
  115. if __name__ == '__main__':
  116. unittest.main()