Browse Source

Fix BuildPartitionFrame failed

tags/v1.5.1
zhangxiaokun 3 years ago
parent
commit
fe77ec974a
2 changed files with 16 additions and 13 deletions
  1. +13
    -12
      ge/graph/partition/dynamic_shape_partition.cc
  2. +3
    -1
      ge/graph/partition/dynamic_shape_partition.h

+ 13
- 12
ge/graph/partition/dynamic_shape_partition.cc View File

@@ -284,9 +284,6 @@ Status DynamicShapePartitioner::InitClusters() {
auto cluster = MakeShared<Cluster>(rank++, type, node, this);
REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed.");
node_2_cluster_[node] = cluster;
if (cluster->IsUnknownShape()) {
ordered_cluster_.push_back(cluster);
}

int64_t group_index = -1;
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
@@ -306,7 +303,7 @@ Status DynamicShapePartitioner::InitClusters() {
return SUCCESS;
}

Status DynamicShapePartitioner::TopologicalSortClusters() {
Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) {
ordered_cluster_.clear();
// BFS topological sort clusters for known shape cluster
std::queue<ClusterPtr> ready_clusters;
@@ -331,7 +328,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() {
auto cluster = ready_clusters.front();
ready_clusters.pop();
cluster->UpdateRank(rank++);
if (cluster->IsKnownShape() || cluster->IsInputNode()) {
if (ordered_filter == nullptr || ordered_filter(cluster)) {
ordered_cluster_.push_back(cluster);
}
for (const auto &out_cluster : cluster->Outputs()) {
@@ -378,7 +375,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
continue;
}

bool is_unknown_cluster = cluster->IsUnknownShape();
for (++rit; rit != control_cluster.rend(); ++rit) {
const auto &cluster_from = *rit;
if (all_merged_clusters.count(cluster_from) > 0) {
@@ -395,11 +391,6 @@ void DynamicShapePartitioner::MergeClustersControlFlow() {
}
}
}

if (!is_unknown_cluster && cluster->IsUnknownShape()) {
GELOGD("Add to ordered cluster: %s", cluster->DebugString().c_str());
ordered_cluster_.push_back(cluster);
}
}
}

@@ -475,9 +466,19 @@ void DynamicShapePartitioner::MergeClustersInputData() {
}

Status DynamicShapePartitioner::MergeClusters() {
const auto filter_known = [](const ClusterPtr &cluster) {
return cluster->IsKnownShape() || cluster->IsInputNode();
};
const auto filter_unknown = [](const ClusterPtr &cluster) {
return cluster->IsUnknownShape();
};

MergeClustersControlFlow();
REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown),
"[TopologicalSort][Clusters] after merge control flow clusters failed.");
MergeClustersUnknownShape();
REQUIRE_SUCCESS(TopologicalSortClusters(), "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
REQUIRE_SUCCESS(TopologicalSortClusters(filter_known),
"[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
MergeClustersKnownShape();
MergeClustersInputData();
return SUCCESS;


+ 3
- 1
ge/graph/partition/dynamic_shape_partition.h View File

@@ -111,6 +111,8 @@ class DynamicShapePartitioner {

Status Partition();

using OrderedFilter = std::function<bool(const std::shared_ptr<Cluster> &cluster)>;

private:
Status PartitionImpl();
// Collect nodes that satisfy the unknowshape rules:
@@ -138,7 +140,7 @@ class DynamicShapePartitioner {
// Merge clusters step3
void MergeClustersInputData();
// Topological sort clusters after merge unknown shape clusters.
Status TopologicalSortClusters();
Status TopologicalSortClusters(const OrderedFilter &ordered_filter);
// Deduplicate merged clusters
void PruneUniqueClusters();
// Establish the input-output anchors for each partition of the cluster and record links to other clusters


Loading…
Cancel
Save