@@ -52,7 +52,6 @@ ConstructorInitializerAllOnOneLineOrOnePerLine: true | |||
ConstructorInitializerIndentWidth: 4 | |||
ContinuationIndentWidth: 4 | |||
Cpp11BracedListStyle: true | |||
DerivePointerAlignment: true | |||
DisableFormat: false | |||
ExperimentalAutoDetectBinPacking: false | |||
FixNamespaceComments: true | |||
@@ -94,7 +93,7 @@ PenaltyBreakString: 1000 | |||
PenaltyBreakTemplateDeclaration: 10 | |||
PenaltyExcessCharacter: 1000000 | |||
PenaltyReturnTypeOnItsOwnLine: 200 | |||
PointerAlignment: Left | |||
PointerAlignment: Right | |||
RawStringFormats: | |||
- Language: Cpp | |||
Delimiters: | |||
@@ -1 +1 @@ | |||
Subproject commit f3f137de034885f0c7394d7f04b41b08d450d2d2 | |||
Subproject commit 9c9907b76a457f456072af96b8cbcfb7943beccc |
@@ -38,5 +38,20 @@ RUN wget https://github.com/ccup/lcov/archive/refs/tags/add_lcov.tar.gz -O add_l | |||
ENV PROJECT_HOME=/code/Turing/graphEngine | |||
RUN mkdir /var/run/sshd | |||
RUN echo "root:root" | chpasswd | |||
RUN sed -i 's/\#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config | |||
RUN sed 's@session\s*required\s*pam_loginuid.so@session optional pam_loginuid.so@g' -i /etc/pam.d/sshd | |||
ENV NOTVISIBLE "in users profile" | |||
RUN echo "export VISIBLE=now" >> /etc/profile | |||
EXPOSE 22 7777 | |||
RUN useradd -ms /bin/bash debugger | |||
RUN echo "debugger:ge123" | chpasswd | |||
CMD ["/usr/sbin/sshd" "-D" "&"] | |||
RUN echo "alias ge=/code/Turing/graphEngine/scripts/ge.sh">>~/.bashrc | |||
@@ -21,7 +21,7 @@ MOUNT_PROJECT_HOME=$(cd $PROJECT_HOME || return; pwd) | |||
DOCKER_BUILD_ENV_NAME=${MOUNT_PROJECT_HOME#*/} | |||
DOCKER_BUILD_ENV_NAME=${DOCKER_BUILD_ENV_NAME//\//\_} | |||
DOCKER_IMAGE_TAG=ge_build_env.1.0.6 | |||
DOCKER_IMAGE_TAG=ge_build_env.1.0.9 | |||
DOCKER_IAMGE_NAME=joycode2art/turing | |||
DOCKER_FULL_IMAGE_NAME=${DOCKER_IAMGE_NAME}:${DOCKER_IMAGE_TAG} | |||
@@ -61,7 +61,7 @@ function enter_docker_env(){ | |||
if test -z "$(docker images |grep ${DOCKER_IAMGE_NAME} | grep ${DOCKER_IMAGE_TAG})"; then | |||
echo "please run 'ge env --pull' to download images first!" | |||
elif test -z "$(docker ps -a |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
$docker_cmd run -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||
$docker_cmd run -p 7002:22 -p 7003:7777 --privileged=true -it -v ${MOUNT_PROJECT_HOME}:/code/Turing/graphEngine --workdir ${docker_work_dir} --name ${DOCKER_BUILD_ENV_NAME} ${DOCKER_FULL_IMAGE_NAME} ${docker_bash_dir} | |||
elif test -z "$(docker ps |grep ${DOCKER_BUILD_ENV_NAME})"; then | |||
$docker_cmd start ${DOCKER_BUILD_ENV_NAME} | |||
$docker_cmd exec -w ${docker_work_dir} -it ${DOCKER_BUILD_ENV_NAME} ${docker_bash_dir} | |||
@@ -60,6 +60,7 @@ set(SRCS | |||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||
@@ -17,16 +17,3 @@ include(cmake/graphengine.cmake) | |||
add_subdirectory(easy_graph) | |||
add_subdirectory(ge_graph_dsl) | |||
add_subdirectory(ge_running_env) | |||
file(GLOB_RECURSE UTILS_SRC CONFIGURE_DEPENDS | |||
"utils/*.cc" | |||
) | |||
add_library(framework STATIC ${UTILS_SRC}) | |||
target_include_directories(framework | |||
PUBLIC utils/ | |||
) | |||
set_target_properties(framework PROPERTIES CXX_STANDARD 11) | |||
target_link_libraries(framework PUBLIC ge_graph_dsl ge_with_env) |
@@ -26,16 +26,32 @@ EG_NS_BEGIN | |||
//////////////////////////////////////////////////////////////// | |||
namespace detail { | |||
template<typename GRAPH_BUILDER> | |||
template <typename GRAPH_BUILDER> | |||
Graph BuildGraph(const char *name, GRAPH_BUILDER builderInDSL) { | |||
GraphBuilder builder(name); | |||
builderInDSL(builder); | |||
return std::move(*builder); | |||
} | |||
struct GraphDefiner { | |||
GraphDefiner(const char *defaultName, const char *specifiedName = nullptr) { | |||
name = specifiedName ? specifiedName : defaultName; | |||
} | |||
template <typename USER_BUILDER> | |||
auto operator|(USER_BUILDER &&userBuilder) { | |||
GraphBuilder graphBuilder{name}; | |||
std::forward<USER_BUILDER>(userBuilder)(graphBuilder); | |||
return *graphBuilder; | |||
} | |||
private: | |||
const char *name; | |||
}; | |||
} // namespace detail | |||
#define HAS_NAME(...) NOT_EMPTY_SELECT(__VA_ARGS__) | |||
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::BuildGraph(HAS_NAME(__VA_ARGS__)(__VA_ARGS__, #G), [&](::EG_NS::GraphBuilder& BUILDER) | |||
#define DEF_GRAPH(G, ...) ::EG_NS::Graph G = ::EG_NS::detail::GraphDefiner(#G, ##__VA_ARGS__) | [&](auto &&BUILDER) | |||
#define DATA_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::DATA)->__VA_ARGS__ | |||
#define CTRL_CHAIN(...) ::EG_NS::ChainBuilder(BUILDER, ::EG_NS::EdgeType::CTRL)->__VA_ARGS__ | |||
#define CHAIN(...) DATA_CHAIN(__VA_ARGS__) | |||
@@ -16,10 +16,15 @@ | |||
#include "easy_graph/layout/graph_layout.h" | |||
#include "easy_graph/layout/layout_executor.h" | |||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||
#include "easy_graph/graph/graph.h" | |||
EG_NS_BEGIN | |||
namespace { | |||
GraphEasyExecutor default_executor; | |||
} | |||
void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||
this->executor_ = &executor; | |||
options_ = opts; | |||
@@ -27,8 +32,7 @@ void GraphLayout::Config(LayoutExecutor &executor, const LayoutOption *opts) { | |||
Status GraphLayout::Layout(const Graph &graph, const LayoutOption *opts) { | |||
const LayoutOption *options = opts ? opts : this->options_; | |||
if (!executor_) | |||
return EG_UNIMPLEMENTED; | |||
if (!executor_) return static_cast<LayoutExecutor &>(default_executor).Layout(graph, options); | |||
return executor_->Layout(graph, options); | |||
} | |||
@@ -0,0 +1,37 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef D52AA06185E34BBFB714FFBCDAB0D53A | |||
#define D52AA06185E34BBFB714FFBCDAB0D53A | |||
#include "ge_graph_dsl/ge.h" | |||
#include <exception> | |||
#include <string> | |||
GE_NS_BEGIN | |||
struct AssertError : std::exception { | |||
AssertError(const char *file, int line, const std::string &info); | |||
private: | |||
const char *what() const noexcept override; | |||
private: | |||
std::string info; | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -0,0 +1,32 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_31309AA0A4E44C009C22AD9351BF3410 | |||
#define INC_31309AA0A4E44C009C22AD9351BF3410 | |||
#include "ge_graph_dsl/ge.h" | |||
#include "graph/compute_graph.h" | |||
GE_NS_BEGIN | |||
using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||
struct CheckUtils { | |||
static bool CheckGraph(const std::string &phase_id, const GraphCheckFun &fun); | |||
static void init(); | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -1,17 +1,32 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "tensor_builder_utils.h" | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef C8B32320BD4943D588594B82FFBF2685 | |||
#define C8B32320BD4943D588594B82FFBF2685 | |||
#include <vector> | |||
#include <string> | |||
#include "ge_graph_dsl/ge.h" | |||
GE_NS_BEGIN | |||
struct FilterScopeGuard { | |||
FilterScopeGuard(const std::vector<std::string> &); | |||
~FilterScopeGuard(); | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -0,0 +1,59 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||
#define AD954C4ADF5B44F5B1CC8BCD72EE9ED6 | |||
#include "ge_graph_dsl/ge.h" | |||
#include "ge_graph_dsl/assert/check_utils.h" | |||
#include "ge_graph_dsl/assert/assert_error.h" | |||
#include "ge_graph_dsl/assert/filter_scope_guard.h" | |||
GE_NS_BEGIN | |||
#ifdef GTEST_MESSAGE_AT_ | |||
#define GRAPH_CHECK_MESSAGE(file, line, message) \ | |||
GTEST_MESSAGE_AT_(file, line, message, ::testing::TestPartResult::kFatalFailure) | |||
#elif | |||
#define GRAPH_CHECK_MESSAGE(file, line, message) throw AssertError(file, line, message) | |||
#endif | |||
namespace detail { | |||
struct GraphAssert { | |||
GraphAssert(const char *file, unsigned int line, const std::string &phase_id) | |||
: file_(file), line_(line), phase_id_(phase_id) {} | |||
void operator|(const ::GE_NS::GraphCheckFun &check_fun) { | |||
bool ret = ::GE_NS::CheckUtils::CheckGraph(phase_id_, check_fun); | |||
if (!ret) { | |||
auto message = "expect dump graph in phase: [" + phase_id_ + "], while not find the dump graph! "; | |||
GRAPH_CHECK_MESSAGE(file_, line_, message.c_str()); | |||
} | |||
} | |||
private: | |||
const char *file_; | |||
unsigned int line_; | |||
const std::string phase_id_; | |||
}; | |||
} // namespace detail | |||
#define DUMP_GRAPH_WHEN(...) ::GE_NS::FilterScopeGuard guard__COUNTER__({__VA_ARGS__}); | |||
#define CHECK_GRAPH(phase_id) \ | |||
::GE_NS::detail::GraphAssert(__FILE__, __LINE__, #phase_id) | [&](const ::GE_NS::ComputeGraphPtr &graph) | |||
GE_NS_END | |||
#endif |
@@ -33,14 +33,12 @@ struct OpDescCfg { | |||
std::vector<int64_t> shape_; | |||
}; | |||
OpDescCfg(const OpType &type, int in_cnt = 0, int out_cnt = 0, Format format = FORMAT_NCHW, | |||
OpDescCfg(const OpType &type, int in_cnt = 1, int out_cnt = 1, Format format = FORMAT_NCHW, | |||
DataType data_type = DT_FLOAT, std::vector<int64_t> shape = {1, 1, 224, 224}) | |||
: type_(type), in_cnt_(in_cnt), out_cnt_(out_cnt), default_tensor_(format, data_type, shape) {} | |||
protected: | |||
OpType GetType() const { | |||
return type_; | |||
} | |||
OpType GetType() const { return type_; } | |||
OpType type_; | |||
int in_cnt_; | |||
int out_cnt_; | |||
@@ -0,0 +1,26 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_graph_dsl/assert/assert_error.h" | |||
GE_NS_BEGIN | |||
AssertError::AssertError(const char *file, int line, const std::string &info) { | |||
this->info = std::string(file) + ":" + std::to_string(line) + "\n" + info; | |||
} | |||
const char *AssertError::what() const noexcept { return info.c_str(); } | |||
GE_NS_END |
@@ -0,0 +1,34 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_graph_dsl/assert/check_utils.h" | |||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||
#include "ge_graph_default_checker.h" | |||
#include "ge_graph_check_dumper.h" | |||
GE_NS_BEGIN | |||
bool CheckUtils::CheckGraph(const std::string &phase_id, const GraphCheckFun &fun) { | |||
auto &dumper = dynamic_cast<GeGraphCheckDumper &>(GraphDumperRegistry::GetDumper()); | |||
return dumper.CheckFor(GeGraphDefaultChecker(phase_id, fun)); | |||
} | |||
void CheckUtils::init() { | |||
static GeGraphCheckDumper checkDumper; | |||
GraphDumperRegistry::Register(checkDumper); | |||
} | |||
GE_NS_END |
@@ -0,0 +1,31 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_graph_dsl/assert/filter_scope_guard.h" | |||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||
#include "ge_dump_filter.h" | |||
GE_NS_BEGIN | |||
namespace { | |||
GeDumpFilter &GetDumpFilter() { return dynamic_cast<GeDumpFilter &>(GraphDumperRegistry::GetDumper()); } | |||
} // namespace | |||
FilterScopeGuard::FilterScopeGuard(const std::vector<std::string> &filter) { GetDumpFilter().Update(filter); } | |||
FilterScopeGuard::~FilterScopeGuard() { GetDumpFilter().Reset(); } | |||
GE_NS_END |
@@ -0,0 +1,33 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||
#define INC_4C6224E8F7474EF89B18CCB0E4B19FD6 | |||
#include <vector> | |||
#include <string> | |||
#include "ge_graph_dsl/ge.h" | |||
#include "easy_graph/infra/keywords.h" | |||
GE_NS_BEGIN | |||
INTERFACE(GeDumpFilter) { | |||
ABSTRACT(void Update(const std::vector<std::string> &)); | |||
ABSTRACT(void Reset()); | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -0,0 +1,79 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_graph_check_dumper.h" | |||
#include "graph/model.h" | |||
#include "graph/buffer.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "ge_graph_default_checker.h" | |||
GE_NS_BEGIN | |||
GeGraphCheckDumper::GeGraphCheckDumper() { Reset(); } | |||
bool GeGraphCheckDumper::IsNeedDump(const std::string &suffix) const { | |||
auto iter = std::find(suffixes_.begin(), suffixes_.end(), suffix); | |||
return (iter != suffixes_.end()); | |||
} | |||
void GeGraphCheckDumper::Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix) { | |||
if (!IsNeedDump(suffix)) { | |||
return; | |||
} | |||
auto iter = buffers_.find(suffix); | |||
if (iter != buffers_.end()) { | |||
DumpGraph(graph, iter->second); | |||
} else { | |||
buffers_[suffix] = Buffer(); | |||
DumpGraph(graph, buffers_.at(suffix)); | |||
} | |||
} | |||
bool GeGraphCheckDumper::CheckFor(const GeGraphChecker &checker) { | |||
auto iter = buffers_.find(checker.PhaseId()); | |||
if (iter == buffers_.end()) { | |||
return false; | |||
} | |||
DoCheck(checker, iter->second); | |||
return true; | |||
} | |||
void GeGraphCheckDumper::DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer) { | |||
Model model("", ""); | |||
Model::Load(buffer.GetData(), buffer.GetSize(), model); | |||
auto load_graph = model.GetGraph(); | |||
checker.Check(GraphUtils::GetComputeGraph(load_graph)); | |||
} | |||
void GeGraphCheckDumper::DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer) { | |||
Model model("", ""); | |||
buffer.clear(); | |||
model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(graph)); | |||
model.Save(buffer, true); | |||
} | |||
void GeGraphCheckDumper::Update(const std::vector<std::string> &new_suffixes_) { | |||
suffixes_ = new_suffixes_; | |||
buffers_.clear(); | |||
} | |||
void GeGraphCheckDumper::Reset() { | |||
static std::vector<std::string> default_suffixes_{"PreRunAfterBuild"}; | |||
suffixes_ = default_suffixes_; | |||
buffers_.clear(); | |||
} | |||
GE_NS_END |
@@ -0,0 +1,49 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_8EFED0015C27464897BF64531355C810 | |||
#define INC_8EFED0015C27464897BF64531355C810 | |||
#include "ge_graph_dsl/ge.h" | |||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||
#include "ge_dump_filter.h" | |||
#include <string> | |||
GE_NS_BEGIN | |||
struct GeGraphChecker; | |||
struct GeGraphCheckDumper : GeGraphDumper, GeDumpFilter { | |||
GeGraphCheckDumper(); | |||
virtual void Dump(const ge::ComputeGraphPtr &graph, const std::string &suffix); | |||
bool CheckFor(const GeGraphChecker &checker); | |||
private: | |||
void DoCheck(const GeGraphChecker &checker, ::GE_NS::Buffer &buffer); | |||
void DumpGraph(const ge::ComputeGraphPtr &graph, ::GE_NS::Buffer &buffer); | |||
private: | |||
void Update(const std::vector<std::string> &) override; | |||
void Reset() override; | |||
bool IsNeedDump(const std::string &suffix) const; | |||
private: | |||
std::map<std::string, ::GE_NS::Buffer> buffers_; | |||
std::vector<std::string> suffixes_; | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -0,0 +1,32 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef INC_5960A8F437324904BEE0690271258762 | |||
#define INC_5960A8F437324904BEE0690271258762 | |||
#include "ge_graph_dsl/ge.h" | |||
#include "easy_graph/infra/keywords.h" | |||
#include "graph/compute_graph.h" | |||
GE_NS_BEGIN | |||
INTERFACE(GeGraphChecker) { | |||
ABSTRACT(const std::string &PhaseId() const); | |||
ABSTRACT(void Check(const ge::ComputeGraphPtr &graph) const); | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -0,0 +1,28 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "ge_graph_default_checker.h" | |||
GE_NS_BEGIN | |||
GeGraphDefaultChecker::GeGraphDefaultChecker(const std::string &phase_id, const GraphCheckFun &check_fun) | |||
: phase_id_(phase_id), check_fun_(check_fun) {} | |||
const std::string &GeGraphDefaultChecker::PhaseId() const { return phase_id_; } | |||
void GeGraphDefaultChecker::Check(const ge::ComputeGraphPtr &graph) const { return check_fun_(graph); } | |||
GE_NS_END |
@@ -0,0 +1,41 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef BCF4D96BE9FC48938DE7B7E93B551C54 | |||
#define BCF4D96BE9FC48938DE7B7E93B551C54 | |||
#include "ge_graph_dsl/ge.h" | |||
#include "ge_graph_checker.h" | |||
#include "graph/compute_graph.h" | |||
GE_NS_BEGIN | |||
using GraphCheckFun = std::function<void(const ::GE_NS::ComputeGraphPtr &)>; | |||
struct GeGraphDefaultChecker : GeGraphChecker { | |||
GeGraphDefaultChecker(const std::string &, const GraphCheckFun &); | |||
private: | |||
const std::string &PhaseId() const override; | |||
void Check(const ge::ComputeGraphPtr &graph) const override; | |||
private: | |||
const std::string phase_id_; | |||
const GraphCheckFun check_fun_; | |||
}; | |||
GE_NS_END | |||
#endif |
@@ -23,15 +23,22 @@ GE_NS_BEGIN | |||
namespace { | |||
#define OP_CFG(optype, ...) \ | |||
{ \ | |||
optype, OpDescCfg { \ | |||
optype, __VA_ARGS__ \ | |||
} \ | |||
#define OP_CFG(optype, ...) \ | |||
{ \ | |||
optype, OpDescCfg { optype, __VA_ARGS__ } \ | |||
} | |||
static std::map<OpType, OpDescCfg> cfg_repo{OP_CFG(DATA, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(ADD, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(ENTER, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(MERGE, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(CONSTANT, 0, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(LESS, 2, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL, {1, 1, 224, 224}), | |||
OP_CFG(SWITCH, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(EXIT, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(NEXTITERATION, 1, 1, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(NETOUTPUT, 2, 2, FORMAT_NCHW, DT_FLOAT, {1, 1, 224, 224}), | |||
OP_CFG(VARIABLE, 1, 1)}; | |||
} // namespace | |||
@@ -19,6 +19,4 @@ | |||
USING_GE_NS | |||
OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { | |||
return op_; | |||
} | |||
OpDescPtr OpDescPtrBox::Build(const ::EG_NS::NodeId &id) const { return op_; } |
@@ -36,17 +36,11 @@ GE_NS_BEGIN | |||
GeGraphVisitor::GeGraphVisitor() : build_graph_(std::make_shared<ComputeGraph>("")) {} | |||
void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { | |||
build_graph_ = graph; | |||
} | |||
void GeGraphVisitor::reset(const ComputeGraphPtr &graph) { build_graph_ = graph; } | |||
Graph GeGraphVisitor::BuildGeGraph() const { | |||
return GraphUtils::CreateGraphFromComputeGraph(build_graph_); | |||
} | |||
Graph GeGraphVisitor::BuildGeGraph() const { return GraphUtils::CreateGraphFromComputeGraph(build_graph_); } | |||
ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { | |||
return build_graph_; | |||
} | |||
ComputeGraphPtr GeGraphVisitor::BuildComputeGraph() const { return build_graph_; } | |||
Status GeGraphVisitor::Visit(const ::EG_NS::Graph &graph) { | |||
build_graph_->SetName(graph.GetName()); |
@@ -26,7 +26,7 @@ target_compile_options(ge_graph_dsl_test PRIVATE | |||
) | |||
set_target_properties(ge_graph_dsl_test PROPERTIES CXX_STANDARD 17) | |||
target_link_libraries(ge_graph_dsl_test PUBLIC gtest gtest_main ge_graph_dsl) | |||
target_link_libraries(ge_graph_dsl_test PUBLIC gtest ge_graph_dsl) | |||
include(CTest) | |||
enable_testing() |
@@ -0,0 +1,129 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "gtest/gtest.h" | |||
#include "easy_graph/layout/graph_layout.h" | |||
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||
#include "ge_graph_dsl/graph_dsl.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/utils/dumper/ge_graph_dumper.h" | |||
#include "framework/common/types.h" | |||
#include "ge_graph_dsl/assert/graph_assert.h" | |||
#include "graph/model.h" | |||
#include "graph/buffer.h" | |||
USING_GE_NS | |||
class CheckGraphTest : public testing::Test { | |||
private: | |||
EG_NS::GraphEasyExecutor executor; | |||
protected: | |||
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||
void TearDown() {} | |||
}; | |||
TEST_F(CheckGraphTest, test_ge_graph_dump_is_work) { | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
DUMP_GRAPH_WHEN("after_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||
CHECK_GRAPH(after_build) { | |||
ASSERT_EQ(graph->GetName(), "g1"); | |||
ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||
}; | |||
} | |||
TEST_F(CheckGraphTest, test_ge_graph_dump_two_phase) { | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
DEF_GRAPH(g2) { | |||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||
}; | |||
DUMP_GRAPH_WHEN("before_build", "after_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "after_build"); | |||
CHECK_GRAPH(before_build) { | |||
ASSERT_EQ(graph->GetName(), "g1"); | |||
ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||
}; | |||
CHECK_GRAPH(after_build) { | |||
ASSERT_EQ(graph->GetName(), "g2"); | |||
ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||
}; | |||
} | |||
TEST_F(CheckGraphTest, test_ge_graph_dump_one_phase_two_times) { | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
DEF_GRAPH(g2) { | |||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CTRL_CHAIN(NODE("data2", DATA)->NODE("add", ADD)); | |||
}; | |||
DUMP_GRAPH_WHEN("before_build") | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g2), "before_build"); | |||
CHECK_GRAPH(before_build) { | |||
ASSERT_EQ(graph->GetName(), "g2"); | |||
ASSERT_EQ(graph->GetAllNodesSize(), 3); | |||
}; | |||
} | |||
TEST_F(CheckGraphTest, test_check_phases_is_work) { | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
DUMP_GRAPH_WHEN("before_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||
auto ret = ::GE_NS::CheckUtils::CheckGraph("after_build", [&](const ::GE_NS::ComputeGraphPtr &graph) {}); | |||
ASSERT_FALSE(ret); | |||
} | |||
TEST_F(CheckGraphTest, test_check_one_phase_dump_another_not_dump) { | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
DUMP_GRAPH_WHEN("before_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "before_build"); | |||
GraphDumperRegistry::GetDumper().Dump(ToComputeGraph(g1), "after_build"); | |||
CHECK_GRAPH(before_build) { | |||
ASSERT_EQ(graph->GetName(), "g1"); | |||
ASSERT_EQ(graph->GetAllNodesSize(), 2); | |||
}; | |||
} | |||
TEST_F(CheckGraphTest, test_model_serialize_and_unserialize_success) { | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
auto ge_graph = ToGeGraph(g1); | |||
ge::Model model("", ""); | |||
model.SetGraph(ge_graph); | |||
Buffer buffer; | |||
model.Save(buffer, true); | |||
ge::Model loadModel("", ""); | |||
Model::Load(buffer.GetData(), buffer.GetSize(), loadModel); | |||
auto load_graph = loadModel.GetGraph(); | |||
ASSERT_EQ(load_graph.GetName(), "g1"); | |||
ASSERT_EQ(load_graph.GetAllNodes().size(), 2); | |||
} |
@@ -37,17 +37,13 @@ class GraphDslTest : public testing::Test { | |||
EG_NS::GraphEasyExecutor executor; | |||
protected: | |||
void SetUp() { | |||
EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); | |||
} | |||
void SetUp() { EG_NS::GraphLayout::GetInstance().Config(executor, nullptr); } | |||
void TearDown() {} | |||
}; | |||
TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
}); | |||
DEF_GRAPH(g1) { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
auto geGraph = ToGeGraph(g1); | |||
auto computeGraph = ToComputeGraph(g1); | |||
@@ -57,9 +53,7 @@ TEST_F(GraphDslTest, test_build_graph_from_optype_with_name) { | |||
} | |||
TEST_F(GraphDslTest, test_build_graph_with_name) { | |||
DEF_GRAPH(g1, "sample_graph") { | |||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
}); | |||
DEF_GRAPH(g1, "sample_graph") { CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -72,7 +66,7 @@ TEST_F(GraphDslTest, test_build_from_from_op_desc_ptr) { | |||
auto data = std::make_shared<OpDesc>("data1", DATA); | |||
auto add = std::make_shared<OpDesc>("Add", ADD); | |||
CHAIN(NODE(data)->NODE(add)); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -84,7 +78,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||
auto datCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | |||
auto addCfg = OP_CFG(DATA).InCnt(1).OutCnt(1); | |||
CHAIN(NODE("data1", datCfg)->NODE("add", addCfg)); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -92,9 +86,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg) { | |||
} | |||
TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); | |||
}); | |||
DEF_GRAPH(g1) { CHAIN(NODE("data1", OP_CFG(DATA).InCnt(1).OutCnt(1))->NODE("add", OP_CFG(ADD).InCnt(2).OutCnt(1))); }; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -102,9 +94,7 @@ TEST_F(GraphDslTest, test_build_from_op_desc_cfg_inline) { | |||
} | |||
TEST_F(GraphDslTest, test_build_from_control_chain) { | |||
DEF_GRAPH(g1) { | |||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
}); | |||
DEF_GRAPH(g1) { CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -112,9 +102,7 @@ TEST_F(GraphDslTest, test_build_from_control_chain) { | |||
} | |||
TEST_F(GraphDslTest, test_build_from_data_chain) { | |||
DEF_GRAPH(g1) { | |||
DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
}); | |||
DEF_GRAPH(g1) { DATA_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); }; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -125,7 +113,7 @@ TEST_F(GraphDslTest, test_build_from_data_chain_with_edge) { | |||
DEF_GRAPH(g1) { | |||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CHAIN(NODE("data1", DATA)->EDGE(2, 2)->NODE("add")); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -136,7 +124,7 @@ TEST_F(GraphDslTest, test_build_graph_reused_before_node) { | |||
DEF_GRAPH(g1) { | |||
CTRL_CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CHAIN(NODE("data1")->EDGE(2, 2)->NODE("add")); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -147,7 +135,7 @@ TEST_F(GraphDslTest, test_build_graph_with_constant_folding) { | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CHAIN(NODE("data2", DATA)->NODE("add")); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -168,7 +156,7 @@ TEST_F(GraphDslTest, test_build_complex_normal_graph_build_suggested) { | |||
->NODE("Add4") | |||
->NODE("Add5") | |||
->NODE("net_output", NETOUTPUT)); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -187,7 +175,7 @@ TEST_F(GraphDslTest, test_build_complex_mult_normal_graph_build) { | |||
CHAIN(NODE("add2")->NODE("net_output")); | |||
CHAIN(NODE("add3")->NODE("net_output")); | |||
CTRL_CHAIN(NODE("add1")->NODE("add2")->NODE("add3")); | |||
}); | |||
}; | |||
auto geGraph = ToGeGraph(g1); | |||
@@ -198,17 +186,17 @@ TEST_F(GraphDslTest, test_build_graph_with_sub_graph) { | |||
DEF_GRAPH(sub_1) { | |||
CHAIN(NODE("data_i", DATA)->NODE("less", LESS)->NODE("netoutput", NETOUTPUT)); | |||
CHAIN(NODE("const_5", CONSTANTOP)->NODE("less")); | |||
}); | |||
}; | |||
DEF_GRAPH(sub_2) { | |||
CHAIN(NODE("data_a", DATA)->NODE("mul", MUL)->NODE("netoutput", NETOUTPUT)); | |||
CHAIN(NODE("const_2", CONSTANTOP)->NODE("mul")); | |||
}); | |||
}; | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data_a", DATA)->NODE("while", WHILE, sub_1, sub_2)->NODE("netoutput", NETOUTPUT)); | |||
CHAIN(NODE("data_i", DATA)->NODE("while")); | |||
}); | |||
}; | |||
sub_1.Layout(); | |||
sub_2.Layout(); | |||
@@ -30,5 +30,11 @@ REGISTER_OPTYPE_DEFINE(MUL, "Mul"); | |||
REGISTER_OPTYPE_DEFINE(NETOUTPUT, "NetOutput"); | |||
REGISTER_OPTYPE_DEFINE(ADD, "Add"); | |||
REGISTER_OPTYPE_DEFINE(WHILE, "While"); | |||
REGISTER_OPTYPE_DEFINE(ENTER, "Enter"); | |||
REGISTER_OPTYPE_DEFINE(MERGE, "Merge"); | |||
REGISTER_OPTYPE_DEFINE(LOOPCOND, "Loopcond"); | |||
REGISTER_OPTYPE_DEFINE(SWITCH, "Switch"); | |||
REGISTER_OPTYPE_DEFINE(EXIT, "Exit"); | |||
REGISTER_OPTYPE_DEFINE(NEXTITERATION, "Nextiteration"); | |||
GE_NS_END |
@@ -1,22 +1,25 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||
#define GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||
class tensor_builder_utils {}; | |||
#endif // GRAPHENGINE_LLT_ST_TENSOR_BUILDER_UTILS_H | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "ge_graph_dsl/assert/check_utils.h" | |||
int main(int argc, char **argv) { | |||
::GE_NS::CheckUtils::init(); | |||
testing::InitGoogleTest(&argc, argv); | |||
int ret = RUN_ALL_TESTS(); | |||
return ret; | |||
} |
@@ -1,48 +0,0 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include "graph_builder_utils.h" | |||
#include "inc/external/graph/operator.h" | |||
#include "inc/external/graph/operator_factory.h" | |||
#include "graph/utils/graph_utils.h" | |||
namespace ge { | |||
namespace st { | |||
NodePtr ComputeGraphBuilder::AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||
Format format, DataType data_type, std::vector<int64_t> shape) { | |||
auto tensor_desc = std::make_shared<GeTensorDesc>(); | |||
tensor_desc->SetShape(GeShape(std::move(shape))); | |||
tensor_desc->SetFormat(format); | |||
tensor_desc->SetDataType(data_type); | |||
auto op_desc = std::make_shared<OpDesc>(name, type); | |||
for (int i = 0; i < in_cnt; ++i) { | |||
op_desc->AddInputDesc(tensor_desc->Clone()); | |||
} | |||
for (int i = 0; i < out_cnt; ++i) { | |||
op_desc->AddOutputDesc(tensor_desc->Clone()); | |||
} | |||
return graph_->AddNode(op_desc); | |||
} | |||
void ComputeGraphBuilder::AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx) { | |||
GraphUtils::AddEdge(src_node->GetOutDataAnchor(src_idx), dst_node->GetInDataAnchor(dst_idx)); | |||
} | |||
void ComputeGraphBuilder::AddControlEdge(NodePtr &src_node, NodePtr &dst_node) { | |||
GraphUtils::AddEdge(src_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); | |||
} | |||
} // namespace st | |||
} // namespace ge |
@@ -1,55 +0,0 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#ifndef GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||
#define GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H | |||
#include <string> | |||
#include <vector> | |||
#include "graph/compute_graph.h" | |||
#include "graph/utils/graph_utils.h" | |||
#include "graph/graph.h" | |||
#include "graph/node.h" | |||
namespace ge { | |||
namespace st { | |||
class ComputeGraphBuilder { | |||
public: | |||
explicit ComputeGraphBuilder(const std::string &name) { | |||
graph_ = std::make_shared<ComputeGraph>(name); | |||
} | |||
NodePtr AddNode(const std::string &name, const std::string &type, int in_cnt, int out_cnt, | |||
Format format = FORMAT_NCHW, DataType data_type = DT_FLOAT, | |||
std::vector<int64_t> shape = {1, 1, 224, 224}); | |||
void AddDataEdge(NodePtr &src_node, int src_idx, NodePtr &dst_node, int dst_idx); | |||
void AddControlEdge(NodePtr &src_node, NodePtr &dst_node); | |||
ComputeGraphPtr GetComputeGraph() { | |||
graph_->TopologicalSorting(); | |||
return graph_; | |||
} | |||
Graph GetGraph() { | |||
graph_->TopologicalSorting(); | |||
return GraphUtils::CreateGraphFromComputeGraph(graph_); | |||
} | |||
private: | |||
ComputeGraphPtr graph_; | |||
}; | |||
} // namespace st | |||
} // namespace ge | |||
#endif // GRAPHENGINE_LLT_ST_GRAPH_BUILDER_H |
@@ -8,7 +8,7 @@ target_include_directories(graph_engine_test | |||
set_target_properties(graph_engine_test PROPERTIES CXX_STANDARD 17) | |||
target_link_libraries(graph_engine_test PRIVATE gtest framework) | |||
target_link_libraries(graph_engine_test PRIVATE gtest ge_graph_dsl ge_with_env) | |||
include(CTest) | |||
enable_testing() |
@@ -15,23 +15,12 @@ | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <map> | |||
#include "external/ge/ge_api.h" | |||
#include "ge_running_env/fake_engine.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "framework/common/types.h" | |||
#include "builder/graph_builder_utils.h" | |||
#include "ge_running_env/ge_running_env_faker.h" | |||
#include "graph/operator_reg.h" | |||
#include "graph/operator.h" | |||
#define protected public | |||
#define private public | |||
#include "graph/utils/op_desc_utils.h" | |||
#include "ge_graph_dsl/graph_dsl.h" | |||
#undef protected | |||
#undef private | |||
#include "ge_graph_dsl/assert/graph_assert.h" | |||
using namespace std; | |||
using namespace ge; | |||
@@ -57,76 +46,58 @@ namespace { | |||
* | |||
**/ | |||
Graph BuildV1ControlFlowGraph() { | |||
// build graph | |||
st::ComputeGraphBuilder graphBuilder("g1"); | |||
auto data_i = graphBuilder.AddNode("data_i", DATA, 1, 1); | |||
auto enter_i = graphBuilder.AddNode("enter_i", ENTER, 1, 1); | |||
ge::AttrUtils::SetStr(enter_i->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||
auto merge_i = graphBuilder.AddNode("merge_i", MERGE, 2, 1); | |||
auto const_5 = graphBuilder.AddNode("const_5", CONSTANT, 0, 1); | |||
auto less = graphBuilder.AddNode("less", LESS, 2, 1); | |||
auto loopcond = graphBuilder.AddNode("loopcond", LOOPCOND, 1, 1, FORMAT_NCHW, DT_BOOL); | |||
auto switch_i = graphBuilder.AddNode("switch_i", SWITCH, 2, 2); | |||
auto exit_i = graphBuilder.AddNode("switch_i", EXIT, 1, 1); | |||
auto const_1 = graphBuilder.AddNode("const_1", CONSTANT, 0, 1); | |||
auto add = graphBuilder.AddNode("add", ADD, 2, 1); | |||
auto next_iteration_i = graphBuilder.AddNode("next_iteration_i", NEXTITERATION, 1, 1); | |||
auto data_a = graphBuilder.AddNode("data_a", DATA, 1, 1); | |||
auto enter_a = graphBuilder.AddNode("enter_a", ENTER, 1, 1); | |||
ge::AttrUtils::SetStr(enter_a->GetOpDesc(), ENTER_ATTR_FRAME_NAME, "1"); | |||
auto merge_a = graphBuilder.AddNode("merge_a", MERGE, 2, 1); | |||
auto switch_a = graphBuilder.AddNode("switch_a", SWITCH, 2, 2); | |||
auto exit_a = graphBuilder.AddNode("exit_a", EXIT, 1, 1); | |||
auto mul = graphBuilder.AddNode("mul", MUL, 2, 1); | |||
auto const_2 = graphBuilder.AddNode("const_2", CONSTANT, 0, 1); | |||
auto next_iteration_a = graphBuilder.AddNode("next_iteration_a", NEXTITERATION, 1, 1); | |||
auto netoutput = graphBuilder.AddNode("netoutput", NETOUTPUT, 2, 2); | |||
// i = i+1 | |||
graphBuilder.AddDataEdge(data_i, 0, enter_i, 0); | |||
graphBuilder.AddDataEdge(enter_i, 0, merge_i, 0); | |||
graphBuilder.AddDataEdge(next_iteration_i, 0, merge_i, 1); | |||
graphBuilder.AddDataEdge(merge_i, 0, less, 0); | |||
graphBuilder.AddDataEdge(const_5, 0, less, 1); | |||
graphBuilder.AddDataEdge(less, 0, loopcond, 0); | |||
graphBuilder.AddDataEdge(loopcond, 0, switch_i, 1); | |||
graphBuilder.AddDataEdge(merge_i, 0, switch_i, 0); | |||
graphBuilder.AddDataEdge(switch_i, 0, exit_i, 0); | |||
graphBuilder.AddDataEdge(switch_i, 1, add, 0); | |||
graphBuilder.AddDataEdge(const_1, 0, add, 1); | |||
graphBuilder.AddDataEdge(add, 0, next_iteration_i, 0); | |||
graphBuilder.AddDataEdge(exit_i, 0, netoutput, 1); | |||
// a=a*2 | |||
graphBuilder.AddDataEdge(data_a, 0, enter_a, 0); | |||
graphBuilder.AddDataEdge(enter_a, 0, merge_a, 0); | |||
graphBuilder.AddDataEdge(next_iteration_a, 0, merge_a, 1); | |||
graphBuilder.AddDataEdge(loopcond, 0, switch_a, 1); | |||
graphBuilder.AddDataEdge(merge_a, 0, switch_a, 0); | |||
graphBuilder.AddDataEdge(switch_a, 0, exit_a, 0); | |||
graphBuilder.AddDataEdge(switch_a, 1, mul, 0); | |||
graphBuilder.AddDataEdge(const_2, 0, mul, 1); | |||
graphBuilder.AddDataEdge(mul, 0, next_iteration_a, 0); | |||
graphBuilder.AddDataEdge(exit_a, 0, netoutput, 0); | |||
// set const weight | |||
int64_t dims_size = 1; | |||
vector<int64_t> data_vec = {5}; | |||
for_each(data_vec.begin(), data_vec.end(), [&](int64_t &data) { dims_size *= data; }); | |||
vector<int32_t> data_value_vec(dims_size, 1); | |||
GeTensorDesc data_tensor_desc(GeShape(data_vec), FORMAT_NCHW, DT_INT32); | |||
GeTensorPtr data_tensor = | |||
make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), data_value_vec.size() * sizeof(int32_t)); | |||
OpDescUtils::SetWeights(const_5->GetOpDesc(), data_tensor); | |||
OpDescUtils::SetWeights(const_2->GetOpDesc(), data_tensor); | |||
OpDescUtils::SetWeights(const_1->GetOpDesc(), data_tensor); | |||
GeTensorPtr data_tensor = make_shared<GeTensor>(data_tensor_desc, (uint8_t *)data_value_vec.data(), | |||
data_value_vec.size() * sizeof(int32_t)); | |||
return graphBuilder.GetGraph(); | |||
auto enter = OP_CFG(ENTER).Attr(ENTER_ATTR_FRAME_NAME, "1"); | |||
auto const_op = OP_CFG(CONSTANT).Weight(data_tensor); | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data_i", DATA) | |||
->NODE("enter_i", enter) | |||
->EDGE(0, 0) | |||
->NODE("merge_i", MERGE) | |||
->NODE("less", LESS) | |||
->NODE("loopcond", LOOPCOND)); | |||
CHAIN(NODE("const_1", const_op) | |||
->EDGE(0, 1) | |||
->NODE("add", ADD) | |||
->NODE("iteration_i", NEXTITERATION) | |||
->EDGE(0, 1) | |||
->NODE("merge_i")); | |||
CHAIN(NODE("const_5", const_op)->EDGE(0, 1)->NODE("less")); | |||
CHAIN(NODE("loopcond") | |||
->EDGE(0, 1) | |||
->NODE("switch_i", SWITCH) | |||
->EDGE(0, 0) | |||
->NODE("exit_i", EXIT) | |||
->EDGE(0, 1) | |||
->NODE("netoutput", NETOUTPUT)); | |||
CHAIN(NODE("merge_i")->EDGE(0, 0)->NODE("switch_i")->EDGE(1, 0)->NODE("add")); | |||
CHAIN(NODE("data_a", DATA) | |||
->NODE("enter_a", enter) | |||
->NODE("merge_a", MERGE) | |||
->NODE("switch_a", SWITCH) | |||
->NODE("exit_a", EXIT) | |||
->EDGE(0, 0) | |||
->NODE("netoutput")); | |||
CHAIN(NODE("iteration_a", NEXTITERATION)->EDGE(0, 1)->NODE("merge_a")); | |||
CHAIN(NODE("loopcond")->EDGE(0, 1)->NODE("switch_a")->EDGE(1, 0)->NODE("mul", MUL)); | |||
CHAIN(NODE("const_2", const_op)->EDGE(0, 1)->NODE("mul")->EDGE(0, 0)->NODE("iteration_a")); | |||
}; | |||
return ToGeGraph(g1); | |||
} | |||
} // namespace | |||
class FrameworkTest : public testing::Test { | |||
protected: | |||
GeRunningEnvFaker ge_env; | |||
void SetUp() { ge_env.InstallDefault(); } | |||
void TearDown() {} | |||
GeRunningEnvFaker ge_env; | |||
}; | |||
/// data data | |||
@@ -136,19 +107,19 @@ TEST_F(FrameworkTest, test_framework_add) { | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CHAIN(NODE("data2", DATA)->NODE("add")); | |||
}); | |||
}; | |||
auto graph = ToGeGraph(g1); | |||
// new session & add graph | |||
map<AscendString, AscendString> options; | |||
Session session(options); | |||
auto ret = session.AddGraph(1, graph, options); | |||
EXPECT_EQ(ret, SUCCESS); | |||
// build input tensor | |||
session.AddGraph(1, ToGeGraph(g1), options); | |||
std::vector<InputTensorInfo> inputs; | |||
// build_graph through session | |||
ret = session.BuildGraph(1, inputs); | |||
auto ret = session.BuildGraph(1, inputs); | |||
EXPECT_EQ(ret, SUCCESS); | |||
CHECK_GRAPH(PreRunAfterBuild) { | |||
ASSERT_EQ(graph->GetName(), "g1_1"); | |||
ASSERT_EQ(graph->GetAllNodesSize(), 4); | |||
}; | |||
} | |||
/** data a = 2; | |||
@@ -15,24 +15,12 @@ | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "easy_graph/graph/box.h" | |||
#include "easy_graph/graph/node.h" | |||
#include "external/ge/ge_api.h" | |||
#include "easy_graph/builder/graph_dsl.h" | |||
#include "easy_graph/builder/box_builder.h" | |||
#include "easy_graph/layout/graph_layout.h" | |||
#include "easy_graph/layout/engines/graph_easy/graph_easy_option.h" | |||
#include "easy_graph/layout/engines/graph_easy/graph_easy_executor.h" | |||
#include "graph/graph.h" | |||
#include "graph/compute_graph.h" | |||
#include "framework/common/types.h" | |||
#include "graph/debug/ge_attr_define.h" | |||
#include "graph/ge_local_context.h" | |||
#include "ge_graph_dsl/graph_dsl.h" | |||
#include "ge_graph_dsl/op_desc/op_desc_cfg_box.h" | |||
#define protected public | |||
#define private public | |||
#include "ge_opt_info/ge_opt_info.h" | |||
#undef private | |||
#undef protected | |||
namespace ge { | |||
class STEST_opt_info : public testing::Test { | |||
@@ -52,7 +40,7 @@ TEST_F(STEST_opt_info, get_opt_info_all) { | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CHAIN(NODE("data2", DATA)->NODE("add")); | |||
}); | |||
}; | |||
auto graph = ToGeGraph(g1); | |||
@@ -95,7 +83,7 @@ TEST_F(STEST_opt_info, get_opt_info_success) { | |||
DEF_GRAPH(g1) { | |||
CHAIN(NODE("data1", DATA)->NODE("add", ADD)); | |||
CHAIN(NODE("data2", DATA)->NODE("add")); | |||
}); | |||
}; | |||
auto graph = ToGeGraph(g1); | |||
@@ -15,9 +15,8 @@ | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include "common/debug/log.h" | |||
#include "external/ge/ge_api.h" | |||
#include "ge_graph_dsl/assert/check_utils.h" | |||
#include "ge_running_env/include/ge_running_env/ge_running_env_faker.h" | |||
using namespace std; | |||
@@ -31,6 +30,7 @@ int main(int argc, char **argv) { | |||
std::cout << "ge init failed , ret code:" << init_status << endl; | |||
} | |||
GeRunningEnvFaker::BackupEnv(); | |||
CheckUtils::init(); | |||
testing::InitGoogleTest(&argc, argv); | |||
int ret = RUN_ALL_TESTS(); | |||
return ret; | |||
@@ -90,6 +90,7 @@ set(SRC_FILES | |||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/type_utils.cc" | |||
@@ -102,6 +102,7 @@ set(GRAPH_SRC_FILES | |||
"${GE_CODE_DIR}/metadef/graph/detail/attributes_holder.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/anchor_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/graph_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/dumper/ge_graph_dumper.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/ge_ir_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/node_utils.cc" | |||
"${GE_CODE_DIR}/metadef/graph/utils/op_desc_utils.cc" | |||