| @@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true | |||
| ConstructorInitializerIndentWidth: 4 | |||
| ContinuationIndentWidth: 4 | |||
| Cpp11BracedListStyle: true | |||
| DerivePointerAlignment: true | |||
| DisableFormat: false | |||
| ExperimentalAutoDetectBinPacking: false | |||
| FixNamespaceComments: true | |||
| @@ -94,7 +93,7 @@ PenaltyBreakString: 1000 | |||
| PenaltyBreakTemplateDeclaration: 10 | |||
| PenaltyExcessCharacter: 1000000 | |||
| PenaltyReturnTypeOnItsOwnLine: 200 | |||
| PointerAlignment: Left | |||
| PointerAlignment: Right | |||
| RawStringFormats: | |||
| - Language: Cpp | |||
| Delimiters: | |||
| @@ -1 +1 @@ | |||
| Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2 | |||
| Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc | |||
| @@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l | |||
| ENV PROJECT_HOME=/code/Turing/graphEngine | |||
| RUN mkdir /var/run/sshd | |||
| RUN echo "root:root" | chpasswd | |||
| RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config | |||
| RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd | |||
| ENV NOTVISIBLE "in users profile" | |||
| RUN echo "export VISIBLE=now" >> /etc/profile | |||
| EXPOSE 22 7777 | |||
| RUN useradd -ms /bin/bash debugger | |||
| RUN echo "debugger:ge123" | chpasswd | |||
| CMD ["/usr/sbin/sshd" "-D" "&"] | |||
| RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | |||
| @@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
| DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | |||
| DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | |||
| DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||
| DOCKER_IMAGE_TAG=ge_build_env.1.0.9 | |||
| DOCKER_IAMGE_NAME=joycode2art/turing | |||
| DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | |||
| @@ -61,7 +61,7 @@ function enter_docker_env(){ | |||
| if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | |||
| echo "please run 'ge env --pull' to download images first!" | |||
| elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
| $docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||
| $docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||
| elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
| $docker_cmd start ${DOCKER_BUILD_ENV_NAME} | |||
| $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | |||
| @@ -60,6 +60,7 @@ set(SRCS | |||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||
| @@ -17,16 +17,3 @@ include(cmake/graphengine.cmake) | |||
| add_subdirectory(easy_graph) | |||
| add_subdirectory(ge_graph_dsl) | |||
| add_subdirectory(ge_running_env) | |||
| file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | |||
| "utils/*.cc" | |||
| ) | |||
| add_library(framework STATIC ${UTILS_SRC}) | |||
| target_include_directories(framework | |||
| PUBLIC utils/ | |||
| ) | |||
| set_target_properties(framework PROPERTIES CXX_STANDARD 11) | |||
| target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) | |||
| @@ -26,16 +26,32 @@ EG_NS_BEGIN | |||
| //////////////////////////////////////////////////////////////// | |||
| namespace detail { | |||
| template<typename GRAPH_BUILDER> | |||
| template <typename GRAPH_BUILDER> | |||
| Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | |||
| GraphBuilder builder(name); | |||
| builderInDSL(builder); | |||
| return std::move(*builder); | |||
| } | |||
| struct GraphDefiner { | |||
| GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) { | |||
| name = specifiedName ? specifiedName : defaultName; | |||
| } | |||
| template <typename USER_BUILDER> | |||
| auto operator|(USER_BUILDER &&userBuilder) { | |||
| GraphBuilder graphBuilder{name}; | |||
| std::forward<USER_BUILDER>(userBuilder)(graphBuilder); | |||
| return *graphBuilder; | |||
| } | |||
| private: | |||
| const char *name; | |||
| }; | |||
| } // namespace detail | |||
| #define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__) | |||
| #define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER) | |||
| #define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER) | |||
| #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | |||
| #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | |||
| #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | |||
| @@ -16,10 +16,15 @@ | |||
| #include "easy_graph/layout/graph_layout.h" | |||
| #include "easy_graph/layout/layout_executor.h" | |||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||
| #include "easy_graph/graph/graph.h" | |||
| EG_NS_BEGIN | |||
| namespace { | |||
| GraphEasyExecutor default_executor; | |||
| } | |||
| void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||
| this->executor_ = &executor; | |||
| options_ = opts; | |||
| @@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||
| Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | |||
| const LayoutOption *options = opts ? opts : this->options_; | |||
| if (!executor_) | |||
| return EG_UNIMPLEMENTED; | |||
| if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options); | |||
| return executor_->Layout(graph, options); | |||
| } | |||
| @@ -0,0 +1,37 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef D52AA06185E34BBFB714FFBCDAB0D53A | |||
| #define D52AA06185E34BBFB714FFBCDAB0D53A | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include <exception> | |||
| #include <string> | |||
| GE_NS_BEGIN | |||
| struct AssertError : std::exception { | |||
| AssertError(const char *file, int line, const std::string &info); | |||
| private: | |||
| const char *what() const noexcept override; | |||
| private: | |||
| std::string info; | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_31309AA0A4E44C009C22AD9351BF3410 | |||
| #define INC_31309AA0A4E44C009C22AD9351BF3410 | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include "graph/compute_graph.h" | |||
| GE_NS_BEGIN | |||
| using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||
| struct CheckUtils { | |||
| static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun); | |||
| static void init(); | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -1,17 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "tensor_builder_utils.h" | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef C8B32320BD4943D588594B82FFBF2685 | |||
| #define C8B32320BD4943D588594B82FFBF2685 | |||
| #include <vector> | |||
| #include <string> | |||
| #include "ge_graph_dsl/ge.h" | |||
| GE_NS_BEGIN | |||
| struct FilterScopeGuard { | |||
| FilterScopeGuard(const std::vector<std::string> &); | |||
| ~FilterScopeGuard(); | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -0,0 +1,59 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||
| #define AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include "ge_graph_dsl/assert/check_utils.h" | |||
| #include "ge_graph_dsl/assert/assert_error.h" | |||
| #include "ge_graph_dsl/assert/filter_scope_guard.h" | |||
| GE_NS_BEGIN | |||
| #ifdef GTEST_MESSAGE_AT_ | |||
| #define GRAPH_CHECK_MESSAGE(file, line, message) \ | |||
| GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure) | |||
| #elif | |||
| #define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message) | |||
| #endif | |||
| namespace detail { | |||
| struct GraphAssert { | |||
| GraphAssert(const char *file, unsigned int line, const std::string &phase_id) | |||
| : file_(file), line_(line), phase_id_(phase_id) {} | |||
| void operator|(const ::GE_NS::GraphCheckFun &check_fun) { | |||
| bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun); | |||
| if (!ret) { | |||
| auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! "; | |||
| GRAPH_CHECK_MESSAGE(file_, line_, message.c_str()); | |||
| } | |||
| } | |||
| private: | |||
| const char *file_; | |||
| unsigned int line_; | |||
| const std::string phase_id_; | |||
| }; | |||
| } // namespace detail | |||
| #define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__}); | |||
| #define CHECK_GRAPH(phase_id) \ | |||
| ::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph) | |||
| GE_NS_END | |||
| #endif | |||
| @@ -33,14 +33,12 @@ struct OpDescCfg { | |||
| std::vector<int64_t> shape_; | |||
| }; | |||
| OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW, | |||
| OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW, | |||
| DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | |||
| : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | |||
| protected: | |||
| OpType GetType() const { | |||
| return type_; | |||
| } | |||
| OpType GetType() const { return type_; } | |||
| OpType type_; | |||
| int in_cnt_; | |||
| int out_cnt_; | |||
| @@ -0,0 +1,26 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_graph_dsl/assert/assert_error.h" | |||
| GE_NS_BEGIN | |||
| AssertError::AssertError(const char *file, int line, const std::string &info) { | |||
| this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info; | |||
| } | |||
| const char *AssertError::what() const noexcept { return info.c_str(); } | |||
| GE_NS_END | |||
| @@ -0,0 +1,34 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_graph_dsl/assert/check_utils.h" | |||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||
| #include "ge_graph_default_checker.h" | |||
| #include "ge_graph_check_dumper.h" | |||
| GE_NS_BEGIN | |||
| bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) { | |||
| auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper()); | |||
| return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun)); | |||
| } | |||
| void CheckUtils::init() { | |||
| static GeGraphCheckDumper checkDumper; | |||
| GraphDumperRegistry::Register(checkDumper); | |||
| } | |||
| GE_NS_END | |||
| @@ -0,0 +1,31 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_graph_dsl/assert/filter_scope_guard.h" | |||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||
| #include "ge_dump_filter.h" | |||
| GE_NS_BEGIN | |||
| namespace { | |||
| GeDumpFilter &GetDumpFilter() { return dynamic_cast<GeDumpFilter &>(GraphDumperRegistry::GetDumper()); } | |||
| } // namespace | |||
| FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { GetDumpFilter().Update(filter); } | |||
| FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); } | |||
| GE_NS_END | |||
| @@ -0,0 +1,33 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||
| #define INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||
| #include <vector> | |||
| #include <string> | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include "easy_graph/infra/keywords.h" | |||
| GE_NS_BEGIN | |||
| INTERFACE(GeDumpFilter) { | |||
| ABSTRACT(void Update(const std::vector<std::string> &)); | |||
| ABSTRACT(void Reset()); | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -0,0 +1,79 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_graph_check_dumper.h" | |||
| #include "graph/model.h" | |||
| #include "graph/buffer.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "ge_graph_default_checker.h" | |||
| GE_NS_BEGIN | |||
| GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); } | |||
| bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { | |||
| auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix); | |||
| return (iter != suffixes_.end()); | |||
| } | |||
| void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { | |||
| if (!IsNeedDump(suffix)) { | |||
| return; | |||
| } | |||
| auto iter = buffers_.find(suffix); | |||
| if (iter != buffers_.end()) { | |||
| DumpGraph(graph, iter->second); | |||
| } else { | |||
| buffers_[suffix] = Buffer(); | |||
| DumpGraph(graph, buffers_.at(suffix)); | |||
| } | |||
| } | |||
| bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { | |||
| auto iter = buffers_.find(checker.PhaseId()); | |||
| if (iter == buffers_.end()) { | |||
| return false; | |||
| } | |||
| DoCheck(checker, iter->second); | |||
| return true; | |||
| } | |||
| void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { | |||
| Model model("", ""); | |||
| Model::Load(buffer.GetData(), buffer.GetSize(), model); | |||
| auto load_graph = model.GetGraph(); | |||
| checker.Check(GraphUtils::GetComputeGraph(load_graph)); | |||
| } | |||
| void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { | |||
| Model model("", ""); | |||
| buffer.clear(); | |||
| model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||
| model.Save(buffer, true); | |||
| } | |||
| void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) { | |||
| suffixes_ = new_suffixes_; | |||
| buffers_.clear(); | |||
| } | |||
| void GeGraphCheckDumper::Reset() { | |||
| static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"}; | |||
| suffixes_ = default_suffixes_; | |||
| buffers_.clear(); | |||
| } | |||
| GE_NS_END | |||
| @@ -0,0 +1,49 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_8EFED0015C27464897BF64531355C810 | |||
| #define INC_8EFED0015C27464897BF64531355C810 | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||
| #include "ge_dump_filter.h" | |||
| #include <string> | |||
| GE_NS_BEGIN | |||
| struct GeGraphChecker; | |||
| struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { | |||
| GeGraphCheckDumper(); | |||
| virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); | |||
| bool CheckFor(const GeGraphChecker &checker); | |||
| private: | |||
| void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); | |||
| void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); | |||
| private: | |||
| void Update(const std::vector<std::string> &) override; | |||
| void Reset() override; | |||
| bool IsNeedDump(const std::string &suffix) const; | |||
| private: | |||
| std::map<std::string, ::GE_NS::Buffer> buffers_; | |||
| std::vector<std::string> suffixes_; | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -0,0 +1,32 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef INC_5960A8F437324904BEE0690271258762 | |||
| #define INC_5960A8F437324904BEE0690271258762 | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include "easy_graph/infra/keywords.h" | |||
| #include "graph/compute_graph.h" | |||
| GE_NS_BEGIN | |||
| INTERFACE(GeGraphChecker) { | |||
| ABSTRACT(const std::string &PhaseId() const); | |||
| ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -0,0 +1,28 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "ge_graph_default_checker.h" | |||
| GE_NS_BEGIN | |||
| GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) | |||
| : phase_id_(phase_id), check_fun_(check_fun) {} | |||
| const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; } | |||
| void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); } | |||
| GE_NS_END | |||
| @@ -0,0 +1,41 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef BCF4D96BE9FC48938DE7B7E93B551C54 | |||
| #define BCF4D96BE9FC48938DE7B7E93B551C54 | |||
| #include "ge_graph_dsl/ge.h" | |||
| #include "ge_graph_checker.h" | |||
| #include "graph/compute_graph.h" | |||
| GE_NS_BEGIN | |||
| using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||
| struct GeGraphDefaultChecker : GeGraphChecker { | |||
| GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); | |||
| private: | |||
| const std::string &PhaseId() const override; | |||
| void Check(const ge::ComputeGraphPtr &graph) const override; | |||
| private: | |||
| const std::string phase_id_; | |||
| const GraphCheckFun check_fun_; | |||
| }; | |||
| GE_NS_END | |||
| #endif | |||
| @@ -23,15 +23,22 @@ GE_NS_BEGIN | |||
| namespace { | |||
| #define OP_CFG(optype, ...) \ | |||
| { \ | |||
| optype, OpDescCfg { \ | |||
| optype, __VA_ARGS__ \ | |||
| } \ | |||
| #define OP_CFG(optype, ...) \ | |||
| { \ | |||
| optype, OpDescCfg { optype, __VA_ARGS__ } \ | |||
| } | |||
| static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), | |||
| OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
| OP_CFG(VARIABLE, 1, 1)}; | |||
| } // namespace | |||
| @@ -19,6 +19,4 @@ | |||
| USING_GE_NS | |||
| OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { | |||
| return op_; | |||
| } | |||
| OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } | |||
| @@ -36,17 +36,11 @@ GE_NS_BEGIN | |||
| GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | |||
| void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { | |||
| build_graph_ = graph; | |||
| } | |||
| void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } | |||
| Graph GeGraphVisitor::BuildGeGraph() const { | |||
| return GraphUtils::CreateGraphFromComputeGraph(build_graph_); | |||
| } | |||
| Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); } | |||
| ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { | |||
| return build_graph_; | |||
| } | |||
| ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } | |||
| Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | |||
| build_graph_->SetName(graph.GetName()); | |||
| @@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE | |||
| ) | |||
| set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | |||
| target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) | |||
| target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) | |||
| include(CTest) | |||
| enable_testing() | |||
| @@ -0,0 +1,129 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "gtest/gtest.h" | |||
| #include "easy_graph/layout/graph_layout.h" | |||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||
| #include "ge_graph_dsl/graph_dsl.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/utils/dumper/ge_graph_dumper.h" | |||
| #include "framework/common/types.h" | |||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||
| #include "graph/model.h" | |||
| #include "graph/buffer.h" | |||
| USING_GE_NS | |||
| class CheckGraphTest : public testing::Test { | |||
| private: | |||
| EG_NS::GraphEasyExecutor executor; | |||
| protected: | |||
| void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||
| void TearDown() {} | |||
| }; | |||
| TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| DUMP_GRAPH_WHEN("after_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||
| CHECK_GRAPH(after_build) { | |||
| ASSERT_EQ(graph->GetName(), "g1"); | |||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||
| }; | |||
| } | |||
| TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| DEF_GRAPH(g2) { | |||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||
| }; | |||
| DUMP_GRAPH_WHEN("before_build", "after_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); | |||
| CHECK_GRAPH(before_build) { | |||
| ASSERT_EQ(graph->GetName(), "g1"); | |||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||
| }; | |||
| CHECK_GRAPH(after_build) { | |||
| ASSERT_EQ(graph->GetName(), "g2"); | |||
| ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||
| }; | |||
| } | |||
| TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| DEF_GRAPH(g2) { | |||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||
| }; | |||
| DUMP_GRAPH_WHEN("before_build") | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); | |||
| CHECK_GRAPH(before_build) { | |||
| ASSERT_EQ(graph->GetName(), "g2"); | |||
| ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||
| }; | |||
| } | |||
| TEST_F(CheckGraphTest, test_check_phases_is_work) { | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| DUMP_GRAPH_WHEN("before_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||
| auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); | |||
| ASSERT_FALSE(ret); | |||
| } | |||
| TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| DUMP_GRAPH_WHEN("before_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||
| GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||
| CHECK_GRAPH(before_build) { | |||
| ASSERT_EQ(graph->GetName(), "g1"); | |||
| ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||
| }; | |||
| } | |||
| TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| auto ge_graph = ToGeGraph(g1); | |||
| ge::Model model("", ""); | |||
| model.SetGraph(ge_graph); | |||
| Buffer buffer; | |||
| model.Save(buffer, true); | |||
| ge::Model loadModel("", ""); | |||
| Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); | |||
| auto load_graph = loadModel.GetGraph(); | |||
| ASSERT_EQ(load_graph.GetName(), "g1"); | |||
| ASSERT_EQ(load_graph.GetAllNodes().size(), 2); | |||
| } | |||
| @@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { | |||
| EG_NS::GraphEasyExecutor executor; | |||
| protected: | |||
| void SetUp() { | |||
| EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); | |||
| } | |||
| void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||
| void TearDown() {} | |||
| }; | |||
| TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| }); | |||
| DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| auto computeGraph = ToComputeGraph(g1); | |||
| @@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||
| } | |||
| TEST_F(GraphDslTest, test_build_graph_with_name) { | |||
| DEF_GRAPH(g1, "sample_graph") { | |||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| }); | |||
| DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { | |||
| auto data = std::make_shared<OpDesc>("data1", DATA); | |||
| auto add = std::make_shared<OpDesc>("Add", ADD); | |||
| CHAIN(NODE(data)->NODE(add)); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||
| auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | |||
| auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | |||
| CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||
| } | |||
| TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); | |||
| }); | |||
| DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||
| } | |||
| TEST_F(GraphDslTest, test_build_from_control_chain) { | |||
| DEF_GRAPH(g1) { | |||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| }); | |||
| DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { | |||
| } | |||
| TEST_F(GraphDslTest, test_build_from_data_chain) { | |||
| DEF_GRAPH(g1) { | |||
| DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| }); | |||
| DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { | |||
| DEF_GRAPH(g1) { | |||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { | |||
| DEF_GRAPH(g1) { | |||
| CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CHAIN(NODE("data2", DATA)->NODE("add")); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { | |||
| ->NODE("Add4") | |||
| ->NODE("Add5") | |||
| ->NODE("net_output", NETOUTPUT)); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { | |||
| CHAIN(NODE("add2")->NODE("net_output")); | |||
| CHAIN(NODE("add3")->NODE("net_output")); | |||
| CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | |||
| }); | |||
| }; | |||
| auto geGraph = ToGeGraph(g1); | |||
| @@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { | |||
| DEF_GRAPH(sub_1) { | |||
| CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | |||
| CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | |||
| }); | |||
| }; | |||
| DEF_GRAPH(sub_2) { | |||
| CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | |||
| CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | |||
| }); | |||
| }; | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | |||
| CHAIN(NODE("data_i", DATA)->NODE("while")); | |||
| }); | |||
| }; | |||
| sub_1.Layout(); | |||
| sub_2.Layout(); | |||
| @@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | |||
| REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | |||
| REGISTER_OPTYPE_DEFINE(ADD, "Add"); | |||
| REGISTER_OPTYPE_DEFINE(WHILE, "While"); | |||
| REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); | |||
| REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); | |||
| REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); | |||
| REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); | |||
| REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); | |||
| REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); | |||
| GE_NS_END | |||
| @@ -1,22 +1,25 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||
| #define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||
| class tensor_builder_utils {}; | |||
| #endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include "ge_graph_dsl/assert/check_utils.h" | |||
| int main(int argc, char **argv) { | |||
| ::GE_NS::CheckUtils::init(); | |||
| testing::InitGoogleTest(&argc, argv); | |||
| int ret = RUN_ALL_TESTS(); | |||
| return ret; | |||
| } | |||
| @@ -1,48 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #include "graph_builder_utils.h" | |||
| #include "inc/external/graph/operator.h" | |||
| #include "inc/external/graph/operator_factory.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| namespace ge { | |||
| namespace st { | |||
| NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||
| Format format, DataType data_type, std::vector<int64_t> shape) { | |||
| auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||
| tensor_desc->SetShape(GeShape(std::move(shape))); | |||
| tensor_desc->SetFormat(format); | |||
| tensor_desc->SetDataType(data_type); | |||
| auto op_desc = std::make_shared<OpDesc>(name, type); | |||
| for (int i = 0; i < in_cnt; ++i) { | |||
| op_desc->AddInputDesc(tensor_desc->Clone()); | |||
| } | |||
| for (int i = 0; i < out_cnt; ++i) { | |||
| op_desc->AddOutputDesc(tensor_desc->Clone()); | |||
| } | |||
| return graph_->AddNode(op_desc); | |||
| } | |||
| void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { | |||
| GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||
| } | |||
| void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { | |||
| GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||
| } | |||
| } // namespace st | |||
| } // namespace ge | |||
| @@ -1,55 +0,0 @@ | |||
| /** | |||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||
| * | |||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||
| * you may not use this file except in compliance with the License. | |||
| * You may obtain a copy of the License at | |||
| * | |||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||
| * | |||
| * Unless required by applicable law or agreed to in writing, software | |||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| * See the License for the specific language governing permissions and | |||
| * limitations under the License. | |||
| */ | |||
| #ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||
| #define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||
| #include <string> | |||
| #include <vector> | |||
| #include "graph/compute_graph.h" | |||
| #include "graph/utils/graph_utils.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/node.h" | |||
| namespace ge { | |||
| namespace st { | |||
| class ComputeGraphBuilder { | |||
| public: | |||
| explicit ComputeGraphBuilder(const std::string &name) { | |||
| graph_ = std::make_shared<ComputeGraph>(name); | |||
| } | |||
| NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||
| Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||
| std::vector<int64_t> shape = {1, 1, 224, 224}); | |||
| void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||
| void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||
| ComputeGraphPtr GetComputeGraph() { | |||
| graph_->TopologicalSorting(); | |||
| return graph_; | |||
| } | |||
| Graph GetGraph() { | |||
| graph_->TopologicalSorting(); | |||
| return GraphUtils::CreateGraphFromComputeGraph(graph_); | |||
| } | |||
| private: | |||
| ComputeGraphPtr graph_; | |||
| }; | |||
| } // namespace st | |||
| } // namespace ge | |||
| #endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||
| @@ -8,7 +8,7 @@ target_include_directories(graph_engine_test | |||
| set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | |||
| target_link_libraries(graph_engine_test PRIVATE gtest framework) | |||
| target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env) | |||
| include(CTest) | |||
| enable_testing() | |||
| @@ -15,23 +15,12 @@ | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include <map> | |||
| #include "external/ge/ge_api.h" | |||
| #include "ge_running_env/fake_engine.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "framework/common/types.h" | |||
| #include "builder/graph_builder_utils.h" | |||
| #include "ge_running_env/ge_running_env_faker.h" | |||
| #include "graph/operator_reg.h" | |||
| #include "graph/operator.h" | |||
| #define protected public | |||
| #define private public | |||
| #include "graph/utils/op_desc_utils.h" | |||
| #include "ge_graph_dsl/graph_dsl.h" | |||
| #undef protected | |||
| #undef private | |||
| #include "ge_graph_dsl/assert/graph_assert.h" | |||
| using namespace std; | |||
| using namespace ge; | |||
| @@ -57,76 +46,58 @@ namespace { | |||
| * | |||
| **/ | |||
| Graph BuildV1ControlFlowGraph() { | |||
| // build graph | |||
| st::ComputeGraphBuilder graphBuilder("g1"); | |||
| auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1); | |||
| auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1); | |||
| ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||
| auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1); | |||
| auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1); | |||
| auto less = graphBuilder.AddNode("less", LESS, 2, 1); | |||
| auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL); | |||
| auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2); | |||
| auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1); | |||
| auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1); | |||
| auto add = graphBuilder.AddNode("add", ADD, 2, 1); | |||
| auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1); | |||
| auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1); | |||
| auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1); | |||
| ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||
| auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1); | |||
| auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2); | |||
| auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1); | |||
| auto mul = graphBuilder.AddNode("mul", MUL, 2, 1); | |||
| auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1); | |||
| auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1); | |||
| auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2); | |||
| // i = i+1 | |||
| graphBuilder.AddDataEdge(data_i, 0, enter_i, 0); | |||
| graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0); | |||
| graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1); | |||
| graphBuilder.AddDataEdge(merge_i, 0, less, 0); | |||
| graphBuilder.AddDataEdge(const_5, 0, less, 1); | |||
| graphBuilder.AddDataEdge(less, 0, loopcond, 0); | |||
| graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1); | |||
| graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0); | |||
| graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0); | |||
| graphBuilder.AddDataEdge(switch_i, 1, add, 0); | |||
| graphBuilder.AddDataEdge(const_1, 0, add, 1); | |||
| graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0); | |||
| graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1); | |||
| // a=a*2 | |||
| graphBuilder.AddDataEdge(data_a, 0, enter_a, 0); | |||
| graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0); | |||
| graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1); | |||
| graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1); | |||
| graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0); | |||
| graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0); | |||
| graphBuilder.AddDataEdge(switch_a, 1, mul, 0); | |||
| graphBuilder.AddDataEdge(const_2, 0, mul, 1); | |||
| graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0); | |||
| graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0); | |||
| // set const weight | |||
| int64_t dims_size = 1; | |||
| vector<int64_t> data_vec = {5}; | |||
| for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | |||
| vector<int32_t> data_value_vec(dims_size, 1); | |||
| GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | |||
| GeTensorPtr data_tensor = | |||
| make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); | |||
| OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | |||
| OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | |||
| OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | |||
| GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), | |||
| data_value_vec.size() * sizeof(int32_t)); | |||
| return graphBuilder.GetGraph(); | |||
| auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1"); | |||
| auto const_op = OP_CFG(CONSTANT).Weight(data_tensor); | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data_i", DATA) | |||
| ->NODE("enter_i", enter) | |||
| ->EDGE(0, 0) | |||
| ->NODE("merge_i", MERGE) | |||
| ->NODE("less", LESS) | |||
| ->NODE("loopcond", LOOPCOND)); | |||
| CHAIN(NODE("const_1", const_op) | |||
| ->EDGE(0, 1) | |||
| ->NODE("add", ADD) | |||
| ->NODE("iteration_i", NEXTITERATION) | |||
| ->EDGE(0, 1) | |||
| ->NODE("merge_i")); | |||
| CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less")); | |||
| CHAIN(NODE("loopcond") | |||
| ->EDGE(0, 1) | |||
| ->NODE("switch_i", SWITCH) | |||
| ->EDGE(0, 0) | |||
| ->NODE("exit_i", EXIT) | |||
| ->EDGE(0, 1) | |||
| ->NODE("netoutput", NETOUTPUT)); | |||
| CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add")); | |||
| CHAIN(NODE("data_a", DATA) | |||
| ->NODE("enter_a", enter) | |||
| ->NODE("merge_a", MERGE) | |||
| ->NODE("switch_a", SWITCH) | |||
| ->NODE("exit_a", EXIT) | |||
| ->EDGE(0, 0) | |||
| ->NODE("netoutput")); | |||
| CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a")); | |||
| CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL)); | |||
| CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a")); | |||
| }; | |||
| return ToGeGraph(g1); | |||
| } | |||
| } // namespace | |||
| class FrameworkTest : public testing::Test { | |||
| protected: | |||
| GeRunningEnvFaker ge_env; | |||
| void SetUp() { ge_env.InstallDefault(); } | |||
| void TearDown() {} | |||
| GeRunningEnvFaker ge_env; | |||
| }; | |||
| /// data data | |||
| @@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) { | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CHAIN(NODE("data2", DATA)->NODE("add")); | |||
| }); | |||
| }; | |||
| auto graph = ToGeGraph(g1); | |||
| // new session & add graph | |||
| map<AscendString, AscendString> options; | |||
| Session session(options); | |||
| auto ret = session.AddGraph(1, graph, options); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| // build input tensor | |||
| session.AddGraph(1, ToGeGraph(g1), options); | |||
| std::vector<InputTensorInfo> inputs; | |||
| // build_graph through session | |||
| ret = session.BuildGraph(1, inputs); | |||
| auto ret = session.BuildGraph(1, inputs); | |||
| EXPECT_EQ(ret, SUCCESS); | |||
| CHECK_GRAPH(PreRunAfterBuild) { | |||
| ASSERT_EQ(graph->GetName(), "g1_1"); | |||
| ASSERT_EQ(graph->GetAllNodesSize(), 4); | |||
| }; | |||
| } | |||
| /** data a = 2; | |||
| @@ -15,24 +15,12 @@ | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include "easy_graph/graph/box.h" | |||
| #include "easy_graph/graph/node.h" | |||
| #include "external/ge/ge_api.h" | |||
| #include "easy_graph/builder/graph_dsl.h" | |||
| #include "easy_graph/builder/box_builder.h" | |||
| #include "easy_graph/layout/graph_layout.h" | |||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||
| #include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||
| #include "graph/graph.h" | |||
| #include "graph/compute_graph.h" | |||
| #include "framework/common/types.h" | |||
| #include "graph/debug/ge_attr_define.h" | |||
| #include "graph/ge_local_context.h" | |||
| #include "ge_graph_dsl/graph_dsl.h" | |||
| #include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||
| #define protected public | |||
| #define private public | |||
| #include "ge_opt_info/ge_opt_info.h" | |||
| #undef private | |||
| #undef protected | |||
| namespace ge { | |||
| class STEST_opt_info : public testing::Test { | |||
| @@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) { | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CHAIN(NODE("data2", DATA)->NODE("add")); | |||
| }); | |||
| }; | |||
| auto graph = ToGeGraph(g1); | |||
| @@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) { | |||
| DEF_GRAPH(g1) { | |||
| CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
| CHAIN(NODE("data2", DATA)->NODE("add")); | |||
| }); | |||
| }; | |||
| auto graph = ToGeGraph(g1); | |||
| @@ -15,9 +15,8 @@ | |||
| */ | |||
| #include <gtest/gtest.h> | |||
| #include "common/debug/log.h" | |||
| #include "external/ge/ge_api.h" | |||
| #include "ge_graph_dsl/assert/check_utils.h" | |||
| #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | |||
| using namespace std; | |||
| @@ -31,6 +30,7 @@ int main(int argc, char **argv) { | |||
| std::cout << "ge init failed , ret code:" << init_status << endl; | |||
| } | |||
| GeRunningEnvFaker::BackupEnv(); | |||
| CheckUtils::init(); | |||
| testing::InitGoogleTest(&argc, argv); | |||
| int ret = RUN_ALL_TESTS(); | |||
| return ret; | |||
| @@ -90,6 +90,7 @@ set(SRC_FILES | |||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||
| @@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES | |||
| "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
| "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||