@@ -25,6 +25,7 @@
 namespace mindspore {
 class FuncGraph;
 namespace inference {
+using VectorForMSTensorPtr = std::vector<std::shared_ptr<inference::MSTensor>>;
 class MS_API MSSession {
  public:
   MSSession() = default;
@@ -33,7 +34,9 @@ class MS_API MSSession {
   virtual uint32_t CompileGraph(std::shared_ptr<FuncGraph> funcGraphPtr) = 0;
-  virtual MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) = 0;
+  virtual MultiTensor RunGraph(uint32_t graph_id, const VectorForMSTensorPtr &inputs) = 0;
+  virtual bool CheckModelInputs(uint32_t graph_id, const VectorForMSTensorPtr &inputs) const = 0;
 };
 std::shared_ptr<FuncGraph> MS_API LoadModel(const char *model_buf, size_t size, const std::string &device);
@@ -13,6 +13,8 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+#include <algorithm>
 #include "backend/session/ascend_inference_session.h"
 #include "frontend/operator/ops.h"
 #include "ir/tensor.h"
@@ -85,5 +87,80 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
   }
   return graph_id;
 }
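+
+// Walk the inputs of the compiled kernel graph and verify that every non-weight
+// parameter matches one of the tensors supplied by the client.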
+bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id,
+                                              const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
+  MS_LOG(INFO) << "Start checking client inputs, graph id : " << graph_id;
+  auto kernel_graph = GetGraph(graph_id);
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  auto kernel_graph_inputs = kernel_graph->inputs();
+  size_t no_weight_input = 0;
+  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
+    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
| MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter."; | |||||
| continue; | |||||
| } | |||||
+    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
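+    // Weight parameters take their data from the model file rather than from the
+    // client, so only non-weight parameters are compared against the client inputs.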
+    if (!AnfAlgo::IsParameterWeight(parameter)) {
+      // compare input number
+      if (no_weight_input >= inputs.size()) {
+        MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
+                      << "] is less than the graph requires.";
+        return false;
+      }
+      auto input = inputs[no_weight_input++];
+      if (!CompareInput(input, parameter)) {
+        MS_LOG(ERROR) << "Input " << (no_weight_input - 1) << " does not match the graph parameter.";
+        return false;
+      }
+    }
+  }
+  return true;
+}
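+
+// Compare one client input tensor against one graph parameter: the dimension
+// count, the shape and the device data type must all match.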
+bool AscendInferenceSession::CompareInput(const std::shared_ptr<inference::MSTensor> &input,
+                                          const ParameterPtr &parameter) {
+  MS_EXCEPTION_IF_NULL(input);
+  MS_EXCEPTION_IF_NULL(parameter);
+  // compare dims
+  auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
+  if (input->shape().size() != parameter_shape.size()) {
+    MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size()
+                  << ", but the parameter dim is " << parameter_shape.size()
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+  // compare shape
+  auto input_shape = input->shape();
+  std::vector<size_t> trans_input;
+  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
+                       [](const int dim) { return static_cast<size_t>(dim); });
+  if (trans_input != parameter_shape) {
+    MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
+                  << ", but the parameter shape is " << PrintInputShape(parameter_shape)
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+  // compare data type
+  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
+  if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
+    MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
+                  << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
+                  << ". parameter : " << parameter->DebugString();
+    return false;
+  }
+  return true;
+}
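+
+// Render a shape vector as "[ d0 d1 ... ]" for readable log output.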
+std::string AscendInferenceSession::PrintInputShape(const std::vector<size_t> &shape) {
+  std::string res = "[";
+  for (auto dim : shape) {
+    res += " " + std::to_string(dim);
+  }
+  return res + " ]";
+}
 }  // namespace session
 }  // namespace mindspore
@@ -39,6 +39,9 @@ class AscendInferenceSession : public AscendSession {
   void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
                      const std::vector<tensor::TensorPtr> &inputs_const) const;
   GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
+  bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
+  bool CompareInput(const std::shared_ptr<inference::MSTensor> &input, const ParameterPtr &parameter);
+  std::string PrintInputShape(const std::vector<size_t> &shape);
 };
 MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
 }  // namespace session
@@ -204,5 +204,11 @@ int Session::Init(const std::string &device, uint32_t device_id) {
   return 0;
 }
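+
+// Forward the input check to the backend session implementation.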
+bool Session::CheckModelInputs(uint32_t graph_id,
+                               const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const {
+  MS_ASSERT(session_impl_ != nullptr);
+  return session_impl_->CheckModelInputs(graph_id, inputs);
+}
 Session::Session() = default;
 }  // namespace mindspore::inference
@@ -37,6 +37,9 @@ class Session : public MSSession {
   MultiTensor RunGraph(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) override;
+  bool CheckModelInputs(uint32_t graph_id,
+                        const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) const override;
   int Init(const std::string &device, uint32_t device_id);
   static void RegAllOp();
@@ -106,6 +106,9 @@ class SessionBasic {
   virtual void GetSummaryNodes(KernelGraph *graph);
   void AssignParamKey(const KernelGraphPtr &kernel_graph);
   void InitPSParamAndOptim(const KernelGraphPtr &kernel_graph, const std::vector<tensor::TensorPtr> &inputs_const);
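+  // Default implementation accepts any inputs; backends that can validate
+  // model inputs (e.g. AscendInferenceSession) override this.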
+  virtual bool CheckModelInputs(uint32_t graph_id, const std::vector<std::shared_ptr<inference::MSTensor>> &inputs) {
+    return true;
+  }
 #ifdef ENABLE_DEBUGGER
   // set debugger
@@ -67,6 +67,11 @@ Status Session::Predict(const std::vector<MSTensorPtr> &inputs, inference::Multi
   std::lock_guard<std::mutex> lock(mutex_);
   MS_LOG(INFO) << "run Predict";
+  if (!session_->CheckModelInputs(graph_id_, inputs)) {
+    MS_LOG(ERROR) << "Check inputs failed: the inputs do not match the model.";
+    return FAILED;
+  }
   *outputs = session_->RunGraph(graph_id_, inputs);
   MS_LOG(INFO) << "run Predict finished";
   return SUCCESS;