|
|
|
@@ -22,6 +22,10 @@ from . import dtype as mstype |
|
|
|
from ._register_for_tensor import tensor_operator_registry |
|
|
|
|
|
|
|
__all__ = ['Tensor', 'MetaTensor'] |
|
|
|
np_types = (np.int8, np.int16, np.int32, np.int64, |
|
|
|
np.uint8, np.uint16, np.uint32, np.uint64, np.float16, |
|
|
|
np.float32, np.float64, np.bool_) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Tensor(Tensor_): |
|
|
|
@@ -54,6 +58,10 @@ class Tensor(Tensor_): |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, input_data, dtype=None): |
|
|
|
# If input data is numpy number, convert it to np array |
|
|
|
if isinstance(input_data, np_types): |
|
|
|
input_data = np.array(input_data) |
|
|
|
|
|
|
|
# If input_data is tuple/list/numpy.ndarray, it's support in check_type method. |
|
|
|
check_type('tensor input_data', input_data, (Tensor_, float, int)) |
|
|
|
if dtype is not None: |
|
|
|
|