From: @zhangxiaokun9 Reviewed-by: @xchu42, @wqtshg Signed-off-by: @wqtshg tags/v1.3.0
@@ -266,19 +266,34 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | ||||
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | ||||
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | ||||
for (const auto &switch_node : loop_group.switch_nodes) { | |||||
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||||
for (const auto &exit_node : switch_node->GetOutDataNodes()) { | |||||
if (exit_node->GetType() == EXIT || exit_node->GetType() == REFEXIT) { | |||||
MarkForceUnknownShape(exit_node, loop_group.is_unknown_shape); | |||||
} | |||||
} | |||||
} | |||||
HandleSwitchExitNodes(loop_group); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
/// | |||||
/// @brief Mark force unknown for Exit node | |||||
/// @param [in] group of LoopCond | |||||
/// @return void | |||||
/// | |||||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { | |||||
if (!loop_group.is_unknown_shape) { | |||||
return; | |||||
} | |||||
for (const auto &switch_node : loop_group.switch_nodes) { | |||||
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||||
for (const auto &node : switch_node->GetOutDataNodes()) { | |||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
if (node_type == EXIT || node_type == REFEXIT) { | |||||
MarkForceUnknownShape(node, loop_group.is_unknown_shape); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
/// | /// | ||||
/// @brief Create Active Node | /// @brief Create Active Node | ||||
/// @param [in] graph | /// @param [in] graph | ||||
@@ -93,6 +93,13 @@ class NextIterationPass : public GraphPass { | |||||
/// | /// | ||||
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); | ||||
/// | |||||
/// @brief Mark force unknown for Exit node | |||||
/// @param [in] group of LoopCond | |||||
/// @return void | |||||
/// | |||||
void HandleSwitchExitNodes(const LoopCondGroup &loop_group); | |||||
// map<frame_name, LoopCondGroup> | // map<frame_name, LoopCondGroup> | ||||
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | ||||
}; | }; | ||||
@@ -177,6 +177,7 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | ||||
GE_CHECK_NOTNULL(known_shape_task_context_); | GE_CHECK_NOTNULL(known_shape_task_context_); | ||||
node_state->SetTaskContext(known_shape_task_context_); | |||||
std::function<void()> callback; | std::function<void()> callback; | ||||
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback)); | ||||
@@ -20,6 +20,7 @@ | |||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "graph/common/omg_util.h" | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
#include "runtime/rt.h" | #include "runtime/rt.h" | ||||
@@ -132,15 +133,15 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function< | |||||
// Creates the RTS node task for `node` through RtsTaskFactory, stores it in
// `task`, and returns the result of the task's Init(model, node).
// Returns early (via GE_CHECK_NOTNULL) when `node` is null.
// NOTE(review): this span is a rendered diff, so two variants of the body are
// interleaved below and it is not compilable as shown. One variant looks the
// task up by node->GetType(), fails with INTERNAL_ERROR and downcasts with
// dynamic_cast; the other looks it up by GetOriginalType(), fails with
// UNSUPPORTED and receives an RtsNodeTaskPtr directly from the factory.
// Presumably the first is the removed code and the second the added code;
// confirm against the actual commit.
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const { | ||||
GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
GELOGD("[%s] Load for local task.", node->GetName().c_str()); | GELOGD("[%s] Load for local task.", node->GetName().c_str()); | ||||
// Variant A: factory lookup keyed by the node's current type.
auto op_type = node->GetType(); | |||||
task = RtsTaskFactory::GetInstance().Create(op_type); | |||||
if (task == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str()); | |||||
return INTERNAL_ERROR; | |||||
// Variant B: factory lookup keyed by the type reported by GetOriginalType();
// a failure of that call is propagated by GE_CHK_STATUS_RET.
std::string node_type; | |||||
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); | |||||
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type); | |||||
if (rts_task == nullptr) { | |||||
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str()); | |||||
return UNSUPPORTED; | |||||
} | } | ||||
// Variant A only: downcast of the generic task to RtsNodeTask, null-checked.
RtsNodeTask *rts_task = dynamic_cast<RtsNodeTask *>(task.get()); | |||||
GE_CHECK_NOTNULL(rts_task); | |||||
// Hands the created task back to the caller through the out-parameter.
task = rts_task; | |||||
return rts_task->Init(model, node); | return rts_task->Init(model, node); | ||||
} | } | ||||
} // namespace hybrid | } // namespace hybrid |
@@ -18,7 +18,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
NodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { | |||||
RtsNodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const { | |||||
auto it = creators_.find(task_type); | auto it = creators_.find(task_type); | ||||
if (it == creators_.end()) { | if (it == creators_.end()) { | ||||
GELOGW("Cannot find task type %s in inner map.", task_type.c_str()); | GELOGW("Cannot find task type %s in inner map.", task_type.c_str()); | ||||
@@ -17,11 +17,12 @@ | |||||
#ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | ||||
#define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | ||||
#include "hybrid/node_executor/node_executor.h" | |||||
#include "hybrid/node_executor/rts/rts_node_task.h" | |||||
namespace ge { | namespace ge { | ||||
namespace hybrid { | namespace hybrid { | ||||
using RtsTaskCreatorFun = std::function<NodeTaskPtr()>; | |||||
using RtsNodeTaskPtr = std::shared_ptr<RtsNodeTask>; | |||||
using RtsTaskCreatorFun = std::function<RtsNodeTaskPtr()>; | |||||
class RtsTaskFactory { | class RtsTaskFactory { | ||||
public: | public: | ||||
@@ -30,7 +31,7 @@ class RtsTaskFactory { | |||||
return instance; | return instance; | ||||
} | } | ||||
NodeTaskPtr Create(const std::string &task_type) const; | |||||
RtsNodeTaskPtr Create(const std::string &task_type) const; | |||||
class RtsTaskRegistrar { | class RtsTaskRegistrar { | ||||
public: | public: | ||||
@@ -60,6 +61,6 @@ class RtsTaskFactory { | |||||
REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz) | ||||
#define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | #define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \ | ||||
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> NodeTaskPtr { return MakeShared<clazz>(); }) | |||||
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> RtsNodeTaskPtr { return MakeShared<clazz>(); }) | |||||
#endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ | #endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_ |