Browse Source

support catch IndexError in C++

tags/v0.5.0-beta
buxue 5 years ago
parent
commit
2c9ca199ad
2 changed files with 27 additions and 6 deletions
  1. +3
    -0
      mindspore/ccsrc/pipeline/pipeline.cc
  2. +24
    -6
      tests/ut/python/ops/test_tuple_slice.py

+ 3
- 0
mindspore/ccsrc/pipeline/pipeline.cc View File

@@ -471,6 +471,9 @@ bool ExecutorPy::Compile(const py::object &obj, const py::tuple &args, const py:
} catch (const py::value_error &ex) { } catch (const py::value_error &ex) {
ReleaseResource(phase); ReleaseResource(phase);
throw py::value_error(ex); throw py::value_error(ex);
} catch (const py::index_error &ex) {
ReleaseResource(phase);
throw py::index_error(ex);
} catch (const std::exception &ex) { } catch (const std::exception &ex) {
ReleaseResource(phase); ReleaseResource(phase);
// re-throw this exception to Python interpreter to handle it // re-throw this exception to Python interpreter to handle it


+ 24
- 6
tests/ut/python/ops/test_tuple_slice.py View File

@@ -80,6 +80,17 @@ class NetWork_3(Cell):
return res return res




class NetWorkOutOfBounds(Cell):
""" NetWork_3 definition """

def __init__(self):
super(NetWorkOutOfBounds, self).__init__()
self.addN = P.AddN()

def construct(self, tensor_tuple):
return tensor_tuple[100]


test_cases = [ test_cases = [
('SlicePositive', { ('SlicePositive', {
'block': NetWork_1(), 'block': NetWork_1(),
@@ -104,16 +115,23 @@ test_cases = [
test_cases_for_verify_exception = [ test_cases_for_verify_exception = [
('SliceStartCross', { ('SliceStartCross', {
'block': (NetWork_3(), {'exception': RuntimeError}), 'block': (NetWork_3(), {'exception': RuntimeError}),
'desc_inputs': [*(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))],
}), }),
('SliceStepZero', { ('SliceStepZero', {
'block': (NetWork_3(), {'exception': RuntimeError}), 'block': (NetWork_3(), {'exception': RuntimeError}),
'desc_inputs': [*(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
'desc_inputs': [Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32))],
}),
('SliceOutOfBounds', {
'block': (NetWorkOutOfBounds(), {'exception': IndexError}),
'desc_inputs': [(Tensor(np.ones([2, 3, 4], np.int32)),
Tensor(np.zeros([2, 3, 4], np.int32)),
Tensor(np.ones([2, 3, 4], np.int32)))],
}), }),

] ]






Loading…
Cancel
Save