You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.cc 6.0 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/model.h"
  17. #include <fcntl.h>
  18. #include <google/protobuf/io/coded_stream.h>
  19. #include <google/protobuf/io/zero_copy_stream.h>
  20. #include <google/protobuf/io/zero_copy_stream_impl.h>
  21. #include <google/protobuf/text_format.h>
  22. #include <sys/stat.h>
  23. #include <sys/types.h>
  24. #include <unistd.h>
  25. #include <algorithm>
  26. #include <cstring>
  27. #include <fstream>
  28. #include <iomanip>
  29. #include "debug/ge_attr_define.h"
  30. #include "debug/ge_util.h"
  31. #include "framework/common/debug/ge_log.h"
  32. #include "graph/model_serialize.h"
  33. #include "proto/ge_ir.pb.h"
  34. #include "utils/attr_utils.h"
  35. #include "utils/ge_ir_utils.h"
  36. using google::protobuf::io::FileInputStream;
  37. using google::protobuf::io::FileOutputStream;
  38. using google::protobuf::io::ZeroCopyInputStream;
  namespace {
  // Version value the named Model constructor places in version_.
  constexpr int DEFAULT_VERSION = 1;
  // Mode bits passed to open() when SaveToFile creates the file (octal 0400).
  constexpr int ACCESS_PERMISSION_BITS = 0400;
  }  // namespace
  43. namespace ge {
  44. void Model::Init() {
  45. (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0);
  46. (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0);
  47. (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0);
  48. (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0);
  49. (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0);
  50. (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI);
  51. version_ = 0;
  52. }
  53. Model::Model() {
  54. attrs_.InitDefault();
  55. Init();
  56. }
  57. Model::Model(const string &name, const string &custom_version)
  58. : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) {
  59. attrs_.InitDefault();
  60. Init();
  61. }
  62. string Model::GetName() const { return name_; }
  63. void Model::SetName(const string &name) { name_ = name; }
  64. uint32_t Model::GetVersion() const { return version_; }
  65. string Model::GetPlatformVersion() const { return platform_version_; }
  66. void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; }
  67. Graph Model::GetGraph() const { return graph_; }
  68. graphStatus Model::Save(Buffer &buffer, bool is_dump) const {
  69. ModelSerialize serialize;
  70. buffer = serialize.SerializeModel(*this, is_dump);
  71. return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED;
  72. }
  73. void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; }
  74. graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) {
  75. ModelSerialize serialize;
  76. model = serialize.UnserializeModel(data, len);
  77. return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED;
  78. }
  79. graphStatus Model::SaveToFile(const string &file_name) const {
  80. Buffer buffer;
  81. if ((*this).Save(buffer) != GRAPH_SUCCESS) {
  82. GE_LOGE("save to file fail.");
  83. return GRAPH_FAILED;
  84. }
  85. // Write file
  86. ge::proto::ModelDef ge_proto;
  87. if (buffer.GetData() != nullptr) {
  88. std::string str((const char *)buffer.GetData(), buffer.GetSize());
  89. if (!ge_proto.ParseFromString(str)) {
  90. return GRAPH_FAILED;
  91. }
  92. char real_path[PATH_MAX] = {0x00};
  93. if (strlen(file_name.c_str()) >= PATH_MAX) {
  94. return GRAPH_FAILED;
  95. }
  96. if (realpath(file_name.c_str(), real_path) == nullptr) {
  97. GELOGI("file %s does not exit, it will be created.", file_name.c_str());
  98. }
  99. int fd = open(real_path, O_WRONLY | O_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS);
  100. if (fd < 0) {
  101. GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno));
  102. return GRAPH_FAILED;
  103. }
  104. bool ret = ge_proto.SerializeToFileDescriptor(fd);
  105. if (!ret) {
  106. GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed");
  107. if (close(fd) != 0) {
  108. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  109. return GRAPH_FAILED;
  110. }
  111. return GRAPH_FAILED;
  112. }
  113. if (close(fd) != 0) {
  114. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  115. return GRAPH_FAILED;
  116. }
  117. if (!ret) {
  118. GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed");
  119. return GRAPH_FAILED;
  120. }
  121. }
  122. return GRAPH_SUCCESS;
  123. }
  124. graphStatus Model::Load(ge::proto::ModelDef &model_def) {
  125. ModelSerialize serialize;
  126. *this = serialize.UnserializeModel(model_def);
  127. return this->IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED;
  128. }
  129. bool Model::IsValid() const { return graph_.IsValid(); }
  130. graphStatus Model::LoadFromFile(const string &file_name) {
  131. char real_path[PATH_MAX] = {0x00};
  132. if (strlen(file_name.c_str()) >= PATH_MAX) {
  133. return GRAPH_FAILED;
  134. }
  135. if (realpath(file_name.c_str(), real_path) == nullptr) {
  136. GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str());
  137. return GRAPH_FAILED;
  138. }
  139. int fd = open(real_path, O_RDONLY);
  140. if (fd < 0) {
  141. GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno));
  142. return GRAPH_FAILED;
  143. }
  144. ge::proto::ModelDef model_def;
  145. bool ret = model_def.ParseFromFileDescriptor(fd);
  146. if (!ret) {
  147. GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed");
  148. if (close(fd) != 0) {
  149. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  150. return GRAPH_FAILED;
  151. }
  152. return GRAPH_FAILED;
  153. }
  154. if (close(fd) != 0) {
  155. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  156. return GRAPH_FAILED;
  157. }
  158. if (!ret) {
  159. GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed");
  160. return GRAPH_FAILED;
  161. }
  162. return Load(model_def);
  163. }
  164. ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; }
  165. ConstProtoAttrMapHelper Model::GetAttrMap() const {
  166. return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
  167. }
  168. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示