From: @wan_xuelei Reviewed-by: @tangqunzhang,@wqtshg Signed-off-by:tags/v1.2.0
@@ -28,10 +28,9 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | |||||
kBinSizeUnit8 * kMByteSize, | kBinSizeUnit8 * kMByteSize, | ||||
kBinSizeUnit32 * kMByteSize, | kBinSizeUnit32 * kMByteSize, | ||||
kBinSizeUnit128 * kMByteSize, | kBinSizeUnit128 * kMByteSize, | ||||
kGByteSize, | |||||
kBinSizeUnit4 * kGByteSize, | |||||
kBinSizeUnit16 * kGByteSize, | |||||
kBinSizeUnit26 * kGByteSize}; | |||||
kBinSizeUnit256 * kMByteSize, | |||||
kBinSizeUnit512 * kMByteSize, | |||||
kGByteSize}; | |||||
static bool BlockComparator(const Block *left, const Block *right) { | static bool BlockComparator(const Block *left, const Block *right) { | ||||
if (left->size != right->size) { | if (left->size != right->size) { | ||||
@@ -63,7 +62,10 @@ size_t GetBinIndex(size_t size) { | |||||
size_t GetAllocationSize(size_t size) { | size_t GetAllocationSize(size_t size) { | ||||
size_t index = GetBinIndex(size); | size_t index = GetBinIndex(size); | ||||
return bin_ranges[index]; | |||||
if (bin_ranges[index] >= size) { | |||||
return bin_ranges[index]; | |||||
} | |||||
return kGByteSize * ((size + kGByteSize - 1) / kGByteSize); | |||||
} | } | ||||
/// | /// | ||||
@@ -119,6 +121,7 @@ void CachingAllocator::Finalize(uint32_t device_id) { | |||||
} | } | ||||
uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { | uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) { | ||||
GELOGI("Start malloc pool memory, size = %zu, device id = %u", size, device_id); | |||||
uint8_t *ptr = nullptr; | uint8_t *ptr = nullptr; | ||||
size = GetBlockSize(size); | size = GetBlockSize(size); | ||||
Block *block = FindFreeBlock(size, org_ptr, device_id); | Block *block = FindFreeBlock(size, org_ptr, device_id); | ||||
@@ -253,6 +256,7 @@ Block *CachingAllocator::SplitBlock(Block *block, size_t size, BlockBin &bin, ui | |||||
} | } | ||||
Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | ||||
GELOGI("Try to extend cache. size = %zu, device id = %u", size, device_id); | |||||
auto memory_size = GetAllocationSize(size); | auto memory_size = GetAllocationSize(size); | ||||
const std::string purpose = "Memory for caching."; | const std::string purpose = "Memory for caching."; | ||||
auto memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); | auto memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); | ||||
@@ -36,17 +36,17 @@ namespace ge { | |||||
constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | ||||
constexpr size_t kBinSizeUnit4 = 4; | constexpr size_t kBinSizeUnit4 = 4; | ||||
constexpr size_t kBinSizeUnit8 = 8; | constexpr size_t kBinSizeUnit8 = 8; | ||||
constexpr size_t kBinSizeUnit16 = 16; | |||||
constexpr size_t kBinSizeUnit26 = 26; | |||||
constexpr size_t kBinSizeUnit32 = 32; | constexpr size_t kBinSizeUnit32 = 32; | ||||
constexpr size_t kBinSizeUnit128 = 128; | constexpr size_t kBinSizeUnit128 = 128; | ||||
constexpr size_t kBinSizeUnit256 = 256; | |||||
constexpr size_t kBinSizeUnit512 = 512; | |||||
constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold | |||||
constexpr double kSplitThreshold = 0.5; // split when malloc size <= small block size * kSpliThreshold | |||||
constexpr size_t kKByteSize = 1024; | constexpr size_t kKByteSize = 1024; | ||||
constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | constexpr size_t kMByteSize = 1048576; // 1024 * 1024 | ||||
constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | constexpr size_t kGByteSize = 1073741824; // 1024 * 1024 * 1024 | ||||
static const uint32_t kNumBins = 8; | |||||
static const uint32_t kNumBins = 7; | |||||
class MemoryAllocator; | class MemoryAllocator; | ||||
@@ -37,7 +37,7 @@ Status FlowCtrlPass::Run(ComputeGraphPtr compute_graph) { | |||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
GELOGI("FlowCtrl pass begin"); | |||||
GELOGI("FlowCtrl pass begin.graph is [%s]", compute_graph->GetName().c_str()); | |||||
bool graph_change = false; | bool graph_change = false; | ||||
// 1. Add FP/BP flow ctrl (big cycle) | // 1. Add FP/BP flow ctrl (big cycle) | ||||
for (auto &node : compute_graph->GetDirectNode()) { | for (auto &node : compute_graph->GetDirectNode()) { | ||||
@@ -458,7 +458,7 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) { | |||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "Compute graph is null."); | ||||
return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
} | } | ||||
GELOGI("NetOutputPass Run."); | |||||
GELOGI("NetOutputPass Run.graph is [%s]", graph->GetName().c_str()); | |||||
NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT); | NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT); | ||||
// save user targets node | // save user targets node | ||||
SaveAndRemoveTargets(graph); | SaveAndRemoveTargets(graph); | ||||
@@ -27,12 +27,11 @@ | |||||
namespace ge { | namespace ge { | ||||
Status PrunePass::Run(ge::ComputeGraphPtr graph) { | Status PrunePass::Run(ge::ComputeGraphPtr graph) { | ||||
GELOGD("PrunePass Start"); | |||||
GELOGD("PrunePass Start, graph is [%s]", graph->GetName().c_str()); | |||||
if (graph == nullptr) { | if (graph == nullptr) { | ||||
GELOGE(GE_GRAPH_ISNULL, "input compute graph is NULL."); | GELOGE(GE_GRAPH_ISNULL, "input compute graph is NULL."); | ||||
return GE_GRAPH_ISNULL; | return GE_GRAPH_ISNULL; | ||||
} | } | ||||
std::vector<NodePtr> out_nodes; | std::vector<NodePtr> out_nodes; | ||||
std::unordered_set<NodePtr> nodes; | std::unordered_set<NodePtr> nodes; | ||||
for (NodePtr &node_ptr : graph->GetDirectNode()) { | for (NodePtr &node_ptr : graph->GetDirectNode()) { | ||||
@@ -42,7 +41,6 @@ Status PrunePass::Run(ge::ComputeGraphPtr graph) { | |||||
out_nodes.push_back(node_ptr); | out_nodes.push_back(node_ptr); | ||||
} | } | ||||
} | } | ||||
if (out_nodes.empty()) { | if (out_nodes.empty()) { | ||||
GELOGW("graph [%s] does not contain NETOUTPUT type node,no return value. Do nothing!", graph->GetName().c_str()); | GELOGW("graph [%s] does not contain NETOUTPUT type node,no return value. Do nothing!", graph->GetName().c_str()); | ||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
@@ -323,6 +323,8 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
node_item.NodeName().c_str()); | node_item.NodeName().c_str()); | ||||
} | } | ||||
// release workspace | |||||
context_->ReleaseWorkspace(); | |||||
// release inputs | // release inputs | ||||
for (int i = 0; i < context_->NumInputs(); ++i) { | for (int i = 0; i < context_->NumInputs(); ++i) { | ||||
context_->ReleaseInput(i); | context_->ReleaseInput(i); | ||||
@@ -36,10 +36,6 @@ TaskContext::TaskContext(GraphExecutionContext *execution_context, | |||||
TaskContext::~TaskContext() { | TaskContext::~TaskContext() { | ||||
GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | ||||
for (auto ws_addr : workspaces_) { | |||||
execution_context_->allocator->Deallocate(ws_addr); | |||||
} | |||||
// release output | // release output | ||||
for (int i = 0; i < NumOutputs(); ++i) { | for (int i = 0; i < NumOutputs(); ++i) { | ||||
auto output_tensor = MutableOutput(i); | auto output_tensor = MutableOutput(i); | ||||
@@ -49,6 +45,13 @@ TaskContext::~TaskContext() { | |||||
} | } | ||||
} | } | ||||
void TaskContext::ReleaseWorkspace() { | |||||
GELOGD("[%s] Start ReleaseWorkspace.", node_item_->NodeName().c_str()); | |||||
for (auto ws_addr : workspaces_) { | |||||
execution_context_->allocator->Deallocate(ws_addr); | |||||
} | |||||
} | |||||
std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | ||||
GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
SubgraphContext *subgraph_context) { | SubgraphContext *subgraph_context) { | ||||
@@ -56,6 +56,7 @@ class TaskContext { | |||||
void ReleaseInputsAndOutputs(); | void ReleaseInputsAndOutputs(); | ||||
bool NeedCallback(); | bool NeedCallback(); | ||||
void ReleaseInput(int index); | void ReleaseInput(int index); | ||||
void ReleaseWorkspace(); | |||||
const TensorValue *GetInput(int index) const; | const TensorValue *GetInput(int index) const; | ||||
const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||
TensorValue *MutableOutput(int index); | TensorValue *MutableOutput(int index); | ||||
@@ -752,6 +752,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
"graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
"graph/preprocess/graph_preprocess_unittest.cc" | "graph/preprocess/graph_preprocess_unittest.cc" | ||||
"graph/manager/hcom_util_unittest.cc" | "graph/manager/hcom_util_unittest.cc" | ||||
"graph/manager/graph_caching_allocator_unittest.cc" | |||||
"session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
) | ) | ||||
@@ -0,0 +1,87 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <memory> | |||||
#include "graph/anchor.h" | |||||
#include "graph/attr_value.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "graph/utils/tensor_utils.h" | |||||
#include "omg/omg_inner_types.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/manager/graph_caching_allocator.h" | |||||
#include "graph/manager/graph_mem_allocator.h" | |||||
#undef protected | |||||
#undef private | |||||
using namespace std; | |||||
using namespace testing; | |||||
using namespace ge; | |||||
using domi::GetContext; | |||||
class UtestGraphCachingAllocatorTest : public testing::Test { | |||||
protected: | |||||
void SetUp() {} | |||||
void TearDown() { GetContext().out_nodes_map.clear(); } | |||||
}; | |||||
TEST_F(UtestGraphCachingAllocatorTest, initialize_success) { | |||||
std::vector<rtMemType_t> mem_type; | |||||
mem_type.push_back(RT_MEMORY_HBM); | |||||
EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
MemManager::Instance().Finalize(); | |||||
} | |||||
TEST_F(UtestGraphCachingAllocatorTest, malloc_success) { | |||||
std::vector<rtMemType_t> mem_type; | |||||
mem_type.push_back(RT_MEMORY_HBM); | |||||
EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
EXPECT_NE(nullptr, ptr); | |||||
MemManager::Instance().Finalize(); | |||||
} | |||||
TEST_F(UtestGraphCachingAllocatorTest, extend_malloc_success) { | |||||
std::vector<rtMemType_t> mem_type; | |||||
mem_type.push_back(RT_MEMORY_HBM); | |||||
EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
EXPECT_NE(nullptr, ptr); | |||||
ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kBinSizeUnit32*kMByteSize); | |||||
EXPECT_NE(nullptr, ptr); | |||||
MemManager::Instance().Finalize(); | |||||
} | |||||
TEST_F(UtestGraphCachingAllocatorTest, malloc_statics) { | |||||
std::vector<rtMemType_t> mem_type; | |||||
mem_type.push_back(RT_MEMORY_HBM); | |||||
EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS); | |||||
uint8_t *ptr = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kMByteSize); | |||||
EXPECT_NE(nullptr, ptr); | |||||
uint8_t *ptr1 = MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Malloc(kKByteSize); | |||||
EXPECT_NE(nullptr, ptr); | |||||
EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr), SUCCESS); | |||||
EXPECT_EQ(MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(ptr1), SUCCESS); | |||||
MemManager::Instance().CachingInstance(RT_MEMORY_HBM).FreeCachedBlocks(); | |||||
MemManager::Instance().Finalize(); | |||||
} |