|
|
@@ -37,6 +37,7 @@ inline bool IsMergeInLoop(const NodePtr &node) { |
|
|
|
|
|
|
|
Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { |
|
|
|
GELOGD("MarkForceUnknownForCondPass Enter"); |
|
|
|
std::map<NodePtr, std::vector<NodePtr>> switch_groups; |
|
|
|
for (const auto &node : graph->GetDirectNode()) { |
|
|
|
std::string node_type; |
|
|
|
GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original type failed."); |
|
|
@@ -44,20 +45,15 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
const auto op_desc = node->GetOpDesc(); |
|
|
|
if (!op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) && !IsUnknownShapeTensor(op_desc->GetOutputDesc(0))) { |
|
|
|
GELOGI("Merge[%s] has known shape, no need check switch", node->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
const auto &all_in_nodes = node->GetInDataNodes(); |
|
|
|
if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsMergeInLoop)) { |
|
|
|
continue; // LoopCond marked in NextIterationPass. |
|
|
|
} |
|
|
|
|
|
|
|
MarkUnknownForSwitch(node); |
|
|
|
MarkUnknownForSwitch(node, switch_groups[node]); |
|
|
|
} |
|
|
|
|
|
|
|
MarkUnknownForSwitch(switch_groups); |
|
|
|
GELOGD("MarkForceUnknownForCondPass Leave"); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
@@ -65,13 +61,12 @@ Status MarkForceUnknownForCondPass::Run(ComputeGraphPtr graph) { |
|
|
|
/// |
|
|
|
/// @brief Mark force unknown shape for Switch node |
|
|
|
/// @param [in] merge node |
|
|
|
/// @param [out] switch group |
|
|
|
/// @return |
|
|
|
/// |
|
|
|
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { |
|
|
|
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std::vector<NodePtr> &switch_group) { |
|
|
|
// Switch --> {Switch --> Merge} --> Merge |
|
|
|
std::vector<NodePtr> switch_group; |
|
|
|
std::unordered_set<NodePtr> nodes_seen; |
|
|
|
|
|
|
|
std::queue<std::pair<NodePtr, uint32_t>> search_queue({{node, 0}}); |
|
|
|
while (!search_queue.empty()) { |
|
|
|
const auto dst_node = search_queue.front().first; |
|
|
@@ -117,9 +112,30 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node) { |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
for (const auto &n : switch_group) { |
|
|
|
MarkForceUnknownShape(n, true); |
|
|
|
/// |
|
|
|
/// @brief Mark force unknown shape for Switch node |
|
|
|
/// @param [in] switch groups |
|
|
|
/// @return |
|
|
|
/// |
|
|
|
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { |
|
|
|
std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { |
|
|
|
return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); |
|
|
|
}; |
|
|
|
|
|
|
|
for (const auto &group : switch_groups) { |
|
|
|
const auto &node = group.first; |
|
|
|
const auto &switch_group = group.second; |
|
|
|
const auto &op_desc = node->GetOpDesc(); |
|
|
|
if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || |
|
|
|
std::any_of(switch_group.begin(), switch_group.end(), callback)) { |
|
|
|
GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str()); |
|
|
|
MarkForceUnknownShape(node, true); |
|
|
|
for (const auto &n : switch_group) { |
|
|
|
MarkForceUnknownShape(n, true); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
} // namespace ge |