|
- # Copyright 2019 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """A reservoir sampling on the values."""
-
- import random
- import threading
-
- from mindinsight.datavisual.common.log import logger
- from mindinsight.datavisual.common.enums import PluginNameEnum
- from mindinsight.utils.exceptions import ParamValueError
- from mindinsight.datavisual.utils.utils import calc_histogram_bins
-
-
def binary_search(samples, target):
    """
    Locate a step value inside a list of samples ordered by step.

    Args:
        samples (list): Samples sorted in ascending order of their `step` attribute.
        target (Any): The step value to look for.

    Returns:
        int, the index of a sample whose step equals `target` if one is found,
        otherwise the index at which `target` has to be inserted to keep the
        list ordered.
    """
    low = 0
    high = len(samples) - 1
    while low <= high:
        middle = (low + high) // 2
        middle_step = samples[middle].step
        if target < middle_step:
            high = middle - 1
            continue
        if target > middle_step:
            low = middle + 1
            continue
        return middle

    # No exact match. `high` now points at the last sample whose step is
    # smaller than target (-1 when target precedes every sample), so the
    # insertion point is the slot right after it. For example, with steps
    # [1, 2, 4] and target 3, `high` ends at 1 and we insert at index 2.
    return high + 1
-
-
class Reservoir:
    """
    A step-ordered container based on the Reservoir Sampling algorithm.

    Every newly added sample is kept. Once the container is full, adding a
    sample first evicts one of the stored samples, chosen with a seeded random
    generator. All access to the stored samples is guarded by a mutex.
    """

    def __init__(self, size):
        """
        Create an empty Reservoir.

        Args:
            size (int): Container Size. If the size is 0, the container is not limited.

        Raises:
            ParamValueError: If size is not an integer or is a negative integer.
        """
        size_is_valid = isinstance(size, int) and size >= 0
        if not size_is_valid:
            raise ParamValueError('size must be nonnegative integer, was %s' % size)

        self._samples_max_size = size
        self._samples = []
        # Number of samples offered so far; upper bound of the random draw.
        self._sample_counter = 0
        # Fixed seed keeps the eviction sequence reproducible between runs.
        self._sample_selector = random.Random(0)
        self._mutex = threading.Lock()

    def samples(self):
        """Return a copy of the list of stored samples."""
        with self._mutex:
            return self._samples[:]

    def add_sample(self, sample):
        """
        Add a sample to Reservoir.

        When the container is already full, one stored sample is evicted
        first, so the new sample is always stored.

        Args:
            sample (Any): The sample to add to the Reservoir. It must expose a
                `step` attribute, which is used to keep the samples ordered.
        """
        with self._mutex:
            capacity = self._samples_max_size
            is_full = capacity != 0 and len(self._samples) >= capacity
            if is_full:
                # Reservoir Sampling: draw from everything seen so far. A draw
                # inside the container evicts that slot; any other draw evicts
                # the last stored sample.
                victim = self._sample_selector.randint(0, self._sample_counter)
                if victim < capacity:
                    self._samples.pop(victim)
                else:
                    # The container is full here, so it cannot be empty.
                    self._samples.pop()
            self._add_sample(sample)
            self._sample_counter += 1

    def _add_sample(self, sample):
        """Store the sample at the position that keeps the list ordered by step."""
        stored = self._samples
        goes_last = not stored or sample.step > stored[-1].step
        if goes_last:
            # Common case: steps arrive in increasing order, no search needed.
            position = len(stored)
        else:
            position = binary_search(stored, sample.step)
        # Inserting at len(stored) is the same as appending.
        stored.insert(position, sample)

    def remove_sample(self, filter_fun):
        """
        Keep only the samples accepted by the filter and drop all the others.

        Args:
            filter_fun (Callable[..., Any]): Called with each stored sample.
                Samples for which it returns a true value are kept, the rest
                are removed.

        Returns:
            int, the number of samples removed.
        """
        with self._mutex:
            size_before = len(self._samples)
            if size_before <= 0:
                return 0

            self._samples = list(filter(filter_fun, self._samples))
            size_after = len(self._samples)
            removed = size_before - size_after

            if removed > 0:
                # Shrink the seen-sample counter by the same ratio as the
                # stored samples.
                remaining_rate = float(size_after) / size_before
                self._sample_counter = int(round(self._sample_counter * remaining_rate))

            return removed
-
-
- class _VisualRange:
- """Simple helper class to merge visual ranges."""
- def __init__(self):
- self._max = 0.0
- self._min = 0.0
- self._updated = False
-
- def update(self, max_val: float, min_val: float) -> None:
- """
- Merge visual range with given range.
-
- Args:
- max_val (float): Max value of given range.
- min_val (float): Min value of given range.
-
- """
- if not self._updated:
- self._max = max_val
- self._min = min_val
- self._updated = True
- return
-
- if max_val > self._max:
- self._max = max_val
-
- if min_val < self._min:
- self._min = min_val
-
- @property
- def max(self):
- """Gets max value of current range."""
- return self._max
-
- @property
- def min(self):
- """Gets min value of current range."""
- return self._min
-
-
class HistogramReservoir(Reservoir):
    """
    Reservoir for histogram, which needs updating range over all steps.

    All stored histograms share one visual range (max, min and bin count)
    that is derived from every stored sample. The range is recalculated
    lazily in `samples()` and only when the stored samples have changed.

    Args:
        size (int): Container Size. If the size is 0, the container is not limited.
    """
    def __init__(self, size):
        super().__init__(size)
        # Marker to avoid redundant calc for unchanged histograms.
        # It is only read and written while holding self._mutex.
        self._visual_range_up_to_date = False

    def _add_sample(self, sample):
        """
        Store the sample and mark the shared visual range as stale.

        Reservoir.add_sample calls this hook while it holds the mutex, so the
        marker changes together with the sample list. Clearing it only after
        the lock has been released would let a concurrent `samples()` call
        return the new sample before its histogram received a visual range.
        """
        self._visual_range_up_to_date = False
        super()._add_sample(sample)

    def remove_sample(self, filter_fun):
        """
        Removes samples, see parent class for details.

        Removing samples can shrink the range covered by the remaining
        histograms, so the shared visual range is marked stale whenever at
        least one sample has been removed.

        Args:
            filter_fun (Callable[..., Any]): Called with each stored sample.
                Samples for which it returns a true value are kept.

        Returns:
            int, the number of samples removed.
        """
        removed = super().remove_sample(filter_fun)
        if removed > 0:
            with self._mutex:
                self._visual_range_up_to_date = False
        return removed

    def samples(self):
        """
        Return all stored samples, refreshing the shared visual range first if needed.

        Returns:
            list, a copy of the list of stored samples.
        """
        with self._mutex:
            if not self._visual_range_up_to_date:
                self._refresh_visual_range()
                self._visual_range_up_to_date = True
            return list(self._samples)

    def _refresh_visual_range(self):
        """Recalculate the shared visual range and apply it to every stored histogram.

        The caller must hold self._mutex.
        """
        # calc visual range
        visual_range = _VisualRange()
        max_count = 0
        for sample in self._samples:
            histogram_container = sample.value
            if histogram_container.count == 0:
                # ignore empty tensor
                continue
            max_count = max(histogram_container.count, max_count)
            visual_range.update(histogram_container.max, histogram_container.min)

        if visual_range.max == visual_range.min and not max_count:
            logger.info("Max equals to min. Count is zero.")

        bins = calc_histogram_bins(max_count)

        # update visual range
        logger.debug(
            "Visual histogram: min %s, max %s, bins %s, max_count %s.",
            visual_range.min,
            visual_range.max,
            bins,
            max_count)
        for sample in self._samples:
            sample.value.histogram.set_visual_range(visual_range.max, visual_range.min, bins)
-
-
class ReservoirFactory:
    """Factory that picks the reservoir implementation matching a plugin."""
    def create_reservoir(self, plugin_name: str, size: int) -> Reservoir:
        """
        Build a new reservoir suited to the given plugin.

        Args:
            plugin_name (str): Plugin name
            size (int): Container Size. If the size is 0, the container is not limited.

        Returns:
            Reservoir, a HistogramReservoir for the histogram and tensor
            plugins, a plain Reservoir for every other plugin name.
        """
        histogram_like_plugins = (
            PluginNameEnum.HISTOGRAM.value,
            PluginNameEnum.TENSOR.value,
        )
        if plugin_name in histogram_like_plugins:
            reservoir_cls = HistogramReservoir
        else:
            reservoir_cls = Reservoir
        return reservoir_cls(size)
|