diff --git a/ge/graph/passes/next_iteration_pass.cc b/ge/graph/passes/next_iteration_pass.cc index bac6b3eb..5f4fc4d0 100644 --- a/ge/graph/passes/next_iteration_pass.cc +++ b/ge/graph/passes/next_iteration_pass.cc @@ -266,19 +266,34 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); - for (const auto &switch_node : loop_group.switch_nodes) { - MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); - for (const auto &exit_node : switch_node->GetOutDataNodes()) { - if (exit_node->GetType() == EXIT || exit_node->GetType() == REFEXIT) { - MarkForceUnknownShape(exit_node, loop_group.is_unknown_shape); - } - } - } + HandleSwitchExitNodes(loop_group); } return SUCCESS; } +/// +/// @brief Mark force unknown for Exit node +/// @param [in] group of LoopCond +/// @return void +/// +void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { + if (!loop_group.is_unknown_shape) { + return; + } + + for (const auto &switch_node : loop_group.switch_nodes) { + MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); + for (const auto &node : switch_node->GetOutDataNodes()) { + std::string node_type; + (void)GetOriginalType(node, node_type); + if (node_type == EXIT || node_type == REFEXIT) { + MarkForceUnknownShape(node, loop_group.is_unknown_shape); + } + } + } +} + /// /// @brief Create Active Node /// @param [in] graph diff --git a/ge/graph/passes/next_iteration_pass.h b/ge/graph/passes/next_iteration_pass.h index dea088ee..e8786516 100755 --- a/ge/graph/passes/next_iteration_pass.h +++ b/ge/graph/passes/next_iteration_pass.h @@ -93,6 +93,13 @@ class NextIterationPass : public GraphPass { /// Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); + /// + /// @brief Mark force unknown for Exit node + /// @param [in] group of LoopCond + /// @return void + /// + void HandleSwitchExitNodes(const LoopCondGroup &loop_group); + // map std::unordered_map loop_group_map_; }; diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index c9c6f768..60895c7e 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -177,6 +177,7 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vectorSetTaskContext(known_shape_task_context_); std::function callback; GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index 438045db..3ad791b6 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -20,6 +20,7 @@ #include "common/debug/log.h" #include "common/ge/ge_util.h" #include "common/types.h" +#include "graph/common/omg_util.h" #include "graph/utils/tensor_utils.h" #include "hybrid/model/hybrid_model.h" #include "runtime/rt.h" @@ -132,15 +133,15 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); GELOGD("[%s] Load for local task.", node->GetName().c_str()); - auto op_type = node->GetType(); - task = RtsTaskFactory::GetInstance().Create(op_type); - if (task == nullptr) { - GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str()); - return INTERNAL_ERROR; + std::string node_type; + GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); + RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); + if (rts_task == nullptr) { + GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); + return UNSUPPORTED; } - RtsNodeTask *rts_task = dynamic_cast(task.get()); - GE_CHECK_NOTNULL(rts_task); + task = rts_task; return rts_task->Init(model, node); } } // namespace hybrid diff --git a/ge/hybrid/node_executor/rts/rts_task_factory.cc b/ge/hybrid/node_executor/rts/rts_task_factory.cc index 73fd065d..0072fdf6 100644 --- a/ge/hybrid/node_executor/rts/rts_task_factory.cc +++ b/ge/hybrid/node_executor/rts/rts_task_factory.cc @@ -18,7 +18,7 @@ namespace ge { namespace hybrid { -NodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { +RtsNodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { auto it = creators_.find(task_type); if (it == creators_.end()) { GELOGW("Cannot find task type %s in inner map.", task_type.c_str()); diff --git a/ge/hybrid/node_executor/rts/rts_task_factory.h b/ge/hybrid/node_executor/rts/rts_task_factory.h index 42a07d4e..a2d2bf56 100644 --- a/ge/hybrid/node_executor/rts/rts_task_factory.h +++ b/ge/hybrid/node_executor/rts/rts_task_factory.h @@ -17,11 +17,12 @@ #ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ #define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ -#include "hybrid/node_executor/node_executor.h" +#include "hybrid/node_executor/rts/rts_node_task.h" namespace ge { namespace hybrid { -using RtsTaskCreatorFun = std::function; +using RtsNodeTaskPtr = std::shared_ptr; +using RtsTaskCreatorFun = std::function; class RtsTaskFactory { public: @@ -30,7 +31,7 @@ class RtsTaskFactory { return instance; } - NodeTaskPtr Create(const std::string &task_type) const; + RtsNodeTaskPtr Create(const std::string &task_type) const; class RtsTaskRegistrar { public: @@ -60,6 +61,6 @@ class RtsTaskFactory { REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) #define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ - RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> NodeTaskPtr { return MakeShared(); }) + RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> RtsNodeTaskPtr { return MakeShared(); }) #endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_