
!949 cache support

From: @chen_yemeng
Reviewed-by: 
Signed-off-by:
tags/v1.2.0
mindspore-ci-bot committed 3 years ago
commit 1ff223abd1
31 changed files with 468 additions and 89 deletions
  1. ge/CMakeLists.txt (+1 -0)
  2. ge/common/types.cc (+1 -0)
  3. ge/executor/CMakeLists.txt (+1 -0)
  4. ge/executor/module.mk (+1 -0)
  5. ge/ge_runner.mk (+1 -0)
  6. ge/graph/build/memory/var_mem_assign_util.cc (+6 -2)
  7. ge/graph/load/new_model_manager/model_utils.cc (+32 -12)
  8. ge/graph/load/new_model_manager/model_utils.h (+9 -0)
  9. ge/graph/manager/graph_var_manager.cc (+58 -16)
  10. ge/graph/manager/graph_var_manager.h (+25 -4)
  11. ge/graph/manager/rdma_pool_allocator.h (+4 -0)
  12. ge/graph/partition/dynamic_shape_partition.cc (+23 -4)
  13. ge/graph/partition/dynamic_shape_partition.h (+2 -1)
  14. ge/graph/partition/stage_partition.cc (+32 -6)
  15. ge/graph/passes/subgraph_pass.cc (+4 -3)
  16. ge/host_cpu_engine/ops_kernel_store/op/host_op.cc (+3 -0)
  17. ge/hybrid/executor/hybrid_model_async_executor.cc (+6 -1)
  18. ge/hybrid/model/hybrid_model_builder.cc (+6 -1)
  19. ge/hybrid/node_executor/hccl/hccl_node_executor.cc (+87 -27)
  20. ge/hybrid/node_executor/hccl/hccl_node_executor.h (+2 -0)
  21. ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc (+1 -3)
  22. ge/hybrid/node_executor/host_cpu/kernel/data_kernel.cc (+41 -0)
  23. ge/hybrid/node_executor/host_cpu/kernel/data_kernel.h (+42 -0)
  24. ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc (+1 -1)
  25. ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc (+1 -3)
  26. ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc (+1 -3)
  27. inc/framework/common/types.h (+1 -0)
  28. inc/framework/omg/parser/parser_types.h (+4 -2)
  29. tests/ut/ge/CMakeLists.txt (+1 -0)
  30. tests/ut/ge/graph/load/model_utils_unittest.cc (+70 -0)
  31. third_party/fwkacllib/inc/runtime/mem.h (+1 -0)

ge/CMakeLists.txt (+1 -0)

@@ -375,6 +375,7 @@ set(TRAIN_SRC_LIST
 "hybrid/node_executor/host_cpu/kernel/variable_kernel.cc"
 "hybrid/node_executor/host_cpu/kernel/assign_kernel.cc"
 "hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc"
+"hybrid/node_executor/host_cpu/kernel/data_kernel.cc"
 "hybrid/node_executor/controlop/control_op_executor.cc"
 "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
 "hybrid/node_executor/hccl/hccl_node_executor.cc"


ge/common/types.cc (+1 -0)

@@ -388,6 +388,7 @@ REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite");
+REGISTER_OPTYPE_DEFINE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");
 
 REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp");


ge/executor/CMakeLists.txt (+1 -0)

@@ -104,6 +104,7 @@ set(SRC_LIST
 "../hybrid/node_executor/host_cpu/kernel/variable_kernel.cc"
 "../hybrid/node_executor/host_cpu/kernel/assign_kernel.cc"
 "../hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc"
+"../hybrid/node_executor/host_cpu/kernel/data_kernel.cc"
 "../hybrid/node_executor/controlop/control_op_executor.cc"
 "../hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc"
 "../hybrid/node_executor/rts/rts_node_executor.cc"


ge/executor/module.mk (+1 -0)

@@ -95,6 +95,7 @@ local_ge_executor_src_files := \
 ../hybrid/node_executor/host_cpu/kernel/variable_kernel.cc \
 ../hybrid/node_executor/host_cpu/kernel/assign_kernel.cc \
 ../hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc \
+../hybrid/node_executor/host_cpu/kernel/data_kernel.cc \
 ../hybrid/node_executor/controlop/control_op_executor.cc \
 ../hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \
 ../hybrid/node_executor/rts/rts_node_executor.cc \


ge/ge_runner.mk (+1 -0)

@@ -300,6 +300,7 @@ LIBGE_LOCAL_SRC_FILES := \
 hybrid/node_executor/host_cpu/kernel/variable_kernel.cc \
 hybrid/node_executor/host_cpu/kernel/assign_kernel.cc \
 hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc \
+hybrid/node_executor/host_cpu/kernel/data_kernel.cc \
 hybrid/node_executor/controlop/control_op_executor.cc \
 hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \
 hybrid/node_executor/hccl/hccl_node_executor.cc \


ge/graph/build/memory/var_mem_assign_util.cc (+6 -2)

@@ -60,9 +60,14 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr
 return FAILED);
 ge::ConstGeTensorDescPtr tensor_desc = n->GetOpDesc()->GetOutputDescPtr(0);
 GE_CHECK_NOTNULL(tensor_desc);
+rtMemType_t memory_type = RT_MEMORY_HBM;
+uint32_t mem_type = 0;
+if (AttrUtils::GetInt(n->GetOpDesc(), ATTR_OUTPUT_MEMORY_TYPE, mem_type) && (mem_type == 1)) {
+memory_type = RT_MEMORY_RDMA_HBM;
+}
 if (!VarManager::Instance(compute_graph->GetSessionID())->IsVarExist(node_name, *tensor_desc)) {
 GE_CHK_STATUS_RET(
-VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, RT_MEMORY_HBM));
+VarManager::Instance(compute_graph->GetSessionID())->AssignVarMem(node_name, *tensor_desc, memory_type));
 GE_IF_BOOL_EXEC(n->GetType() == VARIABLE,
 GE_CHK_STATUS_RET(AssignData2Fp32Var(n, compute_graph->GetSessionID())));
 GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID())
@@ -70,7 +75,6 @@ Status VarMemAssignUtil::AssignStaticMemory2Node(ge::ComputeGraphPtr &compute_gr
 }
 
 uint8_t *dev_ptr = nullptr;
-rtMemType_t memory_type = RT_MEMORY_HBM;
 GE_CHK_STATUS_RET(VarManager::Instance(compute_graph->GetSessionID())
 ->GetVarAddr(node_name, *tensor_desc, &dev_ptr, memory_type));
 vector<int64_t> output_list = n->GetOpDesc()->GetOutputOffset();


ge/graph/load/new_model_manager/model_utils.cc (+32 -12)

@@ -15,18 +15,10 @@
  */
 
 #include "graph/load/new_model_manager/model_utils.h"
-
 #include <string>
-
 #include "common/debug/log.h"
 #include "common/op/ge_op_utils.h"
-#include "graph/debug/ge_attr_define.h"
-#include "graph/utils/attr_utils.h"
 #include "graph/utils/tensor_utils.h"
-#include "runtime/base.h"
-#include "runtime/kernel.h"
-
-#include "framework/common/debug/ge_log.h"
 #include "graph/manager/graph_var_manager.h"
 
 #define VALIDATE_MEM_RANGE(OP, SIZE, OFFSET) \
@@ -342,8 +334,8 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
 int64_t input_offset = v_input_offset[non_const_index];
 non_const_index++;
 GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset),
-VALIDATE_MEM_RANGE(op_desc, model_param.var_size, input_offset - model_param.logic_var_base);
-uint8_t *variable_addr = model_param.var_base + input_offset - model_param.logic_var_base;
+uint8_t *variable_addr = nullptr;
+GE_CHK_STATUS_EXEC(GetVarAddr(model_param, op_desc, input_offset, variable_addr), return {});
 v_input_data_addr.push_back(variable_addr);
 GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]",
 model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr);
@@ -380,6 +372,34 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co
 return v_input_data_addr;
 }
 
+///
+/// @ingroup ge
+/// @brief Get variable address.
+/// @return Status
+///
+Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
+uint8_t *&var_addr) {
+rtMemType_t mem_type = ge::VarManager::Instance(model_param.session_id)->GetVarMemType(offset);
+switch (mem_type) {
+case RT_MEMORY_RDMA_HBM:
+if (offset < 0) {
+GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", reinterpret_cast<uint8_t *>(offset));
+return PARAM_INVALID;
+}
+var_addr = reinterpret_cast<uint8_t *>(offset);
+break;
+case RT_MEMORY_HBM:
+VALIDATE_MEM_RANGE(op_desc, model_param.var_size, offset - model_param.logic_var_base);
+var_addr = model_param.var_base + offset - model_param.logic_var_base;
+break;
+default:
+GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type);
+return PARAM_INVALID;
+}
+GE_CHECK_NOTNULL(var_addr);
+return SUCCESS;
+}
+
 ///
 /// @ingroup ge
 /// @brief Get output data address.
@@ -405,8 +425,8 @@ vector<void *> ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C
 }
 for (size_t i = 0; i < outputs_size; ++i) {
 GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]),
-VALIDATE_MEM_RANGE(op_desc, model_param.var_size, v_output_offset[i] - model_param.logic_var_base);
-uint8_t *variable_addr = model_param.var_base + v_output_offset[i] - model_param.logic_var_base;
+uint8_t *variable_addr = nullptr;
+GE_CHK_STATUS_EXEC(GetVarAddr(model_param, op_desc, v_output_offset[i], variable_addr), return {});
 v_output_data_addr.push_back(variable_addr);
 GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[V] name[%s] output[%zu] memaddr[%p]",
 model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr);
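
The key idea in the new GetVarAddr is that a variable's logical offset doubles as an address discriminator: offsets registered as RT_MEMORY_RDMA_HBM already hold the real device pointer, while RT_MEMORY_HBM offsets are rebased against logic_var_base. A minimal standalone sketch of that dispatch, with simplified types and the VarManager lookup replaced by a hypothetical mem_type parameter:

#include <cstdint>
#include <stdexcept>

enum class MemType { kHbm, kRdmaHbm };

// Simplified stand-in for ModelUtils::GetVarAddr: resolve a logical variable
// offset to a physical address. For RDMA variables the "offset" is already the
// real device address; for HBM variables it is rebased against var_base.
uint8_t *ResolveVarAddr(MemType mem_type, int64_t offset,
                        uint8_t *var_base, int64_t logic_var_base) {
  switch (mem_type) {
    case MemType::kRdmaHbm:
      if (offset < 0) {
        throw std::invalid_argument("rdma var addr is invalid");
      }
      return reinterpret_cast<uint8_t *>(offset);   // already a device pointer
    case MemType::kHbm:
      return var_base + (offset - logic_var_base);  // rebase logical offset
  }
  return nullptr;
}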


ge/graph/load/new_model_manager/model_utils.h (+9 -0)

@@ -107,6 +107,15 @@ class ModelUtils {
 /// @return Status
 ///
 static Status GetRtAddress(const RuntimeParam &model_param, uintptr_t logic_addr, uint8_t *&mem_addr);
+
+ private:
+ ///
+ /// @ingroup ge
+ /// @brief Get variable address.
+ /// @return Status
+ ///
+ static Status GetVarAddr(const RuntimeParam &model_param, const ConstOpDescPtr &op_desc, int64_t offset,
+                          uint8_t *&var_addr);
 };
 } // namespace ge




ge/graph/manager/graph_var_manager.cc (+58 -16)

@@ -16,17 +16,10 @@
 
 #include "graph/manager/graph_var_manager.h"
 
-#include <utility>
-
-#include "common/l2_cache_optimize.h"
-#include "common/types.h"
-#include "framework/common/debug/ge_log.h"
-#include "framework/common/debug/log.h"
-#include "ge/ge_api_types.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/graph_mem_allocator.h"
+#include "graph/manager/rdma_pool_allocator.h"
 #include "graph/manager/trans_var_data_utils.h"
-#include "graph/utils/attr_utils.h"
 #include "graph/utils/type_utils.h"
 
 using std::map;
@@ -37,7 +30,7 @@ namespace ge {
 VarResource::VarResource(uint64_t session_id) : session_id_(session_id) {}
 
 VarResource::~VarResource() {
-var_offset_set_.clear();
+var_offset_map_.clear();
 var_addr_mgr_map_.clear();
 cur_var_tensor_desc_map_.clear();
 var_broad_cast_info_.clear();
@@ -91,8 +84,10 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen
 std::string var_key = VarKey(var_name, tensor_desc);
 GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str());
 if (var_addr_mgr_map_.count(var_key) == 0) {
-uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() +
-static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
+uint64_t logic_address = static_cast<uint64_t>(reinterpret_cast<std::uintptr_t>(address));
+if (memory_type != RT_MEMORY_RDMA_HBM) {
+logic_address += VarManager::Instance(session_id_)->GetVarMemLogicBase();
+}
 GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(),
 TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(),
 TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str());
@@ -102,7 +97,7 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen
 var_addr_mgr.tensor_desc = tensor_desc;
 var_addr_mgr.memory_type = memory_type;
 var_addr_mgr_map_[var_key] = var_addr_mgr;
-var_offset_set_.insert(logic_address);
+var_offset_map_[logic_address] = memory_type;
 
 return SUCCESS;
 }
@@ -211,7 +206,14 @@ ge::Status VarResource::SyncVarData(uint32_t graph_id, const std::string &var_na
 return SyncVarData2BroadCast(graph_id, var_name, var_tensor_desc, base_ptr);
 }
 
-bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_set_.count(offset) > 0; }
+bool VarResource::IsVarAddr(const int64_t &offset) { return var_offset_map_.count(offset) > 0; }
+
+rtMemType_t VarResource::GetVarMemType(const int64_t &offset) {
+if (var_offset_map_.count(offset) > 0) {
+return var_offset_map_[offset];
+}
+return RT_MEMORY_RESERVED;
+}
 
 VarTransRoad *VarResource::GetTransRoad(const std::string &var_name) {
 auto iter = var_to_trans_road_.find(var_name);
@@ -252,7 +254,19 @@ Status VarResource::SetAllocatedGraphId(const std::string &var_name, uint32_t gr
 
 MemResource::MemResource() : total_size_(0), var_mem_size_(0) {}
 
-Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) {
+MemResource *MemResource::BuildMemResourceFromType(rtMemType_t mem_type) {
+switch (mem_type) {
+case RT_MEMORY_HBM:
+return new (std::nothrow) HbmMemResource();
+case RT_MEMORY_RDMA_HBM:
+return new (std::nothrow) RdmaMemResource();
+default:
+return nullptr;
+}
+}
+
+Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id,
+size_t &mem_offset) {
 size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize;
 uint64_t real_size = size;
 total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize();
@@ -282,6 +296,19 @@ Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin
 return SUCCESS;
 }
 
+Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) {
+uint8_t *buffer = MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Malloc(size);
+if (buffer == nullptr) {
+GELOGE(MEMALLOC_FAILED, "Failed to malloc rdma memory for node %s, size = %llu", var_name.c_str(), size);
+return MEMALLOC_FAILED;
+}
+address = reinterpret_cast<size_t>(reinterpret_cast<uintptr_t>(buffer));
+var_mem_size_ += size;
+GELOGI("[IMAS]AssignVarMem Set session_%llu name[%s] output[%d] addr to [%p] size[%llu].",
+session_id, var_name.c_str(), 0, buffer, size);
+return SUCCESS;
+}
+
 uint64_t MemResource::GetVarMemSize() const { return var_mem_size_; }
 
 void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; };
@@ -428,7 +455,7 @@ Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) {
 MemResource *mem_resource = nullptr;
 auto iter = mem_resource_map_.find(memory_type);
 if (iter == mem_resource_map_.end()) {
-mem_resource = new (std::nothrow) MemResource();
+mem_resource = MemResource::BuildMemResourceFromType(memory_type);
 if (mem_resource == nullptr) {
 GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
 return ge::INTERNAL_ERROR;
@@ -465,7 +492,7 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen
 MemResource *mem_resource = nullptr;
 auto it = mem_resource_map_.find(memory_type);
 if (it == mem_resource_map_.end()) {
-mem_resource = new (std::nothrow) MemResource();
+mem_resource = MemResource::BuildMemResourceFromType(memory_type);
 if (mem_resource == nullptr) {
 GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type);
 return ge::INTERNAL_ERROR;
@@ -629,6 +656,15 @@ bool VarManager::IsVarAddr(const int64_t &offset) {
 return var_resource_->IsVarAddr(offset);
 }
 
+rtMemType_t VarManager::GetVarMemType(const int64_t &offset) {
+std::lock_guard<std::recursive_mutex> lock(mutex_);
+if (var_resource_ == nullptr) {
+GELOGW("VarManager has not been init.");
+return RT_MEMORY_RESERVED;
+}
+return var_resource_->GetVarMemType(offset);
+}
+
 ge::Status VarManager::MallocVarMemory(size_t memory_size) {
 std::lock_guard<std::recursive_mutex> lock(mutex_);
 uint8_t *var_mem_base = nullptr;
@@ -654,12 +690,18 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) {
 
 uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) {
 std::lock_guard<std::recursive_mutex> lock(mutex_);
+if (memory_type == RT_MEMORY_RDMA_HBM) {
+return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr();
+}
 string memory_key = std::to_string(session_id_);
 return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key);
 }
 
 uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) {
 std::lock_guard<std::recursive_mutex> lock(mutex_);
+if (memory_type == RT_MEMORY_RDMA_HBM) {
+return logic_addr;
+}
 string mem_key = std::to_string(session_id_);
 uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key);
 if (mem_base == nullptr) {
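
The MemResource split introduced here is a plain factory-plus-virtual-dispatch pattern: the HBM resource hands out bump-pointer offsets inside a session-wide pool, while the RDMA resource allocates directly from the pool allocator and returns the real address as the "offset". A condensed sketch of the pattern, with simplified signatures and the RDMA pool stubbed out:

#include <cstddef>
#include <cstdint>
#include <memory>
#include <new>
#include <string>

// Condensed version of the MemResource split: one interface, two strategies.
class MemResource {
 public:
  virtual ~MemResource() = default;
  virtual bool AssignVarMem(const std::string &name, uint64_t size, size_t &addr) = 0;
  static std::unique_ptr<MemResource> Build(bool rdma);

 protected:
  uint64_t var_mem_size_ = 0;
};

class HbmMemResource : public MemResource {
 public:
  bool AssignVarMem(const std::string &name, uint64_t size, size_t &addr) override {
    // Bump-pointer allocation: hand out the next aligned offset in the pool.
    const uint64_t kAlign = 512;  // assumed alignment, stands in for kSessionMemAlignSize
    size = (size + kAlign - 1) / kAlign * kAlign;
    addr = var_mem_size_;
    var_mem_size_ += size;
    return true;
  }
};

class RdmaMemResource : public MemResource {
 public:
  bool AssignVarMem(const std::string &name, uint64_t size, size_t &addr) override {
    // The real implementation calls RdmaPoolAllocator::Malloc; stubbed with
    // new[] here (and deliberately leaked, since the real pool owns the memory).
    uint8_t *buffer = new (std::nothrow) uint8_t[size];
    if (buffer == nullptr) return false;
    addr = reinterpret_cast<size_t>(buffer);  // the "offset" is the raw address
    var_mem_size_ += size;
    return true;
  }
};

std::unique_ptr<MemResource> MemResource::Build(bool rdma) {
  if (rdma) return std::make_unique<RdmaMemResource>();
  return std::make_unique<HbmMemResource>();
}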


ge/graph/manager/graph_var_manager.h (+25 -4)

@@ -158,13 +158,15 @@ class VarResource {
 
 bool IsVarAddr(const int64_t &offset);
 
+rtMemType_t GetVarMemType(const int64_t &offset);
+
 std::unordered_map<std::string, ge::GeTensorDesc> GetAllVarDesc() const { return cur_var_tensor_desc_map_; }
 
 private:
 std::string VarKey(const std::string &var_name, const ge::GeTensorDesc &tensor_desc);
 
 uint64_t session_id_;
-std::unordered_set<uint64_t> var_offset_set_;
+std::unordered_map<uint64_t, rtMemType_t> var_offset_map_;
 std::unordered_map<std::string, VarAddrMgr> var_addr_mgr_map_;
 std::unordered_map<std::string, ge::GeTensorDesc> cur_var_tensor_desc_map_;
 std::unordered_map<std::string, std::vector<TransNodeInfo>> var_to_trans_road_;
@@ -176,19 +178,36 @@ class VarResource {
 class MemResource {
 public:
 MemResource();
-~MemResource() = default;
+virtual ~MemResource() = default;
+static MemResource *BuildMemResourceFromType(rtMemType_t mem_type);
 
-Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset);
+virtual Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) = 0;
 
 uint64_t GetVarMemSize() const;
 
 void UpdateVarMemSize(int64_t mem_size);
 
-private:
+protected:
 uint64_t total_size_;
 uint64_t var_mem_size_;
 };
 
+class HbmMemResource : public MemResource {
+public:
+HbmMemResource() = default;
+~HbmMemResource() override = default;
+
+Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) override;
+};
+
+class RdmaMemResource : public MemResource {
+public:
+RdmaMemResource() = default;
+~RdmaMemResource() override = default;
+
+Status AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &address) override;
+};
+
 class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
 public:
 static VarManager *Instance(uint64_t session_id);
@@ -275,6 +294,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager {
 
 bool IsVarAddr(const int64_t &offset);
 
+rtMemType_t GetVarMemType(const int64_t &offset);
+
 uint8_t *GetVarMemoryBase(rtMemType_t memory_type);
 
 uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type);


ge/graph/manager/rdma_pool_allocator.h (+4 -0)

@@ -53,6 +53,10 @@ class RdmaPoolAllocator {
 
 Status GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size);
 
+uint8_t *GetRdmaBaseAddr() { return rdma_base_addr_; }
+
+size_t GetRdmaMemSize() { return rdma_mem_size_; }
+
 private:
 void MergeBlocks(Block *dst, Block *src);




ge/graph/partition/dynamic_shape_partition.cc (+23 -4)

@@ -213,6 +213,7 @@ std::string DynamicShapePartitioner::DebugString() const {
 size_t data = 0;
 size_t netoutput = 0;
 size_t is_inputnode = 0;
+size_t stage = 0;
 std::stringstream ss;
 ss << "All unknown shape nodes:" << std::endl;
 for (const auto &node : unknown_shape_nodes_) {
@@ -229,10 +230,13 @@ std::string DynamicShapePartitioner::DebugString() const {
 netoutput++;
 } else if (cluster->IsInputNode()) {
 is_inputnode++;
+} else if (cluster->IsIndependent()) {
+stage++;
 }
 }
 ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known
-<< ", unknown:" << unknown << ", netoutput:" << netoutput << ", is_inputnode:" << is_inputnode << std::endl;
+<< ", unknown:" << unknown << ", netoutput:" << netoutput << ", is_inputnode:" << is_inputnode
+<< ", stage:" << stage << std::endl;
 for (const auto &cluster : unique_clusters_) {
 ss << " " << cluster->DebugString() << std::endl;
 }
@@ -272,12 +276,15 @@ Status DynamicShapePartitioner::InitClusters() {
 for (const auto &node : graph->GetDirectNode()) {
 Cluster::Type type = Cluster::DATA;
 bool is_input = ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) && node->GetInNodes().empty();
+REQUIRE_NOT_NULL(node->GetOpDesc(), "op_desc is null");
 if (node->GetType() == DATA) {
 type = Cluster::DATA;
 } else if (is_input) {
 type = Cluster::INPUT_NODE;
 } else if (node->GetType() == NETOUTPUT) {
 type = Cluster::NETOUTPUT;
+} else if ((node->GetType() == PARTITIONEDCALL) && (node->GetOpDesc()->HasAttr(ATTR_STAGE_LEVEL))) {
+type = Cluster::STAGE;
 } else if (unknown_shape_nodes_.count(node) > 0) {
 type = Cluster::UNKNOWN_SHAPE;
 } else {
@@ -360,6 +367,9 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) {
 void DynamicShapePartitioner::MergeClustersUnknownShape() {
 // Merge unknown shape clusters
 for (const auto &cluster : ordered_cluster_) {
+if (cluster->IsIndependent()) {
+continue;
+}
 for (const auto &in_cluster : cluster->Inputs()) {
 if (!in_cluster->IsUnknownShape()) {
 continue;
@@ -379,6 +389,9 @@ void DynamicShapePartitioner::MergeClustersKnownShape() {
 // Merge known shape clusters
 for (const auto &cluster : ordered_cluster_) {
+if (cluster->IsIndependent()) {
+continue;
+}
 if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) {
 auto in_cluster = *(cluster->Inputs().begin());
 in_cluster->Merge(cluster);
@@ -606,6 +619,7 @@ void Cluster::UpdateRank(size_t rank) {
 bool Cluster::IsData() const { return type_ == DATA; };
 bool Cluster::IsKnownShape() const { return type_ == KNOWN_SHAPE; };
 bool Cluster::IsUnknownShape() const { return type_ == UNKNOWN_SHAPE; };
+bool Cluster::IsIndependent() const { return type_ == STAGE; };
 bool Cluster::IsNetOutput() const { return type_ == NETOUTPUT; };
 bool Cluster::IsInputNode() const { return type_ == INPUT_NODE; };
 bool Cluster::IsRefVariable() const {
@@ -641,6 +655,9 @@ void Cluster::RemoveOutput(ClusterPtr out) {
 out->in_clusters_.end());
 };
 void Cluster::Merge(ClusterPtr other) {
+if (other->IsIndependent()) {
+return;
+}
 nodes_.insert(nodes_.end(), other->nodes_.begin(), other->nodes_.end());
 other->in_clusters_.erase(std::remove(other->in_clusters_.begin(), other->in_clusters_.end(), shared_from_this()),
 other->in_clusters_.end());
@@ -689,7 +706,9 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) {
 std::unordered_set<ClusterPtr> forward_reached_clusters;
 std::unordered_set<ClusterPtr> backward_reached_clusters;
 std::vector<ClusterPtr> path_clusters;
-
+if (other->IsIndependent()) {
+return path_clusters;
+}
 if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) ==
 other->out_clusters_.end()) {
 return path_clusters;
@@ -772,7 +791,7 @@ Status Cluster::BuildFrame() {
 }
 }
 }
-if (IsData()) {
+if (IsData() || IsIndependent()) {
 for (const auto &anchor : node->GetAllOutDataAnchors()) {
 AddFrameOutput(anchor);
 }
@@ -888,7 +907,7 @@ Status Cluster::CombinePartitionFrame() {
 }
 
 Status Cluster::BuildPartitionSubgraph() {
-if (IsData() || IsNetOutput()) {
+if (IsData() || IsNetOutput() || IsIndependent()) {
 return SUCCESS;
 }
 int64_t parent_node_index = 0;
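
The new STAGE cluster type is deliberately inert: every merge path checks IsIndependent() and backs off, so a staged PartitionedCall keeps its own partition frame much like a Data node. A toy illustration of that guard, using a hypothetical minimal Cluster rather than GE's real class:

#include <string>
#include <vector>

// Toy cluster showing how STAGE clusters opt out of merging.
struct Cluster {
  enum Type { DATA, INPUT_NODE, NETOUTPUT, STAGE, KNOWN_SHAPE, UNKNOWN_SHAPE };
  Type type;
  std::vector<std::string> nodes;

  bool IsIndependent() const { return type == STAGE; }

  void Merge(Cluster &other) {
    if (other.IsIndependent()) {
      return;  // staged clusters are never absorbed by a neighbour
    }
    nodes.insert(nodes.end(), other.nodes.begin(), other.nodes.end());
    other.nodes.clear();
  }
};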


ge/graph/partition/dynamic_shape_partition.h (+2 -1)

@@ -32,7 +32,7 @@ class DynamicShapePartitioner {
 // DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT.
 class Cluster : public std::enable_shared_from_this<Cluster> {
 public:
-enum Type { DATA, INPUT_NODE, NETOUTPUT, KNOWN_SHAPE, UNKNOWN_SHAPE };
+enum Type { DATA, INPUT_NODE, NETOUTPUT, STAGE, KNOWN_SHAPE, UNKNOWN_SHAPE };
 Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner)
 : id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) {
 nodes_.push_back(node);
@@ -45,6 +45,7 @@ class DynamicShapePartitioner {
 bool IsData() const;
 bool IsKnownShape() const;
 bool IsUnknownShape() const;
+bool IsIndependent() const;
 bool IsNetOutput() const;
 std::vector<std::shared_ptr<Cluster>> Inputs() const;
 std::vector<std::shared_ptr<Cluster>> Outputs() const;


ge/graph/partition/stage_partition.cc (+32 -6)

@@ -25,6 +25,10 @@
 #include "common/types.h"
 
 namespace ge {
+namespace {
+const std::set<std::string> kSrcNodeTypes = { DATA, AIPPDATA, ANN_DATA };
+}
+
 Status StagePartitioner::Partition() {
 GE_CHECK_NOTNULL(root_graph_);
 if (root_graph_->GetParentGraph() != nullptr) {
@@ -37,6 +41,10 @@ Status StagePartitioner::Partition() {
 if (!AttrUtils::GetInt(op_desc, ATTR_STAGE_LEVEL, level)) {
 continue;
 }
+if ((kSrcNodeTypes.count(op_desc->GetType()) != 0) && node->GetInAllNodes().empty()) {
+continue;
+}
+GELOGD("original node %s for stage %u", node->GetName().c_str(), level);
 stage_nodes_[level].insert(node);
 }
 if (stage_nodes_.empty()) {
@@ -54,6 +62,13 @@ Status StagePartitioner::Partition() {
 return FAILED;
 }
 
+root_graph_->TopologicalSorting([](const NodePtr &a, const NodePtr &b) -> bool {
+uint32_t a_level = UINT32_MAX;
+(void)AttrUtils::GetInt(a->GetOpDesc(), ATTR_STAGE_LEVEL, a_level);
+uint32_t b_level = UINT32_MAX;
+(void)AttrUtils::GetInt(b->GetOpDesc(), ATTR_STAGE_LEVEL, b_level);
+return a_level < b_level;
+});
 if (root_graph_->TopologicalSorting() != GRAPH_SUCCESS) {
 GELOGE(FAILED, "Topological sort for graph %s after stage partition failed, "
 "maybe stage_level was not set correctly.", root_graph_->GetName().c_str());
@@ -76,20 +91,26 @@ Status StagePartitioner::SplitStageLevel() {
 auto node = nodes.top();
 nodes.pop();
 GE_CHECK_NOTNULL(node->GetOpDesc());
-if (node->GetOpDesc()->HasAttr(ATTR_STAGE_LEVEL) && (cur_stage_nodes.count(node) == 0)) {
+uint32_t tmp_level = cur_stage_level;
+(void)AttrUtils::GetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, tmp_level);
+if (tmp_level != cur_stage_level) {
 continue;
 }
 for (const auto &in_node : node->GetInAllNodes()) {
 if (visited_stage_nodes.count(in_node) != 0) {
 continue;
 }
-if (!AttrUtils::SetInt(in_node->GetOpDesc(), ATTR_STAGE_LEVEL, cur_stage_level)) {
-GELOGE(INTERNAL_ERROR, "Set attr ATTR_STAGE_LEVEL on node %s failed.", in_node->GetName().c_str());
-return INTERNAL_ERROR;
-}
-GELOGD("Mark stage_level node %s, stage_level=%u", in_node->GetName().c_str(), cur_stage_level);
+if ((kSrcNodeTypes.count(in_node->GetType()) != 0) && in_node->GetInAllNodes().empty()) {
+GELOGD("skip data node %s for stage %u", in_node->GetName().c_str(), cur_stage_level);
+continue;
+}
 nodes.push(in_node);
 }
+if (!AttrUtils::SetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, cur_stage_level)) {
+GELOGE(INTERNAL_ERROR, "Set attr ATTR_STAGE_LEVEL on node %s failed.", node->GetName().c_str());
+return INTERNAL_ERROR;
+}
+GELOGD("Mark stage_level node %s, stage_level=%u", node->GetName().c_str(), cur_stage_level);
 visited_stage_nodes.emplace(node);
 }
 for (const auto &node : visited_stage_nodes) {
@@ -219,6 +240,11 @@ NodePtr StagePartitioner::BuildSubgraphNode(const std::string &graph_name, const
 op_desc->AddSubgraphName("f");
 op_desc->SetSubgraphInstanceName(0, graph_name);
 
+if (!AttrUtils::SetInt(op_desc, ATTR_STAGE_LEVEL, stage_info.stage_level)) {
+GELOGE(INTERNAL_ERROR, "Set attr ATTR_STAGE_LEVEL on node %s failed", op_desc->GetName().c_str());
+return nullptr;
+}
+
 NodePtr subgraph_node = root_graph_->AddNode(op_desc);
 if (subgraph_node == nullptr) {
 GELOGE(FAILED, "Add node %s failed.", graph_name.c_str());


ge/graph/passes/subgraph_pass.cc (+4 -3)

@@ -142,17 +142,18 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node
 GE_CHECK_NOTNULL(in_node);
 
 // Need insert memcpy
-// 1. Const->NetOutput in subgraph
+// 1. Const->NetOutput in subgraph & parent graph is known
 // 2. AtomicOp->NetOutput in subgraph
 // 3. OutputContinuesRequiredOp->NetOutput in subgraph
 // 4. Data->NetOutput in subgraph but parent_node is not while
 // 5. While->NetOutput in known subgraph
 std::string op_type;
-bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) ||
+bool insert_flag =
+(NodeUtils::GetConstOpType(in_node, op_type) && !graph->GetParentGraph()->GetGraphUnknownFlag()) ||
 IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) ||
 ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)) ||
 (!graph->GetGraphUnknownFlag() && NodeUtils::IsDynamicShape(node) &&
-(kWhileOpTypes.count(in_node->GetType()) != 0));
+(kWhileOpTypes.count(in_node->GetType()) != 0));
 if (insert_flag) {
 GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str());
 std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy";


ge/host_cpu_engine/ops_kernel_store/op/host_op.cc (+3 -0)

@@ -32,5 +32,8 @@ REGISTER_OP_CREATOR(Assign, HostOp);
 REGISTER_OP_CREATOR(RandomUniform, HostOp);
 REGISTER_OP_CREATOR(Add, HostOp);
 REGISTER_OP_CREATOR(Mul, HostOp);
+REGISTER_OP_CREATOR(ConcatV2, HostOp);
+REGISTER_OP_CREATOR(Data, HostOp);
+REGISTER_OP_CREATOR(Fill, HostOp);
 } // namespace host_cpu
 } // namespace ge

ge/hybrid/executor/hybrid_model_async_executor.cc (+6 -1)

@@ -59,6 +59,7 @@ Status HybridModelAsyncExecutor::Start(const std::shared_ptr<ModelListener> &lis
 run_flag_ = true;
 listener_ = listener;
 future_ = std::async(std::launch::async, [&]() -> Status {
+GetThreadLocalContext() = *executor_->GetContext()->ge_context;
 GetContext().SetSessionId(executor_->GetContext()->session_id);
 return RunInternal();
 });
@@ -229,7 +230,11 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData &current_data, Hy
 }
 
 GE_CHECK_GE(tensor_size, 0);
-auto tensor_buffer = TensorBuffer::Create(allocator, tensor_size);
+AllocationAttr attr;
+if (GetContext().GetHostExecFlag()) {
+attr.SetMemType(HOST_DDR);
+}
+auto tensor_buffer = TensorBuffer::Create(allocator, tensor_size, &attr);
 GE_CHECK_NOTNULL(tensor_buffer);
 args.inputs.emplace_back(std::shared_ptr<TensorBuffer>(tensor_buffer.release()));
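
Two small but load-bearing changes here: the worker first copies the creating thread's GE context, since std::async workers do not inherit thread-local state, and input buffers are tagged HOST_DDR when the model runs on the host so the allocator does not reach for device memory. A sketch of the thread-local handoff, with a hypothetical Context type standing in for ge::GEThreadLocalContext:

#include <future>
#include <string>

// Hypothetical stand-in for ge::GEThreadLocalContext.
struct Context { std::string session; };
thread_local Context g_ctx;

// std::async workers start with a default-constructed thread-local context,
// so the parent's context must be captured and re-applied inside the lambda.
std::future<int> Launch(const Context &parent_ctx) {
  return std::async(std::launch::async, [parent_ctx]() -> int {
    g_ctx = parent_ctx;  // mirrors: GetThreadLocalContext() = *...->ge_context
    return 0;
  });
}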




ge/hybrid/model/hybrid_model_builder.cc (+6 -1)

@@ -772,7 +772,12 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_
 var_name.c_str(),
 hybrid_model_.GetSessionId());
 
-uint8_t *dev_mem = var_manager_->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
+rtMemType_t memory_type = RT_MEMORY_HBM;
+uint32_t mem_type = 0;
+if (AttrUtils::GetInt(var_node->GetOpDesc(), ATTR_OUTPUT_MEMORY_TYPE, mem_type) && (mem_type == 1)) {
+memory_type = RT_MEMORY_RDMA_HBM;
+}
+uint8_t *dev_mem = var_manager_->GetVarMemoryAddr(var_logic, memory_type);
 if (dev_mem == nullptr) {
 GELOGE(INTERNAL_ERROR,
 "Failed to copy var %s from device, cant not get "


ge/hybrid/node_executor/hccl/hccl_node_executor.cc (+87 -27)

@@ -15,23 +15,25 @@
 */
 
 #include "hybrid/node_executor/hccl/hccl_node_executor.h"
+#include "common/ge/ge_util.h"
 #include "common/ge/plugin_manager.h"
 #include "common/math/math_util.h"
-#include "framework/common/debug/ge_log.h"
 #include "graph/attr_value.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/manager/util/hcom_util.h"
 #include "graph/runtime_inference_context.h"
-#include "hccl/hcom.h"
+#include "graph/utils/type_utils.h"
+#include "hybrid/executor/hybrid_execution_context.h"
 
-namespace ge {
 namespace {
-const size_t kVarTableDims = 2;
-const size_t kVarTableRowCnt = 3;
-const size_t kVarTableIdxAddr = 1;
-const size_t kVarTableIdxLen = 2;
+constexpr size_t kVarTableDims = 2;
+constexpr size_t kVarTableRowCnt = 3;
+constexpr size_t kVarTableIdxAddr = 1;
+constexpr size_t kVarTableIdxLen = 2;
+const std::set<std::string> kRdmaReadTypes = { HCOMREMOTEREAD, HCOMREMOTEREFREAD };
+const std::set<std::string> kRdmaWriteTypes = { HCOMREMOTEWRITE, HCOMREMOTESCATTERWRITE };
+const std::set<std::string> kRdmaScatterTypes = { HCOMREMOTEREFREAD, HCOMREMOTESCATTERWRITE };
 } // namespace
+namespace ge {
 namespace hybrid {
 
 REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, HcclNodeExecutor);
@@ -142,11 +144,22 @@ Status RdmaNodeTask::Init(TaskContext &context) {
 GE_CHECK_NOTNULL(peer_node->GetOpDesc());
 
 remote_index_ = {peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx()};
-if (node_item.node->GetType() == HCOMREMOTEREAD) {
+if (kRdmaReadTypes.count(node_item.node->GetType()) > 0) {
 local_index_ = 0;
 } else {
 local_index_ = op_desc->GetInputIndexByName("local");
 }
+int32_t offset_idx = node_item.op_desc->GetInputIndexByName("local_offset");
+if ((offset_idx != -1) && (node_item.op_desc->GetInputDescPtr(offset_idx) != nullptr)) {
+skip_flag_ = true;
+GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx));
+GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor());
+GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode());
+GE_CHECK_NOTNULL(node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc());
+offset_index_ = {
+node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetOwnerNode()->GetOpDesc()->GetId(),
+node_item.node->GetInDataAnchor(offset_idx)->GetPeerOutAnchor()->GetIdx() };
+}
 return SUCCESS;
 }
 
@@ -158,8 +171,13 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
 GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor));
 auto data = reinterpret_cast<uint64_t *>(remote_tensor.GetData());
 if (data == nullptr) {
-GELOGE(FAILED, "Tensor data is nullptr.");
-return FAILED;
+if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+GELOGD("data is null, no need to do rdma read/write, node=%s", context.GetNodeName());
+return SUCCESS;
+} else {
+GELOGE(FAILED, "Tensor data is nullptr.");
+return FAILED;
+}
 }
 auto dims = remote_tensor.GetTensorDesc().GetShape().GetDims();
 if (dims.size() != kVarTableDims && dims.back() != kVarTableRowCnt) {
@@ -183,30 +201,63 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess
 auto tensor_buffer = TensorBuffer::Create(allocator, remote_size, &attr);
 GE_CHK_STATUS_RET(context.SetOutput(i, TensorValue(std::shared_ptr<TensorBuffer>(tensor_buffer.release()))));
 }
+} else if (context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD) {
+AllocationAttr attr;
+attr.SetMemType(RDMA_HBM);
+GE_CHK_STATUS_RET(context.AllocateOutputs(&attr))
 }
 
 TensorValue *tv;
-if (context.GetNodeItem().NodeType() == HCOMREMOTEREAD) {
-tv = context.MutableOutput(0);
+if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) > 0) {
+tv = context.MutableOutput(local_index_);
 } else {
 tv = context.MutableInput(local_index_);
 }
 GE_CHECK_NOTNULL(tv);
-auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
 auto row_num = dims.front();
 addr_infos.resize(row_num);
-auto device_len = tv->GetSize() / row_num;
-if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
-GELOGE(FAILED, "Local embedding length is out of range.");
-return FAILED;
-}
-
-for (auto idx = 0; idx < row_num; ++idx) {
-FMK_INT64_MULCHECK(idx, kVarTableRowCnt);
-auto line_idx = idx * kVarTableRowCnt;
-addr_infos[idx] = {static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
-device_len};
-local_addr += device_len;
+if (skip_flag_) {
+int32_t offset_idx = context.GetNodeItem().op_desc->GetInputIndexByName("local_offset");
+GE_CHECK_NOTNULL(context.GetNodeItem().op_desc->GetInputDescPtr(offset_idx));
+auto data_type = context.GetNodeItem().op_desc->GetInputDesc(offset_idx).GetDataType();
+
+Tensor offset_tensor;
+GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor))
+if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) {
+GELOGE(PARAM_INVALID, "num of offset and remote addr mismatch, offset size=%zu, remote_addr size=%lld, dtype=%s",
+offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str());
+return PARAM_INVALID;
+}
+
+auto addr_offset = reinterpret_cast<uint64_t *>(offset_tensor.GetData());
+GE_CHECK_NOTNULL(addr_offset);
+auto base_addr = reinterpret_cast<float *>(tv->MutableData());
+GE_CHECK_NOTNULL(base_addr);
+
+for (auto idx = 0; idx < row_num; idx++) {
+FMK_INT64_MULCHECK(idx, kVarTableRowCnt)
+auto line_idx = idx * kVarTableRowCnt;
+addr_infos[idx] = { static_cast<uint32_t>(data[line_idx]),
+data[line_idx + kVarTableIdxAddr],
+reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(base_addr + addr_offset[idx])),
+data[line_idx + kVarTableIdxLen] };
+}
+} else {
+auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData()));
+auto device_len = tv->GetSize() / row_num;
+if (device_len <= 0 || device_len > data[kVarTableIdxLen]) {
+GELOGE(FAILED, "Local embedding length is out of range, expect %lld, but %lld exactly.",
+data[kVarTableIdxLen], device_len);
+return FAILED;
+}
+
+for (auto idx = 0; idx < row_num; ++idx) {
+FMK_INT64_MULCHECK(idx, kVarTableRowCnt)
+auto line_idx = idx * kVarTableRowCnt;
+addr_infos[idx] = { static_cast<uint32_t>(data[line_idx]), data[line_idx + kVarTableIdxAddr], local_addr,
+device_len };
+local_addr += device_len;
+}
 }
 
 return SUCCESS;
@@ -226,6 +277,10 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
 }
 vector<HcomRemoteAccessAddrInfo> addr_infos;
 GE_CHK_STATUS_RET(ExtractTensor(context, addr_infos));
+if (addr_infos.empty()) {
+done_callback();
+return SUCCESS;
+}
 
 auto callback = [this](HcclResult status) {
 if (status != HCCL_SUCCESS) {
@@ -235,6 +290,11 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
 this->cond_.notify_all();
 GELOGI("rdma callback success.");
 };
+
+std::string executor_type = context.GetNodeItem().NodeType();
+if (kRdmaScatterTypes.count(context.GetNodeItem().NodeType()) > 0) {
+executor_type = context.GetNodeItem().NodeType() == HCOMREMOTEREFREAD ? HCOMREMOTEREAD : HCOMREMOTEWRITE;
+}
 HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback);
 if (hccl_ret != HCCL_SUCCESS) {
 GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", hccl_ret);
@@ -262,7 +322,7 @@ Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const
 
 GE_CHK_STATUS_RET(task.Init(context), "hccl node load hccl so failed.");
 // allocate output mem, output mem or remote read will be calculated when node execute.
-if (context.GetNodeItem().NodeType() != HCOMREMOTEREAD) {
+if (kRdmaReadTypes.count(context.GetNodeItem().NodeType()) == 0) {
 GE_CHK_STATUS_RET(context.AllocateOutputs(), "hccl node task allocate output failed.");
 }
 
@@ -274,7 +334,7 @@ Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
 GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str());
 GE_CHECK_NOTNULL(node);
-if (node->GetType() == HCOMREMOTEREAD || node->GetType() == HCOMREMOTEWRITE) {
+if ((kRdmaReadTypes.count(node->GetType()) > 0) || (kRdmaWriteTypes.count(node->GetType()) > 0)) {
 task = MakeShared<RdmaNodeTask>();
 } else {
 task = MakeShared<HcclNodeTask>();
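
With skip_flag_ set, each row of the remote var-table is paired with a caller-supplied element offset instead of assuming a contiguous local layout: the local address for row i becomes base_addr + offset[i] (counted in float elements), and the transfer length comes from the table itself. A simplified sketch of building those address triples, with a plain struct in place of HcomRemoteAccessAddrInfo:

#include <cstdint>
#include <vector>

struct RemoteAccessInfo {
  uint32_t remote_rank;
  uint64_t remote_addr;
  uint64_t local_addr;
  uint64_t length;
};

// Simplified version of the skip_flag_ branch of RdmaNodeTask::ExtractTensor:
// row i of the var table is (rank, remote_addr, len); the local side is
// scattered at base + offset[i] float elements rather than packed back to back.
std::vector<RemoteAccessInfo> BuildScatterAddrs(const uint64_t *table, int64_t rows,
                                                float *base, const uint64_t *offset) {
  std::vector<RemoteAccessInfo> infos(rows);
  for (int64_t i = 0; i < rows; ++i) {
    const uint64_t *row = table + i * 3;  // kVarTableRowCnt == 3
    infos[i] = {static_cast<uint32_t>(row[0]), row[1],
                reinterpret_cast<uint64_t>(base + offset[i]), row[2]};
  }
  return infos;
}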


ge/hybrid/node_executor/hccl/hccl_node_executor.h (+2 -0)

@@ -55,9 +55,11 @@ class RdmaNodeTask : public NodeTask {
 private:
 Status ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos);
 std::pair<int64_t, int64_t> remote_index_;
+std::pair<int64_t, int64_t> offset_index_;
 int32_t local_index_ = 0;
 std::mutex hccl_mutex_;
 std::condition_variable cond_;
+bool skip_flag_;
 };
 
 class HcclNodeExecutor : public NodeExecutor {


ge/hybrid/node_executor/host_cpu/kernel/assign_kernel.cc (+1 -3)

@@ -29,8 +29,6 @@ namespace ge {
 namespace hybrid {
 namespace host_cpu {
 Status AssignKernel::Compute(TaskContext& context) {
-GELOGI("[%s] compute begin.", node_->GetName().c_str());
-
 auto ref_tensor = context.MutableInput(kAssignRefInputIndex);
 GE_CHECK_NOTNULL(ref_tensor);
 const auto value_tensor = context.GetInput(kAssignValueInputIndex);
@@ -50,7 +48,7 @@ Status AssignKernel::Compute(TaskContext& context) {
 GE_CHK_STATUS_RET(context.SetOutput(kAssignRefOutputIndex, *ref_tensor),
 "[%s] Failed to set output.", context.GetNodeName());
 
-GELOGI("[%s] compute success.", node_->GetName().c_str());
+GELOGD("[%s] compute success.", node_->GetName().c_str());
 return SUCCESS;
 }


ge/hybrid/node_executor/host_cpu/kernel/data_kernel.cc (+41 -0)

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "hybrid/node_executor/host_cpu/kernel/data_kernel.h"
#include "framework/common/debug/ge_log.h"
#include "framework/common/util.h"
#include "hybrid/node_executor/host_cpu/kernel_factory.h"

namespace {
constexpr size_t kDataInputIndex = 0;
constexpr size_t kDataOutputIndex = 0;
}

namespace ge {
namespace hybrid {
namespace host_cpu {
Status DataKernel::Compute(TaskContext& context) {
auto input = context.MutableInput(kDataInputIndex);
GE_CHECK_NOTNULL(input);
GE_CHK_STATUS_RET(context.SetOutput(kDataOutputIndex, *input), "[%s] Failed to set output.", context.GetNodeName())
GELOGD("[%s] compute success.", node_->GetName().c_str());
return SUCCESS;
}

REGISTER_KERNEL_CREATOR(Data, DataKernel);
} // namespace host_cpu
} // namespace hybrid
} // namespace ge
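
REGISTER_KERNEL_CREATOR(Data, DataKernel) wires the new pass-through kernel into the host-CPU kernel factory declared in kernel_factory.h. The macro's definition is not shown in this diff; the following is a sketch of a common shape for such a registrar, under that assumption (KernelRegistrySketch and g_data_kernel_registered are illustrative names, and Kernel/NodePtr/DataKernel come from the headers included above):

#include <functional>
#include <map>
#include <memory>
#include <string>

// Illustrative registry: maps an op type string to a creator returning a Kernel.
class KernelRegistrySketch {
 public:
  using Creator = std::function<std::shared_ptr<Kernel>(const NodePtr &)>;
  static KernelRegistrySketch &Instance() {
    static KernelRegistrySketch instance;  // constructed once, on first use
    return instance;
  }
  void Register(const std::string &type, Creator creator) { creators_[type] = std::move(creator); }
  std::shared_ptr<Kernel> Create(const std::string &type, const NodePtr &node) {
    auto it = creators_.find(type);
    return (it == creators_.end()) ? nullptr : it->second(node);
  }
 private:
  std::map<std::string, Creator> creators_;
};

// A file-scope registrar runs at load time, so merely linking data_kernel.cc
// makes the "Data" op type creatable through the factory.
static const bool g_data_kernel_registered = [] {
  KernelRegistrySketch::Instance().Register(
      "Data", [](const NodePtr &node) { return std::make_shared<DataKernel>(node); });
  return true;
}();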

+ 42
- 0
ge/hybrid/node_executor/host_cpu/kernel/data_kernel.h View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef GE_HYBRID_HOST_CPU_KERNEL_DATA_KERNEL_H_
#define GE_HYBRID_HOST_CPU_KERNEL_DATA_KERNEL_H_

#include "hybrid/node_executor/host_cpu/kernel/kernel.h"

namespace ge {
namespace hybrid {
namespace host_cpu {
class DataKernel : public Kernel {
public:
DataKernel(const NodePtr &node) : Kernel(node) {}
~DataKernel() override = default;
DataKernel &operator=(const DataKernel &op) = delete;
DataKernel(const DataKernel &op) = delete;

/**
* @brief compute for node_task.
* @return result
*/
Status Compute(TaskContext& context) override;
};
} // namespace host_cpu
} // namespace hybrid
} // namespace ge

#endif // GE_HYBRID_HOST_CPU_KERNEL_DATA_KERNEL_H_

+ 1
- 1
ge/hybrid/node_executor/host_cpu/kernel/no_op_kernel.cc View File

@@ -23,7 +23,7 @@ namespace ge {
 namespace hybrid {
 namespace host_cpu {
 Status NoOpKernel::Compute(TaskContext& context) {
-  GELOGI("[%s] no need to compute.", node_->GetName().c_str());
+  GELOGD("[%s] no need to compute.", node_->GetName().c_str());
   return SUCCESS;
 }




+ 1
- 3
ge/hybrid/node_executor/host_cpu/kernel/random_uniform_kernel.cc View File

@@ -30,8 +30,6 @@ namespace ge {
 namespace hybrid {
 namespace host_cpu {
 Status RandomUniformKernel::Compute(TaskContext& context) {
-  GELOGI("[%s] compute begin.", node_->GetName().c_str());
-
   int64_t seed = 0;
   int64_t seed2 = 0;
   (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed", seed);
@@ -66,7 +64,7 @@ Status RandomUniformKernel::Compute(TaskContext& context) {
     return UNSUPPORTED;
   }

-  GELOGI("[%s] compute success.", node_->GetName().c_str());
+  GELOGD("[%s] compute success.", node_->GetName().c_str());
   return SUCCESS;
 }




+ 1
- 3
ge/hybrid/node_executor/host_cpu/kernel/variable_kernel.cc View File

@@ -23,8 +23,6 @@ namespace ge {
 namespace hybrid {
 namespace host_cpu {
 Status VariableKernel::Compute(TaskContext& context) {
-  GELOGI("[%s] compute begin.", node_->GetName().c_str());
-
   auto tensor = context.GetVariable(node_->GetName());
   if (tensor == nullptr) {
     GELOGE(PARAM_INVALID, "tensor is NULL.");
@@ -32,7 +30,7 @@ Status VariableKernel::Compute(TaskContext& context) {
   }
   // Constant & Variable Op has and only has one output
   GE_CHK_STATUS_RET(context.SetOutput(0, *tensor), "[%s] Failed to set output.", context.GetNodeName());
-  GELOGI("[%s] compute success.", node_->GetName().c_str());
+  GELOGD("[%s] compute success.", node_->GetName().c_str());
   return SUCCESS;
 }




+ 1
- 0
inc/framework/common/types.h View File

@@ -437,6 +437,7 @@ REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead");
 REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite");
+REGISTER_OPTYPE_DECLARE(HCOMREMOTESCATTERWRITE, "HcomRemoteScatterWrite");


 REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign");
 REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp");


+ 4
- 2
inc/framework/omg/parser/parser_types.h View File

@@ -238,8 +238,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SOFTSIGN;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *COSH;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SINH;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SQUAREDDIFFERENCE;
-FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char
-    *REQUIREDSPACETOBATCHPADDINGS;  // for retinanet scope fusion
+// for retinanet scope fusion
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *REQUIREDSPACETOBATCHPADDINGS;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *SSDPOSTPROCESSOR;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINANETBOXES;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *RETINAMULTIANCHORS;
@@ -370,7 +370,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREDUCESC
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMSEND;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMRECEIVE;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEREAD;
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEREFREAD;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTEWRITE;
+FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *HCOMREMOTESCATTERWRITE;


 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARASSIGN;
 FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const char *VARISINITIALIZEDOP;


+ 1
- 0
tests/ut/ge/CMakeLists.txt View File

@@ -589,6 +589,7 @@ set(DISTINCT_GRAPH_LOAD_TEST_FILES
#"graph/graph_load_unittest.cc" #"graph/graph_load_unittest.cc"
"graph/ge_executor_unittest.cc" "graph/ge_executor_unittest.cc"
"graph/load/model_helper_unittest.cc" "graph/load/model_helper_unittest.cc"
"graph/load/model_utils_unittest.cc"
) )


 set(PASS_TEST_FILES


+ 70
- 0
tests/ut/ge/graph/load/model_utils_unittest.cc View File

@@ -0,0 +1,70 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#define protected public
#define private public
#include "graph/load/new_model_manager/model_utils.h"
#include "graph/manager/graph_var_manager.h"

using namespace std;

namespace ge {
class UtestModelUtils : public testing::Test {
protected:
void TearDown() {}
};

// test ModelUtils::GetVarAddr
TEST_F(UtestModelUtils, get_var_addr_hbm) {
uint8_t test = 2;
uint8_t *pf = &test;
RuntimeParam runtime_param;
runtime_param.session_id = 0;
runtime_param.logic_var_base = 0;
runtime_param.var_base = pf;
runtime_param.var_size = 16;

int64_t offset = 8;
EXPECT_EQ(VarManager::Instance(runtime_param.session_id)->Init(0, 0, 0, 0), SUCCESS);
EXPECT_NE(VarManager::Instance(runtime_param.session_id)->var_resource_, nullptr);
VarManager::Instance(runtime_param.session_id)->var_resource_->var_offset_map_[offset] = RT_MEMORY_HBM;
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>("test", "test");
uint8_t *var_addr = nullptr;
EXPECT_EQ(ModelUtils::GetVarAddr(runtime_param, op_desc, offset, var_addr), SUCCESS);
EXPECT_EQ(runtime_param.var_base + offset - runtime_param.logic_var_base, var_addr);
VarManager::Instance(runtime_param.session_id)->Destory();
}

TEST_F(UtestModelUtils, get_var_addr_rdma_hbm) {
uint8_t test = 2;
uint8_t *pf = &test;
RuntimeParam runtime_param;
runtime_param.session_id = 0;
runtime_param.logic_var_base = 0;
runtime_param.var_base = pf;

int64_t offset = 8;
EXPECT_EQ(VarManager::Instance(runtime_param.session_id)->Init(0, 0, 0, 0), SUCCESS);
EXPECT_NE(VarManager::Instance(runtime_param.session_id)->var_resource_, nullptr);
VarManager::Instance(runtime_param.session_id)->var_resource_->var_offset_map_[offset] = RT_MEMORY_RDMA_HBM;
std::shared_ptr<OpDesc> op_desc = std::make_shared<OpDesc>("test", "test");
uint8_t *var_addr = nullptr;
EXPECT_EQ(ModelUtils::GetVarAddr(runtime_param, op_desc, offset, var_addr), SUCCESS);
EXPECT_EQ(reinterpret_cast<uint8_t *>(offset), var_addr);
VarManager::Instance(runtime_param.session_id)->Destory();
}
} // namespace ge
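
The two tests pin down how ModelUtils::GetVarAddr resolves a logical variable offset per memory type: HBM variables are rebased onto the session's variable block, while RDMA-HBM variables (the new RT_MEMORY_RDMA_HBM kind) already carry a device address. A condensed sketch of that behavior, assuming only what the assertions above verify; the real implementation lives in ge/graph/load/new_model_manager/model_utils.cc and also consults VarManager, which this sketch omits:

// Sketch of the behavior the assertions above verify; not the actual implementation.
Status GetVarAddrSketch(const RuntimeParam &runtime_param, uint32_t mem_type,
                        int64_t offset, uint8_t *&var_addr) {
  switch (mem_type) {
    case RT_MEMORY_HBM:
      // HBM variables: rebase the logical offset onto the variable memory base.
      var_addr = runtime_param.var_base + offset - runtime_param.logic_var_base;
      return SUCCESS;
    case RT_MEMORY_RDMA_HBM:
      // RDMA-HBM variables: the recorded offset is already a device address.
      var_addr = reinterpret_cast<uint8_t *>(offset);
      return SUCCESS;
    default:
      return PARAM_INVALID;
  }
}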

+ 1
- 0
third_party/fwkacllib/inc/runtime/mem.h View File

@@ -34,6 +34,7 @@ extern "C" {
  */
 #define RT_MEMORY_DEFAULT ((uint32_t)0x0)   // default memory on device
 #define RT_MEMORY_HBM ((uint32_t)0x2)       // HBM memory on device
+#define RT_MEMORY_RDMA_HBM ((uint32_t)0x3)  // RDMA-HBM memory on device
 #define RT_MEMORY_DDR ((uint32_t)0x4)       // DDR memory on device
 #define RT_MEMORY_SPM ((uint32_t)0x8)       // shared physical memory on device
 #define RT_MEMORY_P2P_HBM ((uint32_t)0x10)  // HBM memory on other 4P device

