* update ge/graph/manager/util/hcom_util.cc. * !1 update ge/graph/load/new_model_manager/davinci_model.cc. * update ge/graph/build/stream_allocator.cc. * update ge/graph/load/new_model_manager/davinci_model.cc. * update ge/graph/manager/util/hcom_util.cc. * update inc/framework/common/types.h. * update ge/hybrid/node_executor/hccl/hccl_node_executor.cc. * update ge/graph/optimize/mem_rw_conflict_optimize.cc. * update ge/graph/manager/util/hcom_util.cc. * update ge/graph/load/new_model_manager/task_info/hccl_task_info.cc. * update ge/graph/build/stream_allocator.cc. * update ge/graph/build/stream_allocator.cc. * update ge/common/types.cc.tags/v1.1.0
| @@ -382,6 +382,7 @@ REGISTER_OPTYPE_DEFINE(HCOMBROADCAST, "HcomBroadcast"); | |||||
| REGISTER_OPTYPE_DEFINE(HCOMALLGATHER, "HcomAllGather"); | REGISTER_OPTYPE_DEFINE(HCOMALLGATHER, "HcomAllGather"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMALLREDUCE, "HcomAllReduce"); | REGISTER_OPTYPE_DEFINE(HCOMALLREDUCE, "HcomAllReduce"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREDUCESCATTER, "HcomReduceScatter"); | REGISTER_OPTYPE_DEFINE(HCOMREDUCESCATTER, "HcomReduceScatter"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREDUCE, "HcomReduce"); | |||||
| REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend"); | REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive"); | REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); | ||||
| @@ -49,7 +49,7 @@ inline bool HasContinuousStreamLabel(const ge::OpDescPtr &op_desc, std::string & | |||||
| } | } | ||||
| bool IsHcclOp(const string &op_type) { | bool IsHcclOp(const string &op_type) { | ||||
| const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER}); | |||||
| const set<string> hccl_op_types({ge::HCOMBROADCAST, ge::HCOMALLGATHER, ge::HCOMALLREDUCE, ge::HCOMREDUCESCATTER, ge::HCOMREDUCE}); | |||||
| return hccl_op_types.find(op_type) != hccl_op_types.end(); | return hccl_op_types.find(op_type) != hccl_op_types.end(); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| @@ -411,7 +411,8 @@ void DavinciModel::CheckHasHcomOp() { | |||||
| (op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || | (op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || | ||||
| (op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER) || | (op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER) || | ||||
| (op_desc->GetType() == HVDCALLBACKALLREDUCE) || (op_desc->GetType() == HVDCALLBACKALLGATHER) || | (op_desc->GetType() == HVDCALLBACKALLREDUCE) || (op_desc->GetType() == HVDCALLBACKALLGATHER) || | ||||
| (op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)), | |||||
| (op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT) || | |||||
| (op_desc->GetType() == HCOMREDUCE)), | |||||
| uint32_t stream_id = static_cast<uint32_t>(op_desc->GetStreamId()); | uint32_t stream_id = static_cast<uint32_t>(op_desc->GetStreamId()); | ||||
| (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); | (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); | ||||
| } | } | ||||
| @@ -279,7 +279,7 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc, | |||||
| output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; | output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; | ||||
| } | } | ||||
| kernel_hccl_infos[i].inputDataAddr = input_data_addr; | kernel_hccl_infos[i].inputDataAddr = input_data_addr; | ||||
| if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { | |||||
| if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER || hccl_type == HCOMREDUCE) { | |||||
| kernel_hccl_infos[i].outputDataAddr = output_data_addr; | kernel_hccl_infos[i].outputDataAddr = output_data_addr; | ||||
| } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { | } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { | ||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), | ||||
| @@ -263,7 +263,7 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro | |||||
| Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | ||||
| std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) { | |||||
| if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { | |||||
| GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| int64_t root_id = 0; | int64_t root_id = 0; | ||||
| Status dmrt = GetHcclRootId(op_desc, root_id); | Status dmrt = GetHcclRootId(op_desc, root_id); | ||||
| @@ -281,7 +281,7 @@ Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | |||||
| bool HcomOmeUtil::IsHCOMOp(const string &op_type) { | bool HcomOmeUtil::IsHCOMOp(const string &op_type) { | ||||
| return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || | return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || | ||||
| (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER); | |||||
| (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER) || (op_type == HCOMREDUCE); | |||||
| } | } | ||||
| bool HcomOmeUtil::IsHorovodOp(const string &op_type) { | bool HcomOmeUtil::IsHorovodOp(const string &op_type) { | ||||
| @@ -234,7 +234,7 @@ InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) { | |||||
| return InputRWType::kInvalidRWType; | return InputRWType::kInvalidRWType; | ||||
| } | } | ||||
| if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER | if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER | ||||
| || op_desc->GetType() == HCOMREDUCESCATTER) { | |||||
| || op_desc->GetType() == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE) { | |||||
| return InputRWType::kScopeWriteable; | return InputRWType::kScopeWriteable; | ||||
| } | } | ||||
| // check if it is ref input | // check if it is ref input | ||||
| @@ -85,7 +85,7 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| op_info.dataType = iter->second; | op_info.dataType = iter->second; | ||||
| HcclReduceOp op_type = HCCL_REDUCE_SUM; | HcclReduceOp op_type = HCCL_REDUCE_SUM; | ||||
| if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER || | if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER || | ||||
| op_desc->GetType() == HVDCALLBACKALLREDUCE) { | |||||
| op_desc->GetType() == HVDCALLBACKALLREDUCE || op_desc->GetType() == HCOMREDUCE) { | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), "GetHcclOperationType failed"); | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), "GetHcclOperationType failed"); | ||||
| op_info.opType = op_type; | op_info.opType = op_type; | ||||
| } | } | ||||
| @@ -431,6 +431,7 @@ REGISTER_OPTYPE_DECLARE(HCOMBROADCAST, "HcomBroadcast"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMALLGATHER, "HcomAllGather"); | REGISTER_OPTYPE_DECLARE(HCOMALLGATHER, "HcomAllGather"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce"); | REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); | REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREDUCE, "HcomReduce"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); | REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | ||||