From 81059dd0193a9052d6261d6e39f22f387034225d Mon Sep 17 00:00:00 2001 From: LianLiguang Date: Sat, 12 Dec 2020 18:52:13 +0800 Subject: [PATCH] change cast format and its connected type --- .../ascend/ascend_backend_optimization.cc | 1 + .../ascend/format_type/convert_cast_format.cc | 90 +++++++++++++++---- .../ascend/format_type/convert_cast_format.h | 7 +- 3 files changed, 78 insertions(+), 20 deletions(-) diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index baf08f7073..84f5076237 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -219,6 +219,7 @@ void AscendDataLayout(const std::shared_ptr &kernel_graph) } else { data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); + data_layout_pm->AddPass(std::make_shared()); data_layout_pm->AddPass(std::make_shared()); } data_layout_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc index 763f0da805..552efb5e48 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.cc @@ -16,8 +16,11 @@ #include "backend/optimizer/ascend/format_type/convert_cast_format.h" #include +#include +#include #include "backend/session/anf_runtime_algorithm.h" +#include "backend/optimizer/common/helper.h" namespace mindspore { namespace opt { const BaseRef ConvertCastFormat::DefinePattern() const { @@ -26,8 +29,8 @@ const BaseRef ConvertCastFormat::DefinePattern() const { return VectorRef({X, Xs}); } -const AnfNodePtr ConvertCastFormat::Process(const mindspore::FuncGraphPtr &, const mindspore::AnfNodePtr &node, - const
mindspore::EquivPtr &) const { +const AnfNodePtr ConvertCastFormat::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const EquivPtr &) const { if (node == nullptr || !node->isa() || !AnfAlgo::IsRealCNodeKernel(node)) { return nullptr; } @@ -44,26 +47,77 @@ const AnfNodePtr ConvertCastFormat::Process(const mindspore::FuncGraphPtr &, con continue; } auto cast_node = input_node->cast(); - MS_EXCEPTION_IF_NULL(cast_node); - auto input_node_name = AnfAlgo::GetCNodeName(cast_node); - if (input_node_name != prim::kPrimCast->name()) { - continue; + ChangeCastFormat(cast_node, func_graph); + } + return nullptr; +} + +void ConvertCastFormat::SetCastFormat(const CNodePtr &cast_node, const string &format) const { + auto info_builder = + std::make_shared(AnfAlgo::GetSelectKernelBuildInfo(cast_node)); + info_builder->SetInputsFormat({format}); + info_builder->SetOutputsFormat({format}); + AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get()); +} + +void ConvertCastFormat::ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const { + MS_EXCEPTION_IF_NULL(cast_node); + auto input_node_name = AnfAlgo::GetCNodeName(cast_node); + if (input_node_name != prim::kPrimCast->name()) { + return; + } + if (AnfAlgo::HasNodeAttr(kAttrVisited, cast_node) && AnfAlgo::GetNodeAttr(cast_node, kAttrVisited)) { + return; + } + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), cast_node); + auto used_cast_node_list = GetRealNodeUsedList(func_graph, cast_node); + MS_EXCEPTION_IF_NULL(used_cast_node_list); + std::unordered_map format_counter; + for (const auto &node_info : *used_cast_node_list) { + MS_EXCEPTION_IF_NULL(node_info.first); + auto cast_out_node = node_info.first->cast(); + MS_EXCEPTION_IF_NULL(cast_out_node); + for (size_t index = 0; index < AnfAlgo::GetInputTensorNum(cast_out_node); ++index) { + if (AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_out_node->cast(), index), 0).first != + cast_node) { + continue; + } 
+ auto format = AnfAlgo::GetInputFormat(cast_out_node, index); + auto it = format_counter.find(format); + if (it == format_counter.end()) { + format_counter[format] = 1; + } else { + it->second++; + } } - auto format = AnfAlgo::GetInputFormat(node, input_index); - auto cast_input_node = AnfAlgo::VisitKernelWithReturnType(AnfAlgo::GetInputNode(cast_node, 0), 0).first; - auto cast_input_format = AnfAlgo::GetOutputFormat(cast_input_node, 0); - // change cast to default that can be more faster when it cast other hw format - if (cast_input_format != format) { - if (cast_input_format == kOpFormat_DEFAULT || format == kOpFormat_DEFAULT) { - auto info_builder = std::make_shared( - AnfAlgo::GetSelectKernelBuildInfo(cast_node)); - info_builder->SetInputsFormat({kOpFormat_DEFAULT}); - info_builder->SetOutputsFormat({kOpFormat_DEFAULT}); - AnfAlgo::SetSelectKernelBuildInfo(info_builder->Build(), cast_node.get()); + } + auto cast_input_format = AnfAlgo::GetPrevNodeOutputFormat(cast_node, 0); + string convert_format = kOpFormat_DEFAULT; + if (cast_input_format == kOpFormat_DEFAULT) { + SetCastFormat(cast_node, convert_format); + return; + } + if (format_counter.size() == 1 && format_counter.begin()->first == kOpFormat_DEFAULT) { + SetCastFormat(cast_node, convert_format); + return; + } + auto it = format_counter.find(cast_input_format); + if (it == format_counter.end()) { + format_counter[cast_input_format] = 1; + } else { + it->second++; + } + if (format_counter.size() < 2) { + size_t max_counter = 0; + for (const auto &iter : format_counter) { + if (iter.second > max_counter) { + max_counter = iter.second; + convert_format = iter.first; } } + // change the cast to the default format, which can be faster when it casts to another hw format + SetCastFormat(cast_node, convert_format); } - return nullptr; } } // namespace opt } // namespace mindspore diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h
b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h index 5c7f5a2a0d..5c9d2f22ce 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/convert_cast_format.h @@ -16,6 +16,8 @@ #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_ #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_CONVERT_CAST_FORMAT_H_ +#include + #include "backend/optimizer/common/optimizer.h" namespace mindspore { @@ -25,10 +27,11 @@ class ConvertCastFormat : public PatternProcessPass { explicit ConvertCastFormat(bool multigraph = true) : PatternProcessPass("convert_cast_format", multigraph) {} ~ConvertCastFormat() override = default; const BaseRef DefinePattern() const override; - const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override; + const AnfNodePtr Process(const FuncGraphPtr &func_graph, const AnfNodePtr &, const EquivPtr &) const override; private: - bool NeedChangeCastFormat(); + void ChangeCastFormat(const CNodePtr &cast_node, const FuncGraphPtr &func_graph) const; + void SetCastFormat(const CNodePtr &cast_node, const string &format) const; }; } // namespace opt } // namespace mindspore