You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

graph_util.cc 4.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. /**
  2. * Copyright 2019 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "common/graph_util.h"
  #include <algorithm>
  #include <fstream>
  #include <memory>
  #include <new>
  #include <sstream>
  #include "common/mslog.h"
  #include "include/errorcode.h"
  21. namespace mindspore {
  22. namespace predict {
  23. OpGraph *OpGraph::Build(const SubGraphDef &subGraphDef) {
  24. auto graph = std::unique_ptr<OpGraph>(new OpGraph());
  25. if (graph == nullptr) {
  26. MS_LOGE("malloc opgraph failed");
  27. return nullptr;
  28. }
  29. auto nodeDefs = subGraphDef.nodes();
  30. if (nodeDefs == nullptr) {
  31. MS_LOGE("nodeDefs from subGraphDef is nullptr");
  32. return nullptr;
  33. }
  34. uint32_t opCount = nodeDefs->size();
  35. for (uint32_t i = 0; i < opCount; i++) {
  36. auto nodeDef = nodeDefs->GetAs<NodeDef>(i);
  37. MS_ASSERT(nodeDef != nullptr);
  38. auto ret = graph->AddEdge(*nodeDef, *nodeDefs);
  39. if (ret != RET_OK) {
  40. MS_LOGE("%s add edge failed. ret:%d", nodeDef->opDef()->name()->c_str(), ret);
  41. return nullptr;
  42. }
  43. }
  44. return graph.release();
  45. }
  46. int OpGraph::AddEdge(const NodeDef &srcNodeDef, const flatbuffers::Vector<flatbuffers::Offset<NodeDef>> &nodeDefs) {
  47. MS_ASSERT(srcNodeDef.opDef() != nullptr);
  48. MS_ASSERT(srcNodeDef.opDef()->name() != nullptr);
  49. NODE_ID srcId = std::string(srcNodeDef.opDef()->name()->c_str());
  50. uint32_t opCount = nodeDefs.size();
  51. MS_ASSERT(srcNodeDef.opDef()->outputIndex() != nullptr);
  52. for (auto index : *(srcNodeDef.opDef()->outputIndex())) {
  53. for (uint32_t i = 0; i < opCount; i++) {
  54. auto dstNodeDef = nodeDefs.GetAs<NodeDef>(i);
  55. bool find = false;
  56. MS_ASSERT(dstNodeDef != nullptr);
  57. MS_ASSERT(dstNodeDef->opDef() != nullptr);
  58. auto inputIndex = dstNodeDef->opDef()->inputIndex();
  59. MS_ASSERT(inputIndex != nullptr);
  60. if (std::any_of(inputIndex->begin(), inputIndex->end(), [&index](int i) { return i == index; })) {
  61. find = true;
  62. }
  63. if (!find) {
  64. continue;
  65. }
  66. MS_ASSERT(dstNodeDef->opDef()->name() != nullptr);
  67. NODE_ID dstId = std::string(dstNodeDef->opDef()->name()->c_str());
  68. auto ret = AddEdge(srcId, dstId);
  69. if (ret != RET_OK) {
  70. return ret;
  71. }
  72. }
  73. }
  74. return RET_OK;
  75. }
  76. int OpGraph::AddEdge(const NODE_ID &srcId, const NODE_ID &dstId) {
  77. auto srcNode = AddNode(srcId);
  78. if (srcNode == nullptr) {
  79. MS_LOGE("add srcNode failed");
  80. return RET_ERROR;
  81. }
  82. srcNode->AddOutEdge(dstId);
  83. auto dstNode = AddNode(dstId);
  84. if (dstNode == nullptr) {
  85. MS_LOGE("add dstNode failed");
  86. return RET_ERROR;
  87. }
  88. dstNode->AddInEdge(srcId);
  89. return RET_OK;
  90. }
  91. OpNode *OpGraph::GetNode(const NODE_ID &nodeId) {
  92. auto node = nodes.find(nodeId);
  93. if (node == nodes.end()) {
  94. return nullptr;
  95. }
  96. return node->second;
  97. }
  98. OpNode *OpGraph::AddNode(const NODE_ID &nodeId) {
  99. auto node = GetNode(nodeId);
  100. if (node != nullptr) {
  101. return node;
  102. }
  103. node = new (std::nothrow) OpNode(nodeId);
  104. if (node == nullptr) {
  105. MS_LOGE("new node failed");
  106. return nullptr;
  107. }
  108. nodes[nodeId] = node;
  109. return node;
  110. }
  111. std::unordered_set<NODE_ID> OpGraph::GetInputNode() {
  112. std::unordered_set<NODE_ID> inputNodes;
  113. for (const auto &iter : nodes) {
  114. auto node = iter.second;
  115. MS_ASSERT(node != nullptr);
  116. if (node->GetAllInEdge().empty()) {
  117. inputNodes.insert(node->ID());
  118. }
  119. }
  120. return inputNodes;
  121. }
  122. std::unordered_set<NODE_ID> OpGraph::GetOutputNode() {
  123. std::unordered_set<NODE_ID> outputNodes;
  124. for (const auto &iter : nodes) {
  125. auto node = iter.second;
  126. MS_ASSERT(node != nullptr);
  127. if (node->GetAllOutEdge().empty()) {
  128. outputNodes.insert(node->ID());
  129. }
  130. }
  131. return outputNodes;
  132. }
  133. OpGraph::~OpGraph() {
  134. for (auto iter : nodes) {
  135. if (iter.second != nullptr) {
  136. delete iter.second;
  137. }
  138. }
  139. nodes.clear();
  140. }
  141. NODE_ID OpNode::ID() { return id; }
  142. void OpNode::AddInEdge(const NODE_ID &nodeId) { inEdges.insert(nodeId); }
  143. void OpNode::AddOutEdge(const NODE_ID &nodeId) { outEdges.insert(nodeId); }
  144. std::unordered_set<NODE_ID> OpNode::GetAllInEdge() { return inEdges; }
  145. std::unordered_set<NODE_ID> OpNode::GetAllOutEdge() { return outEdges; }
  146. } // namespace predict
  147. } // namespace mindspore