You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

infer_value_range_pass.cc 13 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/infer_value_range_pass.h"
  17. #include "common/util/error_manager/error_manager.h"
  18. #include "framework/common/debug/ge_log.h"
  19. #include "graph/debug/ge_attr_define.h"
  20. #include "graph/operator_factory_impl.h"
  21. #include "graph/passes/folding_pass.h"
  22. #include "graph/utils/op_desc_utils.h"
  23. #include "common/ge/ge_util.h"
  24. #include "init/gelib.h"
  25. using std::unique_ptr;
  26. namespace ge {
  27. namespace {
  28. Status RunCpuKernelForValueRange(NodePtr &node, const vector<ConstGeTensorPtr> &inputs,
  29. std::vector<GeTensorPtr> &outputs) {
  30. // should use RunOpKernelWithCheck, RunOpKernel for ut test
  31. auto ret = FoldingPass::RunOpKernel(node, inputs, outputs);
  32. if (ret != SUCCESS) {
  33. auto op_kernel = folding_pass::GetKernelByType(node);
  34. if (op_kernel == nullptr) {
  35. GELOGE(PARAM_INVALID, "Calculate value range failed, no op kernel for node %s type %s", node->GetName().c_str(),
  36. node->GetType().c_str());
  37. return PARAM_INVALID;
  38. }
  39. ret = op_kernel->Compute(node->GetOpDesc(), inputs, outputs);
  40. if (ret != SUCCESS) {
  41. REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(),
  42. node->GetType().c_str());
  43. GELOGE(INTERNAL_ERROR, "Calculate for node %s failed in constant folding", node->GetName().c_str());
  44. return ret;
  45. }
  46. }
  47. GELOGI("Node %s type %s, run cpu kernel success.", node->GetName().c_str(), node->GetType().c_str());
  48. return SUCCESS;
  49. }
  50. } // namespace
  51. graphStatus InferValueRangePass::Infer(NodePtr &node) {
  52. PrintInOutTensorShape(node, "before_infer_value_range");
  53. auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  54. // Use registered func to calculate value range
  55. if (!infer_value_range_param.use_cpu_kernel) {
  56. if (infer_value_range_param.infer_value_func == nullptr) {
  57. GELOGE(GRAPH_PARAM_INVALID, "The registered func to infer value range is nullptr.");
  58. return GRAPH_PARAM_INVALID;
  59. }
  60. Operator op = OpDescUtils::CreateOperatorFromNode(node);
  61. auto ret = node->GetOpDesc()->CallInferValueRangeFunc(op);
  62. if (ret != GRAPH_SUCCESS) {
  63. REPORT_CALL_ERROR("E19999", "Node %s call infer value range function failed.", node->GetName().c_str());
  64. GELOGE(GRAPH_FAILED, "[Call][InferFunction] failed, node: %s.", node->GetName().c_str());
  65. return GRAPH_FAILED;
  66. }
  67. return GRAPH_SUCCESS;
  68. }
  69. // Use CPU kernel func to calculate value range
  70. return ConstructInputAndInferValueRange(node);
  71. }
  72. bool InferValueRangePass::NeedInfer(const NodePtr &node) {
  73. auto infer_value_range_param = OperatorFactoryImpl::GetInferValueRangePara(node->GetType());
  74. if (!infer_value_range_param.is_initialized) {
  75. GELOGD("Node %s does not register func to infer value range, skip infer_value_range_pass.",
  76. node->GetName().c_str());
  77. return false;
  78. }
  79. if (infer_value_range_param.when_call == INPUT_IS_DYNAMIC) {
  80. // Only do infer for node that all inputs are dynamic, such as shape
  81. if (InputIsDynamic(node)) {
  82. return true;
  83. }
  84. } else if (infer_value_range_param.when_call == INPUT_HAS_VALUE_RANGE) {
  85. // Only do infer for node that all inputs have value_range or node type of inputs is constant/const
  86. if (InputIsConstOrHasValueRange(node)) {
  87. return true;
  88. }
  89. }
  90. GELOGD("Node %s does not need to infer value range, skip infer_value_range_pass.", node->GetName().c_str());
  91. return false;
  92. }
  93. graphStatus InferValueRangePass::UpdateTensorDesc(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  94. changed = false;
  95. std::vector<std::pair<int64_t, int64_t>> src_value_range;
  96. std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  97. (void)src->GetValueRange(src_value_range);
  98. (void)dst->GetValueRange(dst_value_range);
  99. if (src_value_range != dst_value_range) {
  100. changed = true;
  101. }
  102. dst = src;
  103. return GRAPH_SUCCESS;
  104. }
  105. graphStatus InferValueRangePass::UpdateInputDescAttr(const GeTensorDescPtr &src, GeTensorDescPtr &dst, bool &changed) {
  106. changed = false;
  107. std::vector<std::pair<int64_t, int64_t>> src_value_range;
  108. std::vector<std::pair<int64_t, int64_t>> dst_value_range;
  109. (void)src->GetValueRange(src_value_range);
  110. (void)dst->GetValueRange(dst_value_range);
  111. if (src_value_range != dst_value_range) {
  112. changed = true;
  113. }
  114. dst->SetValueRange(src_value_range);
  115. return GRAPH_SUCCESS;
  116. }
  117. void InferValueRangePass::AnalyzeFailedInfo(const NodePtr &node) {
  118. REPORT_CALL_ERROR("E19999", "Infer value range for node:%s(%s) failed.", node->GetName().c_str(),
  119. node->GetType().c_str());
  120. GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infer value range failed. node: %s", node->GetName().c_str());
  121. }
  122. bool InferValueRangePass::InputIsDynamic(const NodePtr &node) {
  123. bool input_is_dynamic = false;
  124. auto cur_op_desc = node->GetOpDesc();
  125. for (const auto &input_desc : cur_op_desc->GetAllInputsDescPtr()) {
  126. auto dims = input_desc->GetShape().GetDims();
  127. for (auto dim : dims) {
  128. if (dim == UNKNOWN_DIM || dim == UNKNOWN_DIM_NUM) {
  129. input_is_dynamic = true;
  130. break;
  131. }
  132. }
  133. }
  134. return input_is_dynamic;
  135. }
  136. bool InferValueRangePass::InputIsConstOrHasValueRange(const NodePtr &node) {
  137. bool input_is_const_or_has_value_range = true;
  138. auto cur_op_desc = node->GetOpDesc();
  139. auto in_data_anchors = node->GetAllInDataAnchors();
  140. for (auto i = 0; i < in_data_anchors.size(); ++i) {
  141. auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
  142. if (peer_out_anchor == nullptr) {
  143. continue;
  144. }
  145. auto peer_node = peer_out_anchor->GetOwnerNode();
  146. if (peer_node == nullptr) {
  147. continue;
  148. }
  149. if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
  150. continue;
  151. }
  152. const auto &input_desc = cur_op_desc->GetInputDesc(i);
  153. std::vector<std::pair<int64_t, int64_t>> value_range;
  154. (void)input_desc.GetValueRange(value_range);
  155. if (value_range.empty()) {
  156. input_is_const_or_has_value_range = false;
  157. break;
  158. }
  159. }
  160. return input_is_const_or_has_value_range;
  161. }
  162. vector<ConstGeTensorPtr> InferValueRangePass::ConstructInputTensors(const NodePtr &node, bool use_floor_value) {
  163. vector<ConstGeTensorPtr> input_tensors;
  164. auto cur_op_desc = node->GetOpDesc();
  165. auto in_data_anchors = node->GetAllInDataAnchors();
  166. for (auto i = 0; i < in_data_anchors.size(); ++i) {
  167. auto peer_out_anchor = in_data_anchors.at(i)->GetPeerOutAnchor();
  168. if (peer_out_anchor == nullptr) {
  169. continue;
  170. }
  171. auto peer_node = peer_out_anchor->GetOwnerNode();
  172. if (peer_node == nullptr) {
  173. continue;
  174. }
  175. // construct input tensor by constant node
  176. if ((peer_node->GetType() == CONSTANT) || (peer_node->GetType() == CONSTANTOP)) {
  177. vector<GeTensorPtr> const_weight = OpDescUtils::MutableWeights(peer_node);
  178. if (const_weight.empty()) {
  179. REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight is empty, node: %s(%s)",
  180. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  181. GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight is empty, node: %s(%s)", peer_node->GetName().c_str(),
  182. peer_node->GetType().c_str());
  183. return vector<ConstGeTensorPtr>();
  184. }
  185. // const/constant op has only one weight
  186. if (const_weight.at(0) == nullptr) {
  187. REPORT_INNER_ERROR("E19999", "MutableWeights failed, weight of constant is null, node: %s(%s)",
  188. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  189. GELOGE(INTERNAL_ERROR, "MutableWeights failed, weight of constant is null, node name: %s(%s)",
  190. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  191. return vector<ConstGeTensorPtr>();
  192. }
  193. input_tensors.push_back(const_weight.at(0));
  194. continue;
  195. }
  196. // construct input tensor by boundary of value range
  197. const auto &input_tensor_desc = cur_op_desc->GetInputDesc(i);
  198. std::vector<std::pair<int64_t, int64_t>> value_range;
  199. (void)input_tensor_desc.GetValueRange(value_range);
  200. if (value_range.size() != input_tensor_desc.GetShape().GetShapeSize()) {
  201. REPORT_INNER_ERROR("E19999", "Value range of input %s is invalid.", input_tensor_desc.GetName().c_str());
  202. GELOGE(GRAPH_PARAM_INVALID, "Value range of input %s is invalid.", input_tensor_desc.GetName().c_str());
  203. return vector<ConstGeTensorPtr>();
  204. }
  205. auto value_range_data_num = value_range.size();
  206. unique_ptr<int64_t[]> buf(new (std::nothrow) int64_t[value_range_data_num]());
  207. if (buf == nullptr) {
  208. REPORT_INNER_ERROR("E19999", "New buf failed");
  209. GELOGE(MEMALLOC_FAILED, "new buf failed");
  210. return vector<ConstGeTensorPtr>();
  211. }
  212. for (auto j = 0; j < value_range_data_num; ++j) {
  213. buf[j] = use_floor_value ? value_range[j].first : value_range[j].second;
  214. }
  215. GeTensorPtr tmp_tensor_ptr = MakeShared<GeTensor>(input_tensor_desc, reinterpret_cast<uint8_t *>(buf.get()),
  216. sizeof(int64_t) * value_range_data_num);
  217. if (tmp_tensor_ptr == nullptr) {
  218. REPORT_INNER_ERROR("E19999", "Make shared failed");
  219. GELOGE(MEMALLOC_FAILED, "Make shared failed");
  220. return vector<ConstGeTensorPtr>();
  221. }
  222. tmp_tensor_ptr->MutableTensorDesc().SetDataType(ge::DT_INT64);
  223. input_tensors.push_back(tmp_tensor_ptr);
  224. }
  225. return input_tensors;
  226. }
  227. graphStatus InferValueRangePass::ConstructInputAndInferValueRange(NodePtr &node) {
  228. auto inputs = ConstructInputTensors(node, true);
  229. if (inputs.empty()) {
  230. return GRAPH_PARAM_INVALID;
  231. }
  232. vector<GeTensorPtr> outputs_lower;
  233. auto ret = RunCpuKernelForValueRange(node, inputs, outputs_lower);
  234. if (ret != SUCCESS) {
  235. REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
  236. GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str());
  237. return GRAPH_FAILED;
  238. }
  239. inputs = ConstructInputTensors(node, false);
  240. if (inputs.empty()) {
  241. return GRAPH_PARAM_INVALID;
  242. }
  243. vector<GeTensorPtr> outputs_higher;
  244. ret = RunCpuKernelForValueRange(node, inputs, outputs_higher);
  245. if (ret != SUCCESS) {
  246. REPORT_INNER_ERROR("E19999", "Calculate for node %s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
  247. GELOGE(GRAPH_FAILED, "Calculate for node %s failed in constant folding", node->GetName().c_str());
  248. return GRAPH_FAILED;
  249. }
  250. // construct value range from output tensor
  251. OpDescPtr node_desc = node->GetOpDesc();
  252. std::vector<std::pair<int64_t, int64_t>> output_tensor_value_range;
  253. size_t node_output_desc_size = node_desc->GetOutputsSize();
  254. for (size_t i = 0; i < node_output_desc_size; ++i) {
  255. output_tensor_value_range.clear();
  256. auto lower_tensor = outputs_lower[i];
  257. auto lower_tensor_shape_size = lower_tensor->GetTensorDesc().GetShape().GetShapeSize();
  258. auto higher_tensor = outputs_higher[i];
  259. auto higher_tensor_shape_size = higher_tensor->GetTensorDesc().GetShape().GetShapeSize();
  260. auto output_tensor_desc = node_desc->MutableOutputDesc(i);
  261. auto output_tensor_shape_size = output_tensor_desc->GetShape().GetShapeSize();
  262. if (output_tensor_shape_size != lower_tensor_shape_size || output_tensor_shape_size != higher_tensor_shape_size) {
  263. GELOGE(GRAPH_PARAM_INVALID, "Value range of output %s is invalid.", output_tensor_desc->GetName().c_str());
  264. }
  265. for (auto j = 0; j < output_tensor_shape_size; ++j) {
  266. int64_t *x = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(lower_tensor->GetData().GetData()));
  267. int64_t *y = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(higher_tensor->GetData().GetData()));
  268. std::pair<int64_t, int64_t> pair({x[j], y[j]});
  269. output_tensor_value_range.emplace_back(pair);
  270. }
  271. output_tensor_desc->SetValueRange(output_tensor_value_range);
  272. }
  273. return GRAPH_SUCCESS;
  274. }
  275. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示