|
- /**
- * Copyright 2019-2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
- #ifndef INC_GRAPH_UTILS_GRAPH_UTILS_H_
- #define INC_GRAPH_UTILS_GRAPH_UTILS_H_
-
- #include <fstream>
- #include <iostream>
- #include <map>
- #include <string>
- #include <vector>
- #include <list>
- #include "graph/anchor.h"
- #include "graph/node.h"
- #include "graph/compute_graph.h"
- #include "graph/utils/anchor_utils.h"
- #include "graph/graph.h"
- #include "graph/model.h"
-
///
/// @brief Dump `compute_graph` and every graph returned by its GetAllSubgraphs(), each through both
///        GraphUtils::DumpGEGraph and GraphUtils::DumpGEGraphToOnnx.
///        Subgraph dumps are named `<name>_sub_graph_<k>`, where k starts at 0 on every expansion.
///        The counter is a plain local rather than a `static`, so numbering restarts for each dump and
///        cannot wrap the way an 8-bit counter does. Locals carry a `ge_dump_` prefix so that they do
///        not shadow identifiers appearing in the macro arguments.
/// @param compute_graph pointer-like graph handle; evaluated more than once
/// @param name dump suffix convertible to std::string; evaluated more than once
///
#define GE_DUMP(compute_graph, name)                                                                   \
  do {                                                                                                 \
    GraphUtils::DumpGEGraph(compute_graph, name);                                                      \
    GraphUtils::DumpGEGraphToOnnx(*(compute_graph), name);                                             \
    uint32_t ge_dump_sub_graph_index = 0;                                                              \
    for (const auto &ge_dump_sub_graph : (compute_graph)->GetAllSubgraphs()) {                         \
      const std::string ge_dump_sub_graph_name =                                                       \
          std::string(name) + std::string("_sub_graph_") + std::to_string(ge_dump_sub_graph_index++);  \
      GraphUtils::DumpGEGraph(ge_dump_sub_graph, ge_dump_sub_graph_name);                              \
      GraphUtils::DumpGEGraphToOnnx(*ge_dump_sub_graph, ge_dump_sub_graph_name);                       \
    }                                                                                                  \
  } while (0)
-
///
/// Attribute printing helpers.
///
/// These macros are building blocks for an `if / else if` chain that prints an attribute according to
/// its value type. To make `else` chaining possible, the `*_IF` macros expand to a bare
/// `if (...) { ... }` block (no `do { } while (0)` wrapper), and the `*_ELIF` macros expand to
/// `else` followed by the matching `*_IF`.
/// Usage rules that follow from that:
///   - do not write `;` between chained links (it would terminate the chain);
///   - when a chain is placed under an unbraced `if`/`else`, wrap it in braces.
///

///
/// @brief Declare a value-initialized local `ret` of type `DataType` and load the attribute into it with
///        `attr.GetValue<DataType>(ret)`. The status returned by GetValue is deliberately discarded;
///        if loading fails `ret` keeps its value-initialized state instead of being indeterminate.
///        Expands to two statements in the current scope so that `ret` stays visible to the caller.
/// @param VT_ENUM unused, kept for a uniform macro signature
/// @param DataType type of the declared variable
/// @param attr object providing `GetValue<DataType>(DataType &)`
/// @param ret name of the variable to declare
///
#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \
  DataType ret{};                                      \
  static_cast<void>((attr).GetValue<DataType>(ret));

///
/// @brief If `value_type == VT_ENUM`, read the attribute as `DataType` and stream it to `stream`.
///
#define PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream) \
  if ((value_type) == (VT_ENUM)) {                                       \
    REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret)                       \
    stream << ret;                                                       \
  }

///
/// @brief If `value_type == VT_ENUM`, read the attribute as the indexable container `DataType` and
///        stream it as `[e0, e1, ...]` (elements separated by ", ", `[]` when empty).
///
#define PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)              \
  if ((value_type) == (VT_ENUM)) {                                                         \
    REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret)                                         \
    stream << "[";                                                                         \
    for (size_t print_attr_idx = 0; print_attr_idx < ret.size(); ++print_attr_idx) {       \
      if (print_attr_idx != 0) stream << ", ";                                             \
      stream << ret[print_attr_idx];                                                       \
    }                                                                                      \
    stream << "]";                                                                         \
  }

/// @brief `else` link of the chain for scalar attributes; must directly follow a `*_IF` / `*_ELIF`.
#define PRINT_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
  else PRINT_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)

/// @brief `else` link of the chain for list attributes; must directly follow a `*_IF` / `*_ELIF`.
#define PRINT_LIST_ATTR_VALUE_ELIF(value_type, VT_ENUM, DataType, attr, stream) \
  else PRINT_LIST_ATTR_VALUE_IF(value_type, VT_ENUM, DataType, attr, stream)
-
///
/// @brief Stream the dims of one tensor desc of node `n` to `stream` as `[d0, d1, ...]`.
///        The input desc at `idx` is used when `i_o` equals the text "input"; any other value selects
///        the output desc. `i_o` is converted to std::string before comparing, so C strings are
///        compared by content rather than by address.
///        Locals carry a `print_shape_` prefix so they cannot collide with the macro arguments.
/// @param i_o "input" or anything else (treated as output); `const char *` or std::string
/// @param n pointer-like node handle providing GetOpDesc(); the op desc is not null-checked
/// @param idx index of the input/output tensor desc
/// @param stream output stream expression
///
#define PRINT_SHAPE(i_o, n, idx, stream)                                                                    \
  do {                                                                                                      \
    auto print_shape_op = (n)->GetOpDesc();                                                                 \
    GeTensorDesc print_shape_td = (std::string(i_o) == "input") ? print_shape_op->GetInputDesc(idx)         \
                                                                : print_shape_op->GetOutputDesc(idx);       \
    auto print_shape_dims = print_shape_td.GetShape().GetDims();                                            \
    stream << "[";                                                                                          \
    for (size_t print_shape_i = 0; print_shape_i < print_shape_dims.size(); ++print_shape_i) {              \
      if (print_shape_i != 0) stream << ", ";                                                               \
      stream << print_shape_dims[print_shape_i];                                                            \
    }                                                                                                       \
    stream << "]";                                                                                          \
  } while (0)
-
///
/// @brief Expands to a lambda (capturing by reference, terminated by `;`) that takes a GeAttrValue by
///        value and prints it to `stream`.
///        String, float, bool and int attributes, and lists of those, are printed through the
///        PRINT_ATTR_VALUE_* / PRINT_LIST_ATTR_VALUE_* macros. Tensor-desc, tensor and bytes
///        attributes (and their lists) are printed as a fixed type tag only. Any other value type
///        prints nothing.
///        The body is a single if / else-if chain, so it only expands to valid C++ when
///        PRINT_ATTR_VALUE_IF and the *_ELIF macros expand to chainable `if` / `else if` statements.
/// @param stream output stream expression, captured by reference
///
#define PRINT_ATTR_FUNC(stream)                                                                                       \
  [&](GeAttrValue attr) {                                                                                             \
    auto attr_type = attr.GetValueType();                                                                             \
    PRINT_ATTR_VALUE_IF(attr_type, GeAttrValue::ValueType::VT_STRING, GeAttrValue::STR, attr, stream)                 \
    PRINT_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_FLOAT, GeAttrValue::FLOAT, attr, stream)              \
    PRINT_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_BOOL, GeAttrValue::BOOL, attr, stream)                \
    PRINT_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_INT, GeAttrValue::INT, attr, stream)                  \
    PRINT_LIST_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_LIST_STRING, GeAttrValue::LIST_STR, attr,        \
                               stream)                                                                                \
    PRINT_LIST_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_LIST_FLOAT, GeAttrValue::LIST_FLOAT, attr,       \
                               stream)                                                                                \
    PRINT_LIST_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_LIST_BOOL, GeAttrValue::LIST_BOOL, attr,         \
                               stream)                                                                                \
    PRINT_LIST_ATTR_VALUE_ELIF(attr_type, GeAttrValue::ValueType::VT_LIST_INT, GeAttrValue::LIST_INT, attr,           \
                               stream)                                                                                \
    else if (attr_type == GeAttrValue::ValueType::VT_TENSOR_DESC) {                                                   \
      stream << "TENSOR_DESC";                                                                                        \
    }                                                                                                                 \
    else if (attr_type == GeAttrValue::ValueType::VT_TENSOR) {                                                        \
      stream << "TENSOR";                                                                                             \
    }                                                                                                                 \
    else if (attr_type == GeAttrValue::ValueType::VT_BYTES) {                                                         \
      stream << "BYTES";                                                                                              \
    }                                                                                                                 \
    else if (attr_type == GeAttrValue::ValueType::VT_LIST_TENSOR_DESC) {                                              \
      stream << "LIST_TENSOR_DESC";                                                                                   \
    }                                                                                                                 \
    else if (attr_type == GeAttrValue::ValueType::VT_LIST_TENSOR) {                                                   \
      stream << "LIST_TENSOR";                                                                                        \
    }                                                                                                                 \
    else if (attr_type == GeAttrValue::ValueType::VT_LIST_BYTES) {                                                    \
      stream << "LIST_BYTES";                                                                                         \
    }                                                                                                                 \
  };
-
- namespace ge {
/// Direction of an anchor relative to its node: kIn for an input anchor, kOut for an output anchor.
enum IOType {
  kIn = 0,
  kOut = 1
};
-
- struct NodeIndexIO {
- NodeIndexIO(ge::NodePtr node, uint32_t index, IOType io_type)
- : node_(std::move(node)), index_(index), io_type_(io_type) {
- if (node_ != nullptr) {
- value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_);
- }
- }
- NodeIndexIO(ge::NodePtr node, int index, IOType io_type)
- : node_(std::move(node)), index_(static_cast<uint32_t>(index)), io_type_(io_type) {
- if (node_ != nullptr) {
- value_ = node_->GetName() + (io_type_ == kOut ? "_out_" : "_in_") + std::to_string(index_);
- }
- }
- ~NodeIndexIO() {}
-
- NodePtr node_ = nullptr;
- uint32_t index_ = 0;
- IOType io_type_ = kOut;
- std::string value_;
-
- const std::string &ToString() const { return value_; }
- };
-
- class GraphUtils {
- public:
- static ComputeGraphPtr GetComputeGraph(const Graph &graph);
-
- static Graph CreateGraphFromComputeGraph(const ComputeGraphPtr compute_graph);
-
- static graphStatus RecoverGraphOperators(const Graph &graph);
-
- static ComputeGraphPtr CreateGraphFromOperator(const string &name, const std::vector<Operator> &inputs);
-
- static graphStatus AddEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
-
- static graphStatus AddEdge(const OutDataAnchorPtr &src, const Format &src_format, const InDataAnchorPtr &dst,
- const Format &dst_format);
-
- static graphStatus AddEdge(const AnchorPtr &src, const AnchorPtr &dst);
-
- static graphStatus AddEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
-
- static graphStatus AddEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
-
- // check whether src is link to dst and then remove
- static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst);
-
- static graphStatus RemoveEdge(const AnchorPtr &src, const AnchorPtr &dst);
-
- static graphStatus RemoveEdge(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst);
-
- static graphStatus RemoveEdge(const OutDataAnchorPtr &src, const InControlAnchorPtr &dst);
-
- static graphStatus ReplaceEdgeDst(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
- const InDataAnchorPtr &new_dst);
-
- static graphStatus ReplaceEdgeDst(const OutControlAnchorPtr &src, const InControlAnchorPtr &dst,
- const InControlAnchorPtr &new_dst);
-
- static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst,
- const NodePtr &new_node);
-
- static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node);
-
- static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node);
-
- static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor,
- const std::vector<OpDescPtr> &vec_op_desc);
-
- ///
- /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst
- /// @param [in] src
- /// @param [in] dsts
- /// @param [in] insert_node
- /// @param [in] input_index
- /// @param [in] output_index
- /// @return graphStatus
- ///
- static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDataAnchorPtr> &dsts,
- const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0);
-
- static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node);
-
- static graphStatus RemoveJustNode(ComputeGraph &compute_graph, const NodePtr &node);
-
- static void RecordOriginalNames(std::vector<ge::NodePtr> original_nodes, const ge::NodePtr &node);
-
- static void RecordOriginalNames(std::vector<std::string> names_tmp, const ge::NodePtr &node);
-
- static bool MatchDumpStr(const std::string &suffix);
-
- static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false);
-
- static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph);
-
- static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos);
-
- static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix);
-
- static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph);
-
- static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message);
-
- static void WriteProtoToTextFile(const google::protobuf::Message &proto, const char *real_path);
-
- static graphStatus AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node);
-
- ///
- /// Isolating `node`, relinking data links from the in-anchor peer nodes to
- /// the out-anchor peer nodes according to `io_map`, relinking control links
- /// to ensure that input nodes of `node` are before out nodes
- ///
- /// Link the `io_map[i]` input anchor peer node to `i` output anchor peer
- /// nodes, then unlink all links connecting with `node`. If `io_map[i]` < 0,
- /// unlink all links from `i` output anchor without any relinking.
- ///
- /// @param node
- /// @param io_map
- /// @return
- ///
- static graphStatus IsolateNode(const NodePtr &node, const std::initializer_list<int> &io_map);
- static graphStatus IsolateNode(const NodePtr &node, const std::vector<int> &io_map);
-
- ///
- /// Isolate `node` which must be one input one output, equivalent to
- /// `IsolateNode(node, {0})`
- /// @param node
- /// @return
- ///
- static graphStatus IsolateNodeOneIO(const NodePtr &node);
-
- ///
- /// The data anchors replacing behavior is the same with
- /// `ReplaceNodeDataAnchors`. In addition, replace all `old_node` control
- /// anchors with `new_node`'s.
- /// @param new_node
- /// @param old_node
- /// @param inputs_map
- /// @param outputs_map
- /// @return
- ///
- static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
- std::initializer_list<int> inputs_map, std::initializer_list<int> outputs_map);
-
- static graphStatus ReplaceNodeAnchors(const NodePtr &new_node, const NodePtr &old_node,
- const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
-
- ///
- /// Replace `old_node` data anchors with `new_node`'s according to `inputs_map` and `outputs_map`.
- /// Replace the `i` in/out data anchor on `old_node` with
- /// `inputs_map[i]`/`outputs_map[i]` data anchor on `new_node`.
- /// If `inputs_map[i]`/`outputs_map[i]` < 0 or the index not contained in
- /// `inputs_map[i]`/`outputs_map[i]`, the `i` data anchor will remain
- /// on `old_node`.
- /// @param new_node
- /// @param old_node
- /// @param inputs_map
- /// @param outputs_map
- /// @return
- ///
- static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
- std::initializer_list<int> inputs_map,
- std::initializer_list<int> outputs_map);
-
- static graphStatus ReplaceNodeDataAnchors(const NodePtr &new_node, const NodePtr &old_node,
- const std::vector<int> &inputs_map, const std::vector<int> &outputs_map);
-
- ///
- /// Copy all in-control edges from `src_node` to `dst_node`
- /// @param src_node
- /// @param dst_node
- /// @return
- ///
- static graphStatus CopyInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
-
- static graphStatus MoveInCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
-
- ///
- /// Copy all out-control edges from `src_node` to `dst_node`
- /// @param src_node
- /// @param dst_node
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus CopyOutCtrlEdges(const NodePtr &src_node, NodePtr &dst_node);
-
- ///
- /// Move all out-control edges from `src_node` to `dst_node`
- /// @param src_node
- /// @param dst_node
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);
-
- ///
- /// Copy all in-data edges from `src_node` to `dst_node`
- /// @param src_node
- /// @param dst_node
- /// @return
- ///
- static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node);
-
- static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
-
- static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);
-
- ///
- /// Get reference-mapping of all data_anchors in graph
- /// @param [in] graph
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus GetRefMapping(const ComputeGraphPtr &graph,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs
- /// of the graph have UNKNOWN_SHAPE operators or not.
- /// Note: This function will only look 'down' from the graph, not 'up'. For example, the following
- /// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE
- /// ROOT graph: A -----> B -----> C
- /// K subgraph U
- /// |
- /// V
- /// SUB graph: D --> E --> F
- /// K K K
- /// @param [in] graph
- /// @return bool
- ///
- static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph);
-
- static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name);
-
- private:
- ///
- /// Get reference-mapping for in_data_anchors of node
- /// @param [in] node
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus HandleInAnchorMapping(const NodePtr &node,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Get reference-mapping for out_data_anchors of node
- /// @param [in] node
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus HandleOutAnchorMapping(const NodePtr &node,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Handle input of subgraph
- /// @param [in] node
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus HandleSubgraphInput(const NodePtr &node,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Handle input of Merge op
- /// @param [in] node
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus HandleMergeInput(const NodePtr &node,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Handle output of subgraph
- /// @param [in] node
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus HandleSubgraphOutput(const NodePtr &node,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Union ref-mapping
- /// @param [in] exist_node_info1
- /// @param [in] exist_node_info2
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @param [out] symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol, std::string &symbol);
-
- ///
- /// Update symbol mapping with a new reference pair
- /// @param [in] cur_node_info
- /// @param [in] exist_node_info
- /// @param [out] symbol_to_anchors
- /// @param [out] anchor_to_symbol
- /// @return success: GRAPH_SUCESS
- ///
- static graphStatus UpdateRefMapping(const NodeIndexIO &cur_node_info, const NodeIndexIO &exist_node_info,
- std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors,
- std::map<std::string, std::string> &anchor_to_symbol);
-
- ///
- /// Check if out_data_anchor is reference of input
- /// @param [in] out_data_anchor
- /// @param [out] reuse_in_index
- /// @return bool
- ///
- static bool IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t &reuse_in_index);
- };
-
- class ComputeGraphBuilder {
- public:
- ComputeGraphBuilder() : owner_graph_(nullptr) {}
- ComputeGraphBuilder(const ComputeGraphBuilder &) = delete;
- ComputeGraphBuilder &operator=(const ComputeGraphBuilder &) = delete;
- ComputeGraphBuilder(const ComputeGraphBuilder &&) = delete;
- ComputeGraphBuilder &operator=(const ComputeGraphBuilder &&) = delete;
- ~ComputeGraphBuilder() = default;
-
- ///
- /// @brief Add node to graph
- /// @param [in] op_desc
- /// @return ComputeGraphBuilder
- ///
- virtual ComputeGraphBuilder &AddNode(const OpDescPtr &op_desc);
-
- ///
- /// @brief Add data-link among nodes in graph
- /// @param [in] src_name
- /// @param [in] out_anchor_ind
- /// @param [in] dst_name
- /// @param [in] in_anchor_ind
- /// @return ComputeGraphBuilder
- ///
- virtual ComputeGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind,
- const std::string &dst_name, uint32_t in_anchor_ind);
-
- ///
- /// @brief Add ctrl-link among nodes in graph
- /// @param [in] src_name
- /// @param [in] dst_name
- /// @return ComputeGraphBuilder
- ///
- virtual ComputeGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name);
-
- ///
- /// @brief Build graph
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return ComputeGraphPtr
- ///
- virtual ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) = 0;
-
- /// @brief Get node with name
- /// @param [in] name
- /// @return NodePtr
- ///
- NodePtr GetNode(const std::string &name);
-
- /// @brief Get all nodes
- /// @return std::vector<NodePtr>
- ///
- std::vector<NodePtr> GetAllNodes();
-
- protected:
- ///
- /// @brief Build nodes
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void BuildNodes(graphStatus &error_code, std::string &error_msg);
-
- ///
- /// @brief Build data-links
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void BuildDataLinks(graphStatus &error_code, std::string &error_msg);
-
- ///
- /// @brief Build ctrl-links
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void BuildCtrlLinks(graphStatus &error_code, std::string &error_msg);
-
- ComputeGraphPtr owner_graph_;
-
- // node_name -> node
- std::map<std::string, NodePtr> node_names_;
- std::vector<OpDescPtr> nodes_;
-
- // <src_node_name, out_anchor_ind> -> <dst_node_name, in_anchor_ind>
- std::vector<std::pair<std::pair<std::string, uint32_t>, std::pair<std::string, uint32_t>>> data_links_;
- // src_node_name -> dst_node_name
- std::vector<std::pair<std::string, std::string>> ctrl_links_;
- };
-
- class CompleteGraphBuilder : public ComputeGraphBuilder {
- public:
- explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {}
- CompleteGraphBuilder(const CompleteGraphBuilder &) = delete;
- CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete;
- CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete;
- CompleteGraphBuilder &operator=(const CompleteGraphBuilder &&) = delete;
- ~CompleteGraphBuilder() = default;
-
- ///
- /// @brief Add node to graph
- /// @param [in] op_desc
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
-
- ///
- /// @brief Add data-link among nodes in graph
- /// @param [in] src_name
- /// @param [in] out_anchor_ind
- /// @param [in] dst_name
- /// @param [in] in_anchor_ind
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
- uint32_t in_anchor_ind) override;
-
- ///
- /// @brief Add ctrl-link among nodes in graph
- /// @param [in] src_name
- /// @param [in] dst_name
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
-
- ///
- /// @brief Set index_th input anchor for graph
- /// @param [in] index
- /// @param [in] node_names
- /// @param [in] anchor_inds
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &SetInput(uint32_t index, const std::vector<std::string> &node_names,
- const std::vector<uint32_t> &anchor_inds);
-
- ///
- /// @brief Set index_th input of graph as useless
- /// @param [in] index
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &SetUselessInput(uint32_t index);
-
- ///
- /// @brief Add output anchor for graph
- /// @param [in] owner_node_name
- /// @param [in] anchor_ind
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind);
-
- ///
- /// @brief Add target for graph
- /// @param [in] target_name
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &AddTarget(const std::string &target_name);
-
- ///
- /// @brief Set parent-node of graph
- /// @param [in] parent_node
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &SetParentNode(const NodePtr &parent_node);
-
- ///
- /// @brief Set mapping-relation of parent-node in_anchor_ind & Data-node
- /// @param [in] input_mapping: index_of_graph_input -> in_anchor_index_of_parent_node
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &SetInputMapping(const std::map<uint32_t, uint32_t> &input_mapping);
-
- ///
- /// @brief Set mapping-relation of parent-node out_anchor_ind & NetOutput-node out_anchor_ind
- /// @param [in] output_mapping: index_of_graph_output -> out_anchor_index_of_parent_node
- /// @return CompleteGraphBuilder
- ///
- CompleteGraphBuilder &SetOutputMapping(const std::map<uint32_t, uint32_t> &output_mapping);
-
- ///
- /// @brief Build graph
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return ComputeGraphPtr
- ///
- ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
-
- private:
- ///
- /// @brief Add data nodes
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void AddDataNodes(graphStatus &error_code, std::string &error_msg);
-
- ///
- /// @brief Add data node
- /// @param [in] index
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- NodePtr AddDataNode(uint32_t index, graphStatus &error_code, std::string &error_msg);
-
- ///
- /// @brief Add RetVal nodes
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void AddRetValNodes(graphStatus &error_code, std::string &error_msg);
-
- ///
- /// @brief Build target-nodes for graph
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void BuildGraphTargets(graphStatus &error_code, std::string &error_msg);
-
- std::string name_;
- NodePtr parent_node_;
- std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_;
- std::vector<std::pair<std::string, uint32_t>> graph_outputs_;
- std::vector<std::string> graph_targets_;
-
- // index_of_graph_input -> in_anchor_index_of_parent_node
- std::map<uint32_t, uint32_t> input_mapping_;
- // index_of_graph_output -> out_anchor_index_of_parent_node
- std::map<uint32_t, uint32_t> output_mapping_;
- };
-
- class PartialGraphBuilder : public ComputeGraphBuilder {
- public:
- PartialGraphBuilder() = default;
- PartialGraphBuilder(const PartialGraphBuilder &) = delete;
- PartialGraphBuilder &operator=(const PartialGraphBuilder &) = delete;
- PartialGraphBuilder(const PartialGraphBuilder &&) = delete;
- PartialGraphBuilder &operator=(const PartialGraphBuilder &&) = delete;
- ~PartialGraphBuilder() = default;
-
- ///
- /// @brief Add node to graph
- /// @param [in] op_desc
- /// @return PartialGraphBuilder
- ///
- PartialGraphBuilder &AddNode(const OpDescPtr &op_desc) override;
-
- ///
- /// @brief Add data-link among nodes in graph
- /// @param [in] src_name
- /// @param [in] out_anchor_ind
- /// @param [in] dst_name
- /// @param [in] in_anchor_ind
- /// @return PartialGraphBuilder
- ///
- PartialGraphBuilder &AddDataLink(const std::string &src_name, uint32_t out_anchor_ind, const std::string &dst_name,
- uint32_t in_anchor_ind) override;
-
- ///
- /// @brief Add ctrl-link among nodes in graph
- /// @param [in] src_name
- /// @param [in] dst_name
- /// @return PartialGraphBuilder
- ///
- PartialGraphBuilder &AddControlLink(const std::string &src_name, const std::string &dst_name) override;
-
- ///
- /// @brief Set owner graph
- /// @param [in] graph
- /// @return PartialGraphBuilder
- ///
- PartialGraphBuilder &SetOwnerGraph(const ComputeGraphPtr &graph);
-
- ///
- /// @brief Add exist node
- /// @param [in] node
- /// @return PartialGraphBuilder
- ///
- PartialGraphBuilder &AddExistNode(const NodePtr &node);
-
- ///
- /// @brief Build multi nodes with links
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return ComputeGraphPtr
- ///
- ComputeGraphPtr Build(graphStatus &error_code, std::string &error_msg) override;
-
- private:
- ///
- /// @brief Build exist nodes
- /// @param [out] error_code
- /// @param [out] error_msg
- /// @return void
- ///
- void BuildExistNodes(graphStatus &error_code, std::string &error_msg);
-
- std::vector<NodePtr> exist_nodes_;
- };
- } // namespace ge
- #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_
|