Browse Source

!1609 SetTaskContext for known

From: @zhangxiaokun9
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
28ee6d5a8e
6 changed files with 45 additions and 20 deletions
  1. +23
    -8
      ge/graph/passes/next_iteration_pass.cc
  2. +7
    -0
      ge/graph/passes/next_iteration_pass.h
  3. +1
    -0
      ge/hybrid/executor/subgraph_executor.cc
  4. +8
    -7
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  5. +1
    -1
      ge/hybrid/node_executor/rts/rts_task_factory.cc
  6. +5
    -4
      ge/hybrid/node_executor/rts/rts_task_factory.h

+ 23
- 8
ge/graph/passes/next_iteration_pass.cc View File

@@ -266,19 +266,34 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) {
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape);
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape);
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape);
for (const auto &switch_node : loop_group.switch_nodes) {
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape);
for (const auto &exit_node : switch_node->GetOutDataNodes()) {
if (exit_node->GetType() == EXIT || exit_node->GetType() == REFEXIT) {
MarkForceUnknownShape(exit_node, loop_group.is_unknown_shape);
}
}
}
HandleSwitchExitNodes(loop_group);
}

return SUCCESS;
}

///
/// @brief Mark force unknown for Exit node
/// @param [in] group of LoopCond
/// @return void
///
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) {
if (!loop_group.is_unknown_shape) {
return;
}

for (const auto &switch_node : loop_group.switch_nodes) {
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape);
for (const auto &node : switch_node->GetOutDataNodes()) {
std::string node_type;
(void)GetOriginalType(node, node_type);
if (node_type == EXIT || node_type == REFEXIT) {
MarkForceUnknownShape(node, loop_group.is_unknown_shape);
}
}
}
}

///
/// @brief Create Active Node
/// @param [in] graph


+ 7
- 0
ge/graph/passes/next_iteration_pass.h View File

@@ -93,6 +93,13 @@ class NextIterationPass : public GraphPass {
///
Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node);

///
/// @brief Mark force unknown for Exit node
/// @param [in] group of LoopCond
/// @return void
///
void HandleSwitchExitNodes(const LoopCondGroup &loop_group);

// map<frame_name, LoopCondGroup>
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_;
};


+ 1
- 0
ge/hybrid/executor/subgraph_executor.cc View File

@@ -177,6 +177,7 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue

known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get());
GE_CHECK_NOTNULL(known_shape_task_context_);
node_state->SetTaskContext(known_shape_task_context_);

std::function<void()> callback;
GE_CHK_STATUS_RET_NOLOG(InitCallback(node_state.get(), callback));


+ 8
- 7
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -20,6 +20,7 @@
#include "common/debug/log.h"
#include "common/ge/ge_util.h"
#include "common/types.h"
#include "graph/common/omg_util.h"
#include "graph/utils/tensor_utils.h"
#include "hybrid/model/hybrid_model.h"
#include "runtime/rt.h"
@@ -132,15 +133,15 @@ Status ProfilingTraceNodeTask::ExecuteAsync(TaskContext &context, std::function<
Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr<NodeTask> &task) const {
GE_CHECK_NOTNULL(node);
GELOGD("[%s] Load for local task.", node->GetName().c_str());
auto op_type = node->GetType();
task = RtsTaskFactory::GetInstance().Create(op_type);
if (task == nullptr) {
GELOGE(INTERNAL_ERROR, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), op_type.c_str());
return INTERNAL_ERROR;
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
RtsNodeTaskPtr rts_task = RtsTaskFactory::GetInstance().Create(node_type);
if (rts_task == nullptr) {
GELOGE(UNSUPPORTED, "[%s] Unsupported RTS op type: %s", node->GetName().c_str(), node_type.c_str());
return UNSUPPORTED;
}

RtsNodeTask *rts_task = dynamic_cast<RtsNodeTask *>(task.get());
GE_CHECK_NOTNULL(rts_task);
task = rts_task;
return rts_task->Init(model, node);
}
} // namespace hybrid

+ 1
- 1
ge/hybrid/node_executor/rts/rts_task_factory.cc View File

@@ -18,7 +18,7 @@

namespace ge {
namespace hybrid {
NodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const {
RtsNodeTaskPtr RtsTaskFactory::Create(const std::string &task_type) const {
auto it = creators_.find(task_type);
if (it == creators_.end()) {
GELOGW("Cannot find task type %s in inner map.", task_type.c_str());


+ 5
- 4
ge/hybrid/node_executor/rts/rts_task_factory.h View File

@@ -17,11 +17,12 @@
#ifndef GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_
#define GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_

#include "hybrid/node_executor/node_executor.h"
#include "hybrid/node_executor/rts/rts_node_task.h"

namespace ge {
namespace hybrid {
using RtsTaskCreatorFun = std::function<NodeTaskPtr()>;
using RtsNodeTaskPtr = std::shared_ptr<RtsNodeTask>;
using RtsTaskCreatorFun = std::function<RtsNodeTaskPtr()>;

class RtsTaskFactory {
public:
@@ -30,7 +31,7 @@ class RtsTaskFactory {
return instance;
}

NodeTaskPtr Create(const std::string &task_type) const;
RtsNodeTaskPtr Create(const std::string &task_type) const;

class RtsTaskRegistrar {
public:
@@ -60,6 +61,6 @@ class RtsTaskFactory {
REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(__COUNTER__, task_type, task_clazz)

#define REGISTER_RTS_TASK_CREATOR_UNIQ_HELPER(ctr, type, clazz) \
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> NodeTaskPtr { return MakeShared<clazz>(); })
RtsTaskFactory::RtsTaskRegistrar g_##type##_Creator##ctr(type, []()-> RtsNodeTaskPtr { return MakeShared<clazz>(); })

#endif // GE_HYBRID_NODE_EXECUTOR_RTS_TASK_FACTORY_H_

Loading…
Cancel
Save