@@ -13,6 +13,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <algorithm>
#include "backend/session/ascend_inference_session.h"
#include "frontend/operator/ops.h"
#include "ir/tensor.h"
@@ -85,5 +87,80 @@ GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
  }
  return graph_id;
}

bool AscendInferenceSession::CheckModelInputs(uint32_t graph_id,
                                              const std::vector<std::shared_ptr<inference::MSTensor> > &inputs) {
  MS_LOG(INFO) << "Start checking client inputs, graph id : " << graph_id;
  auto kernel_graph = GetGraph(graph_id);
  MS_EXCEPTION_IF_NULL(kernel_graph);
  auto kernel_graph_inputs = kernel_graph->inputs();
  // index of the next non-weight (client-provided) graph input to check
  size_t no_weight_input = 0;
  for (size_t i = 0; i < kernel_graph_inputs.size(); ++i) {
    if (!kernel_graph_inputs[i]->isa<Parameter>()) {
      MS_LOG(ERROR) << "Kernel graph inputs have an anfnode which is not a Parameter.";
      continue;
    }
    auto parameter = kernel_graph_inputs[i]->cast<ParameterPtr>();
    // weight parameters are not provided by the client, so only non-weight inputs are checked
    if (!AnfAlgo::IsParameterWeight(parameter)) {
      // compare input number
      if (no_weight_input >= inputs.size()) {
        MS_LOG(ERROR) << "Input number is inconsistent. The actual input number [" << inputs.size()
                      << "] is less than that of the graph.";
        return false;
      }
      auto input = inputs[no_weight_input++];
      if (!CompareInput(input, parameter)) {
        MS_LOG(ERROR) << "Please check the input information.";
        return false;
      }
    }
  }
  return true;
}

bool AscendInferenceSession::CompareInput(const std::shared_ptr<inference::MSTensor> &input,
                                          const ParameterPtr &parameter) {
  MS_EXCEPTION_IF_NULL(input);
  MS_EXCEPTION_IF_NULL(parameter);
  // compare dims
  auto parameter_shape = AnfAlgo::GetOutputDeviceShape(parameter, 0);
  if (input->shape().size() != parameter_shape.size()) {
    MS_LOG(ERROR) << "Input dim is inconsistent. The actual dim is " << input->shape().size()
                  << ", but the parameter dim is " << parameter_shape.size()
                  << ". parameter : " << parameter->DebugString();
    return false;
  }

  // compare shape
  auto input_shape = input->shape();
  std::vector<size_t> trans_input;
  (void)std::transform(input_shape.begin(), input_shape.end(), std::back_inserter(trans_input),
                       [](const int dim) { return static_cast<size_t>(dim); });
  if (trans_input != parameter_shape) {
    MS_LOG(ERROR) << "Input shape is inconsistent. The actual shape is " << PrintInputShape(trans_input)
                  << ", but the parameter shape is " << PrintInputShape(parameter_shape)
                  << ". parameter : " << parameter->DebugString();
    return false;
  }

  // compare data type
  auto kernel_build_info = AnfAlgo::GetSelectKernelBuildInfo(parameter);
  MS_EXCEPTION_IF_NULL(kernel_build_info);
  if (input->data_type() != kernel_build_info->GetOutputDeviceType(0)) {
    MS_LOG(ERROR) << "Input data type is inconsistent. The actual data type is " << input->data_type()
                  << ", but the parameter data type is " << kernel_build_info->GetOutputDeviceType(0)
                  << ". parameter : " << parameter->DebugString();
    return false;
  }
  return true;
}

std::string AscendInferenceSession::PrintInputShape(std::vector<size_t> shape) {
  std::string res = "[";
  for (auto dim : shape) {
    res += " " + std::to_string(dim);
  }
  return res + " ]";
}
}  // namespace session
}  // namespace mindspore