| @@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true | |||||
| ConstructorInitializerIndentWidth: 4 | ConstructorInitializerIndentWidth: 4 | ||||
| ContinuationIndentWidth: 4 | ContinuationIndentWidth: 4 | ||||
| Cpp11BracedListStyle: true | Cpp11BracedListStyle: true | ||||
| DerivePointerAlignment: true | |||||
| DisableFormat: false | DisableFormat: false | ||||
| ExperimentalAutoDetectBinPacking: false | ExperimentalAutoDetectBinPacking: false | ||||
| FixNamespaceComments: true | FixNamespaceComments: true | ||||
| @@ -94,7 +93,7 @@ PenaltyBreakString: 1000 | |||||
| PenaltyBreakTemplateDeclaration: 10 | PenaltyBreakTemplateDeclaration: 10 | ||||
| PenaltyExcessCharacter: 1000000 | PenaltyExcessCharacter: 1000000 | ||||
| PenaltyReturnTypeOnItsOwnLine: 200 | PenaltyReturnTypeOnItsOwnLine: 200 | ||||
| PointerAlignment: Left | |||||
| PointerAlignment: Right | |||||
| RawStringFormats: | RawStringFormats: | ||||
| - Language: Cpp | - Language: Cpp | ||||
| Delimiters: | Delimiters: | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2 | |||||
| Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc | |||||
| @@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l | |||||
| ENV PROJECT_HOME=/code/Turing/graphEngine | ENV PROJECT_HOME=/code/Turing/graphEngine | ||||
| RUN mkdir /var/run/sshd | |||||
| RUN echo "root:root" | chpasswd | |||||
| RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config | |||||
| RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd | |||||
| ENV NOTVISIBLE "in users profile" | |||||
| RUN echo "export VISIBLE=now" >> /etc/profile | |||||
| EXPOSE 22 7777 | |||||
| RUN useradd -ms /bin/bash debugger | |||||
| RUN echo "debugger:ge123" | chpasswd | |||||
| CMD ["/usr/sbin/sshd" "-D" "&"] | |||||
| RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | ||||
| @@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
| DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | ||||
| DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | ||||
| DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||||
| DOCKER_IMAGE_TAG=ge_build_env.1.0.9 | |||||
| DOCKER_IAMGE_NAME=joycode2art/turing | DOCKER_IAMGE_NAME=joycode2art/turing | ||||
| DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | ||||
| @@ -61,7 +61,7 @@ function enter_docker_env(){ | |||||
| if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | ||||
| echo "please run 'ge env --pull' to download images first!" | echo "please run 'ge env --pull' to download images first!" | ||||
| elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | ||||
| $docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
| $docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
| elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | ||||
| $docker_cmd start ${DOCKER_BUILD_ENV_NAME} | $docker_cmd start ${DOCKER_BUILD_ENV_NAME} | ||||
| $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | ||||
| @@ -60,6 +60,7 @@ set(SRCS | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | ||||
| @@ -17,16 +17,3 @@ include(cmake/graphengine.cmake) | |||||
| add_subdirectory(easy_graph) | add_subdirectory(easy_graph) | ||||
| add_subdirectory(ge_graph_dsl) | add_subdirectory(ge_graph_dsl) | ||||
| add_subdirectory(ge_running_env) | add_subdirectory(ge_running_env) | ||||
| file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | |||||
| "utils/*.cc" | |||||
| ) | |||||
| add_library(framework STATIC ${UTILS_SRC}) | |||||
| target_include_directories(framework | |||||
| PUBLIC utils/ | |||||
| ) | |||||
| set_target_properties(framework PROPERTIES CXX_STANDARD 11) | |||||
| target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) | |||||
| @@ -26,16 +26,32 @@ EG_NS_BEGIN | |||||
| //////////////////////////////////////////////////////////////// | //////////////////////////////////////////////////////////////// | ||||
| namespace detail { | namespace detail { | ||||
| template<typename GRAPH_BUILDER> | |||||
| template <typename GRAPH_BUILDER> | |||||
| Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | ||||
| GraphBuilder builder(name); | GraphBuilder builder(name); | ||||
| builderInDSL(builder); | builderInDSL(builder); | ||||
| return std::move(*builder); | return std::move(*builder); | ||||
| } | } | ||||
| struct GraphDefiner { | |||||
| GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) { | |||||
| name = specifiedName ? specifiedName : defaultName; | |||||
| } | |||||
| template <typename USER_BUILDER> | |||||
| auto operator|(USER_BUILDER &&userBuilder) { | |||||
| GraphBuilder graphBuilder{name}; | |||||
| std::forward<USER_BUILDER>(userBuilder)(graphBuilder); | |||||
| return *graphBuilder; | |||||
| } | |||||
| private: | |||||
| const char *name; | |||||
| }; | |||||
| } // namespace detail | } // namespace detail | ||||
| #define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__) | |||||
| #define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER) | |||||
| #define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER) | |||||
| #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | ||||
| #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | ||||
| #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | ||||
| @@ -16,10 +16,15 @@ | |||||
| #include "easy_graph/layout/graph_layout.h" | #include "easy_graph/layout/graph_layout.h" | ||||
| #include "easy_graph/layout/layout_executor.h" | #include "easy_graph/layout/layout_executor.h" | ||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
| #include "easy_graph/graph/graph.h" | #include "easy_graph/graph/graph.h" | ||||
| EG_NS_BEGIN | EG_NS_BEGIN | ||||
| namespace { | |||||
| GraphEasyExecutor default_executor; | |||||
| } | |||||
| void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | ||||
| this->executor_ = &executor; | this->executor_ = &executor; | ||||
| options_ = opts; | options_ = opts; | ||||
| @@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||||
| Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | ||||
| const LayoutOption *options = opts ? opts : this->options_; | const LayoutOption *options = opts ? opts : this->options_; | ||||
| if (!executor_) | |||||
| return EG_UNIMPLEMENTED; | |||||
| if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options); | |||||
| return executor_->Layout(graph, options); | return executor_->Layout(graph, options); | ||||
| } | } | ||||
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef D52AA06185E34BBFB714FFBCDAB0D53A | |||||
| #define D52AA06185E34BBFB714FFBCDAB0D53A | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include <exception> | |||||
| #include <string> | |||||
| GE_NS_BEGIN | |||||
| struct AssertError : std::exception { | |||||
| AssertError(const char *file, int line, const std::string &info); | |||||
| private: | |||||
| const char *what() const noexcept override; | |||||
| private: | |||||
| std::string info; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_31309AA0A4E44C009C22AD9351BF3410 | |||||
| #define INC_31309AA0A4E44C009C22AD9351BF3410 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "graph/compute_graph.h" | |||||
| GE_NS_BEGIN | |||||
| using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||||
| struct CheckUtils { | |||||
| static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun); | |||||
| static void init(); | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -1,17 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "tensor_builder_utils.h" | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef C8B32320BD4943D588594B82FFBF2685 | |||||
| #define C8B32320BD4943D588594B82FFBF2685 | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| GE_NS_BEGIN | |||||
| struct FilterScopeGuard { | |||||
| FilterScopeGuard(const std::vector<std::string> &); | |||||
| ~FilterScopeGuard(); | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,59 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||||
| #define AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| #include "ge_graph_dsl/assert/assert_error.h" | |||||
| #include "ge_graph_dsl/assert/filter_scope_guard.h" | |||||
| GE_NS_BEGIN | |||||
| #ifdef GTEST_MESSAGE_AT_ | |||||
| #define GRAPH_CHECK_MESSAGE(file, line, message) \ | |||||
| GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure) | |||||
| #elif | |||||
| #define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message) | |||||
| #endif | |||||
| namespace detail { | |||||
| struct GraphAssert { | |||||
| GraphAssert(const char *file, unsigned int line, const std::string &phase_id) | |||||
| : file_(file), line_(line), phase_id_(phase_id) {} | |||||
| void operator|(const ::GE_NS::GraphCheckFun &check_fun) { | |||||
| bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun); | |||||
| if (!ret) { | |||||
| auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! "; | |||||
| GRAPH_CHECK_MESSAGE(file_, line_, message.c_str()); | |||||
| } | |||||
| } | |||||
| private: | |||||
| const char *file_; | |||||
| unsigned int line_; | |||||
| const std::string phase_id_; | |||||
| }; | |||||
| } // namespace detail | |||||
| #define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__}); | |||||
| #define CHECK_GRAPH(phase_id) \ | |||||
| ::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph) | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -33,14 +33,12 @@ struct OpDescCfg { | |||||
| std::vector<int64_t> shape_; | std::vector<int64_t> shape_; | ||||
| }; | }; | ||||
| OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW, | |||||
| OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW, | |||||
| DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | ||||
| : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | ||||
| protected: | protected: | ||||
| OpType GetType() const { | |||||
| return type_; | |||||
| } | |||||
| OpType GetType() const { return type_; } | |||||
| OpType type_; | OpType type_; | ||||
| int in_cnt_; | int in_cnt_; | ||||
| int out_cnt_; | int out_cnt_; | ||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_dsl/assert/assert_error.h" | |||||
| GE_NS_BEGIN | |||||
| AssertError::AssertError(const char *file, int line, const std::string &info) { | |||||
| this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info; | |||||
| } | |||||
| const char *AssertError::what() const noexcept { return info.c_str(); } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,34 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "ge_graph_default_checker.h" | |||||
| #include "ge_graph_check_dumper.h" | |||||
| GE_NS_BEGIN | |||||
| bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) { | |||||
| auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper()); | |||||
| return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun)); | |||||
| } | |||||
| void CheckUtils::init() { | |||||
| static GeGraphCheckDumper checkDumper; | |||||
| GraphDumperRegistry::Register(checkDumper); | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,31 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_dsl/assert/filter_scope_guard.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "ge_dump_filter.h" | |||||
| GE_NS_BEGIN | |||||
| namespace { | |||||
| GeDumpFilter &GetDumpFilter() { return dynamic_cast<GeDumpFilter &>(GraphDumperRegistry::GetDumper()); } | |||||
| } // namespace | |||||
| FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { GetDumpFilter().Update(filter); } | |||||
| FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,33 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||||
| #define INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||||
| #include <vector> | |||||
| #include <string> | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "easy_graph/infra/keywords.h" | |||||
| GE_NS_BEGIN | |||||
| INTERFACE(GeDumpFilter) { | |||||
| ABSTRACT(void Update(const std::vector<std::string> &)); | |||||
| ABSTRACT(void Reset()); | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,79 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_check_dumper.h" | |||||
| #include "graph/model.h" | |||||
| #include "graph/buffer.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "ge_graph_default_checker.h" | |||||
| GE_NS_BEGIN | |||||
| GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); } | |||||
| bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { | |||||
| auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix); | |||||
| return (iter != suffixes_.end()); | |||||
| } | |||||
| void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { | |||||
| if (!IsNeedDump(suffix)) { | |||||
| return; | |||||
| } | |||||
| auto iter = buffers_.find(suffix); | |||||
| if (iter != buffers_.end()) { | |||||
| DumpGraph(graph, iter->second); | |||||
| } else { | |||||
| buffers_[suffix] = Buffer(); | |||||
| DumpGraph(graph, buffers_.at(suffix)); | |||||
| } | |||||
| } | |||||
| bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { | |||||
| auto iter = buffers_.find(checker.PhaseId()); | |||||
| if (iter == buffers_.end()) { | |||||
| return false; | |||||
| } | |||||
| DoCheck(checker, iter->second); | |||||
| return true; | |||||
| } | |||||
| void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { | |||||
| Model model("", ""); | |||||
| Model::Load(buffer.GetData(), buffer.GetSize(), model); | |||||
| auto load_graph = model.GetGraph(); | |||||
| checker.Check(GraphUtils::GetComputeGraph(load_graph)); | |||||
| } | |||||
| void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { | |||||
| Model model("", ""); | |||||
| buffer.clear(); | |||||
| model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||||
| model.Save(buffer, true); | |||||
| } | |||||
| void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) { | |||||
| suffixes_ = new_suffixes_; | |||||
| buffers_.clear(); | |||||
| } | |||||
| void GeGraphCheckDumper::Reset() { | |||||
| static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"}; | |||||
| suffixes_ = default_suffixes_; | |||||
| buffers_.clear(); | |||||
| } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,49 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_8EFED0015C27464897BF64531355C810 | |||||
| #define INC_8EFED0015C27464897BF64531355C810 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "ge_dump_filter.h" | |||||
| #include <string> | |||||
| GE_NS_BEGIN | |||||
| struct GeGraphChecker; | |||||
| struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { | |||||
| GeGraphCheckDumper(); | |||||
| virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); | |||||
| bool CheckFor(const GeGraphChecker &checker); | |||||
| private: | |||||
| void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); | |||||
| void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); | |||||
| private: | |||||
| void Update(const std::vector<std::string> &) override; | |||||
| void Reset() override; | |||||
| bool IsNeedDump(const std::string &suffix) const; | |||||
| private: | |||||
| std::map<std::string, ::GE_NS::Buffer> buffers_; | |||||
| std::vector<std::string> suffixes_; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_5960A8F437324904BEE0690271258762 | |||||
| #define INC_5960A8F437324904BEE0690271258762 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "easy_graph/infra/keywords.h" | |||||
| #include "graph/compute_graph.h" | |||||
| GE_NS_BEGIN | |||||
| INTERFACE(GeGraphChecker) { | |||||
| ABSTRACT(const std::string &PhaseId() const); | |||||
| ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "ge_graph_default_checker.h" | |||||
| GE_NS_BEGIN | |||||
| GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) | |||||
| : phase_id_(phase_id), check_fun_(check_fun) {} | |||||
| const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; } | |||||
| void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); } | |||||
| GE_NS_END | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef BCF4D96BE9FC48938DE7B7E93B551C54 | |||||
| #define BCF4D96BE9FC48938DE7B7E93B551C54 | |||||
| #include "ge_graph_dsl/ge.h" | |||||
| #include "ge_graph_checker.h" | |||||
| #include "graph/compute_graph.h" | |||||
| GE_NS_BEGIN | |||||
| using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||||
| struct GeGraphDefaultChecker : GeGraphChecker { | |||||
| GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); | |||||
| private: | |||||
| const std::string &PhaseId() const override; | |||||
| void Check(const ge::ComputeGraphPtr &graph) const override; | |||||
| private: | |||||
| const std::string phase_id_; | |||||
| const GraphCheckFun check_fun_; | |||||
| }; | |||||
| GE_NS_END | |||||
| #endif | |||||
| @@ -23,15 +23,22 @@ GE_NS_BEGIN | |||||
| namespace { | namespace { | ||||
| #define OP_CFG(optype, ...) \ | |||||
| { \ | |||||
| optype, OpDescCfg { \ | |||||
| optype, __VA_ARGS__ \ | |||||
| } \ | |||||
| #define OP_CFG(optype, ...) \ | |||||
| { \ | |||||
| optype, OpDescCfg { optype, __VA_ARGS__ } \ | |||||
| } | } | ||||
| static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | ||||
| OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | ||||
| OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), | |||||
| OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
| OP_CFG(VARIABLE, 1, 1)}; | OP_CFG(VARIABLE, 1, 1)}; | ||||
| } // namespace | } // namespace | ||||
| @@ -19,6 +19,4 @@ | |||||
| USING_GE_NS | USING_GE_NS | ||||
| OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { | |||||
| return op_; | |||||
| } | |||||
| OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } | |||||
| @@ -36,17 +36,11 @@ GE_NS_BEGIN | |||||
| GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | ||||
| void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { | |||||
| build_graph_ = graph; | |||||
| } | |||||
| void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } | |||||
| Graph GeGraphVisitor::BuildGeGraph() const { | |||||
| return GraphUtils::CreateGraphFromComputeGraph(build_graph_); | |||||
| } | |||||
| Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); } | |||||
| ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { | |||||
| return build_graph_; | |||||
| } | |||||
| ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } | |||||
| Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | ||||
| build_graph_->SetName(graph.GetName()); | build_graph_->SetName(graph.GetName()); | ||||
| @@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE | |||||
| ) | ) | ||||
| set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | ||||
| target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) | |||||
| target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) | |||||
| include(CTest) | include(CTest) | ||||
| enable_testing() | enable_testing() | ||||
| @@ -0,0 +1,129 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "gtest/gtest.h" | |||||
| #include "easy_graph/layout/graph_layout.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||||
| #include "graph/model.h" | |||||
| #include "graph/buffer.h" | |||||
| USING_GE_NS | |||||
| class CheckGraphTest : public testing::Test { | |||||
| private: | |||||
| EG_NS::GraphEasyExecutor executor; | |||||
| protected: | |||||
| void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||||
| void TearDown() {} | |||||
| }; | |||||
| TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DUMP_GRAPH_WHEN("after_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
| CHECK_GRAPH(after_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DEF_GRAPH(g2) { | |||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||||
| }; | |||||
| DUMP_GRAPH_WHEN("before_build", "after_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); | |||||
| CHECK_GRAPH(before_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
| }; | |||||
| CHECK_GRAPH(after_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g2"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DEF_GRAPH(g2) { | |||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||||
| }; | |||||
| DUMP_GRAPH_WHEN("before_build") | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); | |||||
| CHECK_GRAPH(before_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g2"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_check_phases_is_work) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DUMP_GRAPH_WHEN("before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
| auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); | |||||
| ASSERT_FALSE(ret); | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| DUMP_GRAPH_WHEN("before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
| CHECK_GRAPH(before_build) { | |||||
| ASSERT_EQ(graph->GetName(), "g1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
| }; | |||||
| } | |||||
| TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto ge_graph = ToGeGraph(g1); | |||||
| ge::Model model("", ""); | |||||
| model.SetGraph(ge_graph); | |||||
| Buffer buffer; | |||||
| model.Save(buffer, true); | |||||
| ge::Model loadModel("", ""); | |||||
| Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); | |||||
| auto load_graph = loadModel.GetGraph(); | |||||
| ASSERT_EQ(load_graph.GetName(), "g1"); | |||||
| ASSERT_EQ(load_graph.GetAllNodes().size(), 2); | |||||
| } | |||||
| @@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { | |||||
| EG_NS::GraphEasyExecutor executor; | EG_NS::GraphEasyExecutor executor; | ||||
| protected: | protected: | ||||
| void SetUp() { | |||||
| EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); | |||||
| } | |||||
| void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | ||||
| DEF_GRAPH(g1) { | |||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| auto computeGraph = ToComputeGraph(g1); | auto computeGraph = ToComputeGraph(g1); | ||||
| @@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_graph_with_name) { | TEST_F(GraphDslTest, test_build_graph_with_name) { | ||||
| DEF_GRAPH(g1, "sample_graph") { | |||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { | |||||
| auto data = std::make_shared<OpDesc>("data1", DATA); | auto data = std::make_shared<OpDesc>("data1", DATA); | ||||
| auto add = std::make_shared<OpDesc>("Add", ADD); | auto add = std::make_shared<OpDesc>("Add", ADD); | ||||
| CHAIN(NODE(data)->NODE(add)); | CHAIN(NODE(data)->NODE(add)); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||||
| auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | ||||
| auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | ||||
| CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | ||||
| DEF_GRAPH(g1) { | |||||
| CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); | |||||
| }); | |||||
| DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_from_control_chain) { | TEST_F(GraphDslTest, test_build_from_control_chain) { | ||||
| DEF_GRAPH(g1) { | |||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { | |||||
| } | } | ||||
| TEST_F(GraphDslTest, test_build_from_data_chain) { | TEST_F(GraphDslTest, test_build_from_data_chain) { | ||||
| DEF_GRAPH(g1) { | |||||
| DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
| }); | |||||
| DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { | |||||
| ->NODE("Add4") | ->NODE("Add4") | ||||
| ->NODE("Add5") | ->NODE("Add5") | ||||
| ->NODE("net_output", NETOUTPUT)); | ->NODE("net_output", NETOUTPUT)); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { | |||||
| CHAIN(NODE("add2")->NODE("net_output")); | CHAIN(NODE("add2")->NODE("net_output")); | ||||
| CHAIN(NODE("add3")->NODE("net_output")); | CHAIN(NODE("add3")->NODE("net_output")); | ||||
| CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | ||||
| }); | |||||
| }; | |||||
| auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
| @@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { | |||||
| DEF_GRAPH(sub_1) { | DEF_GRAPH(sub_1) { | ||||
| CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | ||||
| CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | ||||
| }); | |||||
| }; | |||||
| DEF_GRAPH(sub_2) { | DEF_GRAPH(sub_2) { | ||||
| CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | ||||
| CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | ||||
| }); | |||||
| }; | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | ||||
| CHAIN(NODE("data_i", DATA)->NODE("while")); | CHAIN(NODE("data_i", DATA)->NODE("while")); | ||||
| }); | |||||
| }; | |||||
| sub_1.Layout(); | sub_1.Layout(); | ||||
| sub_2.Layout(); | sub_2.Layout(); | ||||
| @@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | |||||
| REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | ||||
| REGISTER_OPTYPE_DEFINE(ADD, "Add"); | REGISTER_OPTYPE_DEFINE(ADD, "Add"); | ||||
| REGISTER_OPTYPE_DEFINE(WHILE, "While"); | REGISTER_OPTYPE_DEFINE(WHILE, "While"); | ||||
| REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); | |||||
| REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); | |||||
| REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); | |||||
| REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); | |||||
| REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); | |||||
| REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); | |||||
| GE_NS_END | GE_NS_END | ||||
| @@ -1,22 +1,25 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
| #define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
| class tensor_builder_utils {}; | |||||
| #endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| int main(int argc, char **argv) { | |||||
| ::GE_NS::CheckUtils::init(); | |||||
| testing::InitGoogleTest(&argc, argv); | |||||
| int ret = RUN_ALL_TESTS(); | |||||
| return ret; | |||||
| } | |||||
| @@ -1,48 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph_builder_utils.h" | |||||
| #include "inc/external/graph/operator.h" | |||||
| #include "inc/external/graph/operator_factory.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| namespace ge { | |||||
| namespace st { | |||||
| NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
| Format format, DataType data_type, std::vector<int64_t> shape) { | |||||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||||
| tensor_desc->SetShape(GeShape(std::move(shape))); | |||||
| tensor_desc->SetFormat(format); | |||||
| tensor_desc->SetDataType(data_type); | |||||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
| for (int i = 0; i < in_cnt; ++i) { | |||||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| for (int i = 0; i < out_cnt; ++i) { | |||||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||||
| } | |||||
| return graph_->AddNode(op_desc); | |||||
| } | |||||
| void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { | |||||
| GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||||
| } | |||||
| void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { | |||||
| GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||||
| } | |||||
| } // namespace st | |||||
| } // namespace ge | |||||
| @@ -1,55 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
| #define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| namespace st { | |||||
| class ComputeGraphBuilder { | |||||
| public: | |||||
| explicit ComputeGraphBuilder(const std::string &name) { | |||||
| graph_ = std::make_shared<ComputeGraph>(name); | |||||
| } | |||||
| NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
| Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||||
| std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
| void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||||
| void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||||
| ComputeGraphPtr GetComputeGraph() { | |||||
| graph_->TopologicalSorting(); | |||||
| return graph_; | |||||
| } | |||||
| Graph GetGraph() { | |||||
| graph_->TopologicalSorting(); | |||||
| return GraphUtils::CreateGraphFromComputeGraph(graph_); | |||||
| } | |||||
| private: | |||||
| ComputeGraphPtr graph_; | |||||
| }; | |||||
| } // namespace st | |||||
| } // namespace ge | |||||
| #endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
| @@ -8,7 +8,7 @@ target_include_directories(graph_engine_test | |||||
| set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | ||||
| target_link_libraries(graph_engine_test PRIVATE gtest framework) | |||||
| target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env) | |||||
| include(CTest) | include(CTest) | ||||
| enable_testing() | enable_testing() | ||||
| @@ -15,23 +15,12 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include <map> | |||||
| #include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
| #include "ge_running_env/fake_engine.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "builder/graph_builder_utils.h" | |||||
| #include "ge_running_env/ge_running_env_faker.h" | #include "ge_running_env/ge_running_env_faker.h" | ||||
| #include "graph/operator_reg.h" | |||||
| #include "graph/operator.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | #include "ge_graph_dsl/graph_dsl.h" | ||||
| #undef protected | |||||
| #undef private | |||||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||||
| using namespace std; | using namespace std; | ||||
| using namespace ge; | using namespace ge; | ||||
| @@ -57,76 +46,58 @@ namespace { | |||||
| * | * | ||||
| **/ | **/ | ||||
| Graph BuildV1ControlFlowGraph() { | Graph BuildV1ControlFlowGraph() { | ||||
| // build graph | |||||
| st::ComputeGraphBuilder graphBuilder("g1"); | |||||
| auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1); | |||||
| auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1); | |||||
| ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||||
| auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1); | |||||
| auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1); | |||||
| auto less = graphBuilder.AddNode("less", LESS, 2, 1); | |||||
| auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL); | |||||
| auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2); | |||||
| auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1); | |||||
| auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1); | |||||
| auto add = graphBuilder.AddNode("add", ADD, 2, 1); | |||||
| auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1); | |||||
| auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1); | |||||
| auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1); | |||||
| ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||||
| auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1); | |||||
| auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2); | |||||
| auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1); | |||||
| auto mul = graphBuilder.AddNode("mul", MUL, 2, 1); | |||||
| auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1); | |||||
| auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1); | |||||
| auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2); | |||||
| // i = i+1 | |||||
| graphBuilder.AddDataEdge(data_i, 0, enter_i, 0); | |||||
| graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0); | |||||
| graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1); | |||||
| graphBuilder.AddDataEdge(merge_i, 0, less, 0); | |||||
| graphBuilder.AddDataEdge(const_5, 0, less, 1); | |||||
| graphBuilder.AddDataEdge(less, 0, loopcond, 0); | |||||
| graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1); | |||||
| graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0); | |||||
| graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0); | |||||
| graphBuilder.AddDataEdge(switch_i, 1, add, 0); | |||||
| graphBuilder.AddDataEdge(const_1, 0, add, 1); | |||||
| graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0); | |||||
| graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1); | |||||
| // a=a*2 | |||||
| graphBuilder.AddDataEdge(data_a, 0, enter_a, 0); | |||||
| graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0); | |||||
| graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1); | |||||
| graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1); | |||||
| graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0); | |||||
| graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0); | |||||
| graphBuilder.AddDataEdge(switch_a, 1, mul, 0); | |||||
| graphBuilder.AddDataEdge(const_2, 0, mul, 1); | |||||
| graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0); | |||||
| graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0); | |||||
| // set const weight | |||||
| int64_t dims_size = 1; | int64_t dims_size = 1; | ||||
| vector<int64_t> data_vec = {5}; | vector<int64_t> data_vec = {5}; | ||||
| for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | ||||
| vector<int32_t> data_value_vec(dims_size, 1); | vector<int32_t> data_value_vec(dims_size, 1); | ||||
| GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | ||||
| GeTensorPtr data_tensor = | |||||
| make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); | |||||
| OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | |||||
| OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | |||||
| OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | |||||
| GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), | |||||
| data_value_vec.size() * sizeof(int32_t)); | |||||
| return graphBuilder.GetGraph(); | |||||
| auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1"); | |||||
| auto const_op = OP_CFG(CONSTANT).Weight(data_tensor); | |||||
| DEF_GRAPH(g1) { | |||||
| CHAIN(NODE("data_i", DATA) | |||||
| ->NODE("enter_i", enter) | |||||
| ->EDGE(0, 0) | |||||
| ->NODE("merge_i", MERGE) | |||||
| ->NODE("less", LESS) | |||||
| ->NODE("loopcond", LOOPCOND)); | |||||
| CHAIN(NODE("const_1", const_op) | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("add", ADD) | |||||
| ->NODE("iteration_i", NEXTITERATION) | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("merge_i")); | |||||
| CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less")); | |||||
| CHAIN(NODE("loopcond") | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("switch_i", SWITCH) | |||||
| ->EDGE(0, 0) | |||||
| ->NODE("exit_i", EXIT) | |||||
| ->EDGE(0, 1) | |||||
| ->NODE("netoutput", NETOUTPUT)); | |||||
| CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add")); | |||||
| CHAIN(NODE("data_a", DATA) | |||||
| ->NODE("enter_a", enter) | |||||
| ->NODE("merge_a", MERGE) | |||||
| ->NODE("switch_a", SWITCH) | |||||
| ->NODE("exit_a", EXIT) | |||||
| ->EDGE(0, 0) | |||||
| ->NODE("netoutput")); | |||||
| CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a")); | |||||
| CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL)); | |||||
| CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a")); | |||||
| }; | |||||
| return ToGeGraph(g1); | |||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| class FrameworkTest : public testing::Test { | class FrameworkTest : public testing::Test { | ||||
| protected: | protected: | ||||
| GeRunningEnvFaker ge_env; | |||||
| void SetUp() { ge_env.InstallDefault(); } | void SetUp() { ge_env.InstallDefault(); } | ||||
| void TearDown() {} | void TearDown() {} | ||||
| GeRunningEnvFaker ge_env; | |||||
| }; | }; | ||||
| /// data data | /// data data | ||||
| @@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto graph = ToGeGraph(g1); | |||||
| // new session & add graph | |||||
| map<AscendString, AscendString> options; | map<AscendString, AscendString> options; | ||||
| Session session(options); | Session session(options); | ||||
| auto ret = session.AddGraph(1, graph, options); | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| // build input tensor | |||||
| session.AddGraph(1, ToGeGraph(g1), options); | |||||
| std::vector<InputTensorInfo> inputs; | std::vector<InputTensorInfo> inputs; | ||||
| // build_graph through session | |||||
| ret = session.BuildGraph(1, inputs); | |||||
| auto ret = session.BuildGraph(1, inputs); | |||||
| EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
| CHECK_GRAPH(PreRunAfterBuild) { | |||||
| ASSERT_EQ(graph->GetName(), "g1_1"); | |||||
| ASSERT_EQ(graph->GetAllNodesSize(), 4); | |||||
| }; | |||||
| } | } | ||||
| /** data a = 2; | /** data a = 2; | ||||
| @@ -15,24 +15,12 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include "easy_graph/graph/box.h" | |||||
| #include "easy_graph/graph/node.h" | |||||
| #include "external/ge/ge_api.h" | |||||
| #include "easy_graph/builder/graph_dsl.h" | #include "easy_graph/builder/graph_dsl.h" | ||||
| #include "easy_graph/builder/box_builder.h" | |||||
| #include "easy_graph/layout/graph_layout.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "framework/common/types.h" | #include "framework/common/types.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/ge_local_context.h" | |||||
| #include "ge_graph_dsl/graph_dsl.h" | #include "ge_graph_dsl/graph_dsl.h" | ||||
| #include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||||
| #define protected public | |||||
| #define private public | |||||
| #include "ge_opt_info/ge_opt_info.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| namespace ge { | namespace ge { | ||||
| class STEST_opt_info : public testing::Test { | class STEST_opt_info : public testing::Test { | ||||
| @@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto graph = ToGeGraph(g1); | auto graph = ToGeGraph(g1); | ||||
| @@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) { | |||||
| DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
| CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
| }); | |||||
| }; | |||||
| auto graph = ToGeGraph(g1); | auto graph = ToGeGraph(g1); | ||||
| @@ -15,9 +15,8 @@ | |||||
| */ | */ | ||||
| #include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
| #include "common/debug/log.h" | |||||
| #include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
| #include "ge_graph_dsl/assert/check_utils.h" | |||||
| #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | ||||
| using namespace std; | using namespace std; | ||||
| @@ -31,6 +30,7 @@ int main(int argc, char **argv) { | |||||
| std::cout << "ge init failed , ret code:" << init_status << endl; | std::cout << "ge init failed , ret code:" << init_status << endl; | ||||
| } | } | ||||
| GeRunningEnvFaker::BackupEnv(); | GeRunningEnvFaker::BackupEnv(); | ||||
| CheckUtils::init(); | |||||
| testing::InitGoogleTest(&argc, argv); | testing::InitGoogleTest(&argc, argv); | ||||
| int ret = RUN_ALL_TESTS(); | int ret = RUN_ALL_TESTS(); | ||||
| return ret; | return ret; | ||||
| @@ -90,6 +90,7 @@ set(SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | ||||
| @@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||