@@ -327,6 +327,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr<session::Kerne
   auto ir_fusion_pm = std::make_shared<PassManager>("ir_fusion_pm");
   ir_fusion_pm->AddPass(std::make_shared<SplitFission>());
   ir_fusion_pm->AddPass(std::make_shared<BnSplit>());
+  ir_fusion_pm->AddPass(std::make_shared<BnGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LayerNormGradSplit>());
   ir_fusion_pm->AddPass(std::make_shared<TopKSplit>());
   ir_fusion_pm->AddPass(std::make_shared<LinSpaceFission>());
@@ -22,6 +22,7 @@
 #include <list>
 #include "base/core_ops.h"
+#include "base/base_ref_utils.h"
 #include "ir/tensor.h"
 #include "ir/anf.h"
 #include "common/trans.h"
@@ -123,6 +124,284 @@ void InsertMakeTupleForOutput(NotNull<KernelGraphPtr> root_graph) {
     {NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())), root_graph->output()});
   root_graph->set_output(make_tuple);
 }
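+
+// Create a placeholder for one (node, output-index) pair: value nodes and
+// parameters can be resolved immediately, while CNode outputs are recorded in
+// *output_indexes so they can be filled in after the producing op has run.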
+BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph,
+                                    const std::vector<tensor::TensorPtr> &input_tensors,
+                                    const std::vector<size_t> &indexes,
+                                    std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
+  auto &node = node_output_pair.first;
+  MS_EXCEPTION_IF_NULL(node);
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(output_indexes);
+  MS_LOG(INFO) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second
+               << "]";
+  // If the node is a value node, there is no need to sync its addr from device to host.
+  if (node->isa<ValueNode>()) {
+    auto value_node = node->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(value_node);
+    return value_node->value();
+  }
+  if (node->isa<Parameter>()) {
+    for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) {
+      if (input_idx >= input_tensors.size()) {
+        MS_LOG(EXCEPTION) << "Input idx: " << input_idx << " out of range: " << input_tensors.size();
+      }
+      if (graph->inputs()[input_idx] == node) {
+        return input_tensors[input_idx];
+      }
+    }
+    MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr";
+  }
+  (*output_indexes)[node_output_pair] = indexes;
+  BaseRef output_placeholder = std::make_shared<BaseRef>();
+  return output_placeholder;
+}
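+
+// Recursive overload: a MakeTuple output expands into a nested VectorRef, with
+// the tuple element position appended to the index path at each level of nesting.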
+BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph,
+                                    const std::vector<tensor::TensorPtr> &input_tensors,
+                                    const std::vector<size_t> &indexes,
+                                    std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
+  MS_EXCEPTION_IF_NULL(anf);
+  MS_EXCEPTION_IF_NULL(output_indexes);
+  MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]";
+  auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0);
+  MS_EXCEPTION_IF_NULL(item_with_index.first);
+  MS_LOG(INFO) << "Create placeholder for output after visit:" << item_with_index.first->DebugString();
+  // Special handling for MakeTuple.
+  if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
+    auto cnode = item_with_index.first->cast<CNodePtr>();
+    MS_EXCEPTION_IF_NULL(cnode);
+    VectorRef ret;
+    for (size_t i = 1; i < cnode->inputs().size(); ++i) {
+      std::vector<size_t> cur_index = indexes;
+      cur_index.emplace_back(i - 1);
+      auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes);
+      ret.push_back(out);
+    }
+    return ret;
+  }
+  // If the graph returns nothing, the function should return an empty VectorRef.
+  size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first);
+  if (size == 0) {
+    return VectorRef();
+  }
+  return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes);
+}
+
+void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &input_tensors,
+                             VectorRef *outputs, std::map<KernelWithIndex, std::vector<size_t>> *output_indexes) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  MS_EXCEPTION_IF_NULL(outputs);
+  MS_EXCEPTION_IF_NULL(output_indexes);
+  auto anf_outputs = kernel_graph->outputs();
+  size_t index = 0;
+  for (auto &item : anf_outputs) {
+    MS_EXCEPTION_IF_NULL(item);
+    MS_LOG(INFO) << "Create node output placeholder[" << item->DebugString() << "]";
+    std::vector<size_t> indexes{index++};
+    outputs->emplace_back(CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes));
+  }
+}
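+
+// Count how many times each CNode output is consumed as an op input, so its
+// cached output tensor can be released once the count drops to zero.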
+void GetRefCount(KernelGraph *graph, std::map<KernelWithIndex, size_t> *ref_count) {
+  MS_EXCEPTION_IF_NULL(graph);
+  for (const auto &kernel : graph->execution_order()) {
+    for (size_t i = 1; i < kernel->inputs().size(); i += 1) {
+      const auto &input = kernel->input(i);
+      auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
+      const auto &node = kernel_with_index.first;
+      if (node->isa<CNode>()) {
+        (*ref_count)[kernel_with_index] += 1;
+      }
+    }
+  }
+}
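+
+// Map every graph parameter to the position of its tensor in the user-provided
+// inputs, checking that tensor and parameter shapes match along the way.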
+void GetParameterIndex(KernelGraph *graph, const std::vector<tensor::TensorPtr> &inputs,
+                       std::map<AnfNodePtr, size_t> *parameter_index) {
+  size_t index = 0;
+  for (const auto &input_node : graph->inputs()) {
+    auto params = AnfAlgo::GetAllOutput(input_node);
+    for (const auto &param : params) {
+      if (index >= inputs.size()) {
+        MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index
+                          << ", input size: " << inputs.size();
+      }
+      const auto &input = inputs[index];
+      // Check shape of input and parameter
+      const auto &input_shape = input->shape();
+      const auto &param_shape = AnfAlgo::GetOutputInferShape(param, 0);
+      if (input_shape.size() != param_shape.size()) {
+        MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
+                          << ", parameter: " << param->fullname_with_scope();
+      }
+      for (size_t i = 0; i < input_shape.size(); i += 1) {
+        if (input_shape[i] < 0 || static_cast<size_t>(input_shape[i]) != param_shape[i]) {
+          MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index
+                            << ", parameter: " << param->fullname_with_scope();
+        }
+      }
+      parameter_index->emplace(param, index++);
+    }
+  }
+}
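+
+// Collect the input tensors of one op: value nodes contribute their constant
+// tensors, parameters are looked up in the graph inputs, and CNode inputs are
+// taken from the cached outputs of previously executed ops.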
+void GetOpInputTensors(const CNodePtr &cnode, const std::map<KernelWithIndex, tensor::TensorPtr> &op_output,
+                       const std::map<AnfNodePtr, size_t> &parameter_index,
+                       const std::vector<tensor::TensorPtr> &graph_inputs, InputTensorInfo *input_tensor_info) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  for (size_t i = 1; i < cnode->inputs().size(); i += 1) {
+    const auto &input = cnode->input(i);
+    auto kernel_with_index = AnfAlgo::VisitKernel(input, 0);
+    auto real_input = kernel_with_index.first;
+    MS_EXCEPTION_IF_NULL(real_input);
+    tensor::TensorPtr tensor = nullptr;
+    if (real_input->isa<ValueNode>()) {
+      auto value_node = real_input->cast<ValueNodePtr>();
+      MS_EXCEPTION_IF_NULL(value_node);
+      auto value = GetValueNode(value_node);
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa<ValueTuple>()) {
+        auto value_tuple = value->cast<ValueTuplePtr>();
+        MS_EXCEPTION_IF_NULL(value_tuple);
+        if (kernel_with_index.second >= value_tuple->size()) {
+          MS_LOG(EXCEPTION) << "Index " << kernel_with_index.second << " is out of value tuple range";
+        }
+        auto tensor_value = value_tuple->value()[kernel_with_index.second];
+        if (tensor_value->isa<tensor::Tensor>()) {
+          tensor = tensor_value->cast<tensor::TensorPtr>();
+        }
+      } else if (value->isa<tensor::Tensor>()) {
+        if (kernel_with_index.second != 0) {
+          MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << kernel_with_index.second;
+        }
+        tensor = GetValueNode<TensorPtr>(value_node);
+      }
+    } else if (real_input->isa<Parameter>()) {
+      const auto &iter = parameter_index.find(real_input);
+      if (iter == parameter_index.end()) {
+        MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, node = " << cnode->DebugString();
+      }
+      const size_t index = iter->second;
+      if (index >= graph_inputs.size()) {
+        MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensors, node = "
+                          << cnode->DebugString() << ", input tensor size = " << graph_inputs.size();
+      }
+      tensor = graph_inputs[index];
+    } else if (real_input->isa<CNode>()) {
+      const auto &iter = op_output.find(kernel_with_index);
+      if (iter == op_output.end()) {
+        MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << real_input->DebugString();
+      }
+      tensor = iter->second;
+      input_tensor_info->input_kernel.insert(kernel_with_index);
+    } else {
+      MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString();
+    }
+    MS_EXCEPTION_IF_NULL(tensor);
+    MS_LOG(DEBUG) << "Get " << i << "th input tensor of " << cnode->fullname_with_scope() << " from "
+                  << real_input->fullname_with_scope() << "-" << kernel_with_index.second;
+    input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask
+                                                                              : kParameterDataTensorMask);
+    input_tensor_info->input_tensors.emplace_back(tensor);
+  }
+}
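+
+// After an op has run, decrement the reference count of every CNode output it
+// consumed and evict cached tensors whose count has reached zero.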
+void HandleOpInputs(const std::set<KernelWithIndex> &input_kernel, std::map<KernelWithIndex, size_t> *ref_count,
+                    std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map) {
+  for (auto &kernel_with_index : input_kernel) {
+    if (!kernel_with_index.first->isa<CNode>()) {
+      continue;
+    }
+    auto ref_iter = ref_count->find(kernel_with_index);
+    if (ref_iter == ref_count->end()) {
+      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = "
+                        << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
+    }
+    ref_iter->second -= 1;
+    if (ref_iter->second != 0) {
+      continue;
+    }
+    ref_count->erase(ref_iter);
+    auto output_iter = op_output_map->find(kernel_with_index);
+    if (output_iter == op_output_map->end()) {
+      MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = "
+                        << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second;
+    }
+    op_output_map->erase(output_iter);
+  }
+}
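+
+// Cache the op's output tensors that later ops still reference, and write any
+// graph outputs into the placeholder tree along their recorded index paths.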
+void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs,
+                     const std::map<KernelWithIndex, std::vector<size_t>> &output_indexes,
+                     const std::map<KernelWithIndex, size_t> &ref_count,
+                     std::map<KernelWithIndex, tensor::TensorPtr> *op_output_map, VectorRef *outputs) {
+  auto output_tensors = TransformVectorRefToMultiTensor(op_outputs);
+  if (output_tensors.size() != op_outputs.size()) {
+    MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString();
+  }
+  size_t out_index = 0;
+  for (const auto &output_tensor : output_tensors) {
+    auto kernel_with_index = make_pair(kernel, out_index++);
+    if (ref_count.find(kernel_with_index) != ref_count.end()) {
+      (*op_output_map)[kernel_with_index] = output_tensor;
+    }
+    const auto &iter = output_indexes.find(kernel_with_index);
+    if (iter == output_indexes.end()) {
+      continue;
+    }
+    // Walk down the placeholder tree along the index path; the last index
+    // addresses the slot that receives the tensor.
+    const std::vector<size_t> &ref_indexes = iter->second;
+    size_t n = 0;
+    const VectorRef *cur_vector_ref = outputs;
+    while (n != ref_indexes.size() - 1) {
+      size_t index = ref_indexes.at(n++);
+      const BaseRef &base_ref = (*cur_vector_ref)[index];
+      if (!utils::isa<VectorRef>(base_ref)) {
+        MS_LOG(EXCEPTION) << "Get non-VectorRef by ref index, indexes: " << ref_indexes << ", cur n: " << (n - 1);
+      }
+      cur_vector_ref = &utils::cast<VectorRef>(base_ref);
+    }
+    BaseRef &tensor_ref = (*const_cast<VectorRef *>(cur_vector_ref))[ref_indexes.at(n)];
+    tensor_ref = output_tensor;
+  }
+}
+
+void GetSingleOpRunInfo(const CNodePtr &cnode, OpRunInfo *run_info) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(run_info);
+  auto primitive = AnfAlgo::GetCNodePrimitive(cnode);
+  MS_EXCEPTION_IF_NULL(primitive);
+  run_info->primitive = primitive;
+  run_info->op_name = primitive->name();
+  if (cnode->abstract() == nullptr) {
+    MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString();
+  }
+  run_info->abstract = cnode->abstract();
+}
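+
+// Build a cache key for a single op from its input shapes, data types, device
+// formats and evaluated attributes, so identical ops reuse a built graph.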
+GraphInfo GetSingleOpGraphInfo(const PrimitivePtr &prim, const std::vector<tensor::TensorPtr> &input_tensors) {
+  MS_EXCEPTION_IF_NULL(prim);
+  GraphInfo graph_info;
+  // get input tensor info
+  for (const auto &tensor : input_tensors) {
+    MS_EXCEPTION_IF_NULL(tensor);
+    auto tensor_shape = tensor->shape();
+    (void)std::for_each(tensor_shape.begin(), tensor_shape.end(),
+                        [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); });
+    (void)graph_info.append(std::to_string(tensor->data_type()) + "_");
+    if (tensor->device_address() != nullptr) {
+      const auto type_id = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->type_id();
+      (void)graph_info.append(std::to_string(type_id) + "_");
+      const auto format = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address())->format();
+      (void)graph_info.append(format + "_");
+    }
+  }
+  // get attr info
+  const auto &attr_map = prim->evaluate_added_attrs();
+  (void)std::for_each(attr_map.begin(), attr_map.end(),
+                      [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); });
+  graph_info.append(prim->id());
+  return graph_info;
+}
 }  // namespace

 void AscendSession::Init(uint32_t device_id) {
@@ -417,7 +696,7 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
   MS_EXCEPTION_IF_NULL(graph);
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!";
   // malloc mem
-  RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get());
+  RunOpMemoryAlloc(input_tensors, graph.get());
   // Build dynamic kernel
   if (op_run_info.is_dynamic_shape) {
     BuildDynamicKernel(graph);
@@ -432,6 +711,39 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
 }
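+
+// Execute a compiled graph op by op in PyNative mode: every kernel in the
+// execution order is built and run as a single op, with intermediate tensors
+// cached and released according to the precomputed reference counts.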
+void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                                      VectorRef *outputs) {
+  MS_LOG(INFO) << "Start";
+  auto kernel_graph = GetGraph(graph_id);
+  std::map<AnfNodePtr, size_t> parameter_index;
+  GetParameterIndex(kernel_graph.get(), inputs, &parameter_index);
+  std::map<KernelWithIndex, std::vector<size_t>> output_indexes;
+  CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes);
+  std::map<KernelWithIndex, size_t> cnode_ref;
+  GetRefCount(kernel_graph.get(), &cnode_ref);
+  std::map<KernelWithIndex, tensor::TensorPtr> op_output_map;
+  for (const auto &kernel : kernel_graph->execution_order()) {
+    // Generate input tensors, tensor masks and input kernel with index
+    InputTensorInfo input_tensor_info;
+    GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info);
+    // Get OpRunInfo and GraphInfo
+    OpRunInfo run_info;
+    GetSingleOpRunInfo(kernel, &run_info);
+    GraphInfo graph_info = GetSingleOpGraphInfo(run_info.primitive, input_tensor_info.input_tensors);
+    // Build and run current single op
+    BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask);
+    VectorRef op_outputs;
+    RunOpImpl(run_info, graph_info, input_tensor_info.input_tensors, &op_outputs);
+    // Handle inputs and outputs of current op
+    HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map);
+    HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs);
+  }
+}

 // compile graph steps
 void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const {
   MS_LOG(INFO) << "Start!";
@@ -591,15 +903,14 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const {
   MS_LOG(INFO) << "Finish!";
 }

-void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value,
-                                     const std::vector<tensor::TensorPtr> &input_tensors,
+void AscendSession::RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors,
                                      KernelGraph *kernel_graph) const {
   MS_LOG(INFO) << "Start memory alloc!";
   MS_EXCEPTION_IF_NULL(kernel_graph);
   opt::RemoveNopNode(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
+  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
   MS_LOG(INFO) << "Finish!";
 }
@@ -35,6 +35,11 @@
 namespace mindspore {
 namespace session {
 enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 };
+struct InputTensorInfo {
+  std::vector<tensor::TensorPtr> input_tensors;
+  std::vector<int64_t> input_tensors_mask;
+  std::set<KernelWithIndex> input_kernel;
+};

 class AscendSession : public SessionBasic {
  public:
@@ -56,6 +61,8 @@ class AscendSession : public SessionBasic {
                    const std::vector<int64_t> &tensors_mask) override;
   void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                  const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) override;
+  void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                         VectorRef *outputs) override;

  private:
   // compile child graph when session have multiple child graphs
@@ -72,8 +79,7 @@ class AscendSession : public SessionBasic {
   void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void BuildDynamicKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void MemoryAlloc(KernelGraph *kernel_graph) const;
-  void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
-                        KernelGraph *kernel_graph) const;
+  void RunOpMemoryAlloc(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
   void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
   void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
@@ -16,8 +16,9 @@
 #include "backend/session/executor.h"
 #include <algorithm>
 #include <exception>
-#include "runtime/device/kernel_runtime_manager.h"
 #include "backend/session/executor_manager.h"
+#include "runtime/device/kernel_runtime_manager.h"
 #include "utils/comm_manager.h"
 #include "utils/scoped_long_running.h"
@@ -134,6 +135,11 @@ void RunOpTask::Run() {
   session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_);
 }

+void RunOpsInGraphTask::Run() {
+  MS_EXCEPTION_IF_NULL(session_);
+  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
+}
+
 void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }

 void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
@@ -361,6 +367,18 @@ void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const Gr
   *outputs = task->outputs_;
 }

+void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
+                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
+  MS_EXCEPTION_IF_NULL(session);
+  MS_EXCEPTION_IF_NULL(outputs);
+  auto task = std::make_shared<RunOpsInGraphTask>();
+  task->session_ = session;
+  task->graph_id_ = graph_id;
+  task->input_tensors_ = inputs;
+  SyncRunTask(task);
+  *outputs = task->outputs_;
+}
+
 bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
   auto task = std::make_shared<CreateCommGroupTask>();
   task->group_name_ = group_name;
@@ -16,22 +16,23 @@
 #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
 #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H
-#include <vector>
-#include <string>
-#include <utility>
-#include <memory>
-#include <condition_variable>
+#include <condition_variable>
 #include <list>
-#include <queue>
 #include <map>
-#include <thread>
+#include <memory>
 #include <mutex>
+#include <queue>
+#include <string>
+#include <thread>
+#include <utility>
+#include <vector>

 #include "backend/session/session_basic.h"
 #include "ir/anf.h"
 #include "ir/tensor.h"
 #include "utils/any.h"
-#include "utils/contract.h"
 #include "utils/comm_manager.h"
+#include "utils/contract.h"

 namespace mindspore {
 namespace session {
@@ -45,7 +46,8 @@ enum TaskType {
   kRunGraph,
   kRunOp,
   kCreateCommGroup,
-  kDestroyCommGroup
+  kDestroyCommGroup,
+  kRunOpsInGraph
 };

 class Task {
@@ -98,6 +100,16 @@ class RunGraphTask : public Task {
   std::map<tensor::TensorPtr, session::KernelWithIndex> tensor_to_node_;
 };
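+
+// Task that replays a compiled graph op by op on the executor thread.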
+class RunOpsInGraphTask : public Task {
+ public:
+  RunOpsInGraphTask() { type_ = kRunOpsInGraph; }
+  ~RunOpsInGraphTask() override = default;
+  void Run() override;
+  std::vector<tensor::TensorPtr> input_tensors_;
+  VectorRef outputs_;
+  GraphId graph_id_{0};
+};
+
 class BuildOpTask : public Task {
  public:
   BuildOpTask() { type_ = kBuildOp; }
@@ -162,6 +174,8 @@ class Executor {
                const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int64_t> &tensors_mask);
   void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
              const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
+  void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                     VectorRef *outputs);
   void OnRunGraphFinished();
   bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
   bool DestroyCommGroup(const std::string &group_name);
@@ -198,13 +198,12 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const {
   runtime_instance->AssignMemory(kernel_graph);
 }

-void GPUSession::RunOpAllocateMemory(const ValuePtr &pre_output_value,
-                                     const std::vector<tensor::TensorPtr> &input_tensors,
+void GPUSession::RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                      KernelGraph *kernel_graph) const {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
   MS_EXCEPTION_IF_NULL(runtime_instance);
-  runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph);
+  runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph);
 }

 void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const {
@@ -351,6 +350,8 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
                               VectorRef *outputs) {
   auto &kernel_graph = graphs_[graph_id];
   MS_LOG(INFO) << "RunGraph graph_id: " << graph_id;
+  // In pynative mode, device addresses of tensors in value nodes change.
+  SyncValueNodeDeviceAddr(kernel_graph);
   // Load input data from user input
   LoadInputData(kernel_graph, inputs);
   PreIterationDbg(kernel_graph);
@@ -366,6 +367,8 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector<tensor:
     Execute(kernel_graph);
   }
   PostLoadTensor(kernel_graph);
+  // In pynative mode, device addresses of tensors in value nodes need to be cleaned.
+  CleanValueNodeDeviceAddr(kernel_graph);
   // Summary
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -400,7 +403,7 @@ void GPUSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_
   MS_EXCEPTION_IF_NULL(kernel_graph);
   // Remove NopOp from execution graph
   opt::RemoveNopNode(kernel_graph.get());
-  RunOpAllocateMemory(op_run_info.value, input_tensors, kernel_graph.get());
+  RunOpAllocateMemory(input_tensors, kernel_graph.get());
   // Execute the computation
   LoadInputData(kernel_graph, input_tensors);
   Execute(kernel_graph);
@@ -471,6 +474,28 @@ void GPUSession::PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph
   TensorLoader *tensor_loader = debug_services->tensor_loader();
   tensor_loader->EmptyPrevTensor();
 }
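+
+// In PyNative mode, tensors held by value nodes may carry device addresses from
+// earlier single-op runs; sync them into the graph before execution and clean
+// them up afterwards so the next run starts from a consistent state.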
+void GPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+    return;
+  }
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  runtime_instance->SyncValueNodeDeviceAddr(kernel_graph.get());
+}
+
+void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  if (context_ptr->get_param<int>(MS_CTX_EXECUTION_MODE) != kPynativeMode) {
+    return;
+  }
+  auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
+  MS_EXCEPTION_IF_NULL(runtime_instance);
+  runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get());
+}
 }  // namespace gpu
 }  // namespace session
 }  // namespace mindspore
@@ -61,8 +61,7 @@ class GPUSession : public SessionBasic {
   void AllocateMemory(KernelGraph *kernel_graph) const;
-  void RunOpAllocateMemory(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors,
-                           KernelGraph *kernel_graph) const;
+  void RunOpAllocateMemory(const std::vector<tensor::TensorPtr> &input_tensors, KernelGraph *kernel_graph) const;
   void RunOpClearMemory(KernelGraph *kernel_graph) const;
@@ -82,6 +81,10 @@ class GPUSession : public SessionBasic {
   void PreLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
   void PostLoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+
+  void SyncValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
+  void CleanValueNodeDeviceAddr(const std::shared_ptr<KernelGraph> &kernel_graph) const;
 };

 using GPUSessionPtr = std::shared_ptr<GPUSession>;
 MS_REG_SESSION(kGPUDevice, GPUSession);
@@ -14,9 +14,11 @@
  * limitations under the License.
  */

 #include "backend/session/session_basic.h"
-#include <utility>
 #include <algorithm>
+#include <set>
 #include <unordered_map>
+#include <utility>

 #include "c_ops/primitive_c.h"
 #include "ir/manager.h"
@@ -1606,6 +1608,12 @@ void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info,
   executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
 }

+void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                                 VectorRef *outputs) {
+  MS_EXCEPTION_IF_NULL(executor_);
+  executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs);
+}
+
 void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
   executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs);
@@ -22,6 +22,7 @@
 #include <utility>
 #include <memory>
 #include <map>
+#include <set>
 #include "backend/session/session_context.h"
 #include "backend/session/kernel_graph.h"
 #include "backend/session/anf_runtime_algorithm.h"
@@ -49,7 +50,6 @@ struct OpRunInfo {
   std::string op_name;
   PrimitivePtr primitive;
   AbstractBasePtr abstract;
-  ValuePtr value = nullptr;
   bool is_dynamic_shape = false;
   bool is_auto_mixed_precision = false;
   std::string next_op_name = "";
@@ -79,6 +79,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
                const std::vector<int64_t> &tensors_mask);
   void RunOp(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
+  void RunOpsInGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs);

   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
@@ -138,6 +139,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   friend class RunGraphTask;
   friend class BuildOpTask;
   friend class RunOpTask;
+  friend class RunOpsInGraphTask;
   virtual bool IsSupportSummary() { return true; }
   virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &input_tensors,
                                    VectorRef *outputs,
@@ -155,6 +157,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
                            const std::vector<int64_t> &tensors_mask) {}
   virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                          const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {}
+  virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
+                                 VectorRef *outputs) {}
   void RunInfer(NotNull<FuncGraphPtr> func_graph, const std::vector<tensor::TensorPtr> &inputs);
   virtual void SetSummaryNodes(KernelGraph *graph);
@@ -281,24 +281,6 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
   return node_adjoint;
 }

-void TensorSetAddress(const ValuePtr &value, std::map<std::string, tensor::TensorPtr> *tuple_tensors) {
-  MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa<tensor::Tensor>();
-  if (value->isa<tensor::Tensor>()) {
-    auto tnode = value->cast<tensor::TensorPtr>();
-    if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) {
-      MS_LOG(DEBUG) << "Set tensor" << tnode->device_address();
-      (*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address());
-    }
-  }
-  if (value->isa<ValueTuple>()) {
-    auto tuple = value->cast<ValueTuplePtr>();
-    for (size_t i = 0; i < tuple->size(); i++) {
-      MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString();
-      TensorSetAddress((*tuple)[i], tuple_tensors);
-    }
-  }
-}
-
 ValuePtr GenNewTensorInner(const ValuePtr &value) {
   std::vector<ValuePtr> value_list;
   if (value->isa<tensor::Tensor>()) {
@@ -328,7 +310,6 @@ ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, co
 void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) {
   auto forward = cnode_morph->forward().first;
-  auto forward_id = cnode_morph->forward().second;
   if (forward == nullptr) {
     return;
   }
@@ -337,6 +318,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
     return;
   }
   auto fg = GetValueNode<FuncGraphPtr>(input);
+  // {prim::maketuple, forward_output, bprop_graph}
   auto output = fg->output();
   if (!output->isa<CNode>()) {
     return;
@@ -350,25 +332,22 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
   if (!IsValueNode<FuncGraph>(input_fg)) {
     return;
   }
-  std::map<std::string, tensor::TensorPtr> tuple_tensors;
+  // replace forward output with value node
   auto equivdout = cnode_input->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(equivdout);
   auto func_graph = GetValueNode<FuncGraphPtr>(input_fg);
+  MS_EXCEPTION_IF_NULL(func_graph);
   auto manager = Manage({fg, func_graph}, false);
-  auto ref_size = manager->node_users()[equivdout].size();
-  auto forward_value = forward;
-  if (!forward_id.empty() && ref_size > 1) {
-    auto inst = pynative::PynativeExecutor::GetInstance();
-    inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors);
-  }
-  forward_value = GenNewTensor(manager, equivdout, forward);
+  auto forward_value = GenNewTensor(manager, equivdout, forward);
   MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward;
   auto value_node = NewValueNode(forward_value);
   value_node->set_has_new_value(true);
   manager->Replace(equivdout, value_node);
+  // replace input object with value node
   auto paras = fg->parameters();
   auto inputs_value = cnode_morph->inputs_value();
-  if (inputs_value.size() == 0) {
+  if (inputs_value.empty()) {
     return;
   }
   if (inputs_value.size() != paras.size()) {
@@ -379,10 +358,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
     auto input_value = inputs_value[i];
     if (para_ref_size > 0 && input_value.first != nullptr) {
       MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first;
-      auto inst = pynative::PynativeExecutor::GetInstance();
-      if (!input_value.second.empty()) {
-        inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors);
-      }
       auto input_value_node = NewValueNode(input_value.first);
       input_value_node->set_has_new_value(true);
       manager->Replace(paras[i], input_value_node);
@@ -394,30 +369,19 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor
   res->set_func_graph(fg);
   PynativeElimOpt(res);
   auto out = fg->output()->cast<CNodePtr>();
+  MS_EXCEPTION_IF_NULL(out);
   auto c_input = out->input(1);
+  MS_EXCEPTION_IF_NULL(c_input);
   if (!c_input->isa<ValueNode>()) {
     return;
   }
   auto out_node = c_input->cast<ValueNodePtr>();
+  MS_EXCEPTION_IF_NULL(out_node);
   out_node->set_value(GenNewTensor(manager, out_node, out_node->value()));
+  // clear resource
   cnode_morph->clear_inputs_value();
-  if (tuple_tensors.size() != 0) {
-    MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4);
-    for (auto &g : manager->func_graphs()) {
-      for (auto &node : g->value_nodes()) {
-        MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString();
-        auto vnode = node.first->cast<ValueNodePtr>()->value();
-        TensorSetAddress(vnode, &tuple_tensors);
-      }
-    }
-  }
   fg->ClearAllManagerInfo();
   func_graph->ClearAllManagerInfo();
-  return;
 }

 bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) {
@@ -298,14 +298,29 @@ class PynativeEliminater : public OptimizerCaller {
     return out;
   }
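+
+  // Keep only the abstract info of a tensor value node: replace the tensor with
+  // an empty one of the same dtype and shape so the underlying data can be freed.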
+  void OnlySaveAbstractInfo(const ValueNodePtr &value_node) {
+    MS_EXCEPTION_IF_NULL(value_node);
+    auto &value = value_node->value();
+    MS_EXCEPTION_IF_NULL(value);
+    if (value->isa<tensor::Tensor>()) {
+      auto tensor = value->cast<tensor::TensorPtr>();
+      MS_EXCEPTION_IF_NULL(tensor);
+      auto new_tensor = std::make_shared<tensor::Tensor>(tensor->Dtype()->type_id(), tensor->shape());
+      value_node->set_value(MakeValue(new_tensor));
+    }
+  }
+
  public:
   AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
     MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4);
-    PatternNode<AnfNodePtr> symbol_str_vnode, c_vnode, zeros_like_vnode, getitem_vnode, arg, arg1;
+    PatternNode<AnfNodePtr> symbol_str_vnode;
+    PatternNode<AnfNodePtr> c_vnode;
+    PatternNode<AnfNodePtr> zeros_like_vnode;
+    PatternNode<AnfNodePtr> arg;
     auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode);
     auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode);
     auto pattern = PCNode(getattr, arg);
+    // {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} -> Tensor(0, shape(Xy))
     if ((pattern).TryCapture(node) &&
         (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") &&
          CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) {
@@ -320,8 +335,8 @@ class PynativeEliminater : public OptimizerCaller {
         }
       }
     }
-
     MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4);
+    // {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} -> Tensor(0, shape(Xy))
     auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode);
     auto pattern1 = PCNode(resolve1, arg);
@@ -338,7 +353,13 @@ class PynativeEliminater : public OptimizerCaller {
         }
       }
     }
+    // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout}
+    PatternNode<AnfNodePtr> binop_grad_common;
+    PatternNode<AnfNodePtr> getitem_vnode;
+    PatternNode<AnfNodePtr> arg1;
+    PatternNode<AnfNodePtr> arg2;
+    PatternNode<AnfNodePtr> arg3;
+    PatternNode<AnfNodePtr> arg4;
     // resolve(CommonOPS, getitem)((tensors), 3)
     auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode);
     auto pattern2 = PCNode(resolve2, arg, arg1);
@@ -51,21 +51,19 @@ enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM };
 struct OpExecInfo {
   std::string op_name;
-  std::string op_index;
   std::string prim_id;
   PrimitivePyPtr py_primitive;
   AbstractBasePtr abstract;
-  bool is_dynamic_shape = false;
-  ValuePtr value = nullptr;
   py::list op_inputs;
-  py::dict op_attrs;
   std::vector<bool> inputs_mask;
+  bool is_dynamic_shape = false;
   std::string next_op_name = "";
   bool is_mixed_precision_cast = false;
   size_t next_input_index = 0;
 };
 using OpExecInfoPtr = std::shared_ptr<OpExecInfo>;
-OpExecInfoPtr GenerateOpExecInfo(const py::args &args);

 const std::set<std::string> ignore_infer_prim = {"make_ref", "mixed_precision_cast"};
 const std::set<std::string> force_infer_prim = {"TopK", "DropoutGenMask"};
@@ -149,12 +149,6 @@ static std::string GetId(const py::object &obj) {
   return py::cast<std::string>(ret);
 }

-static std::string GetOpId(const OpExecInfoPtr &op_exec_info) {
-  auto id = GetId(op_exec_info->py_primitive->GetPyObj());
-  op_exec_info->prim_id = id;
-  return id;
-}
-
 std::map<SignatureEnumDType, std::vector<size_t>> GetTypeIndex(const std::vector<SignatureEnumDType> &dtypes) {
   std::map<SignatureEnumDType, std::vector<size_t>> type_indexes;
   for (size_t i = 0; i < dtypes.size(); ++i) {
@@ -260,24 +254,6 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn
   MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString();
 }

-OpExecInfoPtr GenerateOpExecInfo(const py::args &args) {
-  if (args.size() != PY_ARGS_NUM) {
-    MS_LOG(ERROR) << "Three args are needed by RunOp";
-    return nullptr;
-  }
-  auto op_exec_info = std::make_shared<OpExecInfo>();
-  MS_EXCEPTION_IF_NULL(op_exec_info);
-  op_exec_info->op_name = py::cast<std::string>(args[PY_NAME]);
-  auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]);
-  if (!prim->HasPyObj()) {
-    MS_LOG(EXCEPTION) << "Pyobj is empty";
-  }
-  op_exec_info->py_primitive = prim;
-  op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs");
-  op_exec_info->op_inputs = args[PY_INPUTS];
-  return op_exec_info;
-}
-
 std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info,
                                  const std::vector<tensor::TensorPtr> &input_tensors) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
@@ -580,7 +556,7 @@ py::tuple RunOp(const py::args &args) {
   auto executor = PynativeExecutor::GetInstance();
   MS_EXCEPTION_IF_NULL(executor);
   MS_LOG(DEBUG) << "RunOp start " << args.size();
-  OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args);
+  OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args);
   try {
     return executor->RunOpInner(op_exec_info);
   } catch (const py::error_already_set &ex) {
@@ -608,16 +584,17 @@ py::tuple RunOp(const py::args &args) {
 }

 py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) {
-  auto prim = op_exec_info->py_primitive;
-  auto name = op_exec_info->op_name;
   if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) {
     return RunOpWithInitBackendPolicy(op_exec_info);
   }
+  // make cnode for building grad graph if grad flag is set.
   abstract::AbstractBasePtrList args_spec_list;
   std::vector<bool> op_masks;
-  auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list);
+  auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
+  op_exec_info->inputs_mask = op_masks;
+  // get output abstract info
   bool is_find = false;
+  auto prim = op_exec_info->py_primitive;
   if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) {
     auto abs_list = prim_abs_list_[prim->id()];
     MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list);
| @@ -629,7 +606,6 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| is_find = true; | is_find = true; | ||||
| } | } | ||||
| } | } | ||||
| if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) { | if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) { | ||||
| // use python infer method | // use python infer method | ||||
| if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { | if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { | ||||
| @@ -648,11 +624,10 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| if (cnode != nullptr) { | if (cnode != nullptr) { | ||||
| cnode->set_abstract(op_exec_info->abstract); | cnode->set_abstract(op_exec_info->abstract); | ||||
| } | } | ||||
| op_exec_info->inputs_mask = op_masks; | |||||
| // infer output value for const prim | |||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | MS_EXCEPTION_IF_NULL(op_exec_info); | ||||
| if (op_exec_info->abstract != nullptr) { | if (op_exec_info->abstract != nullptr) { | ||||
| MS_LOG(DEBUG) << "Run op infer " << name << " " << op_exec_info->abstract->ToString(); | |||||
| MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); | |||||
| py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); | py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); | ||||
| if (!output["value"].is_none()) { | if (!output["value"].is_none()) { | ||||
| py::tuple value_ret(1); | py::tuple value_ret(1); | ||||
| @@ -665,7 +640,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| return value_ret; | return value_ret; | ||||
| } | } | ||||
| } | } | ||||
| // add output abstract info into cache | |||||
| if (!is_find) { | if (!is_find) { | ||||
| // const_value needs to be inferred every step | // const_value needs to be inferred every step | ||||
| auto &out = prim_abs_list_[prim->id()]; | auto &out = prim_abs_list_[prim->id()]; | ||||
| @@ -674,13 +649,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| out[args_spec_list].attrs = prim->evaluate_added_attrs(); | out[args_spec_list].attrs = prim->evaluate_added_attrs(); | ||||
| MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); | ||||
| } | } | ||||
| if (PynativeExecutor::GetInstance()->grad_flag()) { | |||||
| op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info); | |||||
| } else { | |||||
| (void)GetOpId(op_exec_info); | |||||
| } | |||||
| // run op with selected backend | |||||
| auto result = RunOpWithInitBackendPolicy(op_exec_info); | auto result = RunOpWithInitBackendPolicy(op_exec_info); | ||||
| py::object out_real = result; | py::object out_real = result; | ||||
| if (result.size() == 1) { | if (result.size() == 1) { | ||||
| @@ -689,13 +658,38 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { | |||||
| } | } | ||||
| std::string obj_id = GetId(out_real); | std::string obj_id = GetId(out_real); | ||||
| node_abs_map_[obj_id] = op_exec_info->abstract; | node_abs_map_[obj_id] = op_exec_info->abstract; | ||||
| PynativeExecutor::GetInstance()->SaveOutputNodeMap(obj_id, out_real, cnode); | |||||
| if (cnode != nullptr) { | |||||
| PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast<CNodePtr>(), result); | |||||
| } | |||||
| SaveOutputNodeMap(obj_id, out_real, cnode); | |||||
| SaveAllResult(op_exec_info, cnode, out_real); | |||||
| // Update the abstract and device addresses of the grad graph's value nodes with the new output tensors | |||||
| UpdateAbstractAndDeviceAddress(op_exec_info, out_real); | |||||
| return result; | return result; | ||||
| } | } | ||||
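Taken together, the RunOpInner hunks replace the old GetOpId/GetForwardValue bookkeeping with state carried on op_exec_info itself, plus a final device-address refresh. The condensation below is a sketch, not a standalone program: it uses the MindSpore types of this translation unit, the member name RunOpInnerCondensed is hypothetical, and the inference/caching steps are elided.

    // Condensed view of the refactored flow; error paths and abstract
    // caching from the diff are omitted.
    py::tuple PynativeExecutor::RunOpInnerCondensed(const OpExecInfoPtr &op_exec_info) {
      abstract::AbstractBasePtrList args_spec_list;
      std::vector<bool> op_masks;
      auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list);
      op_exec_info->inputs_mask = op_masks;          // set right after MakeCNode now
      auto result = RunOpWithInitBackendPolicy(op_exec_info);
      py::object out_real = result;
      if (result.size() == 1) {
        out_real = result[0];
      }
      SaveOutputNodeMap(GetId(out_real), out_real, cnode);
      SaveAllResult(op_exec_info, cnode, out_real);  // takes AnfNodePtr + py::object now
      UpdateAbstractAndDeviceAddress(op_exec_info, out_real);  // new final step
      return result;
    }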
| OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { | |||||
| if (args.size() != PY_ARGS_NUM) { | |||||
| MS_LOG(ERROR) << "Three args are needed by RunOp"; | |||||
| return nullptr; | |||||
| } | |||||
| auto op_exec_info = std::make_shared<OpExecInfo>(); | |||||
| auto op_name = py::cast<std::string>(args[PY_NAME]); | |||||
| op_exec_info->op_name = op_name; | |||||
| if (grad_flag_) { | |||||
| MS_EXCEPTION_IF_NULL(resource_); | |||||
| int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int64_t>(); | |||||
| op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]); | |||||
| op_index_map_[op_name]++; | |||||
| } | |||||
| auto prim = py::cast<PrimitivePyPtr>(args[PY_PRIM]); | |||||
| MS_EXCEPTION_IF_NULL(prim); | |||||
| if (!prim->HasPyObj()) { | |||||
| MS_LOG(EXCEPTION) << "Pyobj is empty"; | |||||
| } | |||||
| op_exec_info->prim_id = GetId(prim->GetPyObj()); | |||||
| op_exec_info->py_primitive = prim; | |||||
| op_exec_info->op_inputs = args[PY_INPUTS]; | |||||
| return op_exec_info; | |||||
| } | |||||
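The op_index built here is the linchpin of the new scheme: graph id, op name, and a per-name counter, so the same op executed at the same position gets the same index on every step. A self-contained illustration of the construction (the helper name is hypothetical):

    #include <cstdint>
    #include <map>
    #include <string>

    // Reduced model of the op_index construction in GenerateOpExecInfo.
    std::string MakeOpIndex(int64_t graph_id, const std::string &op_name,
                            std::map<std::string, size_t> *op_index_map) {
      // Second "Add" executed in graph 0 -> "0Add1".
      std::string op_index = std::to_string(graph_id) + op_name +
                             std::to_string((*op_index_map)[op_name]);
      (*op_index_map)[op_name]++;
      return op_index;
    }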
| AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector<bool> *op_masks, | ||||
| abstract::AbstractBasePtrList *args_spec_list) { | abstract::AbstractBasePtrList *args_spec_list) { | ||||
| MS_EXCEPTION_IF_NULL(op_masks); | MS_EXCEPTION_IF_NULL(op_masks); | ||||
| @@ -997,6 +991,56 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { | |||||
| return node; | return node; | ||||
| } | } | ||||
| void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) { | |||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||||
| if (!grad_flag_) { | |||||
| return; | |||||
| } | |||||
| auto op_index = op_exec_info->op_index; | |||||
| auto output_value = PyAttrValue(out_real); | |||||
| MS_EXCEPTION_IF_NULL(output_value); | |||||
| std::vector<tensor::TensorPtr> output_tensors; | |||||
| TensorValueToTensor(output_value, &output_tensors); | |||||
| if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) { | |||||
| // first step | |||||
| std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) { | |||||
| op_index_with_tensor_id_[op_index].emplace_back(tensor->id()); | |||||
| }); | |||||
| return; | |||||
| } | |||||
| const auto &tensor_id_list = op_index_with_tensor_id_[op_index]; | |||||
| for (size_t i = 0; i < tensor_id_list.size(); ++i) { | |||||
| auto tensor_id = tensor_id_list[i]; | |||||
| if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) { | |||||
| auto &new_tensor = output_tensors[i]; | |||||
| auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id]; | |||||
| std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { | |||||
| tensor->set_shape(new_tensor->shape()); | |||||
| tensor->set_data_type(new_tensor->data_type()); | |||||
| tensor->set_device_address(new_tensor->device_address()); | |||||
| }); | |||||
| } | |||||
| } | |||||
| } | |||||
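UpdateAbstractAndDeviceAddress is a two-phase protocol keyed by that op_index: on the first step it only records which tensor ids an op produced; on later steps those ids select the stale tensors held by grad-graph value nodes, whose shape, dtype, and device address are overwritten in place. A reduced, standalone model of the refresh phase, with toy types standing in for tensor::TensorPtr:

    #include <cstdint>
    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Toy stand-in for tensor::Tensor with just the refreshed fields.
    struct ToyTensor {
      std::string id;
      std::vector<int64_t> shape;
      int data_type = 0;
      void *device_address = nullptr;
    };
    using ToyTensorPtr = std::shared_ptr<ToyTensor>;

    // Later steps: every grad-graph tensor recorded under a tensor id gets
    // the fresh metadata of the output produced at the same position.
    void RefreshGradGraphTensors(
        const std::vector<std::string> &tensor_id_list,
        const std::vector<ToyTensorPtr> &new_outputs,
        std::unordered_map<std::string, std::vector<ToyTensorPtr>> *id_to_tensors) {
      for (size_t i = 0; i < tensor_id_list.size() && i < new_outputs.size(); ++i) {
        auto it = id_to_tensors->find(tensor_id_list[i]);
        if (it == id_to_tensors->end()) {
          continue;
        }
        for (auto &stale : it->second) {
          stale->shape = new_outputs[i]->shape;
          stale->data_type = new_outputs[i]->data_type;
          stale->device_address = new_outputs[i]->device_address;
        }
      }
    }

Note that the diff indexes output_tensors directly by the recorded list's positions, which relies on the op producing the same number of outputs each step; the sketch guards the bound instead.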
| void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { | |||||
| MS_EXCEPTION_IF_NULL(resource); | |||||
| tensor_id_with_tensor_.clear(); | |||||
| const auto &func_graph = resource->func_graph(); | |||||
| const auto &value_node_list = func_graph->value_nodes(); | |||||
| for (const auto &elem : value_node_list) { | |||||
| auto value_node = elem.first->cast<ValueNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| std::vector<tensor::TensorPtr> tensors; | |||||
| TensorValueToTensor(value_node->value(), &tensors); | |||||
| for (const auto &tensor : tensors) { | |||||
| if (tensor->device_address() != nullptr) { | |||||
| tensor_id_with_tensor_[tensor->id()].emplace_back(tensor); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
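tensor_id_with_tensor_ is rebuilt from scratch here, and GradNetInner calls this only after PynativeOptimizeAction, so the map tracks exactly the tensors that survived graph optimization and already carry a device address. A standalone model of the collection step (toy types again):

    #include <memory>
    #include <string>
    #include <unordered_map>
    #include <vector>

    // Toy tensor with the two fields the collection pass inspects.
    struct GradTensor {
      std::string id;
      void *device_address = nullptr;
    };
    using GradTensorPtr = std::shared_ptr<GradTensor>;

    // Group every device-resident tensor found in the grad graph's value
    // nodes by tensor id (reduced form of SaveTensorsInValueNode).
    void CollectValueNodeTensors(
        const std::vector<std::vector<GradTensorPtr>> &value_node_tensors,
        std::unordered_map<std::string, std::vector<GradTensorPtr>> *id_to_tensors) {
      id_to_tensors->clear();
      for (const auto &tensors : value_node_tensors) {
        for (const auto &tensor : tensors) {
          if (tensor->device_address != nullptr) {
            (*id_to_tensors)[tensor->id].emplace_back(tensor);
          }
        }
      }
    }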
| AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { | AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { | ||||
| auto &out = graph_info_map_[curr_g_].node_map[obj_id]; | auto &out = graph_info_map_[curr_g_].node_map[obj_id]; | ||||
| if (out.second.size() == 1 && out.second[0] == -1) { | if (out.second.size() == 1 && out.second[0] == -1) { | ||||
| @@ -1054,23 +1098,6 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str | |||||
| return node; | return node; | ||||
| } | } | ||||
| ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { | |||||
| auto id = GetOpId(op_exec_info); | |||||
| int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int64_t>(); | |||||
| auto op = std::to_string(graph_id) + id; | |||||
| op.append(std::to_string(op_id_map_[id])); | |||||
| auto iter = op_forward_map_.find(op); | |||||
| if (iter != op_forward_map_.end()) { | |||||
| ++op_id_map_[id]; | |||||
| MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; | |||||
| return iter->second; | |||||
| } | |||||
| if (!first_grad_step_) { | |||||
| ++op_id_map_[id]; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, | void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, | ||||
| const AnfNodePtr &cnode) { | const AnfNodePtr &cnode) { | ||||
| if (!grad_flag_ || graph_info_map_.empty()) { | if (!grad_flag_ || graph_info_map_.empty()) { | ||||
| @@ -1093,16 +1120,16 @@ void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::ob | |||||
| SetPyObjInGraphInfoMap(curr_g_, obj_id); | SetPyObjInGraphInfoMap(curr_g_, obj_id); | ||||
| } | } | ||||
| void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { | |||||
| if (!grad_flag_ || op_exec_info->value != nullptr || cnode == nullptr) { | |||||
| void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, | |||||
| const py::object &out_real) { | |||||
| if (!grad_flag_ || node == nullptr) { | |||||
| return; | return; | ||||
| } | } | ||||
| py::object out_real = out; | |||||
| if (out.size() == 1) { | |||||
| out_real = out[0]; | |||||
| } | |||||
| auto value = PyAttrValue(out_real); | |||||
| MS_EXCEPTION_IF_NULL(op_exec_info); | |||||
| auto cnode = node->cast<CNodePtr>(); | |||||
| MS_EXCEPTION_IF_NULL(cnode); | |||||
| // save input object | |||||
| size_t size = op_exec_info->op_inputs.size(); | size_t size = op_exec_info->op_inputs.size(); | ||||
| for (size_t i = 0; i < size; i++) { | for (size_t i = 0; i < size; i++) { | ||||
| auto obj = op_exec_info->op_inputs[i]; | auto obj = op_exec_info->op_inputs[i]; | ||||
| @@ -1113,59 +1140,19 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN | |||||
| cnode->add_input_value(nullptr, ""); | cnode->add_input_value(nullptr, ""); | ||||
| } | } | ||||
| } | } | ||||
| std::string id = GetOpId(op_exec_info); | |||||
| int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast<int64_t>(); | |||||
| auto op_id = std::to_string(graph_id) + id; | |||||
| op_id.append(std::to_string(op_id_map_[id])); | |||||
| cnode->set_forward(value, op_id); | |||||
| ++op_id_map_[id]; | |||||
| // save output object | |||||
| auto output_value = PyAttrValue(out_real); | |||||
| MS_EXCEPTION_IF_NULL(output_value); | |||||
| cnode->set_forward(output_value, op_exec_info->op_index); | |||||
| auto out_id = GetId(out_real); | auto out_id = GetId(out_real); | ||||
| if (py::isinstance<py::tuple>(out_real)) { | if (py::isinstance<py::tuple>(out_real)) { | ||||
| auto tuple_item = py::cast<py::tuple>(out_real); | auto tuple_item = py::cast<py::tuple>(out_real); | ||||
| for (size_t i = 0; i < tuple_item.size(); i++) { | for (size_t i = 0; i < tuple_item.size(); i++) { | ||||
| auto tuple_item_id = GetId(tuple_item[i]); | auto tuple_item_id = GetId(tuple_item[i]); | ||||
| obj_to_forward_id_[tuple_item_id] = op_id; | |||||
| } | |||||
| SaveOpForwardValue(op_id, value, nullptr); | |||||
| } | |||||
| obj_to_forward_id_[out_id] = op_id; | |||||
| } | |||||
| void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value, | |||||
| std::map<std::string, tensor::TensorPtr> *t_map) { | |||||
| if (op_forward_map_.find(id) != op_forward_map_.end()) { | |||||
| // for one op have multi outputs but save only one tensor | |||||
| if (op_forward_map_[id]->isa<ValueTuple>() && value->isa<tensor::Tensor>()) { | |||||
| auto tuple = op_forward_map_[id]->cast<ValueTuplePtr>(); | |||||
| auto value_t = value->cast<tensor::TensorPtr>(); | |||||
| for (size_t i = 0; i < tuple->size(); i++) { | |||||
| if ((*tuple)[i]->isa<tensor::Tensor>()) { | |||||
| auto tuple_t = (*tuple)[i]->cast<tensor::TensorPtr>(); | |||||
| if (value_t->id() == tuple_t->id()) { | |||||
| tuple_t->set_device_address(value_t->device_address()); | |||||
| MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString(); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| if (value->isa<ValueTuple>() && t_map != nullptr) { | |||||
| GenTupleMap(op_forward_map_[id]->cast<ValueTuplePtr>(), t_map); | |||||
| obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index; | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Save op forward value: " | |||||
| << "(" << id << "), " << op_forward_map_[id]->ToString(); | |||||
| return; | |||||
| } | |||||
| if (value->isa<ValueTuple>() && t_map == nullptr) { | |||||
| // make cnode gen all tuple node and set device_address be null | |||||
| op_forward_map_[id] = CleanTupleAddr(value->cast<ValueTuplePtr>()); | |||||
| } else { | |||||
| op_forward_map_[id] = value; | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Save op forward value: " | |||||
| << "(" << id << "), " << value->ToString(); | |||||
| obj_to_forward_id_[out_id] = op_exec_info->op_index; | |||||
| } | } | ||||
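With the counter-based op_id scheme gone, SaveAllResult now stamps the forward value onto the cnode under op_exec_info->op_index and points every output object id, including each tuple element's id, back at that index. A minimal model of the id mapping (the helper name is hypothetical):

    #include <string>
    #include <unordered_map>
    #include <vector>

    // Reduced form of the bookkeeping at the end of SaveAllResult: each
    // python output object id maps to the producing op_index so later
    // lookups can find the recorded forward value.
    void RecordForwardIds(
        const std::string &op_index, const std::string &out_id,
        const std::vector<std::string> &tuple_item_ids,
        std::unordered_map<std::string, std::string> *obj_to_forward_id) {
      for (const auto &item_id : tuple_item_ids) {
        (*obj_to_forward_id)[item_id] = op_index;
      }
      (*obj_to_forward_id)[out_id] = op_index;
    }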
| void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) { | void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map) { | ||||
| @@ -1307,10 +1294,13 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati | |||||
| ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); | ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); | ||||
| // get graph info for checking it whether existing in the cache | // get graph info for checking it whether existing in the cache | ||||
| std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); | std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); | ||||
| session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, | |||||
| op_exec_info->abstract, op_exec_info->value, | |||||
| op_exec_info->is_dynamic_shape, op_exec_info->is_mixed_precision_cast, | |||||
| op_exec_info->next_op_name, op_exec_info->next_input_index}; | |||||
| session::OpRunInfo op_run_info = {op_exec_info->op_name, | |||||
| op_exec_info->py_primitive, | |||||
| op_exec_info->abstract, | |||||
| op_exec_info->is_dynamic_shape, | |||||
| op_exec_info->is_mixed_precision_cast, | |||||
| op_exec_info->next_op_name, | |||||
| op_exec_info->next_input_index}; | |||||
| session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); | session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); | ||||
| EraseValueNodeTensor(tensors_mask, &input_tensors); | EraseValueNodeTensor(tensors_mask, &input_tensors); | ||||
| VectorRef outputs; | VectorRef outputs; | ||||
| @@ -1524,6 +1514,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg | |||||
| if (it != cell_resource_map_.end()) { | if (it != cell_resource_map_.end()) { | ||||
| resource_ = it->second; | resource_ = it->second; | ||||
| MS_EXCEPTION_IF_NULL(resource_); | MS_EXCEPTION_IF_NULL(resource_); | ||||
| op_index_map_.clear(); | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "Graph already compiled"; | MS_LOG(DEBUG) << "Graph already compiled"; | ||||
| return; | return; | ||||
| @@ -1571,7 +1562,8 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar | |||||
| resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; | resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; | ||||
| cell_resource_map_[cell_id] = resource_; | cell_resource_map_[cell_id] = resource_; | ||||
| MS_LOG(DEBUG) << "New top graph for " << cell_id; | MS_LOG(DEBUG) << "New top graph for " << cell_id; | ||||
| first_grad_step_ = true; | |||||
| op_index_map_.clear(); | |||||
| op_index_with_tensor_id_.clear(); | |||||
| top_graph_cells_.emplace(cell_id); | top_graph_cells_.emplace(cell_id); | ||||
| } | } | ||||
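Clearing op_index_map_ here (and again when a compiled graph is re-entered in NewGraphInner) is what makes the indexes reproducible: each step replays the same op sequence and regenerates the same strings. A runnable illustration:

    #include <iostream>
    #include <map>
    #include <string>

    // op_index values repeat exactly across steps because the per-name
    // counter is cleared whenever the top cell is (re)entered.
    int main() {
      std::map<std::string, size_t> op_index_map;
      for (int step = 0; step < 2; ++step) {
        op_index_map.clear();  // the MakeNewTopGraph / NewGraphInner reset
        for (const char *name : {"Add", "Add", "ReLU"}) {
          std::string op_index = "0" + std::string(name) +
                                 std::to_string(op_index_map[name]++);
          std::cout << op_index << '\n';  // both steps: 0Add0, 0Add1, 0ReLU0
        }
      }
      return 0;
    }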
| @@ -1770,6 +1762,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje | |||||
| MS_LOG(DEBUG) << "Start opt"; | MS_LOG(DEBUG) << "Start opt"; | ||||
| PynativeOptimizeAction(resource_); | PynativeOptimizeAction(resource_); | ||||
| SaveTensorsInValueNode(resource_); | |||||
| TaskEmitAction(resource_); | TaskEmitAction(resource_); | ||||
| ExecuteAction(resource_); | ExecuteAction(resource_); | ||||
| cell_graph_map_[cell_id].second = true; | cell_graph_map_[cell_id].second = true; | ||||
| @@ -2021,7 +2014,6 @@ void PynativeExecutor::Clear(const std::string &flag) { | |||||
| } | } | ||||
| ConfigManager::GetInstance().ResetIterNum(); | ConfigManager::GetInstance().ResetIterNum(); | ||||
| if (top_graph_cells_.find(flag) != top_graph_cells_.end()) { | if (top_graph_cells_.find(flag) != top_graph_cells_.end()) { | ||||
| op_forward_map_.clear(); | |||||
| Clean(); | Clean(); | ||||
| } | } | ||||
| node_abs_map_.clear(); | node_abs_map_.clear(); | ||||
| @@ -2033,9 +2025,7 @@ void PynativeExecutor::Clear(const std::string &flag) { | |||||
| top_g_ = nullptr; | top_g_ = nullptr; | ||||
| df_builder_ = nullptr; | df_builder_ = nullptr; | ||||
| curr_g_ = nullptr; | curr_g_ = nullptr; | ||||
| first_grad_step_ = false; | |||||
| graph_info_map_.clear(); | graph_info_map_.clear(); | ||||
| op_id_map_.clear(); | |||||
| obj_to_forward_id_.clear(); | obj_to_forward_id_.clear(); | ||||
| node_abs_map_.clear(); | node_abs_map_.clear(); | ||||
| std::stack<FuncGraphPtr>().swap(graph_stack_); | std::stack<FuncGraphPtr>().swap(graph_stack_); | ||||
| @@ -83,13 +83,12 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| void set_grad_flag(bool flag) { grad_flag_ = flag; } | void set_grad_flag(bool flag) { grad_flag_ = flag; } | ||||
| py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); | py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); | ||||
| OpExecInfoPtr GenerateOpExecInfo(const py::args &args); | |||||
| void NewGraph(const py::object &cell, const py::args &args); | void NewGraph(const py::object &cell, const py::args &args); | ||||
| py::object Run(const py::tuple &args, const py::object &phase); | py::object Run(const py::tuple &args, const py::object &phase); | ||||
| py::object CheckGraph(const py::object &cell, const py::args &args); | py::object CheckGraph(const py::object &cell, const py::args &args); | ||||
| void EndGraph(const py::object &cell, const py::object &out, const py::args &args); | void EndGraph(const py::object &cell, const py::object &out, const py::args &args); | ||||
| void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); | void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); | ||||
| void SaveOpForwardValue(const std::string &id, const ValuePtr &value, | |||||
| std::map<std::string, tensor::TensorPtr> *t_map); | |||||
| // Call by python | // Call by python | ||||
| void Clear(const std::string &flag = ""); | void Clear(const std::string &flag = ""); | ||||
| @@ -134,9 +133,11 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| // replace for grad graph | // replace for grad graph | ||||
| ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple); | ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple); | ||||
| ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); | |||||
| void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map); | void GenTupleMap(const ValueTuplePtr &tuple, std::map<std::string, tensor::TensorPtr> *t_map); | ||||
| void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out); | |||||
| void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real); | |||||
| // Update the abstract and device address info of the value nodes and tensors in the bprop graph | |||||
| void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); | |||||
| void SaveTensorsInValueNode(const ResourcePtr &resource); | |||||
| // construct grad graph | // construct grad graph | ||||
| void PushCurrentGraphToStack(); | void PushCurrentGraphToStack(); | ||||
| @@ -175,7 +176,6 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| static int64_t graph_id_; | static int64_t graph_id_; | ||||
| bool grad_flag_{false}; | bool grad_flag_{false}; | ||||
| bool dynamic_cell_{false}; | bool dynamic_cell_{false}; | ||||
| bool first_grad_step_{false}; | |||||
| bool grad_is_running{false}; | bool grad_is_running{false}; | ||||
| // Used for construct grad graph | // Used for construct grad graph | ||||
| @@ -199,9 +199,10 @@ class PynativeExecutor : public std::enable_shared_from_this<PynativeExecutor> { | |||||
| std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>> df_builder_map_; | std::unordered_map<std::string, std::pair<FuncGraphPtr, FuncGraphPtr>> df_builder_map_; | ||||
| // used by RunOp to replace forward results in the grad graph | // used by RunOp to replace forward results in the grad graph | ||||
| std::unordered_map<std::string, ValuePtr> op_forward_map_; | |||||
| std::unordered_map<std::string, size_t> op_id_map_; | |||||
| std::unordered_map<std::string, size_t> op_index_map_; | |||||
| std::unordered_map<std::string, std::string> obj_to_forward_id_; | std::unordered_map<std::string, std::string> obj_to_forward_id_; | ||||
| std::unordered_map<std::string, std::vector<std::string>> op_index_with_tensor_id_; | |||||
| std::unordered_map<std::string, std::vector<tensor::TensorPtr>> tensor_id_with_tensor_; | |||||
| std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_; | std::unordered_map<std::string, abstract::AbstractBasePtr> node_abs_map_; | ||||
| std::unordered_map<std::string, AbstractListMap> prim_abs_list_; | std::unordered_map<std::string, AbstractListMap> prim_abs_list_; | ||||
| const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional"; | const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional"; | ||||
| @@ -81,15 +81,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { | |||||
| UpdateRefNodeOutputMem(graph); | UpdateRefNodeOutputMem(graph); | ||||
| } | } | ||||
| void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value, | |||||
| const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| session::KernelGraph *graph) { | session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | MS_EXCEPTION_IF_NULL(mem_manager_); | ||||
| mem_manager_->ResetDynamicMemory(); | mem_manager_->ResetDynamicMemory(); | ||||
| RunOpAssignInputMemory(input_tensors, graph); | RunOpAssignInputMemory(input_tensors, graph); | ||||
| AssignStaticMemoryValueNode(graph); | AssignStaticMemoryValueNode(graph); | ||||
| RunOpAssignOutputNodeMemory(pre_output_value, graph); | |||||
| for (const auto &cnode : graph->execution_order()) { | for (const auto &cnode : graph->execution_order()) { | ||||
| RunOpAssignOutputMemory(cnode); | RunOpAssignOutputMemory(cnode); | ||||
| RunOpAssignWorkSpaceMemory(cnode); | RunOpAssignWorkSpaceMemory(cnode); | ||||
| @@ -680,6 +678,52 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { | |||||
| MS_LOG(INFO) << "AssignStaticMemoryValueNode end"; | MS_LOG(INFO) << "AssignStaticMemoryValueNode end"; | ||||
| } | } | ||||
| void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_LOG(INFO) << "SyncValueNodeDeviceAddr start"; | |||||
| for (auto &value_node : graph->graph_value_nodes()) { | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto &node_value = value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(node_value); | |||||
| if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) { | |||||
| continue; | |||||
| } | |||||
| std::vector<tensor::TensorPtr> tensors; | |||||
| TensorValueToTensor(node_value, &tensors); | |||||
| for (size_t index = 0; index < tensors.size(); index += 1) { | |||||
| const auto &tensor = tensors[index]; | |||||
| if (tensor->device_address() != nullptr) { | |||||
| AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), index, | |||||
| value_node.get()); | |||||
| } else { | |||||
| MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr."; | |||||
| } | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "SyncValueNodeDeviceAddr end"; | |||||
| } | |||||
| void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) { | |||||
| MS_EXCEPTION_IF_NULL(graph); | |||||
| MS_LOG(INFO) << "CleanValueNodeDeviceAddr start"; | |||||
| for (auto &value_node : graph->graph_value_nodes()) { | |||||
| MS_EXCEPTION_IF_NULL(value_node); | |||||
| auto &node_value = value_node->value(); | |||||
| MS_EXCEPTION_IF_NULL(node_value); | |||||
| if (!node_value->isa<Tensor>() && !node_value->isa<ValueTuple>()) { | |||||
| continue; | |||||
| } | |||||
| std::vector<tensor::TensorPtr> tensors; | |||||
| TensorValueToTensor(node_value, &tensors); | |||||
| for (size_t index = 0; index < tensors.size(); index += 1) { | |||||
| if (tensors[index]->device_address() != nullptr) { | |||||
| AnfAlgo::SetOutputAddr(nullptr, index, value_node.get()); | |||||
| } | |||||
| } | |||||
| } | |||||
| MS_LOG(INFO) << "CleanValueNodeDeviceAddr end"; | |||||
| } | |||||
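The call sites are not shown in this diff, but SyncValueNodeDeviceAddr and CleanValueNodeDeviceAddr read as a bracketing pair: before a launch, value nodes borrow whatever device address their tensors currently hold; afterwards the output slots are nulled so nothing dangles once the forward memory is reassigned. A reduced model of the pairing, with a toy slot type in place of AnfAlgo::SetOutputAddr:

    #include <vector>

    // Toy output-address slots of a value node.
    struct ValueNodeSlots {
      std::vector<void *> output_addrs;
    };

    // Before launch: adopt the tensors' current device addresses.
    void SyncSlots(const std::vector<void *> &tensor_addrs, ValueNodeSlots *node) {
      node->output_addrs.assign(tensor_addrs.begin(), tensor_addrs.end());
    }

    // After launch: drop the borrowed addresses so they cannot dangle.
    void CleanSlots(ValueNodeSlots *node) {
      for (auto &addr : node->output_addrs) {
        addr = nullptr;
      }
    }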
| void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { | void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { | ||||
| MS_EXCEPTION_IF_NULL(graph); | MS_EXCEPTION_IF_NULL(graph); | ||||
| MS_EXCEPTION_IF_NULL(mem_manager_); | MS_EXCEPTION_IF_NULL(mem_manager_); | ||||
| @@ -51,8 +51,7 @@ class KernelRuntime { | |||||
| virtual ~KernelRuntime(); | virtual ~KernelRuntime(); | ||||
| virtual bool Init() = 0; | virtual bool Init() = 0; | ||||
| virtual void AssignMemory(session::KernelGraph *graph); | virtual void AssignMemory(session::KernelGraph *graph); | ||||
| void RunOpAssignMemory(const ValuePtr &pre_output_value, const std::vector<tensor::TensorPtr> &input_tensors, | |||||
| session::KernelGraph *graph); | |||||
| void RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors, session::KernelGraph *graph); | |||||
| void RunOpClearMemory(const session::KernelGraph *graph); | void RunOpClearMemory(const session::KernelGraph *graph); | ||||
| static bool DumpDataEnabled(); | static bool DumpDataEnabled(); | ||||
| static bool DumpDataEnabledIteration(); | static bool DumpDataEnabledIteration(); | ||||
| @@ -67,6 +66,8 @@ class KernelRuntime { | |||||
| const AddressPtrList &kernel_workspaces) const; | const AddressPtrList &kernel_workspaces) const; | ||||
| virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); | virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); | ||||
| virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); | virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); | ||||
| virtual void SyncValueNodeDeviceAddr(session::KernelGraph *graph); | |||||
| virtual void CleanValueNodeDeviceAddr(session::KernelGraph *graph); | |||||
| virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs, | ||||
| const std::unordered_set<ValueNodePtr> &value_nodes, | const std::unordered_set<ValueNodePtr> &value_nodes, | ||||
| const std::vector<CNodePtr> &execution_order); | const std::vector<CNodePtr> &execution_order); | ||||
| @@ -18,13 +18,13 @@ | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <vector> | #include <vector> | ||||
| #include "utils/log_adapter.h" | |||||
| #include "backend/session/session_factory.h" | |||||
| #include "ir/anf.h" | #include "ir/anf.h" | ||||
| #include "pybind_api/ir/base_ref_py.h" | |||||
| #include "utils/callbacks.h" | #include "utils/callbacks.h" | ||||
| #include "utils/convert_utils.h" | #include "utils/convert_utils.h" | ||||
| #include "backend/session/session_factory.h" | |||||
| #include "utils/log_adapter.h" | |||||
| #include "utils/ms_utils.h" | #include "utils/ms_utils.h" | ||||
| #include "pybind_api/ir/base_ref_py.h" | |||||
| #ifdef ENABLE_GE | #ifdef ENABLE_GE | ||||
| #include "utils/callbacks_ge.h" | #include "utils/callbacks_ge.h" | ||||
| #endif | #endif | ||||
| @@ -83,10 +83,14 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std: | |||||
| MS_LOG(INFO) << "PrecompileOnly, stop run graph"; | MS_LOG(INFO) << "PrecompileOnly, stop run graph"; | ||||
| return result; | return result; | ||||
| } | } | ||||
| if (target != target_device_ && !target.empty()) { | |||||
| other_sess_->BuildGraph(graph_id); | |||||
| } else if (!is_multi_graph_sink_) { | |||||
| target_sess_->BuildGraph(graph_id); | |||||
| auto ms_context = MsContext::GetInstance(); | |||||
| const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode); | |||||
| if (!pynative_mode || target != "Ascend") { | |||||
| if (target != target_device_ && !target.empty()) { | |||||
| other_sess_->BuildGraph(graph_id); | |||||
| } else if (!is_multi_graph_sink_) { | |||||
| target_sess_->BuildGraph(graph_id); | |||||
| } | |||||
| } | } | ||||
| result.run = std::make_shared<RunFunc>( | result.run = std::make_shared<RunFunc>( | ||||
| [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); | [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); | ||||
| @@ -154,12 +158,19 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s | |||||
| PushInputTensor(arg, &inputs); | PushInputTensor(arg, &inputs); | ||||
| } | } | ||||
| auto ms_context = MsContext::GetInstance(); | |||||
| const bool pynative_mode = (ms_context->get_param<int>(MS_CTX_EXECUTION_MODE) == kPynativeMode); | |||||
| VectorRef outputs; | VectorRef outputs; | ||||
| // call ms rungraph (graphId, input, output) | // call ms rungraph (graphId, input, output) | ||||
| if (target != target_device_ && !target.empty()) { | if (target != target_device_ && !target.empty()) { | ||||
| other_sess_->RunGraphAsync(g, inputs, &outputs); | other_sess_->RunGraphAsync(g, inputs, &outputs); | ||||
| } else { | } else { | ||||
| target_sess_->RunGraphAsync(g, inputs, &outputs); | |||||
| if (pynative_mode && target == "Ascend") { | |||||
| target_sess_->RunOpsInGraph(g, inputs, &outputs); | |||||
| } else { | |||||
| target_sess_->RunGraphAsync(g, inputs, &outputs); | |||||
| } | |||||
| } | } | ||||
| MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); | MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); | ||||
| @@ -134,7 +134,6 @@ class MulAdd(nn.Cell): | |||||
| assert dout.asnumpy() == 1.0 | assert dout.asnumpy() == 1.0 | ||||
| return dout, y | return dout, y | ||||
| class Ms_Cell(nn.Cell): | class Ms_Cell(nn.Cell): | ||||
| def __init__(self): | def __init__(self): | ||||
| super(Ms_Cell, self).__init__() | super(Ms_Cell, self).__init__() | ||||
| @@ -143,6 +142,19 @@ class Ms_Cell(nn.Cell): | |||||
| def construct(self, x): | def construct(self, x): | ||||
| return self.relu(x) | return self.relu(x) | ||||
| def bprop(self, x, out, dout): | |||||
| dout = Tensor(np.float32(0.0)) | |||||
| assert dout.shape == () | |||||
| return dout | |||||
| class Ms_Cell_Change_Shape(nn.Cell): | |||||
| def __init__(self): | |||||
| super(Ms_Cell_Change_Shape, self).__init__() | |||||
| self.relu = P.ReLU() | |||||
| def construct(self, x): | |||||
| return self.relu(x) | |||||
| def bprop(self, x, out, dout): | def bprop(self, x, out, dout): | ||||
| dout = Tensor(np.ones([5, 5]).astype(np.float32)) | dout = Tensor(np.ones([5, 5]).astype(np.float32)) | ||||
| assert dout.shape == (5, 5) | assert dout.shape == (5, 5) | ||||
| @@ -186,6 +198,19 @@ def test_pynative_custom_bprop_and_Cell_MulAdd(): | |||||
| (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32)) | (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32)) | ||||
| @pytest.mark.level0 | |||||
| @pytest.mark.platform_arm_ascend_training | |||||
| @pytest.mark.platform_x86_ascend_training | |||||
| @pytest.mark.env_onecard | |||||
| def test_pynative_custom_bprop_and_Cell_Ms_Cell_Change_Shape(): | |||||
| custom_cell = test_custom_cell_base() | |||||
| ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell_Change_Shape()) | |||||
| ms_Cell.bprop_debug = True | |||||
| with pytest.raises(RuntimeError) as ex: | |||||
| grad_all(ms_Cell)(Tensor(1, mstype.float32)) | |||||
| assert "Shapes of input and parameter are different, input index" in str(ex.value) | |||||
| @pytest.mark.level0 | @pytest.mark.level0 | ||||
| @pytest.mark.platform_arm_ascend_training | @pytest.mark.platform_arm_ascend_training | ||||
| @pytest.mark.platform_x86_ascend_training | @pytest.mark.platform_x86_ascend_training | ||||
| @@ -194,5 +219,5 @@ def test_pynative_custom_bprop_and_Cell_Ms_Cell(): | |||||
| custom_cell = test_custom_cell_base() | custom_cell = test_custom_cell_base() | ||||
| ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell()) | ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell()) | ||||
| ms_Cell.bprop_debug = True | ms_Cell.bprop_debug = True | ||||
| assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),) | |||||
| assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(0.0, mstype.float32),) | |||||
| @@ -65,7 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() { | |||||
| py::none py_none; | py::none py_none; | ||||
| py::args args = py::make_tuple(conv_obj, op_name, op_inputs); | py::args args = py::make_tuple(conv_obj, op_name, op_inputs); | ||||
| py::list args_input = args[PY_INPUTS]; | py::list args_input = args[PY_INPUTS]; | ||||
| return GenerateOpExecInfo(args); | |||||
| return PynativeExecutor::GetInstance()->GenerateOpExecInfo(args); | |||||
| } | } | ||||
| TEST_F(TestPynativeExecute, TestCreateContext) { | TEST_F(TestPynativeExecute, TestCreateContext) { | ||||