You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

common.h 4.6 kB

5 years ago
(line numbers 1–102)
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef ST_RESNET50_GE_COMMON_H_
  17. #define ST_RESNET50_GE_COMMON_H_
  18. #include "common/ge_inner_error_codes.h"
  19. #include "utils/tensor_utils.h"
  20. #define MY_USER_GE_LOGI(...) GE_LOG_INFO(1, __VA_ARGS__)
  21. #define MY_USER_GE_LOGW(...) GE_LOG_WARN(1, __VA_ARGS__)
  22. #define MY_USER_GE_LOGE(...) GE_LOG_ERROR(1, 3, __VA_ARGS__)
  23. #ifndef USER_GE_LOGI
  24. #define USER_GE_LOGI MY_USER_GE_LOGI
  25. #endif // USER_GE_LOGI
  26. #ifndef USER_GE_LOGW
  27. #define USER_GE_LOGW MY_USER_GE_LOGW
  28. #endif // USER_GE_LOGW
  29. #ifndef USER_GE_LOGE
  30. #define USER_GE_LOGE MY_USER_GE_LOGE
  31. #endif // USER_GE_LOGE
  32. /// train_flag is 0 when infer, train_flag is 1 when train.this param is set for RunGranph_readData() and
  33. /// RunGraph_initData()
  34. #define TRAIN_FLAG_INFER "infer"
  35. #define TRAIN_FLAG_TRAIN "train"
  36. #include <string.h>
  37. #include <unistd.h>
  38. #include <algorithm>
  39. #include <chrono>
  40. #include <iostream>
  41. #include <thread>
  42. #include <vector>
  43. #include "ge_api.h"
  44. #include "graph.h"
  45. #include "ptest.h"
  46. #include "ops/all_ops.h"
  47. using namespace std;
  48. using namespace ge;
  49. // read bin file and compile result
  50. void update_op_format(Operator ops, Format format = ge::FORMAT_NCHW);
  51. void getDimInfo(FILE *fp, std::vector<uint64_t> &dim_info);
  52. void *readTestDataFile(std::string infile, std::vector<uint64_t> &dim_info);
  53. void *readUint8TestDataFile(std::string infile, int size);
  54. bool allclose(float *a, float *b, uint64_t count, float rtol, float atol);
  55. bool compFp32WithTData(float *actual_output_data, std::string expected_data_file, float rtol, float atol);
  56. Tensor load_variable_input_data(string input_path, std::vector<int64_t> shapes, Format ft = ge::FORMAT_NCHW,
  57. DataType dt = ge::DT_FLOAT);
  58. // constructor Tensor
  59. int GetDatTypeSize(DataType dt);
  60. ge::Tensor genTensor(std::vector<int64_t> tensor_shape, Format format = ge::FORMAT_NCHW, DataType dt = ge::DT_FLOAT);
  61. ge::Tensor genTensor_withVaule(std::vector<int64_t> tensor_shape, float value = 1);
  62. Tensor genTesnor_Shape_as_data(std::vector<int64_t> tensor_shape);
  63. // Init GE
  64. ge::Status GEInitialize_api(string train_flag = "0", string run_mode_path = "0");
  65. ge::Status GEInitialize_api_new(string train_flag = "infer", string run_mode = "fe");
  66. ge::Status GEFinalize_api();
  67. // constructor session and build graph
  68. ge::Session *create_aipp_session();
  69. ge::Session *create_session();
  70. ge::Status session_add_and_run_graph(ge::Session *session, uint32_t graphId, Graph &graph, std::vector<Tensor> inputs,
  71. std::vector<Tensor> &outputs);
  72. // common interface for infer
  73. int RunGraph_initData(Graph &graph, string op_name, map<string, std::vector<int64_t>> attr_test,
  74. string train_flag = "infer", string run_mode_path = "fe");
  75. void Inputs_load_Data(string op_name, std::vector<Tensor> &input, map<string, std::vector<int64_t>> attr_test,
  76. Format format = ge::FORMAT_NCHW, DataType dt = ge::DT_FLOAT);
  77. bool comparaData(std::vector<Tensor> &output, string op_name, map<string, std::vector<int64_t>> attr_test);
  78. int RunGraph_readData(Graph &graph, string op_name, map<string, std::vector<int64_t>> attr_test,
  79. string train_flag = "infer", string run_mode_path = "fe", Format format = ge::FORMAT_NCHW,
  80. DataType dt = ge::DT_FLOAT);
  81. // common interface for train
  82. int buildCheckPointGraph(Graph &graph, map<string, TensorDesc> variables);
  83. int buildInitGraph(Graph &graph, std::vector<TensorDesc> desc_var, std::vector<std::string> name_var,
  84. std::vector<float> values_var);
  85. int buildInitGraph_other_dataType(Graph &graph, std::vector<TensorDesc> desc_var, std::vector<std::string> name_var);
  86. bool build_multi_input_multi_output_graph(Graph &graph);
  87. void build_big_graph(Graph &graph, map<string, std::vector<int64_t>> attr);
  88. int buildConvGraph_new(Graph &graph, std::vector<TensorDesc> desc_var, std::vector<std::string> name_var, int flag = 2);
  89. #endif // ST_RESNET50_GE_COMMON_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：