Browse Source

del finish in FileReader

tags/v1.0.0
liyong 5 years ago
parent
commit
ac39c20f41
9 changed files with 43 additions and 78 deletions
  1. +0
    -1
      mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
  2. +0
    -4
      mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
  3. +13
    -17
      mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
  4. +0
    -9
      mindspore/mindrecord/filereader.py
  5. +1
    -18
      mindspore/mindrecord/shardreader.py
  6. +17
    -17
      tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
  7. +6
    -6
      tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
  8. +5
    -5
      tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
  9. +1
    -1
      tests/ut/python/mindrecord/test_mindrecord_base.py

+ 0
- 1
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc View File

@@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) {
.def("get_blob_fields", &ShardReader::GetBlobFields) .def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy) ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close); .def("close", &ShardReader::Close);
} }




+ 0
- 4
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h View File

@@ -174,10 +174,6 @@ class ShardReader {
ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns = std::vector<std::string>()); const std::vector<std::string> &columns = std::vector<std::string>());


/// \brief join all created threads
/// \return MSRStatus the status of MSRStatus
MSRStatus Finish();

/// \brief return a batch, given that one is ready /// \brief return a batch, given that one is ready
/// \return a batch of images and image data /// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext(); std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext();


+ 13
- 17
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc View File

@@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() {
ShardReader::~ShardReader() { Close(); } ShardReader::~ShardReader() { Close(); }


void ShardReader::Close() { void ShardReader::Close() {
(void)Finish(); // interrupt reading and stop threads
{
std::lock_guard<std::mutex> lck(mtx_delivery_);
interrupt_ = true; // interrupt reading and stop threads
}
cv_delivery_.notify_all();

// Wait for all threads to finish
for (auto &i_thread : thread_set_) {
if (i_thread.joinable()) {
i_thread.join();
}
}

FileStreamsOperator(); FileStreamsOperator();
} }


@@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int,
return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b));
} }


MSRStatus ShardReader::Finish() {
{
std::lock_guard<std::mutex> lck(mtx_delivery_);
interrupt_ = true;
}
cv_delivery_.notify_all();

// Wait for all threads to finish
for (auto &i_thread : thread_set_) {
if (i_thread.joinable()) {
i_thread.join();
}
}
return SUCCESS;
}

int64_t ShardReader::GetNumClasses(const std::string &category_field) { int64_t ShardReader::GetNumClasses(const std::string &category_field) {
auto shard_count = file_paths_.size(); auto shard_count = file_paths_.size();
auto index_fields = shard_header_->GetFields(); auto index_fields = shard_header_->GetFields();


+ 0
- 9
mindspore/mindrecord/filereader.py View File

@@ -83,15 +83,6 @@ class FileReader:
yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema)
iterator = self._reader.get_next() iterator = self._reader.get_next()


def finish(self):
"""
Stop reader worker.

Raises:
MRMFinishError: If failed to finish worker threads.
"""
return self._reader.finish()

def close(self): def close(self):
"""Stop reader worker and close File.""" """Stop reader worker and close File."""
return self._reader.close() return self._reader.close()

+ 1
- 18
mindspore/mindrecord/shardreader.py View File

@@ -17,8 +17,7 @@ This module is to read data from mindrecord.
""" """
import mindspore._c_mindrecord as ms import mindspore._c_mindrecord as ms
from mindspore import log as logger from mindspore import log as logger
from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError

from .common.exceptions import MRMOpenError, MRMLaunchError
__all__ = ['ShardReader'] __all__ = ['ShardReader']


class ShardReader: class ShardReader:
@@ -102,22 +101,6 @@ class ShardReader:
""" """
return self._reader.get_header() return self._reader.get_header()


def finish(self):
"""
stop the worker threads.

Returns:
MSRStatus, SUCCESS or FAILED.

Raises:
MRMFinishError: If failed to finish worker threads.
"""
ret = self._reader.finish()
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to finish worker threads.")
raise MRMFinishError
return ret

def close(self): def close(self):
"""close MindRecord File.""" """close MindRecord File."""
self._reader.close() self._reader.close()

+ 17
- 17
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc View File

@@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= kSampleCount); ASSERT_TRUE(i <= kSampleCount);
} }


@@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 5); ASSERT_TRUE(i <= 5);
} }


@@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 10); ASSERT_TRUE(i <= 10);
} }


@@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 10); ASSERT_TRUE(i <= 10);
} }


@@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i == 20); ASSERT_TRUE(i == 20);
} // namespace mindrecord } // namespace mindrecord


@@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i == 6); ASSERT_TRUE(i == 6);
} }


@@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
category_no++; category_no++;
category_no %= static_cast<int>(categories.size()); category_no %= static_cast<int>(categories.size());
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardOperator, TestShardShuffle) { TEST_F(TestShardOperator, TestShardShuffle) {
@@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardOperator, TestShardSampleShuffle) { TEST_F(TestShardOperator, TestShardSampleShuffle) {
@@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_LE(i, 35); ASSERT_LE(i, 35);
} }


@@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= kSampleSize); ASSERT_TRUE(i <= kSampleSize);
} }


@@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++; i++;
} }
dataset.Finish();
dataset.Close();
ASSERT_LE(i, 35); ASSERT_LE(i, 35);
} }


@@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
auto y = compare_dataset.GetNext(); auto y = compare_dataset.GetNext();
if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true; if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true;
} }
dataset.Finish();
compare_dataset.Finish();
dataset.Close();
compare_dataset.Close();
ASSERT_TRUE(different); ASSERT_TRUE(different);
} }


@@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
category_no++; category_no++;
category_no %= static_cast<int>(categories.size()); category_no %= static_cast<int>(categories.size());
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardOperator, TestShardCategoryShuffle2) { TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
@@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
category_no++; category_no++;
category_no %= static_cast<int>(categories.size()); category_no %= static_cast<int>(categories.size());
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardOperator, TestShardCategorySample) { TEST_F(TestShardOperator, TestShardCategorySample) {
@@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
category_no++; category_no++;
category_no %= static_cast<int>(categories.size()); category_no %= static_cast<int>(categories.size());
} }
dataset.Finish();
dataset.Close();
ASSERT_EQ(category_no, 0); ASSERT_EQ(category_no, 0);
ASSERT_TRUE(i <= kSampleSize); ASSERT_TRUE(i <= kSampleSize);
} }
@@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
category_no++; category_no++;
category_no %= static_cast<int>(categories.size()); category_no %= static_cast<int>(categories.size());
} }
dataset.Finish();
dataset.Close();
ASSERT_EQ(category_no, 0); ASSERT_EQ(category_no, 0);
ASSERT_TRUE(i <= kSampleSize); ASSERT_TRUE(i <= kSampleSize);
} }


+ 6
- 6
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc View File

@@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
} }
} }
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardReader, TestShardReaderSample) { TEST_F(TestShardReader, TestShardReaderSample) {
@@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
} }
} }
} }
dataset.Finish();
dataset.Close();
dataset.Close(); dataset.Close();
} }


@@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
} }
} }
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
@@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
} }
} }
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
@@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) {
} }
} }
} }
dataset.Finish();
dataset.Close();
} }


TEST_F(TestShardReader, TestShardReaderDir) { TEST_F(TestShardReader, TestShardReaderDir) {
@@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
} }
} }
} }
dataset.Finish();
dataset.Close();
} }
} // namespace mindrecord } // namespace mindrecord
} // namespace mindspore } // namespace mindspore

+ 5
- 5
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc View File

@@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
} }
} }
} }
dataset.Finish();
dataset.Close();
for (int i = 1; i <= 4; i++) { for (int i = 1; i <= 4; i++) {
string filename = std::string("./OneSample.shard0") + std::to_string(i); string filename = std::string("./OneSample.shard0") + std::to_string(i);
string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db"; string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db";
@@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
} }
} }
ASSERT_TRUE(count == 10); ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();


for (const auto &filename : file_names) { for (const auto &filename : file_names) {
auto filename_db = filename + ".db"; auto filename_db = filename + ".db";
@@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
} }
} }
ASSERT_TRUE(count == 10); ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) { for (const auto &filename : file_names) {
auto filename_db = filename + ".db"; auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db)); remove(common::SafeCStr(filename_db));
@@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
} }
} }
ASSERT_TRUE(count == 10); ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) { for (const auto &filename : file_names) {
auto filename_db = filename + ".db"; auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db)); remove(common::SafeCStr(filename_db));
@@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
count++; count++;
} }
ASSERT_TRUE(count == 10); ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) { for (const auto &filename : file_names) {
auto filename_db = filename + ".db"; auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db)); remove(common::SafeCStr(filename_db));


+ 1
- 1
tests/ut/python/mindrecord/test_mindrecord_base.py View File

@@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial():
count = count + 1 count = count + 1
logger.info("#item{}: {}".format(index, x)) logger.info("#item{}: {}".format(index, x))
if count == 5: if count == 5:
reader.finish()
reader.close()
assert count == 5 assert count == 5






Loading…
Cancel
Save