Browse Source

!1266 fix bug of dynamic shape load error

From: @wan_xuelei
Reviewed-by: @xchu42
Signed-off-by:
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
4fa93a81ea
3 changed files with 23 additions and 3 deletions
  1. +12
    -3
      ge/graph/load/model_manager/model_manager.cc
  2. +2
    -0
      ge/graph/load/model_manager/model_manager.h
  3. +9
    -0
      tests/ut/ge/graph/load/model_manager_unittest.cc

+ 12
- 3
ge/graph/load/model_manager/model_manager.cc View File

@@ -286,6 +286,17 @@ ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const string
return SUCCESS;
}

bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) {
auto root_graph = ge_root_model.GetRootGraph();
if (root_graph == nullptr) {
GELOGE(FAILED, "no model on root model");
return false;
}
bool is_shape_unknown = root_graph->GetGraphUnknownFlag();
bool is_dsp_partitioned_graph = false;
(void)AttrUtils::GetBool(root_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dsp_partitioned_graph);
return is_shape_unknown || is_dsp_partitioned_graph || GetContext().GetHostExecFlag();
}
///
/// @ingroup domi_ome
/// @brief load model online
@@ -299,9 +310,7 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge
}
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel();
string model_name = "";
bool is_shape_unknown = ge_root_model->GetRootGraph()->GetGraphUnknownFlag();
// if multi subgraph is known, do hybrid load process
if (is_shape_unknown || GetContext().GetHostExecFlag() || (name_to_model.size() > 1)) {
if (IsNeedHybridLoad(*ge_root_model)) {
return DoLoadHybridModelOnline(model_id, model_name, ge_root_model, listener);
}



+ 2
- 0
ge/graph/load/model_manager/model_manager.h View File

@@ -294,6 +294,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {
std::vector<InputOutputDims> &output_dims);

bool IsDynamicShape(uint32_t model_id);
bool IsNeedHybridLoad(ge::GeRootModel &ge_root_model);
ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info);

ge::Status EnableExceptionDump(const std::map<string, string> &options);
@@ -340,6 +341,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager {

void GenModelId(uint32_t *id);


std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_;
std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_;
std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_;


+ 9
- 0
tests/ut/ge/graph/load/model_manager_unittest.cc View File

@@ -151,6 +151,15 @@ class DModelListener : public ModelListener {
uint32_t OnComputeDone(uint32_t model_id, uint32_t data_index, uint32_t resultCode) { return 0; }
};

TEST_F(UtestModelManagerModelManager, case_is_need_hybrid_load) {
ModelManager mm;
uint32_t model_id = 0;
ComputeGraphPtr root_graph = std::make_shared<ComputeGraph>("graph");
ge::GeRootModel model;
EXPECT_EQ(mm.IsNeedHybridLoad(model), false);
model.SetRootGraph(root_graph);
EXPECT_EQ(mm.IsNeedHybridLoad(model), false);
}

TEST_F(UtestModelManagerModelManager, case_load_incorrect_param) {
ModelManager mm;


Loading…
Cancel
Save