|
- /**
- * Copyright 2021 Huawei Technologies Co., Ltd
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_
- #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_
-
- #include <memory>
- #include <vector>
-
- #include "common/op/ge_op_utils.h"
- #include "graph/compute_graph.h"
- #include "graph/graph.h"
- #include "graph/op_desc.h"
- #include "graph/debug/ge_op_types.h"
- #include "framework/common/types.h"
- #include "graph/utils/op_desc_utils.h"
- #include "graph/utils/graph_utils.h"
- #include "graph/shape_refiner.h"
- #include "../pass_utils.h"
- #include "./mds_utils.h"
- #include "./mds_kernel_factory.h"
-
- namespace ge {
- /// Base class for the kernels used by the MDS cut pass.
- /// The cut entry points are virtual so op-specific kernels can override the
- /// default behaviour; kernels are non-copyable and can only be constructed or
- /// destroyed by this class and its derived kernels.
- class DeploySchedulerKernel {
-  public:
-   /// Returns the shared base-kernel instance.
-   /// NOTE(review): presumably a lazily created singleton — confirm against the
-   /// definition in the corresponding .cc file.
-   /// @return shared pointer to the kernel instance
-   static std::shared_ptr<DeploySchedulerKernel> Instance();
-
-   /// Cut implementation along the N dimension (presumably batch — confirm).
-   /// @param [in] node_ptr node to process
-   /// @return Status of the operation
-   virtual Status CutN(const ge::NodePtr &node_ptr);
-
-   /// Cut implementation along the H dimension (presumably height — confirm).
-   /// @param [in] node_ptr node to process
-   /// @return Status of the operation
-   virtual Status CutH(const ge::NodePtr &node_ptr);
-
-   /// Dynamic-shape variant of CutN.
-   /// @param [in] node_ptr node to process
-   /// @return Status of the operation
-   virtual Status DynamicCutN(const ge::NodePtr &node_ptr);
-
-   /// Dynamic-shape variant of CutH.
-   /// @param [in] node_ptr node to process
-   /// @return Status of the operation
-   virtual Status DynamicCutH(const ge::NodePtr &node_ptr);
-
-   /// Runs the halo-exchange processing for @p node.
-   /// @param [in] node node to process
-   /// @param [in] index index of the tensor to exchange (presumably an input
-   ///                   anchor index — TODO confirm in the definition)
-   /// @param [in] local_slice defaults to false; its exact effect is defined in
-   ///                         the .cc file
-   /// @return Status of the operation
-   Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false);
-
-   /// @return the node held in input_node_. It starts as nullptr and, being
-   ///         private with no setter declared here, can only be assigned by
-   ///         this class's own member definitions.
-   NodePtr GetInputNode() const {
-     return input_node_;
-   }
-
-   // Kernels are non-copyable.
-   DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete;
-   DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete;
-
-  protected:
-   // Construction/destruction restricted to this class and derived kernels.
-   DeploySchedulerKernel() = default;
-   virtual ~DeploySchedulerKernel() = default;
-
-  private:
-   NodePtr input_node_ = nullptr;
- };
-
- namespace mds_cut_pass {
- /// Returns the kernel to use for @p node.
- /// NOTE(review): selection presumably goes through the registry in
- /// mds_kernel_factory.h, keyed on the node's type — confirm in the definition.
- /// @param [in] node node to find a kernel for
- /// @return shared pointer to the selected kernel
- auto GetKernelByType(const NodePtr &node) -> shared_ptr<DeploySchedulerKernel>;
- }  // namespace mds_cut_pass
- } // namespace ge
- #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_
|