From: @zhangxiaokun9 Reviewed-by: @xchu42, @wqtshg Signed-off-by: @wqtshg tags/v1.3.0
| @@ -266,19 +266,34 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | ||||
| MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | ||||
| MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | ||||
| for (const auto &switch_node : loop_group.switch_nodes) { | |||||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||||
| for (const auto &exit_node : switch_node->GetOutDataNodes()) { | |||||
| if (exit_node->GetType() == EXIT || exit_node->GetType() == REFEXIT) { | |||||
| MarkForceUnknownShape(exit_node, loop_group.is_unknown_shape); | |||||
| } | |||||
| } | |||||
| } | |||||
| HandleSwitchExitNodes(loop_group); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| /// | |||||
| /// @brief Mark force unknown for Exit node | |||||
| /// @param [in] group of LoopCond | |||||
| /// @return void | |||||
| /// | |||||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { | |||||
| if (!loop_group.is_unknown_shape) { | |||||
| return; | |||||
| } | |||||
| for (const auto &switch_node : loop_group.switch_nodes) { | |||||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||||
| for (const auto &node : switch_node->GetOutDataNodes()) { | |||||
| std::string node_type; | |||||
| (void)GetOriginalType(node, node_type); | |||||
| if (node_type == EXIT || node_type == REFEXIT) { | |||||
| MarkForceUnknownShape(node, loop_group.is_unknown_shape); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| /// | /// | ||||
| /// @brief Create Active Node | /// @brief Create Active Node | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| @@ -93,6 +93,13 @@ class NextIterationPass : public GraphPass { | |||||
| /// | /// | ||||
| Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | ||||
| /// | |||||
| /// @brief Mark force unknown for Exit node | |||||
| /// @param [in] group of LoopCond | |||||
| /// @return void | |||||
| /// | |||||
| void HandleSwitchExitNodes(const LoopCondGroup &loop_group); | |||||
| // map<frame_name, LoopCondGroup> | // map<frame_name, LoopCondGroup> | ||||
| std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | ||||
| }; | }; | ||||
| @@ -177,6 +177,7 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
| known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | ||||
| GE_CHECK_NOTNULL(known_shape_task_context_); | GE_CHECK_NOTNULL(known_shape_task_context_); | ||||
| node_state->SetTaskContext(known_shape_task_context_); | |||||
| std::function<void()> callback; | std::function<void()> callback; | ||||
| GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/types.h" | #include "common/types.h" | ||||
| #include "graph/common/omg_util.h" | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
| #include "runtime/rt.h" | #include "runtime/rt.h" | ||||
| @@ -132,15 +133,15 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||||
| Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GELOGD("[%s] Load for local task.", node->GetName().c_str()); | GELOGD("[%s] Load for local task.", node->GetName().c_str()); | ||||
| auto op_type = node->GetType(); | |||||
| task = RtsTaskFactory::GetInstance().Create(op_type); | |||||
| if (task == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | |||||
| RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | |||||
| if (rts_task == nullptr) { | |||||
| GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | |||||
| return UNSUPPORTED; | |||||
| } | } | ||||
| RtsNodeTask *rts_task = dynamic_cast<RtsNodeTask *>(task.get()); | |||||
| GE_CHECK_NOTNULL(rts_task); | |||||
| task = rts_task; | |||||
| return rts_task->Init(model, node); | return rts_task->Init(model, node); | ||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| @@ -18,7 +18,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| NodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { | |||||
| RtsNodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { | |||||
| auto it = creators_.find(task_type); | auto it = creators_.find(task_type); | ||||
| if (it == creators_.end()) { | if (it == creators_.end()) { | ||||
| GELOGW("Cannot find task type %s in inner map.", task_type.c_str()); | GELOGW("Cannot find task type %s in inner map.", task_type.c_str()); | ||||
| @@ -17,11 +17,12 @@ | |||||
| #ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | ||||
| #define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | ||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| #include "hybrid/node_executor/rts/rts_node_task.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| using RtsTaskCreatorFun = std::function<NodeTaskPtr()>; | |||||
| using RtsNodeTaskPtr = std::shared_ptr<RtsNodeTask>; | |||||
| using RtsTaskCreatorFun = std::function<RtsNodeTaskPtr()>; | |||||
| class RtsTaskFactory { | class RtsTaskFactory { | ||||
| public: | public: | ||||
| @@ -30,7 +31,7 @@ class RtsTaskFactory { | |||||
| return instance; | return instance; | ||||
| } | } | ||||
| NodeTaskPtr Create(const std::string &task_type) const; | |||||
| RtsNodeTaskPtr Create(const std::string &task_type) const; | |||||
| class RtsTaskRegistrar { | class RtsTaskRegistrar { | ||||
| public: | public: | ||||
| @@ -60,6 +61,6 @@ class RtsTaskFactory { | |||||
| REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | ||||
| #define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | #define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | ||||
| RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> NodeTaskPtr { return MakeShared<clazz>(); }) | |||||
| RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> RtsNodeTaskPtr { return MakeShared<clazz>(); }) | |||||
| #endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | ||||