Browse Source

support catch IndexError in C++

tags/v0.5.0-beta
buxue 5 years ago
parent
commit
2c9ca199ad
2 changed files with 27 additions and 6 deletions
  1. +3
    -0
      mindspore/ccsrc/pipeline/pipeline.cc
  2. +24
    -6
      tests/ut/python/ops/test_tuple_slice.py

+ 3
- 0
mindspore/ccsrc/pipeline/pipeline.cc View File

@@ -471,6 +471,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
} catch (const py::value_error &ex) {
ReleaseResource(phase);
throw py::value_error(ex);
} catch (const py::index_error &ex) {
ReleaseResource(phase);
throw py::index_error(ex);
} catch (const std::exception &ex) {
ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it


+ 24
- 6
tests/ut/python/ops/test_tuple_slice.py View File

@@ -80,6 +80,17 @@ class NetWork_3(Cell):
return res


class NetWorkOutOfBounds(Cell):
""" NetWork_3 definition """

def __init__(self):
super(NetWorkOutOfBounds, self).__init__()
self.addN = P.AddN()

def construct(self, tensor_tuple):
return tensor_tuple[100]


test_cases = [
('SlicePositive', {
'block': NetWork_1(),
@@ -104,16 +115,23 @@ test_cases = [
test_cases_for_verify_exception = [
('SliceStartCross', {
'block': (NetWork_3(), {'exception': RuntimeError}),
'desc_inputs': [*(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))],
}),
('SliceStepZero', {
'block': (NetWork_3(), {'exception': RuntimeError}),
'desc_inputs': [*(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))],
}),
('SliceOutOfBounds', {
'block': (NetWorkOutOfBounds(), {'exception': IndexError}),
'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
}),

]




Loading…
Cancel
Save