Browse Source

support throw attribute error from c++

tags/v0.6.0-beta
Wei Luning 5 years ago
parent
commit
f362503d57
7 changed files with 61 additions and 4 deletions
  1. +4
    -0
      mindspore/ccsrc/pipeline/jit/pipeline.cc
  2. +2
    -2
      mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc
  3. +24
    -0
      mindspore/ccsrc/pybind_api/pybind_patch.h
  4. +3
    -2
      mindspore/ccsrc/utils/log_adapter.cc
  5. +1
    -0
      mindspore/ccsrc/utils/log_adapter.h
  6. +4
    -0
      mindspore/ccsrc/utils/log_adapter_py.cc
  7. +23
    -0
      tests/ut/python/pynative_mode/test_parse_method.py

+ 4
- 0
mindspore/ccsrc/pipeline/jit/pipeline.cc View File

@@ -40,6 +40,7 @@
#include "debug/trace.h"
#include "pipeline/pynative/pynative_execute.h"
#include "frontend/optimizer/py_pass_manager.h"
#include "pybind_api/pybind_patch.h"

#if (!_WIN32 && !ENABLE_GE && !ENABLE_TESTCASES)
#include "frontend/parallel/ps/common.h"
@@ -536,6 +537,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
} catch (const py::index_error &ex) {
ReleaseResource(phase);
throw py::index_error(ex);
} catch (const py::attribute_error &ex) {
ReleaseResource(phase);
throw py::attribute_error(ex);
} catch (const std::exception &ex) {
ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it


+ 2
- 2
mindspore/ccsrc/pipeline/jit/static_analysis/prim.cc View File

@@ -761,8 +761,8 @@ EvalResultPtr GetEvaluatedValueForClassAttrOrMethod(const AnalysisEnginePtr &eng

ValuePtr method = cls->GetMethod(item_name);
if (method->isa<AnyValue>()) {
MS_LOG(EXCEPTION) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString();
MS_EXCEPTION(AttributeError) << "Unknown field, data type: " << args_spec_list[0]->BuildType()->ToString()
<< ", item value: " << item_v->ToString();
}

// Infer class method


+ 24
- 0
mindspore/ccsrc/pybind_api/pybind_patch.h View File

@@ -0,0 +1,24 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef PYBIND_API_PYBIND_PATCH_H_
#define PYBIND_API_PYBIND_PATCH_H_

namespace pybind11 {
PYBIND11_RUNTIME_EXCEPTION(attribute_error, PyExc_AttributeError)
}

#endif // PYBIND_API_PYBIND_PATCH_H_

+ 3
- 2
mindspore/ccsrc/utils/log_adapter.cc View File

@@ -145,10 +145,11 @@ static std::string ExceptionTypeToString(ExceptionType type) {
_TO_STRING(IndexError),
_TO_STRING(ValueError),
_TO_STRING(TypeError),
_TO_STRING(AttributeError),
};
// clang-format on
#undef _TO_STRING
if (type < UnknownError || type > TypeError) {
if (type < UnknownError || type > AttributeError) {
type = UnknownError;
}
return std::string(type_names[type]);
@@ -212,7 +213,7 @@ void LogWriter::operator^(const LogStream &stream) const {
std::ostringstream oss;
oss << location_.file_ << ":" << location_.line_ << " " << location_.func_ << "] ";
if (exception_type_ != NoExceptionType && exception_type_ != IndexError && exception_type_ != TypeError &&
exception_type_ != ValueError) {
exception_type_ != ValueError && exception_type_ != AttributeError) {
oss << ExceptionTypeToString(exception_type_) << " ";
}
oss << msg.str();


+ 1
- 0
mindspore/ccsrc/utils/log_adapter.h View File

@@ -58,6 +58,7 @@ enum ExceptionType {
IndexError,
ValueError,
TypeError,
AttributeError,
};

struct LocationInfo {


+ 4
- 0
mindspore/ccsrc/utils/log_adapter_py.cc View File

@@ -18,6 +18,7 @@

#include <string>
#include "pybind11/pybind11.h"
#include "pybind_api/pybind_patch.h"

namespace py = pybind11;
namespace mindspore {
@@ -38,6 +39,9 @@ class PyExceptionInitializer {
if (exception_type == TypeError) {
throw py::type_error(str);
}
if (exception_type == AttributeError) {
throw py::attribute_error(str);
}
py::pybind11_fail(str);
}
};


+ 23
- 0
tests/ut/python/pynative_mode/test_parse_method.py View File

@@ -304,6 +304,29 @@ def test_access():
""" test_access """
invoke_dataclass(1, 2)

@dataclass
class Access2:
a: int
b: int

def max(self):
if self.a > self.b:
return self.c
return self.b


@ms_function
def invoke_dataclass2(x, y):
""" invoke_dataclass """
acs = Access2(x, y)
return acs.max()


def test_access_attr_error():
""" test_access """
with pytest.raises(AttributeError):
invoke_dataclass2(1, 2)


def myfunc(x):
""" myfunc """


Loading…
Cancel
Save