Browse Source

!475 GE添加reduce算子定义

* update ge/graph/manager/util/hcom_util.cc.
* !1 update ge/graph/load/new_model_manager/davinci_model.cc.
* update ge/graph/build/stream_allocator.cc.
* update ge/graph/load/new_model_manager/davinci_model.cc.
* update ge/graph/manager/util/hcom_util.cc.
* update inc/framework/common/types.h.
* update ge/hybrid/node_executor/hccl/hccl_node_executor.cc.
* update ge/graph/optimize/mem_rw_conflict_optimize.cc.
* update ge/graph/manager/util/hcom_util.cc.
* update ge/graph/load/new_model_manager/task_info/hccl_task_info.cc.
* update ge/graph/build/stream_allocator.cc.
* update ge/graph/build/stream_allocator.cc.
* update ge/common/types.cc.
tags/v1.1.0
cclworkaccount 王笑天 3 years ago
parent
commit
35f7bb4c06
8 changed files with 10 additions and 7 deletions
  1. +1
    -0
      ge/common/types.cc
  2. +1
    -1
      ge/graph/build/stream_allocator.cc
  3. +2
    -1
      ge/graph/load/new_model_manager/davinci_model.cc
  4. +1
    -1
      ge/graph/load/new_model_manager/task_info/hccl_task_info.cc
  5. +2
    -2
      ge/graph/manager/util/hcom_util.cc
  6. +1
    -1
      ge/graph/optimize/mem_rw_conflict_optimize.cc
  7. +1
    -1
      ge/hybrid/node_executor/hccl/hccl_node_executor.cc
  8. +1
    -0
      inc/framework/common/types.h

+ 1
- 0
ge/common/types.cc View File

@@ -382,6 +382,7 @@ REGISTER_OPTYPE_DEFINE(HCOMBROADCAST, "HcomBroadcast");
REGISTER_OPTYPE_DEFINE(HCOMALLGATHER, "HcomAllGather");
REGISTER_OPTYPE_DEFINE(HCOMALLREDUCE, "HcomAllReduce");
REGISTER_OPTYPE_DEFINE(HCOMREDUCESCATTER, "HcomReduceScatter");
REGISTER_OPTYPE_DEFINE(HCOMREDUCE, "HcomReduce");
REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend");
REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive");
REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead");


+ 1
- 1
ge/graph/build/stream_allocator.cc View File

@@ -49,7 +49,7 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string &
}

bool IsHcclOp(const string &op_type) {
const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER});
const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE});
return hccl_op_types.find(op_type) != hccl_op_types.end();
}
} // namespace


+ 2
- 1
ge/graph/load/new_model_manager/davinci_model.cc View File

@@ -411,7 +411,8 @@ void DavinciModel::CheckHasHcomOp() {
(op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) ||
(op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER) ||
(op_desc->GetType() == HVDCALLBACKALLREDUCE) || (op_desc->GetType() == HVDCALLBACKALLGATHER) ||
(op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)),
(op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT) ||
(op_desc->GetType() == HCOMREDUCE)),
uint32_t stream_id = static_cast<uint32_t>(op_desc->GetStreamId());
(void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue);
}


+ 1
- 1
ge/graph/load/new_model_manager/task_info/hccl_task_info.cc View File

@@ -279,7 +279,7 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc,
output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i];
}
kernel_hccl_infos[i].inputDataAddr = input_data_addr;
if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) {
if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER || hccl_type == HCOMREDUCE) {
kernel_hccl_infos[i].outputDataAddr = output_data_addr;
} else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) {
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type),


+ 2
- 2
ge/graph/manager/util/hcom_util.cc View File

@@ -263,7 +263,7 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro
Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc,
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) {
GE_CHECK_NOTNULL(op_desc);
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) {
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) {
GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str());
int64_t root_id = 0;
Status dmrt = GetHcclRootId(op_desc, root_id);
@@ -281,7 +281,7 @@ Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc,

bool HcomOmeUtil::IsHCOMOp(const string &op_type) {
return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) ||
(op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER);
(op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER) || (op_type == HCOMREDUCE);
}

bool HcomOmeUtil::IsHorovodOp(const string &op_type) {


+ 1
- 1
ge/graph/optimize/mem_rw_conflict_optimize.cc View File

@@ -234,7 +234,7 @@ InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) {
return InputRWType::kInvalidRWType;
}
if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER
|| op_desc->GetType() == HCOMREDUCESCATTER) {
|| op_desc->GetType() == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE) {
return InputRWType::kScopeWriteable;
}
// check if it is ref input


+ 1
- 1
ge/hybrid/node_executor/hccl/hccl_node_executor.cc View File

@@ -85,7 +85,7 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
op_info.dataType = iter->second;
HcclReduceOp op_type = HCCL_REDUCE_SUM;
if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER ||
op_desc->GetType() == HVDCALLBACKALLREDUCE) {
op_desc->GetType() == HVDCALLBACKALLREDUCE || op_desc->GetType() == HCOMREDUCE) {
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), "GetHcclOperationType failed");
op_info.opType = op_type;
}


+ 1
- 0
inc/framework/common/types.h View File

@@ -431,6 +431,7 @@ REGISTER_OPTYPE_DECLARE(HCOMBROADCAST, "HcomBroadcast");
REGISTER_OPTYPE_DECLARE(HCOMALLGATHER, "HcomAllGather");
REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce");
REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter");
REGISTER_OPTYPE_DECLARE(HCOMREDUCE, "HcomReduce");
REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend");
REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive");
REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead");


Loading…
Cancel
Save