| @@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) { | |||||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | .def("get_blob_fields", &ShardReader::GetBlobFields) | ||||
| .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & | .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & | ||||
| ShardReader::GetNextPy) | ShardReader::GetNextPy) | ||||
| .def("finish", &ShardReader::Finish) | |||||
| .def("close", &ShardReader::Close); | .def("close", &ShardReader::Close); | ||||
| } | } | ||||
| @@ -174,10 +174,6 @@ class ShardReader { | |||||
| ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | ||||
| const std::vector<std::string> &columns = std::vector<std::string>()); | const std::vector<std::string> &columns = std::vector<std::string>()); | ||||
| /// \brief join all created threads | |||||
| /// \return MSRStatus the status of MSRStatus | |||||
| MSRStatus Finish(); | |||||
| /// \brief return a batch, given that one is ready | /// \brief return a batch, given that one is ready | ||||
| /// \return a batch of images and image data | /// \return a batch of images and image data | ||||
| std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext(); | std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext(); | ||||
| @@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() { | |||||
| ShardReader::~ShardReader() { Close(); } | ShardReader::~ShardReader() { Close(); } | ||||
| void ShardReader::Close() { | void ShardReader::Close() { | ||||
| (void)Finish(); // interrupt reading and stop threads | |||||
| { | |||||
| std::lock_guard<std::mutex> lck(mtx_delivery_); | |||||
| interrupt_ = true; // interrupt reading and stop threads | |||||
| } | |||||
| cv_delivery_.notify_all(); | |||||
| // Wait for all threads to finish | |||||
| for (auto &i_thread : thread_set_) { | |||||
| if (i_thread.joinable()) { | |||||
| i_thread.join(); | |||||
| } | |||||
| } | |||||
| FileStreamsOperator(); | FileStreamsOperator(); | ||||
| } | } | ||||
| @@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int, | |||||
| return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); | return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); | ||||
| } | } | ||||
| MSRStatus ShardReader::Finish() { | |||||
| { | |||||
| std::lock_guard<std::mutex> lck(mtx_delivery_); | |||||
| interrupt_ = true; | |||||
| } | |||||
| cv_delivery_.notify_all(); | |||||
| // Wait for all threads to finish | |||||
| for (auto &i_thread : thread_set_) { | |||||
| if (i_thread.joinable()) { | |||||
| i_thread.join(); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| int64_t ShardReader::GetNumClasses(const std::string &category_field) { | int64_t ShardReader::GetNumClasses(const std::string &category_field) { | ||||
| auto shard_count = file_paths_.size(); | auto shard_count = file_paths_.size(); | ||||
| auto index_fields = shard_header_->GetFields(); | auto index_fields = shard_header_->GetFields(); | ||||
| @@ -83,15 +83,6 @@ class FileReader: | |||||
| yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) | yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) | ||||
| iterator = self._reader.get_next() | iterator = self._reader.get_next() | ||||
| def finish(self): | |||||
| """ | |||||
| Stop reader worker. | |||||
| Raises: | |||||
| MRMFinishError: If failed to finish worker threads. | |||||
| """ | |||||
| return self._reader.finish() | |||||
| def close(self): | def close(self): | ||||
| """Stop reader worker and close File.""" | """Stop reader worker and close File.""" | ||||
| return self._reader.close() | return self._reader.close() | ||||
| @@ -17,8 +17,7 @@ This module is to read data from mindrecord. | |||||
| """ | """ | ||||
| import mindspore._c_mindrecord as ms | import mindspore._c_mindrecord as ms | ||||
| from mindspore import log as logger | from mindspore import log as logger | ||||
| from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError | |||||
| from .common.exceptions import MRMOpenError, MRMLaunchError | |||||
| __all__ = ['ShardReader'] | __all__ = ['ShardReader'] | ||||
| class ShardReader: | class ShardReader: | ||||
| @@ -102,22 +101,6 @@ class ShardReader: | |||||
| """ | """ | ||||
| return self._reader.get_header() | return self._reader.get_header() | ||||
| def finish(self): | |||||
| """ | |||||
| stop the worker threads. | |||||
| Returns: | |||||
| MSRStatus, SUCCESS or FAILED. | |||||
| Raises: | |||||
| MRMFinishError: If failed to finish worker threads. | |||||
| """ | |||||
| ret = self._reader.finish() | |||||
| if ret != ms.MSRStatus.SUCCESS: | |||||
| logger.error("Failed to finish worker threads.") | |||||
| raise MRMFinishError | |||||
| return ret | |||||
| def close(self): | def close(self): | ||||
| """close MindRecord File.""" | """close MindRecord File.""" | ||||
| self._reader.close() | self._reader.close() | ||||
| @@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i <= kSampleCount); | ASSERT_TRUE(i <= kSampleCount); | ||||
| } | } | ||||
| @@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) { | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i <= 5); | ASSERT_TRUE(i <= 5); | ||||
| } | } | ||||
| @@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i <= 10); | ASSERT_TRUE(i <= 10); | ||||
| } | } | ||||
| @@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { | |||||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i <= 10); | ASSERT_TRUE(i <= 10); | ||||
| } | } | ||||
| @@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i == 20); | ASSERT_TRUE(i == 20); | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| @@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i == 6); | ASSERT_TRUE(i == 6); | ||||
| } | } | ||||
| @@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) { | |||||
| category_no++; | category_no++; | ||||
| category_no %= static_cast<int>(categories.size()); | category_no %= static_cast<int>(categories.size()); | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardOperator, TestShardShuffle) { | TEST_F(TestShardOperator, TestShardShuffle) { | ||||
| @@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) { | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardOperator, TestShardSampleShuffle) { | TEST_F(TestShardOperator, TestShardSampleShuffle) { | ||||
| @@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_LE(i, 35); | ASSERT_LE(i, 35); | ||||
| } | } | ||||
| @@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_TRUE(i <= kSampleSize); | ASSERT_TRUE(i <= kSampleSize); | ||||
| } | } | ||||
| @@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { | |||||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | ||||
| i++; | i++; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_LE(i, 35); | ASSERT_LE(i, 35); | ||||
| } | } | ||||
| @@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { | |||||
| auto y = compare_dataset.GetNext(); | auto y = compare_dataset.GetNext(); | ||||
| if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true; | if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true; | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| compare_dataset.Finish(); | |||||
| dataset.Close(); | |||||
| compare_dataset.Close(); | |||||
| ASSERT_TRUE(different); | ASSERT_TRUE(different); | ||||
| } | } | ||||
| @@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { | |||||
| category_no++; | category_no++; | ||||
| category_no %= static_cast<int>(categories.size()); | category_no %= static_cast<int>(categories.size()); | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | ||||
| @@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | |||||
| category_no++; | category_no++; | ||||
| category_no %= static_cast<int>(categories.size()); | category_no %= static_cast<int>(categories.size()); | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardOperator, TestShardCategorySample) { | TEST_F(TestShardOperator, TestShardCategorySample) { | ||||
| @@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) { | |||||
| category_no++; | category_no++; | ||||
| category_no %= static_cast<int>(categories.size()); | category_no %= static_cast<int>(categories.size()); | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_EQ(category_no, 0); | ASSERT_EQ(category_no, 0); | ||||
| ASSERT_TRUE(i <= kSampleSize); | ASSERT_TRUE(i <= kSampleSize); | ||||
| } | } | ||||
| @@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { | |||||
| category_no++; | category_no++; | ||||
| category_no %= static_cast<int>(categories.size()); | category_no %= static_cast<int>(categories.size()); | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| ASSERT_EQ(category_no, 0); | ASSERT_EQ(category_no, 0); | ||||
| ASSERT_TRUE(i <= kSampleSize); | ASSERT_TRUE(i <= kSampleSize); | ||||
| } | } | ||||
| @@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardReader, TestShardReaderSample) { | TEST_F(TestShardReader, TestShardReaderSample) { | ||||
| @@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| dataset.Close(); | dataset.Close(); | ||||
| } | } | ||||
| @@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | ||||
| @@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { | TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { | ||||
| @@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| TEST_F(TestShardReader, TestShardReaderDir) { | TEST_F(TestShardReader, TestShardReaderDir) { | ||||
| @@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| } | } | ||||
| } // namespace mindrecord | } // namespace mindrecord | ||||
| } // namespace mindspore | } // namespace mindspore | ||||
| @@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| for (int i = 1; i <= 4; i++) { | for (int i = 1; i <= 4; i++) { | ||||
| string filename = std::string("./OneSample.shard0") + std::to_string(i); | string filename = std::string("./OneSample.shard0") + std::to_string(i); | ||||
| string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db"; | string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db"; | ||||
| @@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_TRUE(count == 10); | ASSERT_TRUE(count == 10); | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| for (const auto &filename : file_names) { | for (const auto &filename : file_names) { | ||||
| auto filename_db = filename + ".db"; | auto filename_db = filename + ".db"; | ||||
| @@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_TRUE(count == 10); | ASSERT_TRUE(count == 10); | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| for (const auto &filename : file_names) { | for (const auto &filename : file_names) { | ||||
| auto filename_db = filename + ".db"; | auto filename_db = filename + ".db"; | ||||
| remove(common::SafeCStr(filename_db)); | remove(common::SafeCStr(filename_db)); | ||||
| @@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||||
| } | } | ||||
| } | } | ||||
| ASSERT_TRUE(count == 10); | ASSERT_TRUE(count == 10); | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| for (const auto &filename : file_names) { | for (const auto &filename : file_names) { | ||||
| auto filename_db = filename + ".db"; | auto filename_db = filename + ".db"; | ||||
| remove(common::SafeCStr(filename_db)); | remove(common::SafeCStr(filename_db)); | ||||
| @@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { | |||||
| count++; | count++; | ||||
| } | } | ||||
| ASSERT_TRUE(count == 10); | ASSERT_TRUE(count == 10); | ||||
| dataset.Finish(); | |||||
| dataset.Close(); | |||||
| for (const auto &filename : file_names) { | for (const auto &filename : file_names) { | ||||
| auto filename_db = filename + ".db"; | auto filename_db = filename + ".db"; | ||||
| remove(common::SafeCStr(filename_db)); | remove(common::SafeCStr(filename_db)); | ||||
| @@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial(): | |||||
| count = count + 1 | count = count + 1 | ||||
| logger.info("#item{}: {}".format(index, x)) | logger.info("#item{}: {}".format(index, x)) | ||||
| if count == 5: | if count == 5: | ||||
| reader.finish() | |||||
| reader.close() | |||||
| assert count == 5 | assert count == 5 | ||||