/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef REGISTER_SCOPE_SCOPE_PASS_IMPL_H_ #define REGISTER_SCOPE_SCOPE_PASS_IMPL_H_ #include "external/register/scope/scope_fusion_pass_register.h" namespace ge { class ScopesResult::ScopesResultImpl { public: void SetScopes(const std::vector &scopes) { scopes_ = scopes; } const std::vector &GetScopes() const { return scopes_; } void SetNodes(const std::vector &nodes) { nodes_ = nodes; } const std::vector &GetNodes() const { return nodes_; } private: std::vector scopes_; // multiple scopes std::vector nodes_; // op outside of scope }; class ScopeBasePass::ScopeBasePassImpl { public: ScopeBasePassImpl(ScopeBasePass *parent) : parent_(parent) {} virtual ~ScopeBasePassImpl(); Status Run(std::shared_ptr &scope_graph); private: Status AddFusionScopesResultToScopeGraph(std::shared_ptr &scope_graph, std::vector &scope_results); // Match rules one by one, support multiple sets of matching rules, and finally output a single scope // Note: This function does not have to be rewritten. // In order to match the fusion rules designed by you better, // you can implement your specific versions separately. bool MatchAllBatches(const ScopeTree *scope_tree, std::vector &results); bool MatchOneBatch(const ScopeTree *scope_tree, const std::vector &patternlist, std::vector &results); bool MatchOneScope(const ScopePattern *pattern, Scope *scope, std::vector &results); Status PrintFusionScopeInfo(std::shared_ptr &scope_graph); private: std::vector patterns_; ScopeBasePass *parent_; }; } // namespace ge #endif // REGISTER_SCOPE_SCOPE_PASS_IMPL_H_