|
|
|
@@ -16,6 +16,7 @@ |
|
|
|
This is the test module for saveOp. |
|
|
|
""" |
|
|
|
import os |
|
|
|
from string import punctuation |
|
|
|
import mindspore.dataset as ds |
|
|
|
from mindspore import log as logger |
|
|
|
from mindspore.mindrecord import FileWriter |
|
|
|
@@ -24,7 +25,7 @@ import pytest |
|
|
|
|
|
|
|
CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord" |
|
|
|
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord" |
|
|
|
|
|
|
|
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord" |
|
|
|
FILES_NUM = 1 |
|
|
|
num_readers = 1 |
|
|
|
|
|
|
|
@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file): |
|
|
|
|
|
|
|
with pytest.raises(Exception, match="tfrecord dataset format is not supported."): |
|
|
|
d1.save(CV_FILE_NAME2, 1, "tfrecord") |
|
|
|
|
|
|
|
|
|
|
|
def cast_name(key): |
|
|
|
""" |
|
|
|
Cast schema names which containing special characters to valid names. |
|
|
|
""" |
|
|
|
special_symbols = set('{}{}'.format(punctuation, ' ')) |
|
|
|
special_symbols.remove('_') |
|
|
|
new_key = ['_' if x in special_symbols else x for x in key] |
|
|
|
casted_key = ''.join(new_key) |
|
|
|
return casted_key |
|
|
|
|
|
|
|
|
|
|
|
def test_case_07(): |
|
|
|
if os.path.exists("{}".format(CV_FILE_NAME2)): |
|
|
|
os.remove("{}".format(CV_FILE_NAME2)) |
|
|
|
if os.path.exists("{}.db".format(CV_FILE_NAME2)): |
|
|
|
os.remove("{}.db".format(CV_FILE_NAME2)) |
|
|
|
d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False) |
|
|
|
tf_data = [] |
|
|
|
for x in d1.create_dict_iterator(): |
|
|
|
tf_data.append(x) |
|
|
|
d1.save(CV_FILE_NAME2, FILES_NUM) |
|
|
|
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2, |
|
|
|
num_parallel_workers=num_readers, |
|
|
|
shuffle=False) |
|
|
|
mr_data = [] |
|
|
|
for x in d2.create_dict_iterator(): |
|
|
|
mr_data.append(x) |
|
|
|
count = 0 |
|
|
|
for x in tf_data: |
|
|
|
for k, v in x.items(): |
|
|
|
if isinstance(v, np.ndarray): |
|
|
|
assert (v == mr_data[count][cast_name(k)]).all() |
|
|
|
else: |
|
|
|
assert v == mr_data[count][cast_name(k)] |
|
|
|
count += 1 |
|
|
|
assert count == 10 |
|
|
|
|
|
|
|
if os.path.exists("{}".format(CV_FILE_NAME2)): |
|
|
|
os.remove("{}".format(CV_FILE_NAME2)) |
|
|
|
if os.path.exists("{}.db".format(CV_FILE_NAME2)): |
|
|
|
os.remove("{}.db".format(CV_FILE_NAME2)) |