|
|
@@ -160,7 +160,7 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, |
|
|
const std::vector<tensor::TensorPtr> &input_tensors, |
|
|
const std::vector<tensor::TensorPtr> &input_tensors, |
|
|
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) { |
|
|
std::map<tensor::TensorPtr, session::KernelWithIndex> *tensor_to_node) { |
|
|
auto &node = node_output_pair.first; |
|
|
auto &node = node_output_pair.first; |
|
|
auto &output_index = node_output_pair.second; |
|
|
|
|
|
|
|
|
int output_index = SizeToInt(node_output_pair.second); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(node); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; |
|
|
MS_LOG(INFO) << "Create tensor for output[" << node->DebugString() << "] index[" << output_index << "]"; |
|
|
@@ -172,25 +172,24 @@ BaseRef CreateNodeOutputTensor(const session::KernelWithIndex &node_output_pair, |
|
|
if (type_id == kTypeUnknown) { |
|
|
if (type_id == kTypeUnknown) { |
|
|
type_id = AnfAlgo::GetOutputInferDataType(node, output_index); |
|
|
type_id = AnfAlgo::GetOutputInferDataType(node, output_index); |
|
|
} |
|
|
} |
|
|
tensor::TensorPtr tensor = nullptr; |
|
|
|
|
|
std::vector<int64_t> temp_shape; |
|
|
std::vector<int64_t> temp_shape; |
|
|
if (graph->IsUniqueTargetInternalOutput(node, output_index)) { |
|
|
|
|
|
temp_shape.emplace_back(1); |
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); |
|
|
|
|
|
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); |
|
|
|
|
|
tensor->set_sync_status(kNoNeedSync); |
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
|
auto shape = AnfAlgo::GetOutputInferShape(node, output_index); |
|
|
|
|
|
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); |
|
|
|
|
|
tensor::TensorPtr tensor; |
|
|
|
|
|
bool is_internal_output = graph->IsInternalOutput(node, output_index); |
|
|
|
|
|
if (is_internal_output) { |
|
|
tensor = graph->GetInternalOutputTensor(node, output_index); |
|
|
tensor = graph->GetInternalOutputTensor(node, output_index); |
|
|
if (tensor == nullptr) { |
|
|
if (tensor == nullptr) { |
|
|
auto shape = AnfAlgo::GetOutputInferShape(node, output_index); |
|
|
|
|
|
(void)std::copy(shape.begin(), shape.end(), std::back_inserter(temp_shape)); |
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); |
|
|
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); |
|
|
bool is_internal_output = graph->IsInternalOutput(node, output_index); |
|
|
|
|
|
if (is_internal_output) { |
|
|
|
|
|
graph->AddInternalOutputTensor(node, output_index, tensor); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
graph->AddInternalOutputTensor(node, output_index, tensor); |
|
|
} |
|
|
} |
|
|
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); |
|
|
|
|
|
|
|
|
} else { |
|
|
|
|
|
tensor = std::make_shared<tensor::Tensor>(type_id, temp_shape); |
|
|
|
|
|
} |
|
|
|
|
|
tensor->set_padding_type(AnfAlgo::GetOutputReshapeType(node, output_index)); |
|
|
|
|
|
if (is_internal_output) { |
|
|
|
|
|
tensor->set_sync_status(kNoNeedSync); |
|
|
|
|
|
} else { |
|
|
// if in pynative mode,data only copied to host when user want to print data |
|
|
// if in pynative mode,data only copied to host when user want to print data |
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
auto ms_context = MsContext::GetInstance(); |
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
MS_EXCEPTION_IF_NULL(ms_context); |
|
|
@@ -682,16 +681,20 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf |
|
|
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter |
|
|
// if parameter's python parameter has been exist a backend parameter, reuse the exist parameter |
|
|
if (param_value != nullptr) { |
|
|
if (param_value != nullptr) { |
|
|
new_parameter = param_value->parameter(); |
|
|
new_parameter = param_value->parameter(); |
|
|
if (new_parameter == nullptr) { |
|
|
|
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
|
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
|
|
|
param_value->set_parameter(new_parameter); |
|
|
|
|
|
} |
|
|
|
|
|
} else { |
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
if (new_parameter == nullptr) { |
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
TraceGuard trace_guard(std::make_shared<TraceCopy>(anf->debug_info())); |
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
new_parameter = graph->NewParameter(anf->cast<ParameterPtr>()); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto input_node_iter = partial_parameters_map_.find(anf); |
|
|
|
|
|
if (input_node_iter != partial_parameters_map_.end()) { |
|
|
|
|
|
InitInternalOutputParameter(input_node_iter->second, new_parameter); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (param_value != nullptr) { |
|
|
|
|
|
param_value->set_parameter(new_parameter); |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
new_parameter->IncreaseUsedGraphCount(); |
|
|
new_parameter->IncreaseUsedGraphCount(); |
|
|
graph_inputs->push_back(new_parameter); |
|
|
graph_inputs->push_back(new_parameter); |
|
|
valid_inputs->push_back(true); |
|
|
valid_inputs->push_back(true); |
|
|
@@ -1771,10 +1774,11 @@ bool CNodeFirstInputIsPrimitive(const AnfNodePtr &node) { |
|
|
|
|
|
|
|
|
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager, |
|
|
std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_graph_manager, |
|
|
const AnfNodePtr &front_node) { |
|
|
const AnfNodePtr &front_node) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(front_func_graph_manager); |
|
|
auto &users = front_func_graph_manager->node_users()[front_node]; |
|
|
auto &users = front_func_graph_manager->node_users()[front_node]; |
|
|
std::vector<AnfNodePtr> result; |
|
|
std::vector<AnfNodePtr> result; |
|
|
for (auto &user : users) { |
|
|
for (auto &user : users) { |
|
|
if (IsPrimitiveCNode(user.first, prim::kPrimDepend)) { |
|
|
|
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimDepend)) { |
|
|
auto depend_cnode = user.first->cast<CNodePtr>(); |
|
|
auto depend_cnode = user.first->cast<CNodePtr>(); |
|
|
if (depend_cnode == nullptr) { |
|
|
if (depend_cnode == nullptr) { |
|
|
continue; |
|
|
continue; |
|
|
@@ -1784,9 +1788,12 @@ std::vector<AnfNodePtr> ExtendNodeUsers(const FuncGraphManagerPtr &front_func_gr |
|
|
} |
|
|
} |
|
|
auto res = ExtendNodeUsers(front_func_graph_manager, user.first); |
|
|
auto res = ExtendNodeUsers(front_func_graph_manager, user.first); |
|
|
result.insert(result.end(), res.begin(), res.end()); |
|
|
result.insert(result.end(), res.begin(), res.end()); |
|
|
continue; |
|
|
|
|
|
|
|
|
} else if (AnfAlgo::CheckPrimitiveType(user.first, prim::kPrimMakeTuple)) { |
|
|
|
|
|
auto res = ExtendNodeUsers(front_func_graph_manager, user.first); |
|
|
|
|
|
(void)result.insert(result.end(), res.begin(), res.end()); |
|
|
|
|
|
} else { |
|
|
|
|
|
(void)result.emplace_back(user.first); |
|
|
} |
|
|
} |
|
|
(void)result.emplace_back(user.first); |
|
|
|
|
|
} |
|
|
} |
|
|
return result; |
|
|
return result; |
|
|
} |
|
|
} |
|
|
@@ -1812,10 +1819,54 @@ AnfNodePtr GetSupportedInternalNode(const AnfNodePtr &front_node) { |
|
|
} |
|
|
} |
|
|
return nullptr; |
|
|
return nullptr; |
|
|
} |
|
|
} |
|
|
|
|
|
} // namespace |
|
|
|
|
|
|
|
|
void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, |
|
|
|
|
|
const FuncGraphManagerPtr &front_func_graph_manager, |
|
|
|
|
|
const std::shared_ptr<KernelGraph> &backend_graph) { |
|
|
|
|
|
|
|
|
constexpr auto kMixTarget = "MixTarget"; |
|
|
|
|
|
constexpr auto kNoTarget = "NoTarget"; |
|
|
|
|
|
std::string SessionBasic::AddPartialParametersMap(const FuncGraphManagerPtr &front_func_graph_manager, |
|
|
|
|
|
const AnfNodePtr &partial_node) { |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_node); |
|
|
|
|
|
auto iter = partial_target_map_.find(partial_node); |
|
|
|
|
|
if (iter != partial_target_map_.end()) { |
|
|
|
|
|
return iter->second; |
|
|
|
|
|
} |
|
|
|
|
|
auto partial_cnode = partial_node->cast<CNodePtr>(); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_cnode); |
|
|
|
|
|
auto partial_graph = GetValueNode<FuncGraphPtr>(partial_cnode->input(kFirstDataInputIndex)); |
|
|
|
|
|
MS_EXCEPTION_IF_NULL(partial_graph); |
|
|
|
|
|
auto parameters = partial_graph->parameters(); |
|
|
|
|
|
auto partial_inputs = partial_cnode->inputs(); |
|
|
|
|
|
if (parameters.size() + 2 != partial_inputs.size()) { |
|
|
|
|
|
return kMixTarget; |
|
|
|
|
|
} |
|
|
|
|
|
for (size_t i = 0; i < parameters.size(); ++i) { |
|
|
|
|
|
partial_parameters_map_[parameters[i]] = partial_inputs[2 + i]; |
|
|
|
|
|
} |
|
|
|
|
|
auto graph_nodes = TopoSort(partial_graph->get_return()); |
|
|
|
|
|
std::string graph_target = kNoTarget; |
|
|
|
|
|
for (auto &node : graph_nodes) { |
|
|
|
|
|
if (!node->isa<CNode>()) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
if (!AnfAlgo::IsRealKernel(node)) { |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
|
|
|
std::string cur_target = GetCNodeTarget(node); |
|
|
|
|
|
if (graph_target == kNoTarget) { |
|
|
|
|
|
graph_target = cur_target; |
|
|
|
|
|
} |
|
|
|
|
|
if (graph_target != cur_target) { |
|
|
|
|
|
graph_target = kMixTarget; |
|
|
|
|
|
break; |
|
|
|
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
(void)partial_target_map_.insert({partial_node, graph_target}); |
|
|
|
|
|
return graph_target; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
void SessionBasic::HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr &backend_node, |
|
|
|
|
|
const FuncGraphManagerPtr &front_func_graph_manager, |
|
|
|
|
|
const std::shared_ptr<KernelGraph> &backend_graph) { |
|
|
auto front_node = GetSupportedInternalNode(input_front_node); |
|
|
auto front_node = GetSupportedInternalNode(input_front_node); |
|
|
if (front_node == nullptr) { |
|
|
if (front_node == nullptr) { |
|
|
return; |
|
|
return; |
|
|
@@ -1839,7 +1890,14 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr & |
|
|
} |
|
|
} |
|
|
if (internal_output) { |
|
|
if (internal_output) { |
|
|
auto users = ExtendNodeUsers(front_func_graph_manager, front_node); |
|
|
auto users = ExtendNodeUsers(front_func_graph_manager, front_node); |
|
|
for (auto user : users) { |
|
|
|
|
|
|
|
|
for (auto &user : users) { |
|
|
|
|
|
if (AnfAlgo::CheckPrimitiveType(user, prim::kPrimPartial)) { |
|
|
|
|
|
auto partial_target = AddPartialParametersMap(front_func_graph_manager, user); |
|
|
|
|
|
if (partial_target != kNoTarget && partial_target != kernel_target) { |
|
|
|
|
|
unique_target = false; |
|
|
|
|
|
} |
|
|
|
|
|
continue; |
|
|
|
|
|
} |
|
|
if (!CNodeFirstInputIsPrimitive(user)) { |
|
|
if (!CNodeFirstInputIsPrimitive(user)) { |
|
|
internal_output = false; |
|
|
internal_output = false; |
|
|
break; |
|
|
break; |
|
|
@@ -1859,7 +1917,6 @@ void HandleInternalOutput(const AnfNodePtr &input_front_node, const AnfNodePtr & |
|
|
backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target); |
|
|
backend_graph->AddInternalOutput(front_node, backend_real_kernel, backend_real_kernel_pair.second, unique_target); |
|
|
} |
|
|
} |
|
|
} |
|
|
} |
|
|
} // namespace |
|
|
|
|
|
|
|
|
|
|
|
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { |
|
|
CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std::shared_ptr<KernelGraph> &graph) { |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
MS_EXCEPTION_IF_NULL(graph); |
|
|
@@ -1868,7 +1925,7 @@ CNodePtr SessionBasic::ConstructOutput(const AnfNodePtrList &outputs, const std: |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
MS_EXCEPTION_IF_NULL(output); |
|
|
MS_LOG(INFO) << "Output:" << output->DebugString(); |
|
|
MS_LOG(INFO) << "Output:" << output->DebugString(); |
|
|
} |
|
|
} |
|
|
auto FindEqu = [graph, outputs](const AnfNodePtr &out) -> AnfNodePtr { |
|
|
|
|
|
|
|
|
auto FindEqu = [graph, outputs, this](const AnfNodePtr &out) -> AnfNodePtr { |
|
|
auto backend_anf = graph->GetBackendAnfByFrontAnf(out); |
|
|
auto backend_anf = graph->GetBackendAnfByFrontAnf(out); |
|
|
if (backend_anf != nullptr) { |
|
|
if (backend_anf != nullptr) { |
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
auto context_ptr = MsContext::GetInstance(); |
|
|
|