| @@ -263,7 +263,7 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro | |||||
| Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | ||||
| std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) { | |||||
| if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST || op_desc->GetType() == HCOMREDUCE) { | |||||
| GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| int64_t root_id = 0; | int64_t root_id = 0; | ||||
| Status dmrt = GetHcclRootId(op_desc, root_id); | Status dmrt = GetHcclRootId(op_desc, root_id); | ||||
| @@ -281,7 +281,7 @@ Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | |||||
| bool HcomOmeUtil::IsHCOMOp(const string &op_type) { | bool HcomOmeUtil::IsHCOMOp(const string &op_type) { | ||||
| return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || | return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || | ||||
| (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER); | |||||
| (op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER || op_desc->GetType() == HCOMREDUCE); | |||||
| } | } | ||||
| bool HcomOmeUtil::IsHorovodOp(const string &op_type) { | bool HcomOmeUtil::IsHorovodOp(const string &op_type) { | ||||