Browse Source

!3234 [MD] fix bug when save tfrecord data

Merge pull request !3234 from liyong126/fix_save_tfrecod
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
fa7fa8a162
3 changed files with 60 additions and 6 deletions
  1. +15
    -5
      mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
  2. BIN
      tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
  3. +45
    -1
      tests/ut/python/dataset/test_save_op.py

+ 15
- 5
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc View File

@@ -386,9 +386,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}

TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map =
iterator_->GetColumnNameMap(); // map of column name, id
bool first_loop = true; // build schema in first loop
std::unordered_map<std::string, int32_t> column_name_id_map;
for (auto el : iterator_->GetColumnNameMap()) {
std::string column_name = el.first;
std::transform(column_name.begin(), column_name.end(), column_name.begin(),
[](unsigned char c) { return ispunct(c) ? '_' : c; });
column_name_id_map[column_name] = el.second;
}
bool first_loop = true; // build schema in first loop
do {
json row_raw_data;
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
@@ -403,7 +408,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
mr_writer->SetShardHeader(mr_header);
first_loop = false;
}
@@ -423,7 +431,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}
} while (!row.empty());
mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names);
if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
return Status::OK();
}



BIN
tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord View File


+ 45
- 1
tests/ut/python/dataset/test_save_op.py View File

@@ -16,6 +16,7 @@
This is the test module for saveOp.
"""
import os
from string import punctuation
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
@@ -24,7 +25,7 @@ import pytest

CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1
num_readers = 1

@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):

with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
d1.save(CV_FILE_NAME2, 1, "tfrecord")


def cast_name(key):
"""
Cast schema names which containing special characters to valid names.
"""
special_symbols = set('{}{}'.format(punctuation, ' '))
special_symbols.remove('_')
new_key = ['_' if x in special_symbols else x for x in key]
casted_key = ''.join(new_key)
return casted_key


def test_case_07():
if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))
d1 = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
tf_data = []
for x in d1.create_dict_iterator():
tf_data.append(x)
d1.save(CV_FILE_NAME2, FILES_NUM)
d2 = ds.MindDataset(dataset_file=CV_FILE_NAME2,
num_parallel_workers=num_readers,
shuffle=False)
mr_data = []
for x in d2.create_dict_iterator():
mr_data.append(x)
count = 0
for x in tf_data:
for k, v in x.items():
if isinstance(v, np.ndarray):
assert (v == mr_data[count][cast_name(k)]).all()
else:
assert v == mr_data[count][cast_name(k)]
count += 1
assert count == 10

if os.path.exists("{}".format(CV_FILE_NAME2)):
os.remove("{}".format(CV_FILE_NAME2))
if os.path.exists("{}.db".format(CV_FILE_NAME2)):
os.remove("{}.db".format(CV_FILE_NAME2))

Loading…
Cancel
Save