Browse Source

For dynamic shape complie.

pull/656/head
unknown 4 years ago
parent
commit
8e28cacc94
2 changed files with 62 additions and 10 deletions
  1. +60
    -8
      ge/graph/build/graph_builder.cc
  2. +2
    -2
      ge/graph/build/model_builder.h

+ 60
- 8
ge/graph/build/graph_builder.cc View File

@@ -14,6 +14,7 @@
* limitations under the License.
*/

#include "graph/build/memory/graph_mem_assigner.h"
#include "graph/build/graph_builder.h"
#include "common/ge/ge_util.h"
#include "common/helper/model_helper.h"
@@ -272,18 +273,65 @@ Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::v

Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr,
uint64_t session_id) {
ModelPtr model_ptr = MakeShared<ge::Model>();
if (model_ptr == nullptr) {
return MEMALLOC_FAILED;
}

Graph2SubGraphInfoList subgraph_map;
ge::ModelBuilder builder(session_id, com_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_);
GE_CHK_STATUS_RET(builder.RreBuildModel(), "Failed to pre build model.");
GELOGI("Begin to build unknown shape graph[%s].", comp_graph->GetName().c_str());
GE_TIMESTAMP_START(CalcOpParam);
GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Graph[%s] builder CalcOpParam() return fail.",
comp_graph->GetName().c_str());
GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam");
GE_DUMP(comp_graph, "AfterCalcOpParam");
Graph2SubGraphInfoList subgraph_map;
ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_);
ModelPtr model_ptr = MakeShared<ge::Model>();
if (model_ptr == nullptr) {
return MEMALLOC_FAILED;


for (auto &node : comp_graph->GetsDirectNode()) {
GE_CHECK_NOTNULL(node);
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
auto num_inputs = op_desc->GetInputsSize();
std::vector<int64_t> input_offsets(num_inputs, 0);
int valid_input_index = -1;
for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
auto in_anchor = node->GetInDataAnchor(i);
auto peer_out_anchor = in_anchor->GetPeerOutAnchor();
if (peer_out_anchor == nullptr) {
continue;
}
++valid_input_index;
auto peer_node = peer_out_anchor->GetOwnerNode();
if (peer_node == nullptr) {
continue;
}
if (peer_node->GetType() != CONSTANT) {
continue;
}

std::vector<GeTensorPtr> weight = OpDescUtils::MutableWeights(peer_node);
if (weight.empty()) {
GELOGE(FAILED, "weights size of node %s is empty", node->GetName().c_str());
return FAILED;
}
GeTensorPtr weight = weights[0];
GE_CHECK_NOTNULL(weight);
int64_t input_offset = 0;
(void) TensorUtils::GetDataOffset(weight->MutableTensorDesc(), input_offset);
input_offsetp[valid_input_index] = input_offset;
GELOGD("[%s] input[%u] is const, offset = %ld", node->GetName().c_str(), valid_input_index, input_offset);
}

op_desc->SetInputOffset(input_offsets);
std::vector<int64_t> output_offsets(op_desc->GetOutputsSize(), 0);
op_desc->SetOutputOffset(output_offsets);
}
GE_CHK_STATUS_RET(builder.MergeWeights(), "Failed to merge weights.");

GE_TIMESTAMP_START(BuildModelForGetDynShapeTask);
GE_CHK_STATUS_RET(builder.BuildModelForGetDynShapeTask(*model_ptr),
"Graph[%s] builder BuildModelForGetDynShapeTask() return fail.", comp_graph->GetName().c_str());
@@ -375,10 +423,14 @@ Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph,
op_desc->GetName().c_str());
}
}
//
for (auto &sub_graph : comp_graph->GetAllSubgraphs()) {
auto all_graphs = comp_graph->GetAllSubgraphs();
if (all_graphs.empty()) {
all_graphs.push_back(comp_graph);
}
for (auto &sub_graph : all_graphs) {
// exclude functional subgraph in known subgraph
if (sub_graph->GetParentGraph() != comp_graph && !sub_graph->GetParentGraph()->GetGraphUnknownFlag()) {
if (sub_graph->GetParentGraph() != nullptr && sub_graph->GetParentGraph() != comp_graph &&
!sub_graph->GetParentGraph()->GetGraphUnknownFlag()) {
continue;
}



+ 2
- 2
ge/graph/build/model_builder.h View File

@@ -55,13 +55,13 @@ class ModelBuilder {

ge::Buffer GetWeightBuffer() const;

Status MergeWeights();

protected:
void AddNodeInputProperty();

void ClearOriginalFormat();

Status MergeWeights();

private:
bool SetInputConst(const OpDescPtr &op_desc, const NodePtr &src_node, size_t index, vector<bool> &is_input_const);



Loading…
Cancel
Save