| @@ -93,7 +93,6 @@ void BindShardReader(const py::module *m) { | |||
| .def("get_blob_fields", &ShardReader::GetBlobFields) | |||
| .def("get_next", (std::vector<std::tuple<std::vector<std::vector<uint8_t>>, pybind11::object>>(ShardReader::*)()) & | |||
| ShardReader::GetNextPy) | |||
| .def("finish", &ShardReader::Finish) | |||
| .def("close", &ShardReader::Close); | |||
| } | |||
| @@ -174,10 +174,6 @@ class ShardReader { | |||
| ROW_GROUP_BRIEF ReadRowGroupCriteria(int group_id, int shard_id, const std::pair<std::string, std::string> &criteria, | |||
| const std::vector<std::string> &columns = std::vector<std::string>()); | |||
| /// \brief join all created threads | |||
| /// \return MSRStatus the status of MSRStatus | |||
| MSRStatus Finish(); | |||
| /// \brief return a batch, given that one is ready | |||
| /// \return a batch of images and image data | |||
| std::vector<std::tuple<std::vector<uint8_t>, json>> GetNext(); | |||
| @@ -239,7 +239,19 @@ void ShardReader::FileStreamsOperator() { | |||
| ShardReader::~ShardReader() { Close(); } | |||
| void ShardReader::Close() { | |||
| (void)Finish(); // interrupt reading and stop threads | |||
| { | |||
| std::lock_guard<std::mutex> lck(mtx_delivery_); | |||
| interrupt_ = true; // interrupt reading and stop threads | |||
| } | |||
| cv_delivery_.notify_all(); | |||
| // Wait for all threads to finish | |||
| for (auto &i_thread : thread_set_) { | |||
| if (i_thread.joinable()) { | |||
| i_thread.join(); | |||
| } | |||
| } | |||
| FileStreamsOperator(); | |||
| } | |||
| @@ -759,22 +771,6 @@ bool ResortRowGroups(std::tuple<int, int, int, int> a, std::tuple<int, int, int, | |||
| return std::get<1>(a) < std::get<1>(b) || (std::get<1>(a) == std::get<1>(b) && std::get<0>(a) < std::get<0>(b)); | |||
| } | |||
| MSRStatus ShardReader::Finish() { | |||
| { | |||
| std::lock_guard<std::mutex> lck(mtx_delivery_); | |||
| interrupt_ = true; | |||
| } | |||
| cv_delivery_.notify_all(); | |||
| // Wait for all threads to finish | |||
| for (auto &i_thread : thread_set_) { | |||
| if (i_thread.joinable()) { | |||
| i_thread.join(); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| int64_t ShardReader::GetNumClasses(const std::string &category_field) { | |||
| auto shard_count = file_paths_.size(); | |||
| auto index_fields = shard_header_->GetFields(); | |||
| @@ -83,15 +83,6 @@ class FileReader: | |||
| yield populate_data(raw, blob, self._columns, self._header.blob_fields, self._header.schema) | |||
| iterator = self._reader.get_next() | |||
| def finish(self): | |||
| """ | |||
| Stop reader worker. | |||
| Raises: | |||
| MRMFinishError: If failed to finish worker threads. | |||
| """ | |||
| return self._reader.finish() | |||
| def close(self): | |||
| """Stop reader worker and close File.""" | |||
| return self._reader.close() | |||
| @@ -17,8 +17,7 @@ This module is to read data from mindrecord. | |||
| """ | |||
| import mindspore._c_mindrecord as ms | |||
| from mindspore import log as logger | |||
| from .common.exceptions import MRMOpenError, MRMLaunchError, MRMFinishError | |||
| from .common.exceptions import MRMOpenError, MRMLaunchError | |||
| __all__ = ['ShardReader'] | |||
| class ShardReader: | |||
| @@ -102,22 +101,6 @@ class ShardReader: | |||
| """ | |||
| return self._reader.get_header() | |||
| def finish(self): | |||
| """ | |||
| stop the worker threads. | |||
| Returns: | |||
| MSRStatus, SUCCESS or FAILED. | |||
| Raises: | |||
| MRMFinishError: If failed to finish worker threads. | |||
| """ | |||
| ret = self._reader.finish() | |||
| if ret != ms.MSRStatus.SUCCESS: | |||
| logger.error("Failed to finish worker threads.") | |||
| raise MRMFinishError | |||
| return ret | |||
| def close(self): | |||
| """close MindRecord File.""" | |||
| self._reader.close() | |||
| @@ -73,7 +73,7 @@ TEST_F(TestShardOperator, TestShardSampleBasic) { | |||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i <= kSampleCount); | |||
| } | |||
| @@ -99,7 +99,7 @@ TEST_F(TestShardOperator, TestShardSampleWrongNumber) { | |||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i <= 5); | |||
| } | |||
| @@ -125,7 +125,7 @@ TEST_F(TestShardOperator, TestShardSampleRatio) { | |||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i <= 10); | |||
| } | |||
| @@ -151,7 +151,7 @@ TEST_F(TestShardOperator, TestShardSamplePartition) { | |||
| MS_LOG(INFO) << "index: " << i << ", filename: " << common::SafeCStr((std::get<1>(x[0]))["file_name"]); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i <= 10); | |||
| } | |||
| @@ -176,7 +176,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerBasic) { | |||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i == 20); | |||
| } // namespace mindrecord | |||
| @@ -202,7 +202,7 @@ TEST_F(TestShardOperator, TestShardPkSamplerNumClass) { | |||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()) << std::endl; | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i == 6); | |||
| } | |||
| @@ -238,7 +238,7 @@ TEST_F(TestShardOperator, TestShardCategory) { | |||
| category_no++; | |||
| category_no %= static_cast<int>(categories.size()); | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardOperator, TestShardShuffle) { | |||
| @@ -262,7 +262,7 @@ TEST_F(TestShardOperator, TestShardShuffle) { | |||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardOperator, TestShardSampleShuffle) { | |||
| @@ -287,7 +287,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffle) { | |||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_LE(i, 35); | |||
| } | |||
| @@ -314,7 +314,7 @@ TEST_F(TestShardOperator, TestShardShuffleSample) { | |||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_TRUE(i <= kSampleSize); | |||
| } | |||
| @@ -341,7 +341,7 @@ TEST_F(TestShardOperator, TestShardSampleShuffleSample) { | |||
| << ", label: " << common::SafeCStr((std::get<1>(x[0]))["label"].dump()); | |||
| i++; | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_LE(i, 35); | |||
| } | |||
| @@ -373,8 +373,8 @@ TEST_F(TestShardOperator, TestShardShuffleCompare) { | |||
| auto y = compare_dataset.GetNext(); | |||
| if ((std::get<1>(x[0]))["file_name"] != (std::get<1>(y[0]))["file_name"]) different = true; | |||
| } | |||
| dataset.Finish(); | |||
| compare_dataset.Finish(); | |||
| dataset.Close(); | |||
| compare_dataset.Close(); | |||
| ASSERT_TRUE(different); | |||
| } | |||
| @@ -409,7 +409,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle1) { | |||
| category_no++; | |||
| category_no %= static_cast<int>(categories.size()); | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | |||
| @@ -442,7 +442,7 @@ TEST_F(TestShardOperator, TestShardCategoryShuffle2) { | |||
| category_no++; | |||
| category_no %= static_cast<int>(categories.size()); | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardOperator, TestShardCategorySample) { | |||
| @@ -477,7 +477,7 @@ TEST_F(TestShardOperator, TestShardCategorySample) { | |||
| category_no++; | |||
| category_no %= static_cast<int>(categories.size()); | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_EQ(category_no, 0); | |||
| ASSERT_TRUE(i <= kSampleSize); | |||
| } | |||
| @@ -515,7 +515,7 @@ TEST_F(TestShardOperator, TestShardCategorySampleShuffle) { | |||
| category_no++; | |||
| category_no %= static_cast<int>(categories.size()); | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| ASSERT_EQ(category_no, 0); | |||
| ASSERT_TRUE(i <= kSampleSize); | |||
| } | |||
| @@ -67,7 +67,7 @@ TEST_F(TestShardReader, TestShardReaderGeneral) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardReader, TestShardReaderSample) { | |||
| @@ -90,7 +90,7 @@ TEST_F(TestShardReader, TestShardReaderSample) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| dataset.Close(); | |||
| } | |||
| @@ -110,7 +110,7 @@ TEST_F(TestShardReader, TestShardReaderEasy) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||
| @@ -131,7 +131,7 @@ TEST_F(TestShardReader, TestShardReaderColumnNotInIndex) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardReader, TestShardReaderColumnNotInSchema) { | |||
| @@ -161,7 +161,7 @@ TEST_F(TestShardReader, TestShardVersion) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| TEST_F(TestShardReader, TestShardReaderDir) { | |||
| @@ -192,7 +192,7 @@ TEST_F(TestShardReader, TestShardReaderConsumer) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| } | |||
| } // namespace mindrecord | |||
| } // namespace mindspore | |||
| @@ -74,7 +74,7 @@ TEST_F(TestShardWriter, TestShardWriterOneSample) { | |||
| } | |||
| } | |||
| } | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| for (int i = 1; i <= 4; i++) { | |||
| string filename = std::string("./OneSample.shard0") + std::to_string(i); | |||
| string db_name = std::string("./OneSample.shard0") + std::to_string(i) + ".db"; | |||
| @@ -775,7 +775,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberColumnInIndex) { | |||
| } | |||
| } | |||
| ASSERT_TRUE(count == 10); | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| for (const auto &filename : file_names) { | |||
| auto filename_db = filename + ".db"; | |||
| @@ -858,7 +858,7 @@ TEST_F(TestShardWriter, TestShardNoBlob) { | |||
| } | |||
| } | |||
| ASSERT_TRUE(count == 10); | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| for (const auto &filename : file_names) { | |||
| auto filename_db = filename + ".db"; | |||
| remove(common::SafeCStr(filename_db)); | |||
| @@ -952,7 +952,7 @@ TEST_F(TestShardWriter, TestShardReaderStringAndNumberNotColumnInIndex) { | |||
| } | |||
| } | |||
| ASSERT_TRUE(count == 10); | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| for (const auto &filename : file_names) { | |||
| auto filename_db = filename + ".db"; | |||
| remove(common::SafeCStr(filename_db)); | |||
| @@ -1060,7 +1060,7 @@ TEST_F(TestShardWriter, TestShardWriter10Sample40Shard) { | |||
| count++; | |||
| } | |||
| ASSERT_TRUE(count == 10); | |||
| dataset.Finish(); | |||
| dataset.Close(); | |||
| for (const auto &filename : file_names) { | |||
| auto filename_db = filename + ".db"; | |||
| remove(common::SafeCStr(filename_db)); | |||
| @@ -260,7 +260,7 @@ def test_cv_file_reader_partial_tutorial(): | |||
| count = count + 1 | |||
| logger.info("#item{}: {}".format(index, x)) | |||
| if count == 5: | |||
| reader.finish() | |||
| reader.close() | |||
| assert count == 5 | |||