|
|
|
@@ -17,10 +17,12 @@ import numpy as np |
|
|
|
|
|
|
|
from mindinsight.datavisual.data_transform.histogram import Histogram, Bucket |
|
|
|
from mindinsight.datavisual.utils.utils import calc_histogram_bins |
|
|
|
from mindinsight.datavisual.common.exceptions import TensorTooLargeError |
|
|
|
from mindinsight.utils.exceptions import ParamValueError |
|
|
|
from mindinsight.utils.tensor import TensorUtils |
|
|
|
|
|
|
|
MAX_TENSOR_COUNT = 10000000 |
|
|
|
TENSOR_TOO_LARGE_ERROR = TensorTooLargeError("").error_code |
|
|
|
|
|
|
|
|
|
|
|
def calc_original_buckets(np_value, stats): |
|
|
|
@@ -74,6 +76,10 @@ class TensorContainer: |
|
|
|
self._dims = tuple(tensor_message.dims) |
|
|
|
self._data_type = tensor_message.data_type |
|
|
|
self._np_array = self.get_ndarray(tensor_message.float_data) |
|
|
|
self._error_code = None |
|
|
|
if self._np_array.size > MAX_TENSOR_COUNT: |
|
|
|
self._error_code = TENSOR_TOO_LARGE_ERROR |
|
|
|
self._np_array = np.array([]) |
|
|
|
self._stats = TensorUtils.get_statistics_from_tensor(self._np_array) |
|
|
|
original_buckets = calc_original_buckets(self._np_array, self._stats) |
|
|
|
self._count = sum(bucket.count for bucket in original_buckets) |
|
|
|
@@ -81,11 +87,17 @@ class TensorContainer: |
|
|
|
self._min = self._stats.min |
|
|
|
self._histogram = Histogram(tuple(original_buckets), self._max, self._min, self._count) |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
def size(self): |
|
|
|
"""Get size of tensor.""" |
|
|
|
return self._np_array.size |
|
|
|
|
|
|
|
@property |
|
|
|
def error_code(self): |
|
|
|
"""Get size of tensor.""" |
|
|
|
return self._error_code |
|
|
|
|
|
|
|
@property |
|
|
|
def dims(self): |
|
|
|
"""Get dims of tensor.""" |
|
|
|
@@ -128,6 +140,8 @@ class TensorContainer: |
|
|
|
|
|
|
|
def buckets(self): |
|
|
|
"""Get histogram buckets.""" |
|
|
|
if self._histogram is None: |
|
|
|
return None |
|
|
|
return self._histogram.buckets() |
|
|
|
|
|
|
|
def get_ndarray(self, tensor): |
|
|
|
|