/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_
#define INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_

#include <initializer_list>
#include <map>
#include <memory>
#include <string>
#include <vector>

using std::initializer_list;
using std::map;
using std::string;
using std::vector;

// NOTE(review): a using-directive at global scope in a header pulls every std
// name into every file that includes it. It is kept only because existing
// includers may rely on it; new code should use qualified names or the
// using-declarations above instead.
using namespace std;

namespace fe {

/** Fusion pattern
 * @ingroup FUSION_PASS_GROUP
 * Describes the pattern of Ops waiting for fusion (Op types, inputs, output).
 */
class FusionPattern {
 public:
  struct OpDesc;
  /** Shared handle to an OpDesc; OpDescs refer to their inputs through it. */
  using OpDescPtr = std::shared_ptr<OpDesc>;

  /**
   * @ingroup fe
   * @brief description of one Op in the pattern
   *
   * NOTE: `repeatable` and `is_output` have no default member initializer.
   * Create instances value-initialized (e.g. std::make_shared<OpDesc>()) or
   * assign them explicitly before they are read.
   */
  struct OpDesc {
    string id;                      // Identifier
    std::vector<string> types;      // the Op types of Ops
    std::vector<OpDescPtr> inputs;  // all input Ops
    bool repeatable;                // flag to show if match multiple Ops or not
    bool is_output;                 // flag to show if the op is output node
  };

 public:
  /** Construct a pattern.
   *
   * @param name pattern name (empty by default)
   */
  explicit FusionPattern(string name = "");
  ~FusionPattern();

  /** set pattern name
   *
   * @param name pattern name
   * @return FusionPattern
   */
  FusionPattern &SetName(const string &name);

  /** add Op description from a braced list of op types
   *
   * @param id pattern id
   * @param types op type list (empty by default)
   * @return FusionPattern
   */
  FusionPattern &AddOpDesc(const string &id, const initializer_list<string> &types = {});

  /** add Op description from a vector of op types
   *
   * @param id pattern id
   * @param types op type list
   * @return FusionPattern
   */
  FusionPattern &AddOpDesc(const string &id, const vector<string> &types);

  /** set input Ops from a braced list of pattern ids
   *
   * @param id pattern id
   * @param input_ids ids of the Ops feeding the op `id`
   * @return FusionPattern
   */
  FusionPattern &SetInputs(const string &id, const initializer_list<string> &input_ids);

  /** set input Ops from a vector of pattern ids
   *
   * @param id pattern id
   * @param input_ids ids of the Ops feeding the op `id`
   * @return FusionPattern
   */
  FusionPattern &SetInputs(const string &id, const vector<string> &input_ids);

  /** set output Op
   *
   * @param id pattern id
   * @return FusionPattern
   */
  FusionPattern &SetOutput(const string &id);

  /** build pattern and check if error exists
   *
   * @return true or false; presumably true means no error was found
   *         (confirm in the implementation)
   */
  bool Build();

  /** get pattern name
   *
   * @return fusion pattern name
   */
  const string &GetName() const;

  /** get the OpDesc list of input Ops (const)
   *
   * @param op_desc op_desc whose inputs are requested
   * @return pointer to op_desc's input OpDesc list
   *         (NOTE(review): behaviour for a null op_desc is not visible here —
   *         confirm in the implementation)
   */
  static const vector<std::shared_ptr<OpDesc>> *GetInputs(std::shared_ptr<OpDesc> op_desc);

  /** get the OpDesc of output Op
   *
   * @return pattern's output OpDesc
   */
  const std::shared_ptr<OpDesc> GetOutput() const;

  /** print pattern
   *
   */
  void Dump() const;

  /** get the OpDescs of this pattern
   *
   * @param[out] op_desc_list presumably receives the OpDescs added to this
   *             pattern (confirm in the implementation)
   */
  void GetOpDescList(vector<std::shared_ptr<OpDesc>> &op_desc_list);

  /** get OpDesc based on ID, return nullptr if failed
   *
   * @param id pattern id
   * @return the OpDesc registered under id, or nullptr if there is none
   */
  std::shared_ptr<OpDesc> GetOpDesc(const string &id) const;

 private:
  // Copy operations are private, so code outside the class cannot copy a
  // FusionPattern.
  FusionPattern(const FusionPattern &) = default;
  FusionPattern &operator=(const FusionPattern &) = default;

  // Marks the pattern as erroneous (presumably sets has_error_; see the
  // implementation).
  void SetError();

 private:
  string name_;

  vector<std::shared_ptr<OpDesc>> ops_;

  map<string, std::shared_ptr<OpDesc>> op_map_;

  std::shared_ptr<OpDesc> output_;

  // In-class default guarantees a defined value even if a constructor omits
  // it; a constructor mem-initializer still takes precedence.
  bool has_error_{false};
};

}  // namespace fe

#endif  // INC_REGISTER_GRAPH_OPTIMIZER_FUSION_PATTERN_H_