|
|
|
@@ -15,7 +15,6 @@ |
|
|
|
"""Tensor data container.""" |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from mindinsight.datavisual.common.exceptions import DataTypeError |
|
|
|
from mindinsight.datavisual.common.log import logger |
|
|
|
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket |
|
|
|
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2 |
|
|
|
@@ -253,19 +252,14 @@ class TensorContainer: |
|
|
|
Get ndarray of tensor. |
|
|
|
|
|
|
|
Args: |
|
|
|
tensor (float16|float32|float64): tensor data. |
|
|
|
tensor (mindinsight_anf_ir.proto.DataType): tensor data. |
|
|
|
|
|
|
|
Returns: |
|
|
|
numpy.ndarray, ndarray of tensor. |
|
|
|
|
|
|
|
Raises: |
|
|
|
DataTypeError, If data type of tensor is not among float16 or float32 or float64. |
|
|
|
""" |
|
|
|
data_type_str = anf_ir_pb2.DataType.Name(self.data_type) |
|
|
|
if data_type_str == 'DT_FLOAT16': |
|
|
|
return np.array(tuple(tensor), dtype=np.float16).reshape(self.dims) |
|
|
|
if data_type_str == 'DT_FLOAT32': |
|
|
|
return np.array(tuple(tensor), dtype=np.float32).reshape(self.dims) |
|
|
|
if data_type_str == 'DT_FLOAT64': |
|
|
|
return np.array(tuple(tensor), dtype=np.float64).reshape(self.dims) |
|
|
|
raise DataTypeError("Data type: {}.".format(data_type_str)) |
|
|
|
return np.array(tuple(tensor)).reshape(self.dims) |