Browse Source

Support session scope memory

tags/v1.3.0
TangQunzhang 3 years ago
parent
commit
653b914521
4 changed files with 23 additions and 3 deletions
  1. +2
    -0
      ge/graph/manager/graph_mem_manager.cc
  2. +1
    -3
      ge/graph/manager/session_scope_mem_allocator.cc
  3. +1
    -0
      ge/graph/manager/session_scope_mem_allocator.h
  4. +19
    -0
      tests/ut/ge/graph/manager/session_scope_mem_allocator_unittest.cc

+ 2
- 0
ge/graph/manager/graph_mem_manager.cc View File

@@ -65,6 +65,7 @@ Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) {
return ret;
}
init_ = true;
memory_type_ = memory_type;
return SUCCESS;
}

@@ -90,6 +91,7 @@ void MemManager::Finalize() noexcept {
FinalizeAllocatorMap(host_allocator_map_);
FinalizeAllocatorMap(memory_allocator_map_);
init_ = false;
memory_type_.clear();
}

MemoryAllocator &MemManager::MemInstance(rtMemType_t memory_type) {


+ 1
- 3
ge/graph/manager/session_scope_mem_allocator.cc View File

@@ -65,9 +65,7 @@ Status SessionScopeMemAllocator::Free(uint64_t session_id, uint32_t device_id) {
std::lock_guard<std::recursive_mutex> lock(mutex_);
auto it = allocated_memory_.find(session_id);
if (it == allocated_memory_.end()) {
REPORT_INNER_ERROR("E19999", "Param memory not allocated before, session_id:%lu device_id:%u, check invalid",
session_id, device_id);
GELOGE(PARAM_INVALID, "Invalid session_id");
GELOGW("Invalid session_id");
return ge::PARAM_INVALID;
}
allocated_memory_.erase(it);


+ 1
- 0
ge/graph/manager/session_scope_mem_allocator.h View File

@@ -53,6 +53,7 @@ class SessionScopeMemoryInfo {
}
size = other.size;
ptr = other.ptr;
return *this;
};

private:


+ 19
- 0
tests/ut/ge/graph/manager/session_scope_mem_allocator_unittest.cc View File

@@ -73,3 +73,22 @@ TEST_F(UtestSessionScopeMemAllocator, free_success) {
EXPECT_NE(SUCCESS, MemManager::Instance().SessionScopeMemInstance(RT_MEMORY_HBM).Free(0));
MemManager::Instance().Finalize();
}

TEST_F(UtestSessionScopeMemAllocator, free_success_session) {
std::vector<rtMemType_t> mem_type;
mem_type.push_back(RT_MEMORY_HBM);
mem_type.push_back(RT_MEMORY_P2P_DDR);
EXPECT_EQ(MemManager::Instance().Initialize(mem_type), SUCCESS);
uint8_t *ptr = MemManager::Instance().SessionScopeMemInstance(RT_MEMORY_HBM).Malloc(100, 0);
EXPECT_NE(nullptr, ptr);
ptr = MemManager::Instance().SessionScopeMemInstance(RT_MEMORY_HBM).Malloc(100, 0);
EXPECT_NE(nullptr, ptr);
for (auto memory_type : MemManager::Instance().GetAllMemoryType()) {
if (RT_MEMORY_P2P_DDR == memory_type) {
EXPECT_NE(MemManager::Instance().SessionScopeMemInstance(memory_type).Free(0), SUCCESS);
} else {
EXPECT_EQ(MemManager::Instance().SessionScopeMemInstance(memory_type).Free(0), SUCCESS);
}
}
MemManager::Instance().Finalize();
}

Loading…
Cancel
Save