Browse Source

!3234 [MD] fix bug when save tfrecord data

Merge pull request !3234 from liyong126/fix_save_tfrecod
tags/v0.6.0-beta
mindspore-ci-bot Gitee 5 years ago
parent
commit
fa7fa8a162
3 changed files with 60 additions and 6 deletions
  1. +15
    -5
      mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc
  2. BIN
      tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord
  3. +45
    -1
      tests/ut/python/dataset/test_save_op.py

+ 15
- 5
mindspore/ccsrc/minddata/dataset/api/de_pipeline.cc View File

@@ -386,9 +386,14 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}


TensorRow row;
std::unordered_map<std::string, int32_t> column_name_id_map =
iterator_->GetColumnNameMap(); // map of column name, id
bool first_loop = true; // build schema in first loop
std::unordered_map<std::string, int32_t> column_name_id_map;
for (auto el : iterator_->GetColumnNameMap()) {
std::string column_name = el.first;
std::transform(column_name.begin(), column_name.end(), column_name.begin(),
[](unsigned char c) { return ispunct(c) ? '_' : c; });
column_name_id_map[column_name] = el.second;
}
bool first_loop = true; // build schema in first loop
do {
json row_raw_data;
std::map<std::string, std::unique_ptr<std::vector<uint8_t>>> row_bin_data;
@@ -403,7 +408,10 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
std::vector<std::string> index_fields;
s = FetchMetaFromTensorRow(column_name_id_map, row, &mr_json, &index_fields);
RETURN_IF_NOT_OK(s);
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id);
if (mindrecord::SUCCESS !=
mindrecord::ShardHeader::initialize(&mr_header, mr_json, index_fields, blob_fields, mr_schema_id)) {
RETURN_STATUS_UNEXPECTED("Error: failed to initialize ShardHeader.");
}
mr_writer->SetShardHeader(mr_header);
first_loop = false;
}
@@ -423,7 +431,9 @@ Status DEPipeline::SaveDataset(const std::vector<std::string> &file_names, const
}
} while (!row.empty());
mr_writer->Commit();
mindrecord::ShardIndexGenerator::finalize(file_names);
if (mindrecord::SUCCESS != mindrecord::ShardIndexGenerator::finalize(file_names)) {
RETURN_STATUS_UNEXPECTED("Error: failed to finalize ShardIndexGenerator.");
}
return Status::OK();
}




BIN
tests/ut/data/mindrecord/testTFRecordData/dummy.tfrecord View File


+ 45
- 1
tests/ut/python/dataset/test_save_op.py View File

@@ -16,6 +16,7 @@
This is the test module for saveOp.
"""
import os
from string import punctuation
import mindspore.dataset as ds
from mindspore import log as logger
from mindspore.mindrecord import FileWriter
@@ -24,7 +25,7 @@ import pytest


CV_FILE_NAME1 = "../data/mindrecord/testMindDataSet/temp.mindrecord"
CV_FILE_NAME2 = "../data/mindrecord/testMindDataSet/auto.mindrecord"
TFRECORD_FILES = "../data/mindrecord/testTFRecordData/dummy.tfrecord"
FILES_NUM = 1
num_readers = 1


@@ -388,3 +389,46 @@ def test_case_06(add_and_remove_cv_file):


with pytest.raises(Exception, match="tfrecord dataset format is not supported."):
    d1.save(CV_FILE_NAME2, 1, "tfrecord")


def cast_name(key):
    """
    Cast a schema name that may contain special characters to a valid name.

    Every punctuation character or space (except the underscore itself) is
    replaced with an underscore, mirroring the column-name sanitization done
    on the C++ side when a dataset is saved to MindRecord.
    """
    # Characters that must be rewritten: all punctuation plus the space,
    # minus '_' which is already legal in a column name.
    to_replace = (set(punctuation) | {' '}) - {'_'}
    return ''.join('_' if ch in to_replace else ch for ch in key)


def test_case_07():
    """
    Save a TFRecordDataset to MindRecord and verify the round trip.

    Reads the dummy tfrecord file, saves it as a MindRecord file, reloads it
    through MindDataset, and checks every value matches the original, with
    column names sanitized via cast_name().
    """
    # Start from a clean slate: remove any leftover output files.
    for path in (CV_FILE_NAME2, "{}.db".format(CV_FILE_NAME2)):
        if os.path.exists(path):
            os.remove(path)

    source = ds.TFRecordDataset(TFRECORD_FILES, shuffle=False)
    expected_rows = [row for row in source.create_dict_iterator()]
    source.save(CV_FILE_NAME2, FILES_NUM)

    reloaded = ds.MindDataset(dataset_file=CV_FILE_NAME2,
                              num_parallel_workers=num_readers,
                              shuffle=False)
    actual_rows = [row for row in reloaded.create_dict_iterator()]

    num_rows = 0
    for expected in expected_rows:
        for name, value in expected.items():
            # After saving, punctuation in column names becomes '_'.
            saved_value = actual_rows[num_rows][cast_name(name)]
            if isinstance(value, np.ndarray):
                assert (value == saved_value).all()
            else:
                assert value == saved_value
        num_rows += 1
    assert num_rows == 10

    # Clean up the files generated by this test.
    for path in (CV_FILE_NAME2, "{}.db".format(CV_FILE_NAME2)):
        if os.path.exists(path):
            os.remove(path)

Loading…
Cancel
Save