Browse Source

!1620 Fix DSP and MemcpyAsyncNodeTask

From: @zhangxiaokun9
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
a2b3fa2371
4 changed files with 52 additions and 23 deletions
  1. +1
    -1
      ge/graph/partition/dynamic_shape_partition.cc
  2. +28
    -12
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc
  3. +9
    -1
      ge/graph/passes/mark_force_unknown_for_cond_pass.h
  4. +14
    -9
      ge/hybrid/node_executor/rts/rts_node_task.cc

+ 1
- 1
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -357,7 +357,7 @@ void DynamicShapePartitioner::MergeClustersUnknownShape() {
continue;
}
for (const auto &in_cluster : cluster->Inputs()) {
if (!in_cluster->IsUnknownShape()) {
if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) {
continue;
}
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster);


+ 28
- 12
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -37,6 +37,7 @@ inline bool IsMergeInLoop(const NodePtr &node) {

Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
GELOGD("MarkForceUnknownForCondPass Enter");
std::map<NodePtr, std::vector<NodePtr>> switch_groups;
for (const auto &node : graph->GetDirectNode()) {
std::string node_type;
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed.");
@@ -44,20 +45,15 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
continue;
}

const auto op_desc = node->GetOpDesc();
if (!op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) && !IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) {
GELOGI("Merge[%s] has known shape, no need check switch", node->GetName().c_str());
continue;
}

const auto &all_in_nodes = node->GetInDataNodes();
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsMergeInLoop)) {
continue; // LoopCond marked in NextIterationPass.
}

MarkUnknownForSwitch(node);
MarkUnknownForSwitch(node, switch_groups[node]);
}

MarkUnknownForSwitch(switch_groups);
GELOGD("MarkForceUnknownForCondPass Leave");
return SUCCESS;
}
@@ -65,13 +61,12 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) {
///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch group
/// @return
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) {
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) {
// Switch --> {Switch --> Merge} --> Merge
std::vector<NodePtr> switch_group;
std::unordered_set<NodePtr> nodes_seen;

std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}});
while (!search_queue.empty()) {
const auto dst_node = search_queue.front().first;
@@ -117,9 +112,30 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) {
}
}
}
}

for (const auto &n : switch_group) {
MarkForceUnknownShape(n, true);
///
/// @brief Walk every collected group and, where required, flag the key node
///        together with all nodes of its group as force unknown shape.
///        A group is flagged when any of the following holds:
///          - output 0 of the key node is an unknown-shape tensor;
///          - the key node already carries ATTR_NAME_FORCE_UNKNOWN_SHAPE;
///          - any node of the group already carries ATTR_NAME_FORCE_UNKNOWN_SHAPE.
/// @param [in] switch_groups  map from a key node to the nodes grouped under it
/// @return
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
  // Predicate: does this node already have the force-unknown-shape attribute?
  const auto is_forced = [](const NodePtr &member) {
    return member->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
  };

  for (const auto &entry : switch_groups) {
    const auto &key_node = entry.first;
    const auto &members = entry.second;
    const auto &desc = key_node->GetOpDesc();

    // Checks are evaluated left to right and stop at the first one that holds.
    const bool need_mark = IsUnknownShapeTensor(desc->GetOutputDesc(0)) ||
                           desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) ||
                           std::any_of(members.begin(), members.end(), is_forced);
    if (!need_mark) {
      continue;
    }

    GELOGI("Mark [%s] as force unknown shape", key_node->GetName().c_str());
    MarkForceUnknownShape(key_node, true);
    for (const auto &member : members) {
      MarkForceUnknownShape(member, true);
    }
  }
}
} // namespace ge

+ 9
- 1
ge/graph/passes/mark_force_unknown_for_cond_pass.h View File

@@ -28,9 +28,17 @@ class MarkForceUnknownForCondPass : public GraphPass {
///
/// @brief Mark force unknown shape for Switch node
/// @param [in] merge node
/// @param [out] switch group
/// @return
///
void MarkUnknownForSwitch(const NodePtr &node);
void MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group);

///
/// @brief Mark force unknown shape for Switch node
/// @param [in] switch groups
/// @return
///
void MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups);
};
} // namespace ge
#endif // GE_GRAPH_PASSES_MARK_FORCE_UNKNOWN_FOR_COND_PASS_H_

+ 14
- 9
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -169,17 +169,22 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio

Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto in_x = task_context.GetInput(0); // x
GE_CHECK_NOTNULL(in_x);
const auto out_y = task_context.MutableOutput(0); // value_index
GE_CHECK_NOTNULL(out_y);

GELOGD("[%s] input size: %zu, output size: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize());
if (in_x->GetSize() > 0 && out_y->GetSize() > 0) {
GE_CHK_RT_RET(rtMemcpyAsync(out_y->MutableData(), out_y->GetSize(), in_x->GetData(), in_x->GetSize(),
auto input_desc = task_context.MutableInputDesc(0);
GE_CHECK_NOTNULL(input_desc);
int64_t copy_size = 0;
GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size));
// copy_size would not be negative since GetTensorSizeInBytes returned successfully.
if (copy_size > 0) {
const auto in_v = task_context.MutableInput(0);
const auto out_v = task_context.MutableOutput(0);
GE_CHECK_NOTNULL(in_v);
GE_CHECK_NOTNULL(out_v);
GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(),
in_v->GetSize(), out_v->GetSize(), copy_size);
GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size,
RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
} else {
GELOGW("[%s] invalid copy size, src: %zu, dst: %zu", task_context.GetNodeName(), in_x->GetSize(), out_y->GetSize());
GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size);
}

if (done_callback) {


Loading…
Cancel
Save