You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

model.cc 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/model.h"
  17. #include <google/protobuf/io/coded_stream.h>
  18. #include <google/protobuf/io/zero_copy_stream.h>
  19. #include <google/protobuf/io/zero_copy_stream_impl.h>
  20. #include <google/protobuf/text_format.h>
  21. #include <sys/stat.h>
  22. #include <sys/types.h>
  23. #include <algorithm>
  24. #include <cstring>
  25. #include <fstream>
  26. #include <iomanip>
  27. #include "debug/ge_attr_define.h"
  28. #include "debug/ge_util.h"
  29. #include "framework/common/debug/ge_log.h"
  30. #include "graph/model_serialize.h"
  31. #include "mmpa/mmpa_api.h"
  32. #include "utils/attr_utils.h"
  33. #include "utils/ge_ir_utils.h"
  34. #include "proto/ge_ir.pb.h"
  35. using google::protobuf::io::FileInputStream;
  36. using google::protobuf::io::FileOutputStream;
  37. using google::protobuf::io::ZeroCopyInputStream;
  namespace {
  // Version number used in the initializer list of Model(name, custom_version).
  constexpr int DEFAULT_VERSION = 1;
  // Mode bits (octal 0400) passed to mmOpen2 when Model::SaveToFile opens its output file.
  constexpr int ACCESS_PERMISSION_BITS = 0400;
  }  // namespace
  42. namespace ge {
  43. void Model::Init() {
  44. (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0);
  45. (void)AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0);
  46. (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0);
  47. (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0);
  48. (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0);
  49. (void)AttrUtils::SetInt(this, ATTR_MODEL_WEIGHT_SIZE, 0);
  50. (void)AttrUtils::SetStr(this, ATTR_MODEL_TARGET_TYPE, TARGET_TYPE_MINI);
  51. version_ = 0;
  52. }
  53. Model::Model() {
  54. attrs_.InitDefault();
  55. Init();
  56. }
  57. Model::Model(const string &name, const string &custom_version)
  58. : name_(name), version_(DEFAULT_VERSION), platform_version_(custom_version) {
  59. attrs_.InitDefault();
  60. Init();
  61. }
  62. string Model::GetName() const { return name_; }
  63. void Model::SetName(const string &name) { name_ = name; }
  64. uint32_t Model::GetVersion() const { return version_; }
  65. string Model::GetPlatformVersion() const { return platform_version_; }
  66. void Model::SetGraph(const ge::Graph &graph) { graph_ = graph; }
  67. Graph Model::GetGraph() const { return graph_; }
  68. graphStatus Model::Save(Buffer &buffer, bool is_dump) const {
  69. ModelSerialize serialize;
  70. buffer = serialize.SerializeModel(*this, is_dump);
  71. return buffer.GetSize() > 0 ? GRAPH_SUCCESS : GRAPH_FAILED;
  72. }
  73. void Model::SetAttr(const ProtoAttrMapHelper &attrs) { attrs_ = attrs; }
  74. graphStatus Model::Load(const uint8_t *data, size_t len, Model &model) {
  75. ModelSerialize serialize;
  76. model = serialize.UnserializeModel(data, len);
  77. return model.IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED;
  78. }
  79. graphStatus Model::SaveToFile(const string &file_name) const {
  80. Buffer buffer;
  81. if ((*this).Save(buffer) != GRAPH_SUCCESS) {
  82. GE_LOGE("save to file fail.");
  83. return GRAPH_FAILED;
  84. }
  85. // Write file
  86. ge::proto::ModelDef ge_proto;
  87. if (buffer.GetData() != nullptr) {
  88. std::string str((const char *)buffer.GetData(), buffer.GetSize());
  89. if (!ge_proto.ParseFromString(str)) {
  90. return GRAPH_FAILED;
  91. }
  92. char real_path[MMPA_MAX_PATH] = {0x00};
  93. if (strlen(file_name.c_str()) >= MMPA_MAX_PATH) {
  94. return GRAPH_FAILED;
  95. }
  96. INT32 result = mmRealPath(file_name.c_str(), real_path, MMPA_MAX_PATH);
  97. if (result != EN_OK) {
  98. GELOGI("file %s does not exit, it will be created.", file_name.c_str());
  99. }
  100. int fd = mmOpen2(real_path, M_WRONLY | M_CREAT | O_TRUNC, ACCESS_PERMISSION_BITS);
  101. if (fd < 0) {
  102. GELOGE(GRAPH_FAILED, "open file failed, file path [%s], %s ", real_path, strerror(errno));
  103. return GRAPH_FAILED;
  104. }
  105. bool ret = ge_proto.SerializeToFileDescriptor(fd);
  106. if (!ret) {
  107. GELOGE(GRAPH_FAILED, "SerializeToFileDescriptor failed");
  108. if (close(fd) != 0) {
  109. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  110. return GRAPH_FAILED;
  111. }
  112. return GRAPH_FAILED;
  113. }
  114. if (close(fd) != 0) {
  115. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  116. return GRAPH_FAILED;
  117. }
  118. if (!ret) {
  119. GELOGE(GRAPH_FAILED, "function [SerializeToFileDescriptor] failed");
  120. return GRAPH_FAILED;
  121. }
  122. }
  123. return GRAPH_SUCCESS;
  124. }
  125. graphStatus Model::Load(ge::proto::ModelDef &model_def) {
  126. ModelSerialize serialize;
  127. *this = serialize.UnserializeModel(model_def);
  128. return this->IsValid() ? GRAPH_SUCCESS : GRAPH_FAILED;
  129. }
  130. bool Model::IsValid() const { return graph_.IsValid(); }
  131. graphStatus Model::LoadFromFile(const string &file_name) {
  132. char real_path[MMPA_MAX_PATH] = {0x00};
  133. if (strlen(file_name.c_str()) >= MMPA_MAX_PATH) {
  134. return GRAPH_FAILED;
  135. }
  136. INT32 result = mmRealPath(file_name.c_str(), real_path, MMPA_MAX_PATH);
  137. if (result != EN_OK) {
  138. GELOGE(GRAPH_FAILED, "file %s does not exit, can not load.", file_name.c_str());
  139. return GRAPH_FAILED;
  140. }
  141. int fd = mmOpen(real_path, M_RDONLY);
  142. if (fd < 0) {
  143. GELOGE(GRAPH_FAILED, "open file failed, %s", strerror(errno));
  144. return GRAPH_FAILED;
  145. }
  146. ge::proto::ModelDef model_def;
  147. bool ret = model_def.ParseFromFileDescriptor(fd);
  148. if (!ret) {
  149. GELOGE(GRAPH_FAILED, "ParseFromFileDescriptor failed");
  150. if (mmClose(fd) != 0) {
  151. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  152. return GRAPH_FAILED;
  153. }
  154. return GRAPH_FAILED;
  155. }
  156. if (mmClose(fd) != 0) {
  157. GELOGE(GRAPH_FAILED, "close file descriptor fail.");
  158. return GRAPH_FAILED;
  159. }
  160. if (!ret) {
  161. GELOGE(GRAPH_FAILED, "function [ParseFromFileDescriptor] failed");
  162. return GRAPH_FAILED;
  163. }
  164. return Load(model_def);
  165. }
  166. ProtoAttrMapHelper Model::MutableAttrMap() { return attrs_; }
  167. ConstProtoAttrMapHelper Model::GetAttrMap() const {
  168. return ConstProtoAttrMapHelper(attrs_.GetProtoOwner(), attrs_.GetProtoMsg());
  169. }
  170. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。