| @@ -36,8 +36,12 @@ REGISTER_PYBIND_DEFINE( | |||||
| (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); | (void)m_sub.def("str_to_type", &StringToType, "string to typeptr"); | ||||
| (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | (void)py::class_<Type, std::shared_ptr<Type>>(m_sub, "Type") | ||||
| .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | .def_readonly(PYTHON_DTYPE_FLAG, &mindspore::Type::parse_info_) | ||||
| .def("__eq__", | |||||
| [](const TypePtr &t1, const TypePtr &t2) { | |||||
| .def("__eq__", | |||||
| [](const TypePtr &t1, const py::object &other) { | |||||
| if (!py::isinstance<Type>(other)) { | |||||
| return false; | |||||
| } | |||||
| auto t2 = py::cast<TypePtr>(other); | |||||
| if (t1 != nullptr && t2 != nullptr) { | if (t1 != nullptr && t2 != nullptr) { | ||||
| return *t1 == *t2; | return *t1 == *t2; | ||||
| } | } | ||||
| @@ -134,3 +134,11 @@ def test_dtype(): | |||||
| with pytest.raises(NotImplementedError): | with pytest.raises(NotImplementedError): | ||||
| x = 1.5 | x = 1.5 | ||||
| dtype.get_py_obj_dtype(type(type(x))) | dtype.get_py_obj_dtype(type(type(x))) | ||||
| def test_type_equal(): | |||||
| t1 = (dtype.int32, dtype.int32) | |||||
| valid_types = [dtype.float16, dtype.float32] | |||||
| assert t1 not in valid_types | |||||
| assert dtype.int32 not in valid_types | |||||
| assert dtype.float32 in valid_types | |||||
| @@ -971,7 +971,7 @@ raise_error_set = [ | |||||
| Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(7, size=(3, 4, 5)), mstype.int32)], | ||||
| }), | }), | ||||
| ('TensorGetItemByMixedTensorsTypeError', { | ('TensorGetItemByMixedTensorsTypeError', { | ||||
| 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': TypeError}), | |||||
| 'block': (TensorGetItemByMixedTensorsTypeError(), {'exception': IndexError}), | |||||
| 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | 'desc_inputs': [Tensor(np.arange(3 * 4 * 5 * 6 * 7 * 8 * 9).reshape((3, 4, 5, 6, 7, 8, 9)), mstype.int32), | ||||
| Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | Tensor(np.random.randint(3, size=(3, 4, 5)), mstype.int32), | ||||
| Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)], | Tensor(np.random.randint(4, size=(3, 4, 5)), mstype.int32)], | ||||