Browse Source

Remove finish() from FileReader

tags/v1.0.0
liyong 5 years ago
parent
commit
ac39c20f41
9 changed files with 43 additions and 78 deletions
  1. +0
    -1
      mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc
  2. +0
    -4
      mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h
  3. +13
    -17
      mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc
  4. +0
    -9
      mindspore/mindrecord/filereader.py
  5. +1
    -18
      mindspore/mindrecord/shardreader.py
  6. +17
    -17
      tests/ut/cpp/mindrecord/ut_shard_operator_test.cc
  7. +6
    -6
      tests/ut/cpp/mindrecord/ut_shard_reader_test.cc
  8. +5
    -5
      tests/ut/cpp/mindrecord/ut_shard_writer_test.cc
  9. +1
    -1
      tests/ut/python/mindrecord/test_mindrecord_base.py

+ 0
- 1
mindspore/ccsrc/minddata/mindrecord/common/shard_pybind.cc View File

@@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) {
.def("get_blob_fields", &ShardReader::GetBlobFields)
.def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) &
ShardReader::GetNextPy)
.def("finish", &ShardReader::Finish)
.def("close", &ShardReader::Close);
}



+ 0
- 4
mindspore/ccsrc/minddata/mindrecord/include/shard_reader.h View File

@@ -174,10 +174,6 @@ class ShardReader {
ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria,
const std::vector<std::string> &columns = std::vector<std::string>());

/// \brief join all created threads
/// \return MSRStatus the status of MSRStatus
MSRStatus Finish();

/// \brief return a batch, given that one is ready
/// \return a batch of images and image data
std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext();


+ 13
- 17
mindspore/ccsrc/minddata/mindrecord/io/shard_reader.cc View File

@@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() {
/// \brief Destructor; calls Close() when the reader object is destroyed.
ShardReader::~ShardReader() { Close(); }

/// \brief Interrupt reading, join the threads held in thread_set_, then call FileStreamsOperator().
// NOTE(review): this listing is a rendered diff without +/- markers, so the `(void)Finish()` call below is
// presumably the removed line and the inlined interrupt/notify/join sequence after it is its replacement.
// Confirm against the repository before treating both as live code.
void ShardReader::Close() {
(void)Finish(); // interrupt reading and stop threads
{
// Write the flag while holding mtx_delivery_; the inner scope releases the lock before notify_all().
std::lock_guard<std::mutex> lck(mtx_delivery_);
interrupt_ = true; // interrupt reading and stop threads
}
// Wake every thread currently waiting on cv_delivery_.
cv_delivery_.notify_all();

// Wait for all threads to finish
for (auto &i_thread : thread_set_) {
if (i_thread.joinable()) {
i_thread.join();
}
}

// NOTE(review): FileStreamsOperator() is defined outside this view; presumably it closes the file streams -- confirm.
FileStreamsOperator();
}

@@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int,
return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b));
}

/// \brief Set interrupt_, notify all waiters on cv_delivery_, and join every joinable thread in thread_set_.
/// \return MSRStatus, always SUCCESS (no failure path exists in this body)
MSRStatus ShardReader::Finish() {
{
// Hold mtx_delivery_ only while the flag is written; the lock is released at the end of this scope.
std::lock_guard<std::mutex> lck(mtx_delivery_);
interrupt_ = true;
}
// Wake every thread currently waiting on cv_delivery_.
cv_delivery_.notify_all();

// Wait for all threads to finish
for (auto &i_thread : thread_set_) {
if (i_thread.joinable()) {
i_thread.join();
}
}
return SUCCESS;
}

int64_t ShardReader::GetNumClasses(const std::string &category_field) {
auto shard_count = file_paths_.size();
auto index_fields = shard_header_->GetFields();


+ 0
- 9
mindspore/mindrecord/filereader.py View File

@@ -83,15 +83,6 @@ class FileReader:
yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema)
iterator = self._reader.get_next()

def finish(self):
"""
Stop reader worker.

Returns:
The value returned by the underlying reader's finish().

Raises:
MRMFinishError: If failed to finish worker threads.
"""
return self._reader.finish()

def close(self):
"""
Stop reader worker and close File.

Returns:
The value returned by the underlying reader's close().
"""
return self._reader.close()

+ 1
- 18
mindspore/mindrecord/shardreader.py View File

@@ -17,8 +17,7 @@ This module is to read data from mindrecord.
"""
import mindspore._c_mindrecord as ms
from mindspore import log as logger
from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError

from .common.exceptions import MRMOpenError, MRMLaunchError
__all__ = ['ShardReader']

class ShardReader:
@@ -102,22 +101,6 @@ class ShardReader:
"""
return self._reader.get_header()

def finish(self):
"""
Stop the worker threads.

Returns:
MSRStatus, always SUCCESS; any other status raises instead of being returned.

Raises:
MRMFinishError: If failed to finish worker threads.
"""
ret = self._reader.finish()
# Any status other than SUCCESS is logged and converted into an exception.
if ret != ms.MSRStatus.SUCCESS:
logger.error("Failed to finish worker threads.")
raise MRMFinishError
return ret

def close(self):
"""Close the MindRecord file by calling close() on the underlying reader; returns None."""
self._reader.close()

+ 17
- 17
tests/ut/cpp/mindrecord/ut_shard_operator_test.cc View File

@@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= kSampleCount);
}

@@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 5);
}

@@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 10);
}

@@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) {
MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]);
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= 10);
}

@@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i == 20);
} // namespace mindrecord

@@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl;
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i == 6);
}

@@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardOperator, TestShardShuffle) {
@@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardOperator, TestShardSampleShuffle) {
@@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_LE(i, 35);
}

@@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_TRUE(i <= kSampleSize);
}

@@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) {
<< ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump());
i++;
}
dataset.Finish();
dataset.Close();
ASSERT_LE(i, 35);
}

@@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) {
auto y = compare_dataset.GetNext();
if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true;
}
dataset.Finish();
compare_dataset.Finish();
dataset.Close();
compare_dataset.Close();
ASSERT_TRUE(different);
}

@@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
@@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardOperator, TestShardCategorySample) {
@@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
ASSERT_EQ(category_no, 0);
ASSERT_TRUE(i <= kSampleSize);
}
@@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) {
category_no++;
category_no %= static_cast<int>(categories.size());
}
dataset.Finish();
dataset.Close();
ASSERT_EQ(category_no, 0);
ASSERT_TRUE(i <= kSampleSize);
}


+ 6
- 6
tests/ut/cpp/mindrecord/ut_shard_reader_test.cc View File

@@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) {
}
}
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardReader, TestShardReaderSample) {
@@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) {
}
}
}
dataset.Finish();
dataset.Close();
dataset.Close();
}

@@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) {
}
}
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
@@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) {
}
}
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) {
@@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) {
}
}
}
dataset.Finish();
dataset.Close();
}

TEST_F(TestShardReader, TestShardReaderDir) {
@@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) {
}
}
}
dataset.Finish();
dataset.Close();
}
} // namespace mindrecord
} // namespace mindspore

+ 5
- 5
tests/ut/cpp/mindrecord/ut_shard_writer_test.cc View File

@@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) {
}
}
}
dataset.Finish();
dataset.Close();
for (int i = 1; i <= 4; i++) {
string filename = std::string("./OneSample.shard0") + std::to_string(i);
string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db";
@@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) {
}
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();

for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
@@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) {
}
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));
@@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) {
}
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));
@@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) {
count++;
}
ASSERT_TRUE(count == 10);
dataset.Finish();
dataset.Close();
for (const auto &filename : file_names) {
auto filename_db = filename + ".db";
remove(common::SafeCStr(filename_db));


+ 1
- 1
tests/ut/python/mindrecord/test_mindrecord_base.py View File

@@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial():
count = count + 1
logger.info("#item{}: {}".format(index, x))
if count == 5:
reader.finish()
reader.close()
assert count == 5




Loading…
Cancel
Save