@@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true | |||||
ConstructorInitializerIndentWidth: 4 | ConstructorInitializerIndentWidth: 4 | ||||
ContinuationIndentWidth: 4 | ContinuationIndentWidth: 4 | ||||
Cpp11BracedListStyle: true | Cpp11BracedListStyle: true | ||||
DerivePointerAlignment: true | |||||
DisableFormat: false | DisableFormat: false | ||||
ExperimentalAutoDetectBinPacking: false | ExperimentalAutoDetectBinPacking: false | ||||
FixNamespaceComments: true | FixNamespaceComments: true | ||||
@@ -94,7 +93,7 @@ PenaltyBreakString: 1000 | |||||
PenaltyBreakTemplateDeclaration: 10 | PenaltyBreakTemplateDeclaration: 10 | ||||
PenaltyExcessCharacter: 1000000 | PenaltyExcessCharacter: 1000000 | ||||
PenaltyReturnTypeOnItsOwnLine: 200 | PenaltyReturnTypeOnItsOwnLine: 200 | ||||
PointerAlignment: Left | |||||
PointerAlignment: Right | |||||
RawStringFormats: | RawStringFormats: | ||||
- Language: Cpp | - Language: Cpp | ||||
Delimiters: | Delimiters: | ||||
@@ -1 +1 @@ | |||||
Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2 | |||||
Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc |
@@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l | |||||
ENV PROJECT_HOME=/code/Turing/graphEngine | ENV PROJECT_HOME=/code/Turing/graphEngine | ||||
RUN mkdir /var/run/sshd | |||||
RUN echo "root:root" | chpasswd | |||||
RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config | |||||
RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd | |||||
ENV NOTVISIBLE "in users profile" | |||||
RUN echo "export VISIBLE=now" >> /etc/profile | |||||
EXPOSE 22 7777 | |||||
RUN useradd -ms /bin/bash debugger | |||||
RUN echo "debugger:ge123" | chpasswd | |||||
CMD ["/usr/sbin/sshd" "-D" "&"] | |||||
RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | ||||
@@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||||
DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | ||||
DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | ||||
DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||||
DOCKER_IMAGE_TAG=ge_build_env.1.0.9 | |||||
DOCKER_IAMGE_NAME=joycode2art/turing | DOCKER_IAMGE_NAME=joycode2art/turing | ||||
DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | ||||
@@ -61,7 +61,7 @@ function enter_docker_env(){ | |||||
if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | ||||
echo "please run 'ge env --pull' to download images first!" | echo "please run 'ge env --pull' to download images first!" | ||||
elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | ||||
$docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
$docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||||
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | ||||
$docker_cmd start ${DOCKER_BUILD_ENV_NAME} | $docker_cmd start ${DOCKER_BUILD_ENV_NAME} | ||||
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | $docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | ||||
@@ -60,6 +60,7 @@ set(SRCS | |||||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | ||||
@@ -17,16 +17,3 @@ include(cmake/graphengine.cmake) | |||||
add_subdirectory(easy_graph) | add_subdirectory(easy_graph) | ||||
add_subdirectory(ge_graph_dsl) | add_subdirectory(ge_graph_dsl) | ||||
add_subdirectory(ge_running_env) | add_subdirectory(ge_running_env) | ||||
file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | |||||
"utils/*.cc" | |||||
) | |||||
add_library(framework STATIC ${UTILS_SRC}) | |||||
target_include_directories(framework | |||||
PUBLIC utils/ | |||||
) | |||||
set_target_properties(framework PROPERTIES CXX_STANDARD 11) | |||||
target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) |
@@ -26,16 +26,32 @@ EG_NS_BEGIN | |||||
//////////////////////////////////////////////////////////////// | //////////////////////////////////////////////////////////////// | ||||
namespace detail { | namespace detail { | ||||
template<typename GRAPH_BUILDER> | |||||
template <typename GRAPH_BUILDER> | |||||
Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | ||||
GraphBuilder builder(name); | GraphBuilder builder(name); | ||||
builderInDSL(builder); | builderInDSL(builder); | ||||
return std::move(*builder); | return std::move(*builder); | ||||
} | } | ||||
struct GraphDefiner { | |||||
GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) { | |||||
name = specifiedName ? specifiedName : defaultName; | |||||
} | |||||
template <typename USER_BUILDER> | |||||
auto operator|(USER_BUILDER &&userBuilder) { | |||||
GraphBuilder graphBuilder{name}; | |||||
std::forward<USER_BUILDER>(userBuilder)(graphBuilder); | |||||
return *graphBuilder; | |||||
} | |||||
private: | |||||
const char *name; | |||||
}; | |||||
} // namespace detail | } // namespace detail | ||||
#define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__) | |||||
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER) | |||||
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER) | |||||
#define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | #define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | ||||
#define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | #define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | ||||
#define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | #define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | ||||
@@ -16,10 +16,15 @@ | |||||
#include "easy_graph/layout/graph_layout.h" | #include "easy_graph/layout/graph_layout.h" | ||||
#include "easy_graph/layout/layout_executor.h" | #include "easy_graph/layout/layout_executor.h" | ||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
#include "easy_graph/graph/graph.h" | #include "easy_graph/graph/graph.h" | ||||
EG_NS_BEGIN | EG_NS_BEGIN | ||||
namespace { | |||||
GraphEasyExecutor default_executor; | |||||
} | |||||
void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | ||||
this->executor_ = &executor; | this->executor_ = &executor; | ||||
options_ = opts; | options_ = opts; | ||||
@@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||||
Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | ||||
const LayoutOption *options = opts ? opts : this->options_; | const LayoutOption *options = opts ? opts : this->options_; | ||||
if (!executor_) | |||||
return EG_UNIMPLEMENTED; | |||||
if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options); | |||||
return executor_->Layout(graph, options); | return executor_->Layout(graph, options); | ||||
} | } | ||||
@@ -0,0 +1,37 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef D52AA06185E34BBFB714FFBCDAB0D53A | |||||
#define D52AA06185E34BBFB714FFBCDAB0D53A | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include <exception> | |||||
#include <string> | |||||
GE_NS_BEGIN | |||||
struct AssertError : std::exception { | |||||
AssertError(const char *file, int line, const std::string &info); | |||||
private: | |||||
const char *what() const noexcept override; | |||||
private: | |||||
std::string info; | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_31309AA0A4E44C009C22AD9351BF3410 | |||||
#define INC_31309AA0A4E44C009C22AD9351BF3410 | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include "graph/compute_graph.h" | |||||
GE_NS_BEGIN | |||||
using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||||
struct CheckUtils { | |||||
static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun); | |||||
static void init(); | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -1,17 +1,32 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "tensor_builder_utils.h" | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef C8B32320BD4943D588594B82FFBF2685 | |||||
#define C8B32320BD4943D588594B82FFBF2685 | |||||
#include <vector> | |||||
#include <string> | |||||
#include "ge_graph_dsl/ge.h" | |||||
GE_NS_BEGIN | |||||
struct FilterScopeGuard { | |||||
FilterScopeGuard(const std::vector<std::string> &); | |||||
~FilterScopeGuard(); | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -0,0 +1,59 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||||
#define AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include "ge_graph_dsl/assert/check_utils.h" | |||||
#include "ge_graph_dsl/assert/assert_error.h" | |||||
#include "ge_graph_dsl/assert/filter_scope_guard.h" | |||||
GE_NS_BEGIN | |||||
#ifdef GTEST_MESSAGE_AT_ | |||||
#define GRAPH_CHECK_MESSAGE(file, line, message) \ | |||||
GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure) | |||||
#elif | |||||
#define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message) | |||||
#endif | |||||
namespace detail { | |||||
struct GraphAssert { | |||||
GraphAssert(const char *file, unsigned int line, const std::string &phase_id) | |||||
: file_(file), line_(line), phase_id_(phase_id) {} | |||||
void operator|(const ::GE_NS::GraphCheckFun &check_fun) { | |||||
bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun); | |||||
if (!ret) { | |||||
auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! "; | |||||
GRAPH_CHECK_MESSAGE(file_, line_, message.c_str()); | |||||
} | |||||
} | |||||
private: | |||||
const char *file_; | |||||
unsigned int line_; | |||||
const std::string phase_id_; | |||||
}; | |||||
} // namespace detail | |||||
#define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__}); | |||||
#define CHECK_GRAPH(phase_id) \ | |||||
::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph) | |||||
GE_NS_END | |||||
#endif |
@@ -33,14 +33,12 @@ struct OpDescCfg { | |||||
std::vector<int64_t> shape_; | std::vector<int64_t> shape_; | ||||
}; | }; | ||||
OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW, | |||||
OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW, | |||||
DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | ||||
: type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | : type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | ||||
protected: | protected: | ||||
OpType GetType() const { | |||||
return type_; | |||||
} | |||||
OpType GetType() const { return type_; } | |||||
OpType type_; | OpType type_; | ||||
int in_cnt_; | int in_cnt_; | ||||
int out_cnt_; | int out_cnt_; | ||||
@@ -0,0 +1,26 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_graph_dsl/assert/assert_error.h" | |||||
GE_NS_BEGIN | |||||
AssertError::AssertError(const char *file, int line, const std::string &info) { | |||||
this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info; | |||||
} | |||||
const char *AssertError::what() const noexcept { return info.c_str(); } | |||||
GE_NS_END |
@@ -0,0 +1,34 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_graph_dsl/assert/check_utils.h" | |||||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||||
#include "ge_graph_default_checker.h" | |||||
#include "ge_graph_check_dumper.h" | |||||
GE_NS_BEGIN | |||||
bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) { | |||||
auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper()); | |||||
return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun)); | |||||
} | |||||
void CheckUtils::init() { | |||||
static GeGraphCheckDumper checkDumper; | |||||
GraphDumperRegistry::Register(checkDumper); | |||||
} | |||||
GE_NS_END |
@@ -0,0 +1,31 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_graph_dsl/assert/filter_scope_guard.h" | |||||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||||
#include "ge_dump_filter.h" | |||||
GE_NS_BEGIN | |||||
namespace { | |||||
GeDumpFilter &GetDumpFilter() { return dynamic_cast<GeDumpFilter &>(GraphDumperRegistry::GetDumper()); } | |||||
} // namespace | |||||
FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { GetDumpFilter().Update(filter); } | |||||
FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); } | |||||
GE_NS_END |
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||||
#define INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||||
#include <vector> | |||||
#include <string> | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include "easy_graph/infra/keywords.h" | |||||
GE_NS_BEGIN | |||||
INTERFACE(GeDumpFilter) { | |||||
ABSTRACT(void Update(const std::vector<std::string> &)); | |||||
ABSTRACT(void Reset()); | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -0,0 +1,79 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_graph_check_dumper.h" | |||||
#include "graph/model.h" | |||||
#include "graph/buffer.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "ge_graph_default_checker.h" | |||||
GE_NS_BEGIN | |||||
GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); } | |||||
bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { | |||||
auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix); | |||||
return (iter != suffixes_.end()); | |||||
} | |||||
void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { | |||||
if (!IsNeedDump(suffix)) { | |||||
return; | |||||
} | |||||
auto iter = buffers_.find(suffix); | |||||
if (iter != buffers_.end()) { | |||||
DumpGraph(graph, iter->second); | |||||
} else { | |||||
buffers_[suffix] = Buffer(); | |||||
DumpGraph(graph, buffers_.at(suffix)); | |||||
} | |||||
} | |||||
bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { | |||||
auto iter = buffers_.find(checker.PhaseId()); | |||||
if (iter == buffers_.end()) { | |||||
return false; | |||||
} | |||||
DoCheck(checker, iter->second); | |||||
return true; | |||||
} | |||||
void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { | |||||
Model model("", ""); | |||||
Model::Load(buffer.GetData(), buffer.GetSize(), model); | |||||
auto load_graph = model.GetGraph(); | |||||
checker.Check(GraphUtils::GetComputeGraph(load_graph)); | |||||
} | |||||
void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { | |||||
Model model("", ""); | |||||
buffer.clear(); | |||||
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||||
model.Save(buffer, true); | |||||
} | |||||
void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) { | |||||
suffixes_ = new_suffixes_; | |||||
buffers_.clear(); | |||||
} | |||||
void GeGraphCheckDumper::Reset() { | |||||
static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"}; | |||||
suffixes_ = default_suffixes_; | |||||
buffers_.clear(); | |||||
} | |||||
GE_NS_END |
@@ -0,0 +1,49 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_8EFED0015C27464897BF64531355C810 | |||||
#define INC_8EFED0015C27464897BF64531355C810 | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||||
#include "ge_dump_filter.h" | |||||
#include <string> | |||||
GE_NS_BEGIN | |||||
struct GeGraphChecker; | |||||
struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { | |||||
GeGraphCheckDumper(); | |||||
virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); | |||||
bool CheckFor(const GeGraphChecker &checker); | |||||
private: | |||||
void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); | |||||
void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); | |||||
private: | |||||
void Update(const std::vector<std::string> &) override; | |||||
void Reset() override; | |||||
bool IsNeedDump(const std::string &suffix) const; | |||||
private: | |||||
std::map<std::string, ::GE_NS::Buffer> buffers_; | |||||
std::vector<std::string> suffixes_; | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -0,0 +1,32 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_5960A8F437324904BEE0690271258762 | |||||
#define INC_5960A8F437324904BEE0690271258762 | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include "easy_graph/infra/keywords.h" | |||||
#include "graph/compute_graph.h" | |||||
GE_NS_BEGIN | |||||
INTERFACE(GeGraphChecker) { | |||||
ABSTRACT(const std::string &PhaseId() const); | |||||
ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -0,0 +1,28 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "ge_graph_default_checker.h" | |||||
GE_NS_BEGIN | |||||
GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) | |||||
: phase_id_(phase_id), check_fun_(check_fun) {} | |||||
const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; } | |||||
void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); } | |||||
GE_NS_END |
@@ -0,0 +1,41 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef BCF4D96BE9FC48938DE7B7E93B551C54 | |||||
#define BCF4D96BE9FC48938DE7B7E93B551C54 | |||||
#include "ge_graph_dsl/ge.h" | |||||
#include "ge_graph_checker.h" | |||||
#include "graph/compute_graph.h" | |||||
GE_NS_BEGIN | |||||
using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||||
struct GeGraphDefaultChecker : GeGraphChecker { | |||||
GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); | |||||
private: | |||||
const std::string &PhaseId() const override; | |||||
void Check(const ge::ComputeGraphPtr &graph) const override; | |||||
private: | |||||
const std::string phase_id_; | |||||
const GraphCheckFun check_fun_; | |||||
}; | |||||
GE_NS_END | |||||
#endif |
@@ -23,15 +23,22 @@ GE_NS_BEGIN | |||||
namespace { | namespace { | ||||
#define OP_CFG(optype, ...) \ | |||||
{ \ | |||||
optype, OpDescCfg { \ | |||||
optype, __VA_ARGS__ \ | |||||
} \ | |||||
#define OP_CFG(optype, ...) \ | |||||
{ \ | |||||
optype, OpDescCfg { optype, __VA_ARGS__ } \ | |||||
} | } | ||||
static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | ||||
OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | ||||
OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), | |||||
OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||||
OP_CFG(VARIABLE, 1, 1)}; | OP_CFG(VARIABLE, 1, 1)}; | ||||
} // namespace | } // namespace | ||||
@@ -19,6 +19,4 @@ | |||||
USING_GE_NS | USING_GE_NS | ||||
OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { | |||||
return op_; | |||||
} | |||||
OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } |
@@ -36,17 +36,11 @@ GE_NS_BEGIN | |||||
GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | ||||
void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { | |||||
build_graph_ = graph; | |||||
} | |||||
void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } | |||||
Graph GeGraphVisitor::BuildGeGraph() const { | |||||
return GraphUtils::CreateGraphFromComputeGraph(build_graph_); | |||||
} | |||||
Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); } | |||||
ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { | |||||
return build_graph_; | |||||
} | |||||
ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } | |||||
Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | ||||
build_graph_->SetName(graph.GetName()); | build_graph_->SetName(graph.GetName()); |
@@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE | |||||
) | ) | ||||
set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | ||||
target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) | |||||
target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) | |||||
include(CTest) | include(CTest) | ||||
enable_testing() | enable_testing() |
@@ -0,0 +1,129 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "gtest/gtest.h" | |||||
#include "easy_graph/layout/graph_layout.h" | |||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
#include "ge_graph_dsl/graph_dsl.h" | |||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||||
#include "framework/common/types.h" | |||||
#include "ge_graph_dsl/assert/graph_assert.h" | |||||
#include "graph/model.h" | |||||
#include "graph/buffer.h" | |||||
USING_GE_NS | |||||
class CheckGraphTest : public testing::Test { | |||||
private: | |||||
EG_NS::GraphEasyExecutor executor; | |||||
protected: | |||||
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||||
void TearDown() {} | |||||
}; | |||||
TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
DUMP_GRAPH_WHEN("after_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
CHECK_GRAPH(after_build) { | |||||
ASSERT_EQ(graph->GetName(), "g1"); | |||||
ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
}; | |||||
} | |||||
TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
DEF_GRAPH(g2) { | |||||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||||
}; | |||||
DUMP_GRAPH_WHEN("before_build", "after_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); | |||||
CHECK_GRAPH(before_build) { | |||||
ASSERT_EQ(graph->GetName(), "g1"); | |||||
ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
}; | |||||
CHECK_GRAPH(after_build) { | |||||
ASSERT_EQ(graph->GetName(), "g2"); | |||||
ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||||
}; | |||||
} | |||||
TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
DEF_GRAPH(g2) { | |||||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||||
}; | |||||
DUMP_GRAPH_WHEN("before_build") | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); | |||||
CHECK_GRAPH(before_build) { | |||||
ASSERT_EQ(graph->GetName(), "g2"); | |||||
ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||||
}; | |||||
} | |||||
TEST_F(CheckGraphTest, test_check_phases_is_work) { | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
DUMP_GRAPH_WHEN("before_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); | |||||
ASSERT_FALSE(ret); | |||||
} | |||||
TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
DUMP_GRAPH_WHEN("before_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||||
CHECK_GRAPH(before_build) { | |||||
ASSERT_EQ(graph->GetName(), "g1"); | |||||
ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||||
}; | |||||
} | |||||
TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
auto ge_graph = ToGeGraph(g1); | |||||
ge::Model model("", ""); | |||||
model.SetGraph(ge_graph); | |||||
Buffer buffer; | |||||
model.Save(buffer, true); | |||||
ge::Model loadModel("", ""); | |||||
Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); | |||||
auto load_graph = loadModel.GetGraph(); | |||||
ASSERT_EQ(load_graph.GetName(), "g1"); | |||||
ASSERT_EQ(load_graph.GetAllNodes().size(), 2); | |||||
} |
@@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { | |||||
EG_NS::GraphEasyExecutor executor; | EG_NS::GraphEasyExecutor executor; | ||||
protected: | protected: | ||||
void SetUp() { | |||||
EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); | |||||
} | |||||
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||||
void TearDown() {} | void TearDown() {} | ||||
}; | }; | ||||
TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | ||||
DEF_GRAPH(g1) { | |||||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
}); | |||||
DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
auto computeGraph = ToComputeGraph(g1); | auto computeGraph = ToComputeGraph(g1); | ||||
@@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||||
} | } | ||||
TEST_F(GraphDslTest, test_build_graph_with_name) { | TEST_F(GraphDslTest, test_build_graph_with_name) { | ||||
DEF_GRAPH(g1, "sample_graph") { | |||||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
}); | |||||
DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { | |||||
auto data = std::make_shared<OpDesc>("data1", DATA); | auto data = std::make_shared<OpDesc>("data1", DATA); | ||||
auto add = std::make_shared<OpDesc>("Add", ADD); | auto add = std::make_shared<OpDesc>("Add", ADD); | ||||
CHAIN(NODE(data)->NODE(add)); | CHAIN(NODE(data)->NODE(add)); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||||
auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | ||||
auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | ||||
CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||||
} | } | ||||
TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | ||||
DEF_GRAPH(g1) { | |||||
CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); | |||||
}); | |||||
DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||||
} | } | ||||
TEST_F(GraphDslTest, test_build_from_control_chain) { | TEST_F(GraphDslTest, test_build_from_control_chain) { | ||||
DEF_GRAPH(g1) { | |||||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
}); | |||||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { | |||||
} | } | ||||
TEST_F(GraphDslTest, test_build_from_data_chain) { | TEST_F(GraphDslTest, test_build_from_data_chain) { | ||||
DEF_GRAPH(g1) { | |||||
DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||||
}); | |||||
DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { | |||||
->NODE("Add4") | ->NODE("Add4") | ||||
->NODE("Add5") | ->NODE("Add5") | ||||
->NODE("net_output", NETOUTPUT)); | ->NODE("net_output", NETOUTPUT)); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { | |||||
CHAIN(NODE("add2")->NODE("net_output")); | CHAIN(NODE("add2")->NODE("net_output")); | ||||
CHAIN(NODE("add3")->NODE("net_output")); | CHAIN(NODE("add3")->NODE("net_output")); | ||||
CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | ||||
}); | |||||
}; | |||||
auto geGraph = ToGeGraph(g1); | auto geGraph = ToGeGraph(g1); | ||||
@@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { | |||||
DEF_GRAPH(sub_1) { | DEF_GRAPH(sub_1) { | ||||
CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | ||||
CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | ||||
}); | |||||
}; | |||||
DEF_GRAPH(sub_2) { | DEF_GRAPH(sub_2) { | ||||
CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | ||||
CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | ||||
}); | |||||
}; | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | ||||
CHAIN(NODE("data_i", DATA)->NODE("while")); | CHAIN(NODE("data_i", DATA)->NODE("while")); | ||||
}); | |||||
}; | |||||
sub_1.Layout(); | sub_1.Layout(); | ||||
sub_2.Layout(); | sub_2.Layout(); | ||||
@@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | |||||
REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | ||||
REGISTER_OPTYPE_DEFINE(ADD, "Add"); | REGISTER_OPTYPE_DEFINE(ADD, "Add"); | ||||
REGISTER_OPTYPE_DEFINE(WHILE, "While"); | REGISTER_OPTYPE_DEFINE(WHILE, "While"); | ||||
REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); | |||||
REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); | |||||
REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); | |||||
REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); | |||||
REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); | |||||
REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); | |||||
GE_NS_END | GE_NS_END |
@@ -1,22 +1,25 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
class tensor_builder_utils {}; | |||||
#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include <gtest/gtest.h> | |||||
#include "ge_graph_dsl/assert/check_utils.h" | |||||
int main(int argc, char **argv) { | |||||
::GE_NS::CheckUtils::init(); | |||||
testing::InitGoogleTest(&argc, argv); | |||||
int ret = RUN_ALL_TESTS(); | |||||
return ret; | |||||
} |
@@ -1,48 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph_builder_utils.h" | |||||
#include "inc/external/graph/operator.h" | |||||
#include "inc/external/graph/operator_factory.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
namespace ge { | |||||
namespace st { | |||||
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
Format format, DataType data_type, std::vector<int64_t> shape) { | |||||
auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||||
tensor_desc->SetShape(GeShape(std::move(shape))); | |||||
tensor_desc->SetFormat(format); | |||||
tensor_desc->SetDataType(data_type); | |||||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||||
for (int i = 0; i < in_cnt; ++i) { | |||||
op_desc->AddInputDesc(tensor_desc->Clone()); | |||||
} | |||||
for (int i = 0; i < out_cnt; ++i) { | |||||
op_desc->AddOutputDesc(tensor_desc->Clone()); | |||||
} | |||||
return graph_->AddNode(op_desc); | |||||
} | |||||
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { | |||||
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||||
} | |||||
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { | |||||
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||||
} | |||||
} // namespace st | |||||
} // namespace ge |
@@ -1,55 +0,0 @@ | |||||
/** | |||||
* Copyright 2021 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||||
#include <string> | |||||
#include <vector> | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/utils/graph_utils.h" | |||||
#include "graph/graph.h" | |||||
#include "graph/node.h" | |||||
namespace ge { | |||||
namespace st { | |||||
class ComputeGraphBuilder { | |||||
public: | |||||
explicit ComputeGraphBuilder(const std::string &name) { | |||||
graph_ = std::make_shared<ComputeGraph>(name); | |||||
} | |||||
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||||
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||||
std::vector<int64_t> shape = {1, 1, 224, 224}); | |||||
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||||
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||||
ComputeGraphPtr GetComputeGraph() { | |||||
graph_->TopologicalSorting(); | |||||
return graph_; | |||||
} | |||||
Graph GetGraph() { | |||||
graph_->TopologicalSorting(); | |||||
return GraphUtils::CreateGraphFromComputeGraph(graph_); | |||||
} | |||||
private: | |||||
ComputeGraphPtr graph_; | |||||
}; | |||||
} // namespace st | |||||
} // namespace ge | |||||
#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H |
@@ -8,7 +8,7 @@ target_include_directories(graph_engine_test | |||||
set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | ||||
target_link_libraries(graph_engine_test PRIVATE gtest framework) | |||||
target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env) | |||||
include(CTest) | include(CTest) | ||||
enable_testing() | enable_testing() |
@@ -15,23 +15,12 @@ | |||||
*/ | */ | ||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include <map> | |||||
#include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
#include "ge_running_env/fake_engine.h" | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "builder/graph_builder_utils.h" | |||||
#include "ge_running_env/ge_running_env_faker.h" | #include "ge_running_env/ge_running_env_faker.h" | ||||
#include "graph/operator_reg.h" | |||||
#include "graph/operator.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "graph/utils/op_desc_utils.h" | |||||
#include "ge_graph_dsl/graph_dsl.h" | #include "ge_graph_dsl/graph_dsl.h" | ||||
#undef protected | |||||
#undef private | |||||
#include "ge_graph_dsl/assert/graph_assert.h" | |||||
using namespace std; | using namespace std; | ||||
using namespace ge; | using namespace ge; | ||||
@@ -57,76 +46,58 @@ namespace { | |||||
* | * | ||||
**/ | **/ | ||||
Graph BuildV1ControlFlowGraph() { | Graph BuildV1ControlFlowGraph() { | ||||
// build graph | |||||
st::ComputeGraphBuilder graphBuilder("g1"); | |||||
auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1); | |||||
auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1); | |||||
ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||||
auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1); | |||||
auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1); | |||||
auto less = graphBuilder.AddNode("less", LESS, 2, 1); | |||||
auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL); | |||||
auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2); | |||||
auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1); | |||||
auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1); | |||||
auto add = graphBuilder.AddNode("add", ADD, 2, 1); | |||||
auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1); | |||||
auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1); | |||||
auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1); | |||||
ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||||
auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1); | |||||
auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2); | |||||
auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1); | |||||
auto mul = graphBuilder.AddNode("mul", MUL, 2, 1); | |||||
auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1); | |||||
auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1); | |||||
auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2); | |||||
// i = i+1 | |||||
graphBuilder.AddDataEdge(data_i, 0, enter_i, 0); | |||||
graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0); | |||||
graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1); | |||||
graphBuilder.AddDataEdge(merge_i, 0, less, 0); | |||||
graphBuilder.AddDataEdge(const_5, 0, less, 1); | |||||
graphBuilder.AddDataEdge(less, 0, loopcond, 0); | |||||
graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1); | |||||
graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0); | |||||
graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0); | |||||
graphBuilder.AddDataEdge(switch_i, 1, add, 0); | |||||
graphBuilder.AddDataEdge(const_1, 0, add, 1); | |||||
graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0); | |||||
graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1); | |||||
// a=a*2 | |||||
graphBuilder.AddDataEdge(data_a, 0, enter_a, 0); | |||||
graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0); | |||||
graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1); | |||||
graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1); | |||||
graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0); | |||||
graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0); | |||||
graphBuilder.AddDataEdge(switch_a, 1, mul, 0); | |||||
graphBuilder.AddDataEdge(const_2, 0, mul, 1); | |||||
graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0); | |||||
graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0); | |||||
// set const weight | |||||
int64_t dims_size = 1; | int64_t dims_size = 1; | ||||
vector<int64_t> data_vec = {5}; | vector<int64_t> data_vec = {5}; | ||||
for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | ||||
vector<int32_t> data_value_vec(dims_size, 1); | vector<int32_t> data_value_vec(dims_size, 1); | ||||
GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | ||||
GeTensorPtr data_tensor = | |||||
make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); | |||||
OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | |||||
OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | |||||
OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | |||||
GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), | |||||
data_value_vec.size() * sizeof(int32_t)); | |||||
return graphBuilder.GetGraph(); | |||||
auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1"); | |||||
auto const_op = OP_CFG(CONSTANT).Weight(data_tensor); | |||||
DEF_GRAPH(g1) { | |||||
CHAIN(NODE("data_i", DATA) | |||||
->NODE("enter_i", enter) | |||||
->EDGE(0, 0) | |||||
->NODE("merge_i", MERGE) | |||||
->NODE("less", LESS) | |||||
->NODE("loopcond", LOOPCOND)); | |||||
CHAIN(NODE("const_1", const_op) | |||||
->EDGE(0, 1) | |||||
->NODE("add", ADD) | |||||
->NODE("iteration_i", NEXTITERATION) | |||||
->EDGE(0, 1) | |||||
->NODE("merge_i")); | |||||
CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less")); | |||||
CHAIN(NODE("loopcond") | |||||
->EDGE(0, 1) | |||||
->NODE("switch_i", SWITCH) | |||||
->EDGE(0, 0) | |||||
->NODE("exit_i", EXIT) | |||||
->EDGE(0, 1) | |||||
->NODE("netoutput", NETOUTPUT)); | |||||
CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add")); | |||||
CHAIN(NODE("data_a", DATA) | |||||
->NODE("enter_a", enter) | |||||
->NODE("merge_a", MERGE) | |||||
->NODE("switch_a", SWITCH) | |||||
->NODE("exit_a", EXIT) | |||||
->EDGE(0, 0) | |||||
->NODE("netoutput")); | |||||
CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a")); | |||||
CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL)); | |||||
CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a")); | |||||
}; | |||||
return ToGeGraph(g1); | |||||
} | } | ||||
} // namespace | } // namespace | ||||
class FrameworkTest : public testing::Test { | class FrameworkTest : public testing::Test { | ||||
protected: | protected: | ||||
GeRunningEnvFaker ge_env; | |||||
void SetUp() { ge_env.InstallDefault(); } | void SetUp() { ge_env.InstallDefault(); } | ||||
void TearDown() {} | void TearDown() {} | ||||
GeRunningEnvFaker ge_env; | |||||
}; | }; | ||||
/// data data | /// data data | ||||
@@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) { | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
}); | |||||
}; | |||||
auto graph = ToGeGraph(g1); | |||||
// new session & add graph | |||||
map<AscendString, AscendString> options; | map<AscendString, AscendString> options; | ||||
Session session(options); | Session session(options); | ||||
auto ret = session.AddGraph(1, graph, options); | |||||
EXPECT_EQ(ret, SUCCESS); | |||||
// build input tensor | |||||
session.AddGraph(1, ToGeGraph(g1), options); | |||||
std::vector<InputTensorInfo> inputs; | std::vector<InputTensorInfo> inputs; | ||||
// build_graph through session | |||||
ret = session.BuildGraph(1, inputs); | |||||
auto ret = session.BuildGraph(1, inputs); | |||||
EXPECT_EQ(ret, SUCCESS); | EXPECT_EQ(ret, SUCCESS); | ||||
CHECK_GRAPH(PreRunAfterBuild) { | |||||
ASSERT_EQ(graph->GetName(), "g1_1"); | |||||
ASSERT_EQ(graph->GetAllNodesSize(), 4); | |||||
}; | |||||
} | } | ||||
/** data a = 2; | /** data a = 2; | ||||
@@ -15,24 +15,12 @@ | |||||
*/ | */ | ||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include "easy_graph/graph/box.h" | |||||
#include "easy_graph/graph/node.h" | |||||
#include "external/ge/ge_api.h" | |||||
#include "easy_graph/builder/graph_dsl.h" | #include "easy_graph/builder/graph_dsl.h" | ||||
#include "easy_graph/builder/box_builder.h" | |||||
#include "easy_graph/layout/graph_layout.h" | |||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||||
#include "graph/graph.h" | |||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "framework/common/types.h" | #include "framework/common/types.h" | ||||
#include "graph/debug/ge_attr_define.h" | |||||
#include "graph/ge_local_context.h" | |||||
#include "ge_graph_dsl/graph_dsl.h" | #include "ge_graph_dsl/graph_dsl.h" | ||||
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||||
#define protected public | |||||
#define private public | |||||
#include "ge_opt_info/ge_opt_info.h" | |||||
#undef private | |||||
#undef protected | |||||
namespace ge { | namespace ge { | ||||
class STEST_opt_info : public testing::Test { | class STEST_opt_info : public testing::Test { | ||||
@@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) { | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
}); | |||||
}; | |||||
auto graph = ToGeGraph(g1); | auto graph = ToGeGraph(g1); | ||||
@@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) { | |||||
DEF_GRAPH(g1) { | DEF_GRAPH(g1) { | ||||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | ||||
CHAIN(NODE("data2", DATA)->NODE("add")); | CHAIN(NODE("data2", DATA)->NODE("add")); | ||||
}); | |||||
}; | |||||
auto graph = ToGeGraph(g1); | auto graph = ToGeGraph(g1); | ||||
@@ -15,9 +15,8 @@ | |||||
*/ | */ | ||||
#include <gtest/gtest.h> | #include <gtest/gtest.h> | ||||
#include "common/debug/log.h" | |||||
#include "external/ge/ge_api.h" | #include "external/ge/ge_api.h" | ||||
#include "ge_graph_dsl/assert/check_utils.h" | |||||
#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | #include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | ||||
using namespace std; | using namespace std; | ||||
@@ -31,6 +30,7 @@ int main(int argc, char **argv) { | |||||
std::cout << "ge init failed , ret code:" << init_status << endl; | std::cout << "ge init failed , ret code:" << init_status << endl; | ||||
} | } | ||||
GeRunningEnvFaker::BackupEnv(); | GeRunningEnvFaker::BackupEnv(); | ||||
CheckUtils::init(); | |||||
testing::InitGoogleTest(&argc, argv); | testing::InitGoogleTest(&argc, argv); | ||||
int ret = RUN_ALL_TESTS(); | int ret = RUN_ALL_TESTS(); | ||||
return ret; | return ret; | ||||
@@ -90,6 +90,7 @@ set(SRC_FILES | |||||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | ||||
@@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES | |||||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | "${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||||
"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | ||||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | "${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | ||||