Browse Source

graph id

Signed-off-by: zhupuxu <zhupuxu@huawei.com>
tags/v1.5.1
zhupuxu 3 years ago
parent
commit
69a8d570d0
13 changed files with 370 additions and 14 deletions
  1. +32
    -0
      ge/common/profiling/ge_profiling.cc
  2. +47
    -4
      ge/common/profiling/profiling_manager.cc
  3. +18
    -1
      ge/common/profiling/profiling_manager.h
  4. +43
    -1
      ge/graph/execute/graph_execute.cc
  5. +5
    -0
      ge/graph/execute/graph_execute.h
  6. +11
    -2
      ge/graph/load/model_manager/davinci_model.cc
  7. +17
    -5
      ge/graph/load/model_manager/model_manager.cc
  8. +5
    -0
      ge/graph/manager/graph_manager.cc
  9. +4
    -0
      ge/session/inner_session.cc
  10. +2
    -0
      inc/framework/common/profiling/ge_profiling.h
  11. +44
    -0
      tests/ut/ge/graph/execute/graph_execute_unittest.cc
  12. +29
    -1
      tests/ut/ge/graph/load/model_manager_unittest.cc
  13. +113
    -0
      tests/ut/ge/profiling/ge_profiling_manager_unittest.cc

+ 32
- 0
ge/common/profiling/ge_profiling.cc View File

@@ -20,9 +20,11 @@
#include "framework/common/debug/ge_log.h"
#include "framework/common/debug/log.h"
#include "graph/load/graph_loader.h"
#include "graph/ge_context.h"
#include "init/gelib.h"
#include "framework/common/ge_inner_error_codes.h"
#include "model/ge_model.h"
#include "framework/omg/omg_inner_types.h"

namespace {
const uint32_t kDeviceListIndex = 3;
@@ -35,6 +37,7 @@ const std::string kProfilingStop = "prof_stop";
const std::string kProfModelSubscribe = "prof_model_subscribe";
const std::string kProfModelUnsubscribe = "prof_model_cancel_subscribe";
const std::string kRtSetDeviceRegName = "profiling";
const std::string kPofilingModelId = "modelId";

const std::map<ProfCommandHandleType, std::string> kProfCommandTypeMap = {
{kProfCommandhandleInit, kProfilingInit},
@@ -195,6 +198,31 @@ ge::Status ProfCommandHandle(ProfCommandHandleType type, void *data, uint32_t le
return ge::PARAM_INVALID;
}
}
auto &profiling_manager = ge::ProfilingManager::Instance();
auto is_train = domi::GetContext().train_flag;
if (type == kProfCommandhandleModelSubscribe && is_train) {
profiling_manager.SetSubscribeInfo(prof_config_param->profSwitch, prof_config_param->modelId, true);
return ge::SUCCESS;
}
auto is_subscribe = profiling_manager.GetSubscribeInfo().is_subscribe;
if (type == kProfCommandhandleModelUnsubscribe && is_subscribe) {
prof_params.clear();
prof_params.emplace_back(kPofilingModelId);
uint32_t model_id = 0;
// GraphId is actually stored in prof_config_param
uint32_t graph_id = prof_config_param->modelId;
auto ret = profiling_manager.GetModelIdFromGraph(graph_id, model_id);
if (ret != ge::SUCCESS) {
GELOGE(ret, "graph_id:%u not not found", graph_id);
REPORT_INPUT_ERROR("E10001", std::vector<std::string>({"value", "parameter", "reason"}),
std::vector<std::string>({std::to_string(graph_id),
"GraphToModelMap",
"graph_id does not exist!"}));
return ge::FAILED;
}

prof_params.emplace_back(std::to_string(model_id));
}
ge::GraphLoader graph_loader;
ge::Command command;
command.cmd_params.clear();
@@ -248,3 +276,7 @@ ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream
"tag id must be 0 when first run, must be 1 when second run"}));
return ge::FAILED;
}

ge::Status ProfGetDeviceFormGraphId(uint32_t graph_id, uint32_t &device_id) {
return ge::ProfilingManager::Instance().GetDeviceIdFromGraph(graph_id, device_id);
}

+ 47
- 4
ge/common/profiling/profiling_manager.cc View File

@@ -66,10 +66,13 @@ const std::string kIdx = "idx";

namespace ge {
ProfilingManager::ProfilingManager()
: is_load_profiling_(false), is_execute_profiling_(false), is_training_trace_(false), subscribe_count_(0) {
prof_cb_.msprofCtrlCallback = nullptr;
prof_cb_.msprofReporterCallback = nullptr;
index_id_ = UINT64_MAX;
: is_load_profiling_(false),
is_execute_profiling_(false),
is_training_trace_(false),
subscribe_count_(0),
prof_cb_({nullptr, nullptr}),
index_id_(UINT64_MAX),
subscribe_info_({false, 0, 0}) {
}

ProfilingManager::~ProfilingManager() {}
@@ -610,6 +613,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFi
// profiling plugin uninit
PluginUnInit();

CleanSubscribeInfo();

int32_t dev_num = -1;
rtError_t rt_ret = rtProfilerStop(PROF_MODEL_LOAD_MASK, dev_num, nullptr);
if (rt_ret != RT_ERROR_NONE) {
@@ -632,6 +637,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::ProfFi
}
device_id_module_map_.clear();
device_id_.clear();
device_id_map_.clear();
model_id_map_.clear();
GELOGI("Prof finalize success.");
#endif
return SUCCESS;
@@ -1057,4 +1064,40 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP
return;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::GetDeviceIdFromGraph(
uint32_t graph_id, uint32_t &device_id) {
auto iter = device_id_map_.find(graph_id);
if (iter != device_id_map_.end()) {
device_id = iter->second;
return SUCCESS;
}
REPORT_CALL_ERROR("E19999", "graph_id:%u does not exist!", graph_id);
GELOGE(PARAM_INVALID, "[Check][GraphId]graph_id:%u does not exist!", graph_id);
return FAILED;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::SetSubscribeInfo(
uint64_t prof_switch, uint32_t model_id, bool is_subscribe) {
subscribe_info_.is_subscribe = is_subscribe;
subscribe_info_.prof_switch = prof_switch;
subscribe_info_.graph_id = model_id;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::CleanSubscribeInfo() {
subscribe_info_.is_subscribe = false;
subscribe_info_.prof_switch = 0;
subscribe_info_.graph_id = 0;
}

FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ProfilingManager::GetModelIdFromGraph(
uint32_t graph_id, uint32_t &model_id) {
auto iter = model_id_map_.find(graph_id);
if (iter != model_id_map_.end()) {
model_id = iter->second;
return SUCCESS;
}
REPORT_CALL_ERROR("E19999", "graph_id:%u does not exist!", graph_id);
GELOGE(PARAM_INVALID, "[Check][GraphId]graph_id:%u does not exist!", graph_id);
return FAILED;
}
} // namespace ge

+ 18
- 1
ge/common/profiling/profiling_manager.h View File

@@ -62,6 +62,12 @@ struct DeviceSubsInfo {
uint32_t subscribe_count;
};

struct ProfSubscribeInfo {
bool is_subscribe;
uint64_t prof_switch;
uint32_t graph_id;
};

struct MsprofCallback {
MsprofCtrlCallback msprofCtrlCallback;
MsprofReporterCallback msprofReporterCallback;
@@ -102,7 +108,15 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager {
void ReportData(const int32_t &device_id, const std::string &data, const std::string &tag_name);
Status ProfileStepInfo(uint64_t index_id, uint64_t model_id, uint16_t tag_id, rtStream_t stream, int32_t device_id);
void SetStepInfoIndex(uint64_t index_id) { index_id_ = index_id; }
uint64_t GetStepInfoIndex() { return index_id_; }
uint64_t GetStepInfoIndex() const { return index_id_; }
void SetGraphIdToDeviceMap(uint32_t graph_id, uint32_t device_id) { device_id_map_[graph_id] = device_id; }
Status GetDeviceIdFromGraph(uint32_t graph_id, uint32_t &device_id);
void SetSubscribeInfo(uint64_t prof_switch, uint32_t model_id, bool is_subscribe);
const ProfSubscribeInfo &GetSubscribeInfo() const { return subscribe_info_; }
void CleanSubscribeInfo();
void SetGraphIdToModelMap(uint32_t graph_id, uint32_t model_id) { model_id_map_[graph_id] = model_id; }
Status GetModelIdFromGraph(uint32_t graph_id, uint32_t &model_id);

private:
Status InitFromOptions(const Options &options, MsprofGeOptions &prof_conf);
Status ParseOptions(const std::string &options);
@@ -130,6 +144,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager {
std::string bp_point_;
uint32_t reporter_max_len_ = 0;
uint64_t index_id_;
std::map<uint32_t, uint32_t> device_id_map_; // key: graph_id, value: device_id
std::map<uint32_t, uint32_t> model_id_map_; // key: graph_id, value: model_id
ProfSubscribeInfo subscribe_info_;
};
} // namespace ge
#endif // GE_COMMON_PROFILING_PROFILING_MANAGER_H_

+ 43
- 1
ge/graph/execute/graph_execute.cc View File

@@ -21,6 +21,7 @@

#include "graph/load/model_manager/model_manager.h"
#include "graph/load/model_manager/davinci_model.h"
#include "common/profiling/profiling_manager.h"

namespace ge {
using Uint32Pair = pair<uint32_t, uint32_t>;
@@ -365,7 +366,11 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_ro
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[SyncExecute][Model] Error! graph id:%u", graph_id);
return GE_GRAPH_SYNC_MODEL_FAILED;
}

ret = ModelSubscribe(graph_id);
if (ret != SUCCESS) {
GELOGE(ret, "[Call][ModelSubscribe] failed, graph_id:%u", graph_id);
return ret;
}
return SUCCESS;
}

@@ -776,4 +781,41 @@ Status GraphExecutor::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint
}
return SUCCESS;
}

Status GraphExecutor::GetModelByID(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model) {
auto model_manager = ge::ModelManager::GetInstance();
GE_CHECK_NOTNULL(model_manager);
davinci_model = model_manager->GetModel(static_cast<uint32_t>(model_id));
if (davinci_model == nullptr) {
REPORT_INNER_ERROR("E19999", "GetModel from model_manager fail, model_id:%u", model_id);
GELOGE(ge::FAILED, "[Get][Model] failed, Model id:%d is invaild or model is not loaded.", model_id);
return ge::FAILED;
}
return ge::SUCCESS;
}

Status GraphExecutor::ModelSubscribe(uint32_t graph_id) {
auto &profiling_manager = ProfilingManager::Instance();
const auto &subcribe_info = profiling_manager.GetSubscribeInfo();
if (subcribe_info.is_subscribe) {
std::shared_ptr<DavinciModel> davinci_model = nullptr;
uint32_t model_id = 0;
Status ret = profiling_manager.GetModelIdFromGraph(graph_id, model_id);
if (ret != SUCCESS) {
GELOGE(ret, "[Call][GetModelIdFromGraph] failed, graph_id:%u", graph_id);
return ret;
}
ret = GetModelByID(model_id, davinci_model);
if (ret != SUCCESS) {
GELOGE(ret, "[Call][GetModelByID] failed, model_id:%u", model_id);
return ret;
}
ret = profiling_manager.ProfModelSubscribe(subcribe_info.prof_switch, davinci_model.get());
if (ret != SUCCESS) {
GELOGE(ret, "[Call][ProfModelSubscribe] failed");
return ret;
}
}
return SUCCESS;
}
} // namespace ge

+ 5
- 0
ge/graph/execute/graph_execute.h View File

@@ -38,6 +38,7 @@
#include "graph/model.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/load/model_manager/davinci_model.h"

namespace ge {
class GraphExecutor {
@@ -148,6 +149,10 @@ class GraphExecutor {
static Status SetCallback(uint32_t model_id, const GeRootModelPtr &ge_root_model,
const RunAsyncCallback &callback);

Status ModelSubscribe(uint32_t graph_id);

Status GetModelByID(uint32_t model_id, std::shared_ptr<DavinciModel> &davinci_model);

bool init_flag_;

bool train_graph_flag_;


+ 11
- 2
ge/graph/load/model_manager/davinci_model.cc View File

@@ -62,6 +62,7 @@
#include "graph/common/omg_util.h"
#include "graph/build/memory/block_mem_assigner.h"
#include "graph/manager/session_scope_mem_allocator.h"
#include "framework/omg/omg_inner_types.h"

// create std::thread, catch exceptions using try/catch
#define CREATE_STD_THREAD(thread_id, func, args) \
@@ -763,8 +764,16 @@ void DavinciModel::SaveSpecifyAttrValues(const OpDescPtr &op_desc) {
}

Status DavinciModel::ReportProfilingData() {
ProfilingManager::Instance().ReportProfilingData(model_id_, GetTaskDescInfo());
GE_CHK_STATUS(SinkModelProfile(), "[Sink][ModelProfile] failed, model_id:%u.", model_id_);
bool is_train = domi::GetContext().train_flag;
auto model_id = model_id_;
auto &profiling_manager = ProfilingManager::Instance();
auto graph_id = runtime_param_.graph_id;
if (is_train) {
GELOGD("Replace model_id:%u with graph_id:%u, when training.", model_id, graph_id);
model_id = graph_id;
}
profiling_manager.ReportProfilingData(model_id, GetTaskDescInfo());
GE_CHK_STATUS(SinkModelProfile(), "[Sink][ModelProfile] failed, model_id:%u.", model_id);

return SUCCESS;
}


+ 17
- 5
ge/graph/load/model_manager/model_manager.cc View File

@@ -368,7 +368,17 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge

GELOGI("Parse model %u success.", model_id);
} while (0);

auto &profiling_manager = ProfilingManager::Instance();
const auto &subcribe_info = profiling_manager.GetSubscribeInfo();
if (subcribe_info.is_subscribe) {
auto graph_id = davinci_model->GetRuntimeParam().graph_id;
if(subcribe_info.graph_id == graph_id) {
profiling_manager.SetGraphIdToModelMap(graph_id, model_id);
}
else {
GELOGW("graph_id:%u is not in subcribe info.", graph_id);
}
}
return ret;
}

@@ -758,12 +768,15 @@ Status ModelManager::HandleProfModelUnsubscribeCommand(const Command &command) {
if (ret != SUCCESS) {
return ret;
}
if (ProfilingManager::Instance().ProfModelUnsubscribe(static_cast<void *>(davinci_model.get())) != SUCCESS) {
auto &profiling_manager = ProfilingManager::Instance();
if (profiling_manager.ProfModelUnsubscribe(static_cast<void *>(davinci_model.get())) != SUCCESS) {
GELOGE(FAILED, "[Handle][ProfModelUnsubscribe] failed.");
return FAILED;
}

auto is_subscribe = profiling_manager.GetSubscribeInfo().is_subscribe;
if (is_subscribe) {
profiling_manager.CleanSubscribeInfo();
}
return SUCCESS;
}

@@ -1826,5 +1839,4 @@ Status ModelManager::CheckAicpuOpList(GeModelPtr ge_model) {
"[Call][LaunchKernelCheckAicpuOp] failed.");
return SUCCESS;
}

} // namespace ge

+ 5
- 0
ge/graph/manager/graph_manager.cc View File

@@ -109,6 +109,7 @@
#include "register/custom_pass_helper.h"
#include "external/graph/types.h"
#include "common/util/error_manager/error_manager.h"
#include "common/profiling/profiling_manager.h"

namespace {
const char *const kSummary = "Summary";
@@ -462,6 +463,9 @@ Status GraphManager::AddGraph(const GraphId &graph_id, const Graph &graph,
const std::map<std::string, std::string> &options,
const OmgContext &omg_context) {
IncreaseGraphCount(graph_id);
auto device_id = GetContext().DeviceId();
GELOGD("Device id is %u", device_id);
ProfilingManager::Instance().SetGraphIdToDeviceMap(graph_id, device_id);
// validation for adding graphs of same graph_id in multi-thread secenario
// 1.previous thread owns same graph_id has finished the AddGraph procession
if (GetAddGraphCondition(graph_id) == kDoneAdded) {
@@ -1715,6 +1719,7 @@ Status GraphManager::ParseTrainGraphFlag(bool &train_flag) {
train_flag = true;
}
}
domi::GetContext().train_flag = train_flag;
GELOGI("Is train flag: %d.", train_flag);
return SUCCESS;
}


+ 4
- 0
ge/session/inner_session.cc View File

@@ -35,6 +35,7 @@
#include "graph/utils/tensor_adapter.h"
#include "runtime/mem.h"
#include "ir_build/option_utils.h"
#include "common/profiling/profiling_manager.h"

namespace ge {
namespace {
@@ -231,6 +232,9 @@ Status InnerSession::GetVariable(const std::string &name, Tensor &val) {

Status InnerSession::AddGraph(uint32_t graph_id, const Graph &graph) {
std::map<std::string, std::string> options;
auto device_id = GetContext().DeviceId();
GELOGD("Device id is %u", device_id);
ProfilingManager::Instance().SetGraphIdToDeviceMap(graph_id, device_id);
return AddGraph(graph_id, graph, options);
}



+ 2
- 0
inc/framework/common/profiling/ge_profiling.h View File

@@ -50,4 +50,6 @@ GE_FUNC_VISIBILITY ge::Status ProfCommandHandle(ProfCommandHandleType type, void
///
GE_FUNC_VISIBILITY ge::Status ProfSetStepInfo(uint64_t index_id, uint16_t tag_id, rtStream_t stream);

GE_FUNC_VISIBILITY ge::Status ProfGetDeviceFormGraphId(uint32_t graph_id, uint32_t &device_id);

#endif // INC_FRAMEWORK_COMMON_GE_PROFILING_H_

+ 44
- 0
tests/ut/ge/graph/execute/graph_execute_unittest.cc View File

@@ -17,6 +17,8 @@
#include <gtest/gtest.h>
#include <memory>

#include "common/profiling/profiling_manager.h"

#define protected public
#define private public
#include "graph/execute/graph_execute.h"
@@ -125,4 +127,46 @@ TEST_F(UtestGraphExecuteTest, test_set_callback) {
auto status = executor.SetCallback(1, ge_root_model, callback);
EXPECT_EQ(status, SUCCESS);
}

TEST_F(UtestGraphExecuteTest, test_without_subscribe) {
GraphExecutor executor;
auto ret = executor.ModelSubscribe(1);
EXPECT_EQ(ret, SUCCESS);
}

TEST_F(UtestGraphExecuteTest, test_with_subscribe_failed1) {
GraphExecutor executor;
uint32_t graph_id = 1;
auto &profiling_manager = ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, 1, true);
auto ret = executor.ModelSubscribe(graph_id);
profiling_manager.CleanSubscribeInfo();
EXPECT_NE(ret, SUCCESS);
}

TEST_F(UtestGraphExecuteTest, test_with_subscribe_failed2) {
GraphExecutor executor;
uint32_t graph_id = 1;
uint32_t model_id = 1;
auto &profiling_manager = ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, 1, true);
profiling_manager.SetGraphIdToModelMap(2, model_id);
auto ret = executor.ModelSubscribe(graph_id);
profiling_manager.CleanSubscribeInfo();
EXPECT_NE(ret, SUCCESS);
}

TEST_F(UtestGraphExecuteTest, test_with_subscribe_success) {
GraphExecutor executor;
uint32_t graph_id = 1;
uint32_t model_id = 1;
GraphNodePtr graph_node = std::make_shared<GraphNode>(graph_id);
DavinciModel model(model_id, nullptr);
auto &profiling_manager = ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, 1, true);
profiling_manager.SetGraphIdToModelMap(graph_id, model_id);
auto ret = executor.ModelSubscribe(graph_id);
profiling_manager.CleanSubscribeInfo();
EXPECT_EQ(ret, SUCCESS);
}
} // namespace ge

+ 29
- 1
tests/ut/ge/graph/load/model_manager_unittest.cc View File

@@ -26,6 +26,7 @@
#include "graph/load/graph_loader.h"
#include "graph/load/model_manager/davinci_model.h"
#include "graph/ops_stub.h"
#include "common/profiling/profiling_manager.h"

using namespace std;
using namespace testing;
@@ -135,7 +136,8 @@ class UtestModelManagerModelManager : public testing::Test {
class DModelListener : public ModelListener {
public:
DModelListener(){};
uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; }
uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index,
uint32_t resultCode, std::vector<ge::Tensor> &outputs) { return 0; }
};

TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) {
@@ -426,4 +428,30 @@ TEST_F(UtestModelManagerModelManager, test_launch_kernel_cust_aicpu) {
EXPECT_EQ(mm.LaunchKernelCustAicpuSo("deleteCustOp"), SUCCESS);
EXPECT_TRUE(mm.cust_aicpu_so_.empty());
}

shared_ptr<ModelListener> listerner(new DModelListener());
TEST_F(UtestModelManagerModelManager, test_load_model_online) {
ModelManager mm;
uint32_t model_id = 1;
ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, model_id, true);
Status ret = mm.LoadModelOnline(model_id, ge_root_model, listerner);
profiling_manager.CleanSubscribeInfo();
}

TEST_F(UtestModelManagerModelManager, command_profiling) {
ModelManager manager;
uint32_t model_id = 1;
Command cmd;
auto model = std::make_shared<DavinciModel>(1, listerner);
model->SetId(model_id);
cmd.cmd_params.push_back("modelId");
cmd.cmd_params.push_back(to_string(model_id));
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, model_id, true);
Status ret = manager.HandleProfModelUnsubscribeCommand(cmd);
profiling_manager.CleanSubscribeInfo();
}
} // namespace ge

+ 113
- 0
tests/ut/ge/profiling/ge_profiling_manager_unittest.cc View File

@@ -21,11 +21,16 @@
#include <map>
#include <string>

#include "graph/load/model_manager/davinci_model.h"

#define protected public
#define private public
#include "common/profiling/profiling_manager.h"
#include "graph/ge_local_context.h"
#include "inc/framework/common/profiling/ge_profiling.h"
#include "graph/manager/graph_manager.h"
#include "graph/ops_stub.h"
#include "inc/framework/omg/omg_inner_types.h"
#undef protected
#undef private

@@ -43,6 +48,23 @@ int32_t ReporterCallback(uint32_t moduleId, uint32_t type, void *data, uint32_t
return -1;
}

void CreateGraph(Graph &graph) {
TensorDesc desc(ge::Shape({1, 3, 224, 224}));
uint32_t size = desc.GetShape().GetShapeSize();
desc.SetSize(size);
auto data = op::Data("Data").set_attr_index(0);
data.update_input_desc_data(desc);
data.update_output_desc_out(desc);

auto flatten = op::Flatten("Flatten").set_input_x(data, data.name_out_out());

std::vector<Operator> inputs{data};
std::vector<Operator> outputs{flatten};
std::vector<Operator> targets{flatten};
// Graph graph("test_graph");
graph.SetInputs(inputs).SetOutputs(outputs).SetTargets(targets);
}

TEST_F(UtestGeProfilinganager, init_success) {
setenv("PROFILING_MODE", "true", true);
Options options;
@@ -133,3 +155,94 @@ TEST_F(UtestGeProfilinganager, set_step_info_failed) {
Status ret = ProfSetStepInfo(index_id, 1, stream);
EXPECT_EQ(ret, ge::FAILED);
}

TEST_F(UtestGeProfilinganager, get_device_from_graph) {
GraphId graph_id = 1;
uint32_t device_id = 0;
GraphManager graph_manager;
GraphNodePtr graph_node = MakeShared<ge::GraphNode>(graph_id);
graph_manager.AddGraphNode(graph_id, graph_node);
graph_manager.SetAddGraphCondition(graph_id, 2);
Graph graph("test_graph");
CreateGraph(graph);
std::map<std::string, std::string> options;
OmgContext context;
Status ret = graph_manager.AddGraph(graph_id, graph, options, context);
EXPECT_EQ(ret, ge::SUCCESS);
ret = ProfGetDeviceFormGraphId(graph_id, device_id);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST_F(UtestGeProfilinganager, handle_subscribe_info) {
ProfCommandHandleType prof_type = kProfCommandhandleModelSubscribe;
ProfCommandHandleData prof_data;
prof_data.profSwitch = 0;
prof_data.modelId = 1;
domi::GetContext().train_flag = true;
auto prof_ptr = std::make_shared<ProfCommandHandleData>(prof_data);
Status ret = ProfCommandHandle(prof_type, static_cast<void *>(prof_ptr.get()), sizeof(prof_data));
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST_F(UtestGeProfilinganager, handle_unsubscribe_info) {
ProfCommandHandleType prof_type = kProfCommandhandleModelUnsubscribe;
ProfCommandHandleData prof_data;
prof_data.profSwitch = 0;
prof_data.modelId = 1;
domi::GetContext().train_flag = true;
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, 1, true);
auto prof_ptr = std::make_shared<ProfCommandHandleData>(prof_data);
Status ret = ProfCommandHandle(prof_type, static_cast<void *>(prof_ptr.get()), sizeof(prof_data));
profiling_manager.CleanSubscribeInfo();
}

TEST_F(UtestGeProfilinganager, set_subscribe_info) {
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetSubscribeInfo(0, 1, true);
const auto &subInfo = profiling_manager.GetSubscribeInfo();
EXPECT_EQ(subInfo.prof_switch, 0);
EXPECT_EQ(subInfo.graph_id, 1);
EXPECT_EQ(subInfo.is_subscribe, true);
}

TEST_F(UtestGeProfilinganager, clean_subscribe_info) {
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.CleanSubscribeInfo();
const auto &subInfo = profiling_manager.GetSubscribeInfo();
EXPECT_EQ(subInfo.prof_switch, 0);
EXPECT_EQ(subInfo.graph_id, 0);
EXPECT_EQ(subInfo.is_subscribe, false);
}

TEST_F(UtestGeProfilinganager, get_model_id_success) {
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetGraphIdToModelMap(0, 1);
uint32_t model_id = 0;
Status ret = profiling_manager.GetModelIdFromGraph(0, model_id);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST_F(UtestGeProfilinganager, get_model_id_failed) {
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetGraphIdToModelMap(0, 1);
uint32_t model_id = 0;
Status ret = profiling_manager.GetModelIdFromGraph(10, model_id);
EXPECT_EQ(ret, ge::FAILED);
}

TEST_F(UtestGeProfilinganager, get_device_id_success) {
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetGraphIdToDeviceMap(0, 1);
uint32_t device_id = 0;
Status ret = profiling_manager.GetDeviceIdFromGraph(0, device_id);
EXPECT_EQ(ret, ge::SUCCESS);
}

TEST_F(UtestGeProfilinganager, get_device_id_failed) {
auto &profiling_manager = ge::ProfilingManager::Instance();
profiling_manager.SetGraphIdToDeviceMap(0, 1);
uint32_t device_id = 0;
Status ret = profiling_manager.GetDeviceIdFromGraph(10, device_id);
EXPECT_EQ(ret, ge::FAILED);
}

Loading…
Cancel
Save