From: @wan_xuelei
Reviewed-by: @tangqunzhang, @wqtshg
Signed-off-by:
tags/v1.2.0
| @@ -28,10 +28,9 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | |||||
| kBinSizeUnit8 * kMByteSize, | kBinSizeUnit8 * kMByteSize, | ||||
| kBinSizeUnit32 * kMByteSize, | kBinSizeUnit32 * kMByteSize, | ||||
| kBinSizeUnit128 * kMByteSize, | kBinSizeUnit128 * kMByteSize, | ||||
| kGByteSize, | |||||
| kBinSizeUnit4 * kGByteSize, | |||||
| kBinSizeUnit16 * kGByteSize, | |||||
| kBinSizeUnit26 * kGByteSize}; | |||||
| kBinSizeUnit256 * kMByteSize, | |||||
| kBinSizeUnit512 * kMByteSize, | |||||
| kGByteSize}; | |||||
| static bool BlockComparator(const Block *left, const Block *right) { | static bool BlockComparator(const Block *left, const Block *right) { | ||||
| if (left->size != right->size) { | if (left->size != right->size) { | ||||
| @@ -63,7 +62,10 @@ size_t GetBinIndex(size_t size) { | |||||
| size_t GetAllocationSize(size_t size) { | size_t GetAllocationSize(size_t size) { | ||||
| size_t index = GetBinIndex(size); | size_t index = GetBinIndex(size); | ||||
| return bin_ranges[index]; | |||||
| if (bin_ranges[index] >= size) { | |||||
| return bin_ranges[index]; | |||||
| } | |||||
| return kGByteSize * ((size + kGByteSize - 1) / kGByteSize); | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -119,6 +121,7 @@ void CachingAllocator::Finalize(uint32_t device_id) { | |||||
| } | } | ||||
| uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { | uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { | ||||
| GELOGI("Start malloc pool memory, size = %zu, device id = %u", size, device_id); | |||||
| uint8_t *ptr = nullptr; | uint8_t *ptr = nullptr; | ||||
| size = GetBlockSize(size); | size = GetBlockSize(size); | ||||
| Block *block = FindFreeBlock(size, org_ptr, device_id); | Block *block = FindFreeBlock(size, org_ptr, device_id); | ||||
| @@ -253,6 +256,7 @@ Block *CachingAllocator::SplitBlock(Block *block, size_t size, BlockBin &bin, ui | |||||
| } | } | ||||
| Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | ||||
| GELOGI("Try to extend cache. size = %zu, device id = %u", size, device_id); | |||||
| auto memory_size = GetAllocationSize(size); | auto memory_size = GetAllocationSize(size); | ||||
| const std::string purpose = "Memory for caching."; | const std::string purpose = "Memory for caching."; | ||||
| auto memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); | auto memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); | ||||
| @@ -36,17 +36,17 @@ namespace ge { | |||||
| constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | ||||
| constexpr size_t kBinSizeUnit4 = 4; | constexpr size_t kBinSizeUnit4 = 4; | ||||
| constexpr size_t kBinSizeUnit8 = 8; | constexpr size_t kBinSizeUnit8 = 8; | ||||
| constexpr size_t kBinSizeUnit16 = 16; | |||||
| constexpr size_t kBinSizeUnit26 = 26; | |||||
| constexpr size_t kBinSizeUnit32 = 32; | constexpr size_t kBinSizeUnit32 = 32; | ||||
| constexpr size_t kBinSizeUnit128 = 128; | constexpr size_t kBinSizeUnit128 = 128; | ||||
| constexpr size_t kBinSizeUnit256 = 256; | |||||
| constexpr size_t kBinSizeUnit512 = 512; | |||||
| constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSplitThreshold | |||||
| constexpr double kSplitThreshold = 0.5; // split when malloc size <= small block size * kSplitThreshold | |||||
| constexpr size_t kKByteSize = 1024; | constexpr size_t kKByteSize = 1024; | ||||
| constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | ||||
| constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | ||||
| static const uint32_t kNumBins = 8; | |||||
| static const uint32_t kNumBins = 7; | |||||
| class MemoryAllocator; | class MemoryAllocator; | ||||
| @@ -37,7 +37,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||||
| return NOT_CHANGED; | return NOT_CHANGED; | ||||
| } | } | ||||
| GELOGI("FlowCtrl pass begin"); | |||||
| GELOGI("FlowCtrl pass begin.graph is [%s]", compute_graph->GetName().c_str()); | |||||
| bool graph_change = false; | bool graph_change = false; | ||||
| // 1. Add FP/BP flow ctrl (big cycle) | // 1. Add FP/BP flow ctrl (big cycle) | ||||
| for (auto &node : compute_graph->GetDirectNode()) { | for (auto &node : compute_graph->GetDirectNode()) { | ||||
| @@ -458,7 +458,7 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) { | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); | ||||
| return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
| } | } | ||||
| GELOGI("NetOutputPass Run."); | |||||
| GELOGI("NetOutputPass Run.graph is [%s]", graph->GetName().c_str()); | |||||
| NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT); | NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| // save user targets node | // save user targets node | ||||
| SaveAndRemoveTargets(graph); | SaveAndRemoveTargets(graph); | ||||
| @@ -27,12 +27,11 @@ | |||||
| namespace ge { | namespace ge { | ||||
| Status PrunePass::Run(ge::ComputeGraphPtr graph) { | Status PrunePass::Run(ge::ComputeGraphPtr graph) { | ||||
| GELOGD("PrunePass Start"); | |||||
| GELOGD("PrunePass Start, graph is [%s]", graph->GetName().c_str()); | |||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| GELOGE(GE_GRAPH_ISNULL, "input compute graph is NULL."); | GELOGE(GE_GRAPH_ISNULL, "input compute graph is NULL."); | ||||
| return GE_GRAPH_ISNULL; | return GE_GRAPH_ISNULL; | ||||
| } | } | ||||
| std::vector<NodePtr> out_nodes; | std::vector<NodePtr> out_nodes; | ||||
| std::unordered_set<NodePtr> nodes; | std::unordered_set<NodePtr> nodes; | ||||
| for (NodePtr &node_ptr : graph->GetDirectNode()) { | for (NodePtr &node_ptr : graph->GetDirectNode()) { | ||||
| @@ -42,7 +41,6 @@ Status PrunePass::Run(ge::ComputeGraphPtr graph) { | |||||
| out_nodes.push_back(node_ptr); | out_nodes.push_back(node_ptr); | ||||
| } | } | ||||
| } | } | ||||
| if (out_nodes.empty()) { | if (out_nodes.empty()) { | ||||
| GELOGW("graph [%s] does not contain NETOUTPUT type node,no return value. Do nothing!", graph->GetName().c_str()); | GELOGW("graph [%s] does not contain NETOUTPUT type node,no return value. Do nothing!", graph->GetName().c_str()); | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| @@ -323,6 +323,8 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| node_item.NodeName().c_str()); | node_item.NodeName().c_str()); | ||||
| } | } | ||||
| // release workspace | |||||
| context_->ReleaseWorkspace(); | |||||
| // release inputs | // release inputs | ||||
| for (int i = 0; i < context_->NumInputs(); ++i) { | for (int i = 0; i < context_->NumInputs(); ++i) { | ||||
| context_->ReleaseInput(i); | context_->ReleaseInput(i); | ||||
| @@ -36,10 +36,6 @@ TaskContext::TaskContext(GraphExecutionContext *execution_context, | |||||
| TaskContext::~TaskContext() { | TaskContext::~TaskContext() { | ||||
| GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | ||||
| for (auto ws_addr : workspaces_) { | |||||
| execution_context_->allocator->Deallocate(ws_addr); | |||||
| } | |||||
| // release output | // release output | ||||
| for (int i = 0; i < NumOutputs(); ++i) { | for (int i = 0; i < NumOutputs(); ++i) { | ||||
| auto output_tensor = MutableOutput(i); | auto output_tensor = MutableOutput(i); | ||||
| @@ -49,6 +45,13 @@ TaskContext::~TaskContext() { | |||||
| } | } | ||||
| } | } | ||||
| void TaskContext::ReleaseWorkspace() { | |||||
| GELOGD("[%s] Start ReleaseWorkspace.", node_item_->NodeName().c_str()); | |||||
| for (auto ws_addr : workspaces_) { | |||||
| execution_context_->allocator->Deallocate(ws_addr); | |||||
| } | |||||
| } | |||||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | ||||
| GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
| SubgraphContext *subgraph_context) { | SubgraphContext *subgraph_context) { | ||||
| @@ -56,6 +56,7 @@ class TaskContext { | |||||
| void ReleaseInputsAndOutputs(); | void ReleaseInputsAndOutputs(); | ||||
| bool NeedCallback(); | bool NeedCallback(); | ||||
| void ReleaseInput(int index); | void ReleaseInput(int index); | ||||
| void ReleaseWorkspace(); | |||||
| const TensorValue *GetInput(int index) const; | const TensorValue *GetInput(int index) const; | ||||
| const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||
| TensorValue *MutableOutput(int index); | TensorValue *MutableOutput(int index); | ||||
| @@ -752,6 +752,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
| "graph/preprocess/graph_preprocess_unittest.cc" | "graph/preprocess/graph_preprocess_unittest.cc" | ||||
| "graph/manager/hcom_util_unittest.cc" | "graph/manager/hcom_util_unittest.cc" | ||||
| "graph/manager/graph_caching_allocator_unittest.cc" | |||||
| "session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
| ) | ) | ||||
| @@ -0,0 +1,87 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <memory> | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/attr_value.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #undef protected | |||||
| #undef private | |||||
| using namespace std; | |||||
| using namespace testing; | |||||
| using namespace ge; | |||||
| using domi::GetContext; | |||||
| class UtestGraphCachingAllocatorTest : public testing::Test { | |||||
| protected: | |||||
| void SetUp() {} | |||||
| void TearDown() { GetContext().out_nodes_map.clear(); } | |||||
| }; | |||||
| TEST_F(UtestGraphCachingAllocatorTest, initialize_success) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||
| TEST_F(UtestGraphCachingAllocatorTest, malloc_success) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||
| TEST_F(UtestGraphCachingAllocatorTest, extend_malloc_success) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kBinSizeUnit32*kMByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||
| TEST_F(UtestGraphCachingAllocatorTest, malloc_statics) { | |||||
| std::vector<rtMemType_t> mem_type; | |||||
| mem_type.push_back(RT_MEMORY_HBM); | |||||
| EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
| uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| uint8_t *ptr1 = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kKByteSize); | |||||
| EXPECT_NE(nullptr, ptr); | |||||
| EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr), SUCCESS); | |||||
| EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr1), SUCCESS); | |||||
| MemManager::Instance().CachingInstance(RT_MEMORY_HBM).FreeCachedBlocks(); | |||||
| MemManager::Instance().Finalize(); | |||||
| } | |||||