|
- /**
- * Copyright 2020 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
- #define GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
- #include <map>
- #include <queue>
- #include <vector>
-
- #include "external/ge/ge_api_error_codes.h"
-
- #include "graph/compute_graph.h"
-
- namespace ge {
- namespace multibatch {
- Status ProcessMultiBatch(ComputeGraphPtr &graph);
-
- Status GetDynamicOutputShape(ComputeGraphPtr &graph);
-
- enum NodeStatus {
- kNodeInBatchBranch,
- kNodeOutBatchBranch,
- kNodeStartNode,
- kNodeNotSupportNode,
- };
-
- enum DynamicType {
- kDynamicBatch,
- kDynamicImageSize,
- kDynamicDims,
- kDynamicUnknown,
- };
-
- class MultiBatchGraphCopyer {
- public:
- explicit MultiBatchGraphCopyer(ComputeGraphPtr &graph) : graph_(graph) {}
- ~MultiBatchGraphCopyer() = default;
-
- void AddShape(const std::vector<int64_t> &shape) { shapes_.emplace_back(shape); }
- void SetUserDesignateShape(const vector<pair<string, vector<int64_t>>> &designate_shape) {
- user_designate_shape_ = designate_shape;
- for (auto &item : designate_shape) {
- data_name_order_.push_back(item.first);
- }
- }
-
- void SetDynamicType(const DynamicType dynamic_type) {
- dynamic_type_ = dynamic_type;
- }
- Status CopyGraph();
-
- private:
- Status Init();
- Status CheckArguments();
-
- // label status for origin_all_nodes_
- Status LabelStatus();
- Status LabelInBatchBranchStatus();
- void LabelStatusForData(const NodePtr &data);
- void LabelStatusForGetNextSink(const NodePtr &data);
- // add nodes functions
- Status CreateNewNodes();
-
- NodePtr InsertShapeDataNode();
- NodePtr InsertGetDynamicDimsNode();
-
- Status InsertSwitchNAndUpdateMaxShape(const NodePtr &node);
-
- Status InsertSwitchNForData(const NodePtr &node, const size_t &out_anchor_index, const size_t &peer_in_anchor_index,
- std::vector<std::pair<Node *, NodePtr>> &dynamic_out_to_switchn);
-
- Status InsertIdentityAfterSwitchN();
- Status UpdateMaxShapeToData(const NodePtr &node, size_t out_anchor_index);
- Status UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index);
-
- Status InsertMergeForEdgeNode(const NodePtr &node);
- Status LinkGetDynamicDimsToNetOutput(const NodePtr &node);
-
- /// Insert a merge node for src node `node` on output index `index`. The merge node will be used to merge all nodes
- /// in batch-branch to one output to the node out of the batch-branch.
- /// Cond 1: If the `index` is -1, then the src node link a data edge(at output 0) to the merge node,
- /// Cond 2: In condition 1, if the src node does not have any data output, we create a const node after it,
- /// the result like this:
- /// src_node ---------> const_for_src_node --------> merge
- /// control data
- /// Cond 3: If the src node is a data-like node, the SwitchN after it will be link to the merge node.
- /// @param node
- /// @param index
- /// @return
- NodePtr InsertMergeNode(const NodePtr &node, int index);
- Status CopyNodeInBatchBranch(const NodePtr &node);
-
- // link edges functions
- Status LinkEdges();
- Status AddAttrForGetDynamicDims(const NodePtr &node);
- Status AddLinkForGetDynamicDims(const NodePtr &node);
-
- Status LinkDataToSwitchN(const NodePtr &data, const NodePtr &switchn, const int &out_index);
- Status LinkToMerge(const NodePtr &node);
- Status LinkToNodeInBranch(const NodePtr &node);
- Status LinkToNodeOutBranch(const NodePtr &node);
- Status LinkDataToMerge(const NodePtr &data, const NodePtr &merge, const NodePtr &switchn);
- Status LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge);
- NodePtr FindSwitchnNodeForDataEdge(const OutDataAnchorPtr &data_out_anchor, const NodePtr &origin_node);
- Status CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr ©ed_node);
- Status CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr ©ed_node);
- Status CheckAndParseDynamicData();
- bool IsInBatchBranch(const NodePtr &node);
- NodeStatus GetNodeStatus(const NodePtr &node) { return origin_nodes_status_[node.get()]; };
- Status CheckCopyResult(const std::vector<NodePtr> &start_nodes);
-
- // arguments
- ComputeGraphPtr graph_;
- std::vector<std::vector<int64_t>> shapes_;
-
- // the shape data node created
- NodePtr shape_data_;
-
- // all nodes in the origin graph
- std::vector<NodePtr> origin_all_nodes_;
-
- // all data nodes in the origin graph
- std::vector<NodePtr> origin_data_nodes_;
-
- // the nodes in-batch-branch, and the nodes copyed by shapes
- std::map<Node *, std::vector<NodePtr>> nodes_to_batch_nodes_;
-
- // the data nodes, and the SwitchN nodes inserted after it
- std::map<Node *, NodePtr> data_nodes_to_switchn_;
-
- // the getnext_sink nodes, and the SwitchN nodes inserted after it
- std::vector<std::vector<std::pair<Node *, NodePtr>>> getnext_nodes_to_switchn_;
- std::vector<std::vector<std::pair<int, int>>> outidx_inidx_mappings_;
- std::vector<std::pair<int, int>> outidx_inidx_mapping_;
-
- // the nodes on the in/out-batch-branch edge, and the merge nodes inserted after it
- std::map<Node *, std::vector<NodePtr>> nodes_to_merge_nodes_;
-
- // all nodes and their status
- std::map<Node *, NodeStatus> origin_nodes_status_;
-
- // user designate shape, decord the order of each input data
- std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape_;
- std::vector<std::string> data_name_order_;
-
- // each data's own dynamic info
- map<string, vector<vector<int64_t>>> data_to_dynamic_info_;
-
- // dynamic type : dynamic batch,, dynamic image size, dynamic dims.
- DynamicType dynamic_type_ = DynamicType::kDynamicUnknown;
-
- std::vector<std::pair<size_t, size_t>> getnext_sink_dynamic_out_mapping_;
- bool getnext_sink_dynamic_dims_ = false;
- };
- } // namespace multibatch
- } // namespace ge
- #endif // GE_GRAPH_PREPROCESS_MULTI_BATCH_COPY_GRAPH_H_
|