|
@@ -609,7 +609,7 @@ Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, For |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &switchn_node) { |
|
|
|
|
|
|
|
|
Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, NodePtr &mbatch_node, int32_t &index) { |
|
|
is_dynamic_batch = false; |
|
|
is_dynamic_batch = false; |
|
|
std::string related_node_name; |
|
|
std::string related_node_name; |
|
|
if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { |
|
|
if (AttrUtils::GetStr(data_node->GetOpDesc(), kMbatchSwitchnName, related_node_name)) { |
|
@@ -620,13 +620,17 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node |
|
|
data_node->GetName().c_str()); |
|
|
data_node->GetName().c_str()); |
|
|
return INTERNAL_ERROR; |
|
|
return INTERNAL_ERROR; |
|
|
} |
|
|
} |
|
|
for (const NodePtr &next_node : data_node->GetOutNodes()) { |
|
|
|
|
|
if (next_node->GetName() == related_node_name) { |
|
|
|
|
|
switchn_node = next_node; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto out_data_nodes_anchors = data_node->GetOutDataNodesAndAnchors(); |
|
|
|
|
|
for (const auto &out_data_node_anchor : out_data_nodes_anchors) { |
|
|
|
|
|
if (out_data_node_anchor.first->GetName() == related_node_name) { |
|
|
|
|
|
mbatch_node = out_data_node_anchor.first; |
|
|
|
|
|
index = out_data_node_anchor.second->GetIdx(); |
|
|
break; |
|
|
break; |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
if (switchn_node == nullptr) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (mbatch_node == nullptr) { |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
ErrorManager::GetInstance().ATCReportErrMessage( |
|
|
"E15002", {"opname", "value", "reason"}, |
|
|
"E15002", {"opname", "value", "reason"}, |
|
|
{data_node->GetName(), related_node_name, "but can not find it on the graph"}); |
|
|
{data_node->GetName(), related_node_name, "but can not find it on the graph"}); |
|
@@ -679,7 +683,7 @@ Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { |
|
|
// In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole |
|
|
// In the dynamic shape process, transnode insertion by FE is advanced to the stage of whole |
|
|
// graph optimization, GE only sets the final data_type/format/shape information for variable, |
|
|
// graph optimization, GE only sets the final data_type/format/shape information for variable, |
|
|
// data and netoutput, and no longer inserts the transnode. |
|
|
// data and netoutput, and no longer inserts the transnode. |
|
|
Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node, DataType &dt_set) { |
|
|
|
|
|
|
|
|
Status ProcessInputDtDynShape(NodePtr &node_ptr, NodePtr &switchn_node, DataType &dt_set) { |
|
|
GE_CHECK_NOTNULL(node_ptr); |
|
|
GE_CHECK_NOTNULL(node_ptr); |
|
|
auto op_desc = node_ptr->GetOpDesc(); |
|
|
auto op_desc = node_ptr->GetOpDesc(); |
|
|
GE_CHECK_NOTNULL(op_desc); |
|
|
GE_CHECK_NOTNULL(op_desc); |
|
@@ -712,19 +716,84 @@ Status ProcessInputDtDynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr |
|
|
GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str()); |
|
|
GELOGI("[Process][InputDynShape] Set input and output size of node [%s] success.", node_ptr->GetName().c_str()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
if (is_dynamic_batch) { |
|
|
|
|
|
GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); |
|
|
|
|
|
auto switchn_op_desc = switchn_node->GetOpDesc(); |
|
|
|
|
|
GE_CHECK_NOTNULL(switchn_op_desc); |
|
|
|
|
|
auto switchn_input = switchn_op_desc->MutableInputDesc(0); |
|
|
|
|
|
GE_CHECK_NOTNULL(switchn_input); |
|
|
|
|
|
switchn_input->SetDataType(dt_set); |
|
|
|
|
|
for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { |
|
|
|
|
|
const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); |
|
|
|
|
|
GE_CHECK_NOTNULL(switchn_output); |
|
|
|
|
|
switchn_output->SetDataType(dt_set); |
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status UpdateInputOutputDataType(NodePtr &mbatch_node, DataType &dt_set, int32_t index) { |
|
|
|
|
|
auto mbatch_desc = mbatch_node->GetOpDesc(); |
|
|
|
|
|
GE_CHECK_NOTNULL(mbatch_desc); |
|
|
|
|
|
auto mbatch_input = mbatch_desc->MutableInputDesc(index); |
|
|
|
|
|
GE_CHECK_NOTNULL(mbatch_input); |
|
|
|
|
|
mbatch_input->SetDataType(dt_set); |
|
|
|
|
|
|
|
|
|
|
|
if (mbatch_node->GetType() == SWITCHN) { |
|
|
|
|
|
for (uint32_t i = 0; i < mbatch_node->GetAllOutDataAnchorsSize(); ++i) { |
|
|
|
|
|
const GeTensorDescPtr &mbatch_output = mbatch_desc->MutableOutputDesc(i); |
|
|
|
|
|
GE_CHECK_NOTNULL(mbatch_output); |
|
|
|
|
|
mbatch_output->SetDataType(dt_set); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
GELOGD("Update input and output data type of node[name: %s, type: %s, input index: %d] to %s.", |
|
|
|
|
|
mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(), index, |
|
|
|
|
|
TypeUtils::DataTypeToSerialString(dt_set).c_str()); |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status UpdateSubgraphDataOfCase(NodePtr &mbatch_node, DataType &dt_set, int32_t index) { |
|
|
|
|
|
if (mbatch_node->GetType() != CASE) { |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto subgraphs = NodeUtils::GetAllSubgraphs(*mbatch_node); |
|
|
|
|
|
for (const auto &subgraph : subgraphs) { |
|
|
|
|
|
GE_CHECK_NOTNULL(subgraph); |
|
|
|
|
|
for (auto &sub_node : subgraph->GetDirectNode()) { |
|
|
|
|
|
GE_CHECK_NOTNULL(sub_node); |
|
|
|
|
|
if (sub_node->GetType() != DATA) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto data_desc = sub_node->GetOpDesc(); |
|
|
|
|
|
GE_CHECK_NOTNULL(data_desc); |
|
|
|
|
|
int32_t parent_node_index = 0; |
|
|
|
|
|
if (!AttrUtils::GetInt(data_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index) || |
|
|
|
|
|
(parent_node_index != index)) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
auto data_input = data_desc->MutableInputDesc(0); |
|
|
|
|
|
GE_CHECK_NOTNULL(data_input); |
|
|
|
|
|
data_input->SetDataType(dt_set); |
|
|
|
|
|
auto data_output = data_desc->MutableOutputDesc(0); |
|
|
|
|
|
GE_CHECK_NOTNULL(data_output); |
|
|
|
|
|
data_output->SetDataType(dt_set); |
|
|
|
|
|
GELOGD("Update input and output data type of node[name: %s, type: %s, parent_node_index: %d] in subgraph %s " |
|
|
|
|
|
"to %s.", data_desc->GetName().c_str(), data_desc->GetType().c_str(), parent_node_index, |
|
|
|
|
|
subgraph->GetName().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str()); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status ProcessMbatchScene(NodePtr &mbatch_node, DataType &dt_set, int32_t index) { |
|
|
|
|
|
GELOGI("The node [%s] dtype set fp16.", mbatch_node->GetName().c_str()); |
|
|
|
|
|
if (UpdateInputOutputDataType(mbatch_node, dt_set, index) != SUCCESS) { |
|
|
|
|
|
GELOGE(FAILED, "Update input and output data type of node[name: %s, type: %s] to %s failed.", |
|
|
|
|
|
mbatch_node->GetName().c_str(), mbatch_node->GetType().c_str(), |
|
|
|
|
|
TypeUtils::DataTypeToSerialString(dt_set).c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (UpdateSubgraphDataOfCase(mbatch_node, dt_set, index) != SUCCESS) { |
|
|
|
|
|
GELOGE(FAILED, "Update input and output data type of Data node[parent_node_index: %d] in subgraphs of " |
|
|
|
|
|
"node[name: %s, type: %s] to %s failed.", index, mbatch_node->GetName().c_str(), |
|
|
|
|
|
mbatch_node->GetType().c_str(), TypeUtils::DataTypeToSerialString(dt_set).c_str()); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
@@ -785,21 +854,27 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { |
|
|
DataType dt_set = TypeUtils::SerialStringToDataType(set_dt_str); |
|
|
DataType dt_set = TypeUtils::SerialStringToDataType(set_dt_str); |
|
|
GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); |
|
|
GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); |
|
|
bool is_dynamic_batch = false; |
|
|
bool is_dynamic_batch = false; |
|
|
NodePtr switchn_node = nullptr; |
|
|
|
|
|
if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) { |
|
|
|
|
|
|
|
|
NodePtr mbatch_node = nullptr; |
|
|
|
|
|
int32_t index = 0; |
|
|
|
|
|
if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, mbatch_node, index)) { |
|
|
GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); |
|
|
GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
if (ProcessInputDtDynShape(node_ptr, is_dynamic_batch, switchn_node, dt_set) != SUCCESS) { |
|
|
|
|
|
|
|
|
if (ProcessInputDtDynShape(node_ptr, mbatch_node, dt_set) != SUCCESS) { |
|
|
GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); |
|
|
GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|
|
|
|
if (is_dynamic_batch && ProcessMbatchScene(mbatch_node, dt_set, index) != SUCCESS) { |
|
|
|
|
|
GELOGE(INTERNAL_ERROR, "ProcessMbatchScene failed"); |
|
|
|
|
|
return FAILED; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// check if need to set format |
|
|
// check if need to set format |
|
|
string set_format; |
|
|
string set_format; |
|
|
bool ret = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_FORMAT, set_format); |
|
|
bool ret = ge::AttrUtils::GetStr(node_ptr->GetOpDesc(), ATTR_ATC_USER_DEFINE_FORMAT, set_format); |
|
|
if (ret && (!set_format.empty()) && TypeUtils::SerialStringToFormat(set_format) == FORMAT_NC1HWC0) { |
|
|
if (ret && (!set_format.empty()) && TypeUtils::SerialStringToFormat(set_format) == FORMAT_NC1HWC0) { |
|
|
GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); |
|
|
GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); |
|
|
if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { |
|
|
|
|
|
|
|
|
if (ProcessInputNC1HWC0DynShape(node_ptr, is_dynamic_batch, mbatch_node) != SUCCESS) { |
|
|
GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); |
|
|
GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); |
|
|
return FAILED; |
|
|
return FAILED; |
|
|
} |
|
|
} |
|
|