|
|
@@ -33,7 +33,7 @@ class ImgPreprocessor(Preprocessor): |
|
|
|
|
|
|
|
class PyDatasetTest(unittest.TestCase): |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
def test_ds_basic(self): |
|
|
|
ms_ds_full = PyDataset.load('squad') |
|
|
|
ms_ds_full_hf = hfdata.load_dataset('squad') |
|
|
@@ -49,7 +49,7 @@ class PyDatasetTest(unittest.TestCase): |
|
|
|
print(next(iter(ms_ds_train))) |
|
|
|
print(next(iter(ms_image_train))) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
@require_torch |
|
|
|
def test_to_torch_dataset_text(self): |
|
|
|
model_id = 'damo/bert-base-sst2' |
|
|
@@ -64,7 +64,7 @@ class PyDatasetTest(unittest.TestCase): |
|
|
|
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5) |
|
|
|
print(next(iter(dataloader))) |
|
|
|
|
|
|
|
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level') |
|
|
|
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level') |
|
|
|
@require_tf |
|
|
|
def test_to_tf_dataset_text(self): |
|
|
|
import tensorflow as tf |
|
|
|