From: @zhangxiaokun9 Reviewed-by: @wqtshg,@ji_chen Signed-off-by: @ji_chentags/v1.3.0
@@ -272,20 +272,32 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||||
/// @brief Set Op _force_unknown_shape flag | /// @brief Set Op _force_unknown_shape flag | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] force_unknown, set attribute if true | /// @param [in] force_unknown, set attribute if true | ||||
/// @param [in] group_index, condition group index of node. | |||||
/// @return | /// @return | ||||
/// | /// | ||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) { | |||||
GE_RT_VOID_CHECK_NOTNULL(node); | |||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||||
if (!force_unknown) { | if (!force_unknown) { | ||||
return; | return; | ||||
} | } | ||||
GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str()); | |||||
if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { | |||||
GE_RT_VOID_CHECK_NOTNULL(node); | |||||
const auto &op_desc = node->GetOpDesc(); | |||||
GE_RT_VOID_CHECK_NOTNULL(op_desc); | |||||
// op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | |||||
GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index); | |||||
if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { | |||||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | ||||
node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | ||||
node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
} | } | ||||
if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||||
node->GetName().c_str(), node->GetType().c_str()); | |||||
GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||||
node->GetName().c_str(), node->GetType().c_str()); | |||||
} | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -129,9 +129,10 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||||
/// @brief Set Op _force_unknown_shape flag | /// @brief Set Op _force_unknown_shape flag | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] force_unknown, set attribute if true | /// @param [in] force_unknown, set attribute if true | ||||
/// @param [in] group_index, condition group index of node. | |||||
/// @return | /// @return | ||||
/// | /// | ||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown); | |||||
void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_COMMON_OMG_UTIL_H_ | #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ |
@@ -65,7 +65,6 @@ | |||||
#include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
#include "graph/passes/merge_input_memcpy_pass.h" | #include "graph/passes/merge_input_memcpy_pass.h" | ||||
#include "graph/passes/merge_to_stream_merge_pass.h" | #include "graph/passes/merge_to_stream_merge_pass.h" | ||||
#include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
#include "graph/passes/multi_batch_pass.h" | #include "graph/passes/multi_batch_pass.h" | ||||
#include "graph/passes/next_iteration_pass.h" | #include "graph/passes/next_iteration_pass.h" | ||||
#include "graph/passes/permute_pass.h" | #include "graph/passes/permute_pass.h" | ||||
@@ -2582,8 +2581,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::Migration", new (std::nothrow) SubgraphConstMigrationPass)); | ||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ArgsClean", new (std::nothrow) UnusedArgsCleanPass)); | ||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)); | ||||
auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::MarkForceUnknownForCondPass", mark_force_unknown_pass)); | |||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) | ||||
GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) | ||||
GE_CHK_STATUS_RET( | GE_CHK_STATUS_RET( | ||||
@@ -46,11 +46,6 @@ | |||||
#define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | ||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
const std::set<std::string> kControlFlowOps{ | |||||
STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT | |||||
}; | |||||
} | |||||
using Cluster = DynamicShapePartitioner::Cluster; | using Cluster = DynamicShapePartitioner::Cluster; | ||||
using ClusterPtr = std::shared_ptr<Cluster>; | using ClusterPtr = std::shared_ptr<Cluster>; | ||||
@@ -279,9 +274,17 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
auto cluster = MakeShared<Cluster>(rank++, type, node, this); | auto cluster = MakeShared<Cluster>(rank++, type, node, this); | ||||
REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | ||||
node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) { | |||||
if (cluster->IsUnknownShape()) { | |||||
ordered_cluster_.push_back(cluster); | ordered_cluster_.push_back(cluster); | ||||
} | } | ||||
int64_t group_index = -1; | |||||
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
GELOGD("[%s] is rts control flow Op, group index: %ld", node->GetName().c_str(), group_index); | |||||
auto &control_cluster = control_clusters_[group_index]; | |||||
control_cluster.emplace_back(cluster); | |||||
} | |||||
// Already sorted topologically, so access to the parent cluster is safe | // Already sorted topologically, so access to the parent cluster is safe | ||||
for (const auto &parent : node->GetInAllNodes()) { | for (const auto &parent : node->GetInAllNodes()) { | ||||
cluster->AddInput(node_2_cluster_[parent]); | cluster->AddInput(node_2_cluster_[parent]); | ||||
@@ -350,14 +353,38 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||||
} | } | ||||
} | } | ||||
void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
for (const auto &item : control_clusters_) { | |||||
const auto &control_cluster = item.second; | |||||
auto rit = control_cluster.rbegin(); | |||||
if (rit == control_cluster.rend()) { | |||||
GELOGW("Invalid empty control flow cluster."); | |||||
continue; | |||||
} | |||||
const auto &cluster = *rit; | |||||
for (++rit; rit != control_cluster.rend(); ++rit) { | |||||
const auto &cluster_from = *rit; | |||||
auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | |||||
GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | |||||
ToString(merged_clusters).c_str()); | |||||
for (const auto &merged_cluster : merged_clusters) { | |||||
for (const auto &node : merged_cluster->Nodes()) { | |||||
node_2_cluster_[node] = cluster; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
void DynamicShapePartitioner::MergeClustersUnknownShape() { | void DynamicShapePartitioner::MergeClustersUnknownShape() { | ||||
// Merge unknown shape clusters | // Merge unknown shape clusters | ||||
for (const auto &cluster : ordered_cluster_) { | for (const auto &cluster : ordered_cluster_) { | ||||
if (cluster->IsIndependent() || cluster->IsControlFlow()) { | |||||
if (cluster->IsIndependent()) { | |||||
continue; | continue; | ||||
} | } | ||||
for (const auto &in_cluster : cluster->Inputs()) { | for (const auto &in_cluster : cluster->Inputs()) { | ||||
if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) { | |||||
if (!in_cluster->IsUnknownShape()) { | |||||
continue; | continue; | ||||
} | } | ||||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | ||||
@@ -419,6 +446,7 @@ void DynamicShapePartitioner::MergeClustersInputData() { | |||||
} | } | ||||
Status DynamicShapePartitioner::MergeClusters() { | Status DynamicShapePartitioner::MergeClusters() { | ||||
MergeClustersControlFlow(); | |||||
MergeClustersUnknownShape(); | MergeClustersUnknownShape(); | ||||
REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); | REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); | ||||
MergeClustersKnownShape(); | MergeClustersKnownShape(); | ||||
@@ -608,13 +636,6 @@ bool Cluster::IsRefVariable() const { | |||||
return false; | return false; | ||||
} | } | ||||
bool Cluster::IsControlFlow() const { | |||||
const auto &op_desc = nodes_[0]->GetOpDesc(); | |||||
bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
GELOGD("[%s] %s rts control flow Op ", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not"); | |||||
return is_ctrl_flow; | |||||
} | |||||
void Cluster::AddInput(ClusterPtr in) { | void Cluster::AddInput(ClusterPtr in) { | ||||
if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; | if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; | ||||
in_clusters_.insert(in_clusters_.end(), in); | in_clusters_.insert(in_clusters_.end(), in); | ||||
@@ -694,10 +715,7 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) { | |||||
if (other->IsIndependent()) { | if (other->IsIndependent()) { | ||||
return path_clusters; | return path_clusters; | ||||
} | } | ||||
if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) == | |||||
other->out_clusters_.end()) { | |||||
return path_clusters; | |||||
} | |||||
path_clusters.push_back(other); | path_clusters.push_back(other); | ||||
forward_reached_queue.push(other); | forward_reached_queue.push(other); | ||||
backward_reached_queue.push(shared_from_this()); | backward_reached_queue.push(shared_from_this()); | ||||
@@ -761,7 +779,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> | |||||
OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | ||||
Status Cluster::BuildFrame() { | Status Cluster::BuildFrame() { | ||||
if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) { | |||||
if (IsUnknownShape() || IsKnownShape() || IsInputNode()) { | |||||
return BuildPartitionFrame(); | return BuildPartitionFrame(); | ||||
} else { | } else { | ||||
auto node = nodes_.front(); | auto node = nodes_.front(); | ||||
@@ -896,7 +914,7 @@ Status Cluster::CombinePartitionFrame() { | |||||
} | } | ||||
Status Cluster::BuildPartitionSubgraph() { | Status Cluster::BuildPartitionSubgraph() { | ||||
if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) { | |||||
if (IsData() || IsNetOutput() || IsIndependent()) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
int64_t parent_node_index = 0; | int64_t parent_node_index = 0; | ||||
@@ -47,7 +47,6 @@ class DynamicShapePartitioner { | |||||
bool IsUnknownShape() const; | bool IsUnknownShape() const; | ||||
bool IsIndependent() const; | bool IsIndependent() const; | ||||
bool IsNetOutput() const; | bool IsNetOutput() const; | ||||
bool IsControlFlow() const; | |||||
std::vector<std::shared_ptr<Cluster>> Inputs() const; | std::vector<std::shared_ptr<Cluster>> Inputs() const; | ||||
std::vector<std::shared_ptr<Cluster>> Outputs() const; | std::vector<std::shared_ptr<Cluster>> Outputs() const; | ||||
bool IsInputNode() const; | bool IsInputNode() const; | ||||
@@ -126,13 +125,15 @@ class DynamicShapePartitioner { | |||||
// and there's only one path between the two clusters , merge the two clusters | // and there's only one path between the two clusters , merge the two clusters | ||||
// 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA | // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA | ||||
Status MergeClusters(); | Status MergeClusters(); | ||||
// Merge clusters step0 | |||||
void MergeClustersControlFlow(); | |||||
// Merge clusters step1 | // Merge clusters step1 | ||||
void MergeClustersUnknownShape(); | void MergeClustersUnknownShape(); | ||||
// Merge clusters step2 | // Merge clusters step2 | ||||
void MergeClustersKnownShape(); | void MergeClustersKnownShape(); | ||||
// Merge clusters step3 | // Merge clusters step3 | ||||
void MergeClustersInputData(); | void MergeClustersInputData(); | ||||
// Topological sort clusters after merge unknow shape clusters. | |||||
// Topological sort clusters after merge unknown shape clusters. | |||||
Status TopologicalSortClusters(); | Status TopologicalSortClusters(); | ||||
// Deduplicate merged clusters | // Deduplicate merged clusters | ||||
void PruneUniqueClusters(); | void PruneUniqueClusters(); | ||||
@@ -140,7 +141,7 @@ class DynamicShapePartitioner { | |||||
Status BuildPartitionFrame(); | Status BuildPartitionFrame(); | ||||
// Establish connection between corresponding partitioned of clusters | // Establish connection between corresponding partitioned of clusters | ||||
Status CombinePartitionFrame(); | Status CombinePartitionFrame(); | ||||
// Convert the nodes in cluster into a complete ComputeGraoh | |||||
// Convert the nodes in cluster into a complete ComputeGraph | |||||
Status BuildPartitionSubgraph(); | Status BuildPartitionSubgraph(); | ||||
// Clear resource and break circular dependency | // Clear resource and break circular dependency | ||||
void ClearResource(); | void ClearResource(); | ||||
@@ -155,6 +156,8 @@ class DynamicShapePartitioner { | |||||
Status CtrlEdgeTransfer(); | Status CtrlEdgeTransfer(); | ||||
ge::ComputeGraphPtr root_graph_; // The original graph to partition | ge::ComputeGraphPtr root_graph_; // The original graph to partition | ||||
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | ||||
// V1 control flow cluster, need merge to one Graph. | |||||
std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||||
// topological sorted clusters, this field will change with the splitting. | // topological sorted clusters, this field will change with the splitting. | ||||
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | ||||
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | ||||
@@ -18,20 +18,25 @@ | |||||
#include <queue> | #include <queue> | ||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const std::set<std::string> kMergeOpTypes{ MERGE, REFMERGE }; | |||||
inline bool IsMergeInLoop(const NodePtr &node) { | |||||
const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
const std::set<std::string> kSwitchOpTypes{ SWITCH, REFSWITCH }; | |||||
std::string node_type; | |||||
(void)GetOriginalType(node, node_type); | |||||
return kLoopMergeInputs.count(node_type) > 0; | |||||
} | |||||
const std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
inline bool IsSwitchInLoop(const NodePtr &node) { | |||||
const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||||
inline bool IsMergeInLoop(const NodePtr &node) { | |||||
std::string node_type; | std::string node_type; | ||||
(void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
return kLoopMergeInputs.count(node_type) > 0; | |||||
return kLoopSwitchInputs.count(node_type) > 0; | |||||
} | } | ||||
} | } | ||||
@@ -103,7 +108,13 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
if (dst_span > 0) { | if (dst_span > 0) { | ||||
search_queue.push({in_node, dst_span - 1}); | search_queue.push({in_node, dst_span - 1}); | ||||
} else { | } else { | ||||
switch_group.emplace_back(in_node); | |||||
const auto &all_in_nodes = in_node->GetInDataNodes(); | |||||
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||||
GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||||
in_node->GetName().c_str()); | |||||
} else { | |||||
switch_group.emplace_back(in_node); | |||||
} | |||||
} | } | ||||
} else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | ||||
search_queue.push({in_node, dst_span + 1}); | search_queue.push({in_node, dst_span + 1}); | ||||
@@ -121,19 +132,37 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
/// | /// | ||||
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | ||||
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | ||||
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||||
}; | }; | ||||
for (const auto &group : switch_groups) { | |||||
const auto &node = group.first; | |||||
const auto &switch_group = group.second; | |||||
const auto &op_desc = node->GetOpDesc(); | |||||
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || | |||||
std::any_of(switch_group.begin(), switch_group.end(), callback)) { | |||||
GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str()); | |||||
MarkForceUnknownShape(node, true); | |||||
for (const auto &n : switch_group) { | |||||
MarkForceUnknownShape(n, true); | |||||
for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||||
const auto &op_node1 = it1->first; | |||||
const auto &op_desc1 = op_node1->GetOpDesc(); | |||||
if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
continue; | |||||
} | |||||
if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||||
int64_t group_index = op_desc1->GetId(); | |||||
GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||||
MarkForceUnknownShape(op_node1, true, group_index); | |||||
for (const auto &n : it1->second) { | |||||
MarkForceUnknownShape(n, true, group_index); | |||||
} | |||||
for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||||
const auto &op_node2 = it2->first; | |||||
const auto &op_desc2 = op_node2->GetOpDesc(); | |||||
if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
continue; | |||||
} | |||||
if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||||
MarkForceUnknownShape(op_node2, true, group_index); | |||||
for (const auto &n : it2->second) { | |||||
MarkForceUnknownShape(n, true, group_index); | |||||
} | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -84,8 +84,9 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
GE_CHK_BOOL_EXEC(node != nullptr, | GE_CHK_BOOL_EXEC(node != nullptr, | ||||
REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | ||||
return FAILED, "Param of pre node is null."); | return FAILED, "Param of pre node is null."); | ||||
bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
MarkForceUnknownShape(node, force_unknown); | |||||
int64_t group_index = -1; | |||||
bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
MarkForceUnknownShape(node, force_unknown, group_index); | |||||
for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
@@ -102,7 +103,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); | GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
MarkForceUnknownShape(active_node, force_unknown); | |||||
MarkForceUnknownShape(active_node, force_unknown, group_index); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -18,6 +18,7 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "graph/utils/node_utils.h" | |||||
using std::string; | using std::string; | ||||
@@ -203,6 +204,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
for (const auto &loop_cond_iter : loop_group_map_) { | for (const auto &loop_cond_iter : loop_group_map_) { | ||||
const LoopCondGroup &loop_group = *loop_cond_iter.second; | const LoopCondGroup &loop_group = *loop_cond_iter.second; | ||||
const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | ||||
const int64_t group_index = loop_group.loop_cond->GetOpDesc()->GetId(); | |||||
GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | ||||
// Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | ||||
@@ -223,7 +225,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
enter_active->GetName().c_str()); | enter_active->GetName().c_str()); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape); | |||||
MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); | |||||
} | } | ||||
for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | ||||
@@ -253,8 +255,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
MarkForceUnknownShape(next_node, loop_group.is_unknown_shape); | |||||
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape); | |||||
MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); | |||||
MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); | |||||
} | } | ||||
if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | ||||
@@ -263,10 +265,10 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | |||||
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | |||||
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | |||||
HandleSwitchExitNodes(loop_group); | |||||
MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); | |||||
MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); | |||||
MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); | |||||
HandleSwitchExitNodes(loop_group, group_index); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -275,20 +277,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
/// | /// | ||||
/// @brief Mark force unknown for Exit node | /// @brief Mark force unknown for Exit node | ||||
/// @param [in] group of LoopCond | /// @param [in] group of LoopCond | ||||
/// @param [in] index of LoopCond Node | |||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { | |||||
void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||||
if (!loop_group.is_unknown_shape) { | if (!loop_group.is_unknown_shape) { | ||||
return; | return; | ||||
} | } | ||||
for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||||
MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); | |||||
for (const auto &node : switch_node->GetOutDataNodes()) { | for (const auto &node : switch_node->GetOutDataNodes()) { | ||||
std::string node_type; | std::string node_type; | ||||
(void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
if (node_type == EXIT || node_type == REFEXIT) { | |||||
MarkForceUnknownShape(node, loop_group.is_unknown_shape); | |||||
if (kExitOpTypes.count(node_type) > 0) { | |||||
MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -96,9 +96,10 @@ class NextIterationPass : public GraphPass { | |||||
/// | /// | ||||
/// @brief Mark force unknown for Exit node | /// @brief Mark force unknown for Exit node | ||||
/// @param [in] group of LoopCond | /// @param [in] group of LoopCond | ||||
/// @param [in] index of LoopCond Node | |||||
/// @return void | /// @return void | ||||
/// | /// | ||||
void HandleSwitchExitNodes(const LoopCondGroup &loop_group); | |||||
void HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index); | |||||
// map<frame_name, LoopCondGroup> | // map<frame_name, LoopCondGroup> | ||||
std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | ||||
@@ -369,7 +369,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), | ||||
"StreamSwitch node add cond edge failed."); | "StreamSwitch node add cond edge failed."); | ||||
MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)); | |||||
int64_t group_index = -1; | |||||
bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||||
return stream_switch; | return stream_switch; | ||||
} | } | ||||
@@ -488,11 +490,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||||
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
int64_t group_index = -1; | |||||
std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||||
return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
}; | }; | ||||
bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | ||||
MarkForceUnknownShape(active_node, is_unknown_shape); | |||||
MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||||
const std::string &cond_group = cond_node->GetName(); | const std::string &cond_group = cond_node->GetName(); | ||||
for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | ||||
@@ -522,7 +525,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), | ||||
"Cast add data edge failed."); | "Cast add data edge failed."); | ||||
MarkForceUnknownShape(stream_switch, is_unknown_shape); | |||||
MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||||
for (const NodePtr &node : switch_list) { | for (const NodePtr &node : switch_list) { | ||||
GE_IF_BOOL_EXEC(node != stream_switch, { | GE_IF_BOOL_EXEC(node != stream_switch, { | ||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | ||||
@@ -74,6 +74,7 @@ | |||||
#include "graph/passes/unused_const_pass.h" | #include "graph/passes/unused_const_pass.h" | ||||
#include "graph/passes/var_is_initialized_op_pass.h" | #include "graph/passes/var_is_initialized_op_pass.h" | ||||
#include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
#include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
#include "graph/preprocess/insert_op/util_insert_aipp_op.h" | #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
@@ -1675,6 +1676,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||||
PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | ||||
PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | ||||
PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); | PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); | ||||
PP_RUN_AND_DUMP("CtrlFlowPreProcess", CtrlFlowPreProcess); | |||||
PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); | PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); | ||||
PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); | PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); | ||||
PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); | PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); | ||||
@@ -1683,6 +1685,17 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
///
/// @brief Run the control flow pre-processing pass on compute_graph_.
/// @return SUCCESS when the pass has been registered and run; otherwise the
///         status propagated by GE_CHK_STATUS_RET.
///
Status GraphPrepare::CtrlFlowPreProcess() {
  PassManager pass_manager;

  // Runs after InferShape: mark v1 control flow nodes for unknown shape.
  // NOTE(review): presumably AddPass takes ownership of the raw pointer and
  // rejects nullptr when the nothrow allocation fails -- confirm in PassManager.
  auto *cond_mark_pass = new (std::nothrow) MarkForceUnknownForCondPass;
  GE_CHK_STATUS_RET(pass_manager.AddPass("PreRun::MarkForceUnknownForCondPass", cond_mark_pass));

  GE_CHK_STATUS_RET(pass_manager.Run(compute_graph_));
  return SUCCESS;
}
Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { | Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { | ||||
PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); | PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -79,6 +79,7 @@ class GraphPrepare { | |||||
Status ProcessNetOutput(); | Status ProcessNetOutput(); | ||||
Status ProcessBeforeInfershape(); | Status ProcessBeforeInfershape(); | ||||
Status UpdateInputOutputByOptions(); | Status UpdateInputOutputByOptions(); | ||||
Status CtrlFlowPreProcess(); | |||||
bool IsTansDataOpData(const ge::NodePtr &var_node); | bool IsTansDataOpData(const ge::NodePtr &var_node); | ||||
@@ -104,11 +104,47 @@ void ShapeInferenceState::UpdateInputShapeFuture(int idx, ShapeFuture &&future) | |||||
} | } | ||||
} | } | ||||
///
/// @brief Refresh one input desc of a merge node from the cached input tensor descs.
///        The input is selected by the ATTR_NAME_MERGE_INPUT_INDEX attribute on the
///        node's op_desc; its shape, origin shape and tensor size are copied from
///        input_tensor_desc[merge_index] into the node item's mutable input desc.
/// @param [in] context, graph execution context (not referenced by this function)
/// @return SUCCESS when the input desc has been updated;
///         FAILED when the attribute is missing or the index is outside
///         [0, input_tensor_desc.size()); the GE_CHECK_NOTNULL status when the
///         destination desc is null.
///
Status ShapeInferenceState::UpdateInputForMerge(const GraphExecutionContext &context) {
  int merge_index = -1;
  // Kept alive until return.
  // NOTE(review): presumably a scoped lock on the node item -- confirm against NodeItem::MutexGuard.
  const auto &guard = node_item.MutexGuard("UpdateInputForMerge");
  if (!AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_MERGE_INPUT_INDEX, merge_index)) {
    GELOGE(FAILED, "[%s] Get attr %s failed", node_item.NodeName().c_str(), ATTR_NAME_MERGE_INPUT_INDEX.c_str());
    return FAILED;
  }

  // Reject negative values first so the cast to size_t below cannot wrap.
  if (merge_index < 0 || static_cast<size_t>(merge_index) >= input_tensor_desc.size()) {
    GELOGE(FAILED, "[%s] merge index: %d invalid, should in range[0, %zu)",
           node_item.NodeName().c_str(), merge_index, input_tensor_desc.size());
    return FAILED;
  }

  auto dst_tensor_desc = node_item.MutableInputDesc(merge_index);
  GE_CHECK_NOTNULL(dst_tensor_desc);

  // tensor_size stays -1 if GetSize fails; the result is deliberately ignored.
  int64_t tensor_size = -1;
  auto &tensor_desc = input_tensor_desc[merge_index];
  (void)TensorUtils::GetSize(tensor_desc, tensor_size);

  dst_tensor_desc->SetShape(tensor_desc.MutableShape());
  dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape());
  (void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size);
  (void)guard;
  // merge_index is a signed int, so it is formatted with %d (was %u).
  GELOGD("[%s] Update input shape [%d] with shape: [%s] and ori_shape: [%s], tensor size = %ld",
         node_item.NodeName().c_str(), merge_index, dst_tensor_desc->GetShape().ToString().c_str(),
         dst_tensor_desc->GetOriginShape().ToString().c_str(), tensor_size);
  return SUCCESS;
}
Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { | Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { | ||||
if (!node_item.is_dynamic) { | if (!node_item.is_dynamic) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
std::unique_lock<std::mutex> lk(mu_); | std::unique_lock<std::mutex> lk(mu_); | ||||
if (node_item.IsMergeOp()) { | |||||
return UpdateInputForMerge(context); | |||||
} | |||||
if (num_pending_shapes_ > 0) { | if (num_pending_shapes_ > 0) { | ||||
GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); | GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); | ||||
int try_count = 0; | int try_count = 0; | ||||
@@ -169,7 +205,7 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
int64_t tensor_size = -1; | int64_t tensor_size = -1; | ||||
(void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | ||||
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | |||||
GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", | |||||
node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
idx, | idx, | ||||
src_tensor_desc->GetShape().ToString().c_str(), | src_tensor_desc->GetShape().ToString().c_str(), | ||||
@@ -283,11 +319,8 @@ void NodeState::ResetContext(int group) { | |||||
} | } | ||||
switch_index_ = -1; | switch_index_ = -1; | ||||
const auto &guard = node_item_->MutexGuard("ResetContext"); | |||||
shape_inference_state_.InitShapeState(); | |||||
subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | ||||
(void)guard; | |||||
} | } | ||||
void NodeState::ResetSchedule() { | void NodeState::ResetSchedule() { | ||||
@@ -67,6 +67,8 @@ struct ShapeInferenceState { | |||||
const NodeItem &node_item; | const NodeItem &node_item; | ||||
private: | private: | ||||
Status UpdateInputForMerge(const GraphExecutionContext &context); | |||||
friend struct NodeState; | friend struct NodeState; | ||||
std::vector<std::pair<int, ShapeFuture>> shape_futures; | std::vector<std::pair<int, ShapeFuture>> shape_futures; | ||||
// do not directly update op_desc, in case race condition across pipelines | // do not directly update op_desc, in case race condition across pipelines | ||||
@@ -823,7 +823,7 @@ set(PROFILING_MNG_TEST_FILES | |||||
set(HYBRID_TEST_FILES | set(HYBRID_TEST_FILES | ||||
"hybrid/ge_hybrid_unittest.cc" | "hybrid/ge_hybrid_unittest.cc" | ||||
"hybrid/known_node_executor_unittest.cc" | "hybrid/known_node_executor_unittest.cc" | ||||
"hybrid/executor/worker/execution_engine_unittest.cc" | |||||
"hybrid/executor/node_state_unittest.cc" | |||||
"hybrid/executor/subgraph_executor_unittest.cc" | "hybrid/executor/subgraph_executor_unittest.cc" | ||||
"hybrid/executor/worker/execution_engine_unittest.cc" | "hybrid/executor/worker/execution_engine_unittest.cc" | ||||
"hybrid/model/hybrid_model_builder_unittest.cc" | "hybrid/model/hybrid_model_builder_unittest.cc" | ||||
@@ -15,20 +15,17 @@ | |||||
*/ | */ | ||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#define private public | |||||
#define protected public | |||||
#include "graph/partition/dynamic_shape_partition.h" | #include "graph/partition/dynamic_shape_partition.h" | ||||
#include "compute_graph.h" | #include "compute_graph.h" | ||||
#include "inc/framework/common/types.h" | #include "inc/framework/common/types.h" | ||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#define private public | |||||
#define protected public | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW, | GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW, | ||||
DataType data_type = DT_FLOAT) { | DataType data_type = DT_FLOAT) { | ||||
GeShape ge_shape{vector<int64_t>(shape)}; | GeShape ge_shape{vector<int64_t>(shape)}; | ||||
@@ -94,4 +91,29 @@ TEST_F(UtestDynamicShapePartition, single_op_scene_success) { | |||||
DynamicShapePartitioner partitioner(graph); | DynamicShapePartitioner partitioner(graph); | ||||
EXPECT_EQ(partitioner.Partition(), SUCCESS); | EXPECT_EQ(partitioner.Partition(), SUCCESS); | ||||
} | } | ||||
// Two data nodes feeding a merge node, all flagged force-unknown-shape and placed in
// control flow group 3: Partition() is expected to succeed and leave exactly one subgraph.
TEST_F(UtestDynamicShapePartition, merge_control_flow_group) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("default");
  AttrUtils::SetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, "session_graph_id");

  // Topology: data1 -> merge.in0, data2 -> merge.in1.
  NodePtr data1 = NodeBuilder("data1", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
  NodePtr data2 = NodeBuilder("data2", DATA).AddInputDesc({1}).AddOutputDesc({1}).Build(graph);
  NodePtr merge = NodeBuilder("node2", MERGE)
                      .AddInputDesc({1})
                      .AddInputDesc({1})
                      .AddOutputDesc({1})
                      .AddOutputDesc({})
                      .Build(graph);

  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge->GetInDataAnchor(0));
  GraphUtils::AddEdge(data2->GetOutDataAnchor(0), merge->GetInDataAnchor(1));

  // Every node gets the same two attributes, in the same order as before.
  for (const NodePtr &grouped_node : {data1, data2, merge}) {
    (void)AttrUtils::SetBool(grouped_node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, true);
    (void)AttrUtils::SetInt(grouped_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, 3);
  }

  EXPECT_EQ(graph->sub_graph_.size(), 0);
  DynamicShapePartitioner partitioner(graph);
  EXPECT_EQ(partitioner.Partition(), SUCCESS);
  EXPECT_EQ(graph->sub_graph_.size(), 1);
}
} // namespace ge | } // namespace ge |
@@ -223,4 +223,17 @@ TEST_F(UtestGraphPreproces, test_update_dtype_mbatch_case) { | |||||
auto data1_output = data1_desc->MutableOutputDesc(0); | auto data1_output = data1_desc->MutableOutputDesc(0); | ||||
EXPECT_EQ(data1_output->GetDataType(), 1); | EXPECT_EQ(data1_output->GetDataType(), 1); | ||||
} | } | ||||
// PrepareDynShape is expected to return SUCCESS for the BuildGraph5() graph
// when called with an empty user input list and session id 0.
TEST_F(UtestGraphPreproces, test_prepare_dyn_shape) {
  ComputeGraphPtr compute_graph = BuildGraph5();
  GraphPtr graph_ptr = std::make_shared<Graph>(GraphUtils::CreateGraphFromComputeGraph(compute_graph));

  // Graph node with id 0 that carries both views of the same graph.
  GraphNodePtr graph_node = std::make_shared<GraphNode>(0);
  graph_node->SetComputeGraph(compute_graph);
  graph_node->SetGraph(graph_ptr);

  std::vector<GeTensor> empty_inputs;
  GraphPrepare preparer;
  const auto prepare_status = preparer.PrepareDynShape(graph_node, empty_inputs, compute_graph, 0);
  EXPECT_EQ(prepare_status, SUCCESS);
}
} | } |
@@ -0,0 +1,106 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include <gmock/gmock.h> | |||||
#include <vector> | |||||
#define private public | |||||
#define protected public | |||||
#include "hybrid/executor/node_state.h" | |||||
#include "hybrid/executor/subgraph_context.h" | |||||
#include "hybrid/model/graph_item.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
using namespace std; | |||||
using namespace testing; | |||||
namespace ge { | |||||
using namespace hybrid; | |||||
class UtestNodeState : public testing::Test { | |||||
protected: | |||||
void SetUp() { | |||||
} | |||||
void TearDown() { | |||||
} | |||||
}; | |||||
// Test helper: builds an OpDesc named `name` of type `type` with `in_num` inputs and
// `out_num` outputs, then adds it to `graph` and returns the resulting node.
// All tensor descs are copies of one DT_INT64 / FORMAT_ND desc with a default GeShape
// and size 64. Offsets advance in steps of 64, with outputs placed after the inputs.
static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  // Shared across calls: each created op takes the next id.
  static int32_t index = 0;

  OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  op_desc->SetStreamId(0);
  op_desc->SetId(index++);

  GeTensorDesc tensor(GeShape(), FORMAT_ND, DT_INT64);
  TensorUtils::SetSize(tensor, 64);

  // Offsets are based on the already-incremented index, as in the original arithmetic.
  const int input_base = index * 64;
  vector<int64_t> input_offset;
  for (int in_idx = 0; in_idx < in_num; ++in_idx) {
    op_desc->AddInputDesc(tensor);
    input_offset.emplace_back(input_base + in_idx * 64);
  }
  op_desc->SetInputOffset(input_offset);

  const int output_base = input_base + in_num * 64;
  vector<int64_t> output_offset;
  for (int out_idx = 0; out_idx < out_num; ++out_idx) {
    op_desc->AddOutputDesc(tensor);
    output_offset.emplace_back(output_base + out_idx * 64);
  }
  op_desc->SetOutputOffset(output_offset);

  op_desc->SetWorkspace({});
  op_desc->SetWorkspaceBytes({});
  op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");
  return graph.AddNode(op_desc);
}
// Exercises ShapeInferenceState::AwaitShapesReady on a STREAMMERGE node with two inputs:
//   1. non-dynamic node             -> SUCCESS
//   2. dynamic, merge index not set -> FAILED
//   3. merge index 3 (out of range) -> FAILED
//   4. merge index 1 (valid)        -> SUCCESS
TEST_F(UtestNodeState, merge_await_shapes_ready) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

  // Topology: data -> merge.in0, data1 -> merge.in1, merge.out0 -> net_output.
  const auto data0 = CreateNode(*graph, "data", DATA, 1, 1);
  const auto data1 = CreateNode(*graph, "data1", DATA, 1, 1);
  const auto merge1 = CreateNode(*graph, "merge", STREAMMERGE, 2, 2);
  const auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);

  GraphUtils::AddEdge(data0->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));

  GraphItem graph_item;
  GraphExecutionContext graph_context;
  SubgraphContext subgraph_context(&graph_item, &graph_context);

  std::unique_ptr<NodeItem> node_item;
  NodeItem::Create(merge1, node_item);
  // Abort the test instead of dereferencing a null pointer if Create produced nothing.
  ASSERT_TRUE(node_item != nullptr);
  NodeState node_state(*node_item, &subgraph_context);

  // Not dynamic.
  ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), SUCCESS);

  // Not set merge index.
  node_item->is_dynamic = true;
  ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), FAILED);

  // merge index out of bound.
  AttrUtils::SetInt(merge1->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, 3);
  ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), FAILED);

  // Valid merge index.
  AttrUtils::SetInt(merge1->GetOpDesc(), ATTR_NAME_MERGE_INPUT_INDEX, 1);
  ASSERT_EQ(node_state.shape_inference_state_.AwaitShapesReady(graph_context), SUCCESS);
}
} // namespace ge |