| @@ -28,10 +28,9 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | |||||
| kBinSizeUnit8 * kMByteSize, | kBinSizeUnit8 * kMByteSize, | ||||
| kBinSizeUnit32 * kMByteSize, | kBinSizeUnit32 * kMByteSize, | ||||
| kBinSizeUnit128 * kMByteSize, | kBinSizeUnit128 * kMByteSize, | ||||
| kGByteSize, | |||||
| kBinSizeUnit4 * kGByteSize, | |||||
| kBinSizeUnit16 * kGByteSize, | |||||
| kBinSizeUnit26 * kGByteSize}; | |||||
| kBinSizeUnit256 * kMByteSize, | |||||
| kBinSizeUnit512 * kMByteSize, | |||||
| kGByteSize}; | |||||
| static bool BlockComparator(const Block *left, const Block *right) { | static bool BlockComparator(const Block *left, const Block *right) { | ||||
| if (left->size != right->size) { | if (left->size != right->size) { | ||||
| @@ -63,7 +62,10 @@ size_t GetBinIndex(size_t size) { | |||||
| size_t GetAllocationSize(size_t size) { | size_t GetAllocationSize(size_t size) { | ||||
| size_t index = GetBinIndex(size); | size_t index = GetBinIndex(size); | ||||
| return bin_ranges[index]; | |||||
| if (bin_ranges[index] >= size) { | |||||
| return bin_ranges[index]; | |||||
| } | |||||
| return kGByteSize * ((size + kGByteSize - 1) / kGByteSize); | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -36,17 +36,17 @@ namespace ge { | |||||
| constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | ||||
| constexpr size_t kBinSizeUnit4 = 4; | constexpr size_t kBinSizeUnit4 = 4; | ||||
| constexpr size_t kBinSizeUnit8 = 8; | constexpr size_t kBinSizeUnit8 = 8; | ||||
| constexpr size_t kBinSizeUnit16 = 16; | |||||
| constexpr size_t kBinSizeUnit26 = 26; | |||||
| constexpr size_t kBinSizeUnit32 = 32; | constexpr size_t kBinSizeUnit32 = 32; | ||||
| constexpr size_t kBinSizeUnit128 = 128; | constexpr size_t kBinSizeUnit128 = 128; | ||||
| constexpr size_t kBinSizeUnit256 = 256; | |||||
| constexpr size_t kBinSizeUnit512 = 512; | |||||
| constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold | |||||
| constexpr double kSplitThreshold = 0.5; // split when malloc size <= small block size * kSpliThreshold | |||||
| constexpr size_t kKByteSize = 1024; | constexpr size_t kKByteSize = 1024; | ||||
| constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | ||||
| constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | ||||
| static const uint32_t kNumBins = 8; | |||||
| static const uint32_t kNumBins = 7; | |||||
| class MemoryAllocator; | class MemoryAllocator; | ||||
| @@ -323,6 +323,8 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| node_item.NodeName().c_str()); | node_item.NodeName().c_str()); | ||||
| } | } | ||||
| // release workspace | |||||
| context_->ReleaseWorkspace(); | |||||
| // release inputs | // release inputs | ||||
| for (int i = 0; i < context_->NumInputs(); ++i) { | for (int i = 0; i < context_->NumInputs(); ++i) { | ||||
| context_->ReleaseInput(i); | context_->ReleaseInput(i); | ||||
| @@ -36,10 +36,6 @@ TaskContext::TaskContext(GraphExecutionContext *execution_context, | |||||
| TaskContext::~TaskContext() { | TaskContext::~TaskContext() { | ||||
| GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | ||||
| for (auto ws_addr : workspaces_) { | |||||
| execution_context_->allocator->Deallocate(ws_addr); | |||||
| } | |||||
| // release output | // release output | ||||
| for (int i = 0; i < NumOutputs(); ++i) { | for (int i = 0; i < NumOutputs(); ++i) { | ||||
| auto output_tensor = MutableOutput(i); | auto output_tensor = MutableOutput(i); | ||||
| @@ -49,6 +45,13 @@ TaskContext::~TaskContext() { | |||||
| } | } | ||||
| } | } | ||||
| void TaskContext::ReleaseWorkspace() { | |||||
| GELOGD("[%s] Start ReleaseWorkspace.", node_item_->NodeName().c_str()); | |||||
| for (auto ws_addr : workspaces_) { | |||||
| execution_context_->allocator->Deallocate(ws_addr); | |||||
| } | |||||
| } | |||||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | ||||
| GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
| SubgraphContext *subgraph_context) { | SubgraphContext *subgraph_context) { | ||||
| @@ -56,6 +56,7 @@ class TaskContext { | |||||
| void ReleaseInputsAndOutputs(); | void ReleaseInputsAndOutputs(); | ||||
| bool NeedCallback(); | bool NeedCallback(); | ||||
| void ReleaseInput(int index); | void ReleaseInput(int index); | ||||
| void ReleaseWorkspace(); | |||||
| const TensorValue *GetInput(int index) const; | const TensorValue *GetInput(int index) const; | ||||
| const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||
| TensorValue *MutableOutput(int index); | TensorValue *MutableOutput(int index); | ||||
| @@ -752,6 +752,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
| "graph/preprocess/graph_preprocess_unittest.cc" | "graph/preprocess/graph_preprocess_unittest.cc" | ||||
| "graph/manager/hcom_util_unittest.cc" | "graph/manager/hcom_util_unittest.cc" | ||||
| "graph/manager/graph_caching_allocator_unittest.cc" | |||||
| "session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
| ) | ) | ||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <memory> | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| using namespace ge; | |||||
| using domi::GetContext; | |||||
| class UtestGraphCachingAllocatorTest : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() { GetContext().out_nodes_map.clear(); } | |||||
| }; | |||||
| TEST_F(UtestGraphCachingAllocatorTest, initialize_success) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||
| TEST_F(UtestGraphCachingAllocatorTest, malloc_success) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||
| TEST_F(UtestGraphCachingAllocatorTest, malloc_statics) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| uint8_t *ptr1 = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kKByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr), SUCCESS); | |||||
| EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr1), SUCCESS); | |||||
| MemManager::Instance().CachingInstance(RT_MEMORY_HBM).FreeCachedBlocks(); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||