|
|
|
@@ -22,9 +22,11 @@ |
|
|
|
#include "graph/utils/node_utils.h" |
|
|
|
#include "graph/utils/tensor_utils.h" |
|
|
|
#include "graph/utils/type_utils.h" |
|
|
|
#include "graph/compute_graph.h" |
|
|
|
#include "ge_local_engine/ops_kernel_store/op/op_factory.h" |
|
|
|
#include "ge_local_engine/common/constant/constant.h" |
|
|
|
#include "register/ops_kernel_builder_registry.h" |
|
|
|
#include "framework/common/debug/log.h" |
|
|
|
|
|
|
|
namespace ge { |
|
|
|
namespace ge_local { |
|
|
|
@@ -34,6 +36,7 @@ namespace { |
|
|
|
const char *const kConstantOpType = "Constant"; |
|
|
|
const char *const kConstantOpAttrName = "value"; |
|
|
|
const char *const kDataOpType = "Data"; |
|
|
|
const char *const ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; |
|
|
|
} // namespace |
|
|
|
|
|
|
|
GeLocalOpsKernelBuilder::~GeLocalOpsKernelBuilder() { |
|
|
|
@@ -161,13 +164,24 @@ Status GeLocalOpsKernelBuilder::CalcConstantStrMemSize(const OpDescPtr &op_desc, |
|
|
|
} |
|
|
|
|
|
|
|
Status GeLocalOpsKernelBuilder::GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) { |
|
|
|
bool is_shape_unknown = false; |
|
|
|
if (NodeUtils::GetNodeUnknownShapeStatus(node, is_shape_unknown) == GRAPH_SUCCESS) { |
|
|
|
if (is_shape_unknown) { |
|
|
|
GELOGI("op:%s is unknown shape, does not need to generate task", |
|
|
|
node.GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
|
bool is_in_unknown_subgraph = false; |
|
|
|
bool forced_unknown = false; |
|
|
|
for (const auto &node : node.GetOwnerComputeGraph()->GetDirectNode()) { |
|
|
|
GE_CHK_GRAPH_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_in_unknown_subgraph), |
|
|
|
"[Get][ShapeStatus] of node[%s] failed!", node->GetName().c_str()); |
|
|
|
if (is_in_unknown_subgraph) { |
|
|
|
break; |
|
|
|
} |
|
|
|
if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) { |
|
|
|
GELOGD("node %s was marked as unknown shape.", node->GetName().c_str()); |
|
|
|
is_in_unknown_subgraph = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
if (is_in_unknown_subgraph) { |
|
|
|
GELOGI("op:%s is in unknown shape subgraph, does not need to generate task", |
|
|
|
node.GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
string name = node.GetName(); |
|
|
|
string type = node.GetType(); |
|
|
|
|