/** * Copyright 2020 Huawei Technologies Co., Ltd * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H_ #define REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H_ #include "external/register/scope/scope_fusion_pass_register.h" namespace ge { class ScopeAttrValue::ScopeAttrValueImpl { public: ScopeAttrValueImpl() : int_value_(0), float_value_(0.0), string_value_(""), bool_value_(false) {} ~ScopeAttrValueImpl() {} void SetIntValue(const int64_t &value) { int_value_ = value; } void SetFloatValue(const float &value) { float_value_ = value; } void SetStringValue(const std::string &value) { string_value_ = value; } void SetBoolValue(const bool &value) { bool_value_ = value; } const int64_t &GetIntValue() const { return int_value_; } const float &GetFloatValue() const { return float_value_; } const std::string &GetStrValue() const { return string_value_; } const bool &GetBoolValue() const { return bool_value_; } private: int64_t int_value_; float float_value_; std::string string_value_; bool bool_value_; }; class NodeOpTypeFeature::NodeOpTypeFeatureImpl : ScopeBaseFeature { public: NodeOpTypeFeatureImpl(std::string nodeType, int num, int step = 0) : node_type_(nodeType), num_(num), step_(step) {} ~NodeOpTypeFeatureImpl() {} bool Match(const Scope *scope) override; public: std::string node_type_; // Node type int num_; // Node number int step_; // step }; class NodeAttrFeature::NodeAttrFeatureImpl : ScopeBaseFeature { public: NodeAttrFeatureImpl(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value) : node_type_(nodeType), attr_name_(attr_name), datatype_(datatype), attr_value_(attr_value) {} ~NodeAttrFeatureImpl() {} bool Match(const Scope *scope) override; public: std::string node_type_; // Node type std::string attr_name_; // attribute name ge::DataType datatype_; // datatype ScopeAttrValue attr_value_; // AttrValue }; class ScopeFeature::ScopeFeatureImpl : ScopeBaseFeature { public: ScopeFeatureImpl(std::string sub_type, int32_t num, std::string suffix = "", std::string sub_scope_mask = "", int step = 0) : sub_type_(sub_type), num_(num), suffix_(suffix), sub_scope_mask_(sub_scope_mask), step_(step) {} ~ScopeFeatureImpl() {} bool Match(const Scope *scope) override; bool SubScopesMatch(const std::vector &scopes); public: std::string sub_type_; int32_t num_; std::string suffix_; std::string sub_scope_mask_; int step_; }; class ScopePattern::ScopePatternImpl { public: ScopePatternImpl() {} ~ScopePatternImpl() {} bool Match(const Scope *scope) const; void SetSubType(const std::string &sub_type); const std::string &SubType() const { return sub_type_; } void AddNodeOpTypeFeature(NodeOpTypeFeature &feature); void AddNodeAttrFeature(NodeAttrFeature &feature); void AddScopeFeature(ScopeFeature &feature); private: std::string sub_type_; // get Scope sub type std::vector node_optype_features_; std::vector node_attr_features_; std::vector scopes_features_; }; } // namespace ge #endif // REGISTER_SCOPE_SCOPE_PATTERN_IMPL_H_