|
- # -*- coding: UTF-8 -*-
- # MIT License
- #
- # Copyright (c) 2019 Vadim Velicodnii
- #
- # Permission is hereby granted, free of charge, to any person obtaining
- # a copy of this software and associated documentation files
- # (the "Software"), to deal in the Software without restriction,
- # including without limitation the rights to use, copy, modify, merge,
- # publish, distribute, sublicense, and/or sell copies of the Software,
- # and to permit persons to whom the Software is furnished to do so,
- # subject to the following conditions:
- #
- # The above copyright notice and this permission notice shall be included
- # in all copies or substantial portions of the Software.
- #
- # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
- # IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
- # CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
- # TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
- # SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
- from collections import namedtuple
- from collections.abc import Iterable
- from typing import Union, Optional
- import numpy as np
- import struct
- from io import BytesIO
- # Compatible tensorboard calculation graph
- from tensorboard.compat.proto.graph_pb2 import GraphDef
- from oneflow.customized.utils import HParamsPluginData
- # from tensorboard.plugins.hparams.plugin_data_pb2 import HParamsPluginData
- from tbparser.events_reader import EventReadingError, EventsFileReader
-
- SummaryItem = namedtuple(
- 'SummaryItem', ['tag', 'step', 'wall_time', 'value', 'type']
- )
- GraphItem = namedtuple(
- 'GraphItem', ['wall_time', 'value', 'type']
- )
-
- # tensor data type
- _data_type = {1: 'float',
- 2: 'double',
- 3: 'int32',
- # DT_UINT8 = 4;
- # DT_INT16 = 5;
- # DT_INT8 = 6;
- # DT_STRING = 7;
- # DT_COMPLEX64 = 8; // Single-precision complex
- 9: 'int64',
- 10: 'bool',
- # DT_QINT8 = 11; // Quantized int8
- # DT_QUINT8 = 12; // Quantized uint8
- # DT_QINT32 = 13; // Quantized int32
- # DT_BFLOAT16 = 14; // Float32 truncated to 16 bits. Only for cast ops.
- # DT_QINT16 = 15; // Quantized int16
- # DT_QUINT16 = 16; // Quantized uint16
- 17: 'uint16',
- # DT_COMPLEX128 = 18; // Double-precision complex
- # DT_HALF = 19;
- # DT_RESOURCE = 20;
- # DT_VARIANT = 21; // Arbitrary C++ data types
- 22: 'uint32',
- 23: 'uint64'}
-
-
- def _decode_byte(tensor):
- # 若tensor是float类型
- if tensor.dtype == 1:
- return struct.unpack('f', tensor.tensor_content)[0]
-
-
- class SummaryReader(Iterable):
- """
- Iterates over events in all the files in the current logdir.
- """
-
- def _get_scalar(self, value):
- """
- Decode an scalar event
- :param value: A value field of an event
- :return: Decoded scalar
- """
- if value.HasField('simple_value'):
- return value.simple_value
- elif value.HasField('metadata'):
- if value.metadata.plugin_data.plugin_name == 'scalars':
- tensor = value.tensor
- return _decode_byte(tensor)
- return None
-
- def _get_image(self, value) -> Optional[dict]:
- """
- Decode an image event
- :param value: A value field of an event
- :return: Decoded image
- """
- if value.HasField('image'):
- dic = {
- 'width': value.image.width,
- 'height': value.image.height,
- 'colorspace': value.image.colorspace,
- 'encoded_image_string': value.image.encoded_image_string
- }
- return dic
- return None
-
- def _get_text(self, value) -> Optional[np.ndarray]:
- """
- Return text data
- :param value: A value field of an event
- :return: text data
- TODO: Tensorflow API
- """
- if value.HasField('metadata'):
- if value.metadata.plugin_data.plugin_name == 'text':
- return np.array([v.decode() for v in value.tensor.string_val])
- return None
-
- def _get_audio(self, value):
- if value.HasField('audio'):
- dic = {
- 'sample_rate': value.audio.sample_rate,
- 'num_channels': value.audio.num_channels,
- 'length_frames': value.audio.length_frames,
- 'encoded_audio_string': value.audio.encoded_audio_string
- }
- return dic
- # if Tensorboard API, use tensor decoder
- elif value.HasField('metadata'):
- if value.metadata.plugin_data.plugin_name == 'audio':
- dic = {
- 'tensor_shape': tuple([i.size for i in value.tensor.tensor_shape.dim]),
- 'string_val': [v for v in value.tensor.string_val]
- }
- if value.tag not in self.audio_exit_tag:
- # record the tag
- self.audio_exit_tag.append(value.tag)
- return dic
- elif value.tag in self.audio_exit_tag:
- dic = {
- 'tensor_shape': tuple([i.size for i in value.tensor.tensor_shape.dim]),
- 'string_val': [v for v in value.tensor.string_val]
- }
- return dic
- return None
-
- def _get_hist(self, value):
- if value.HasField('histo'):
- dic = {
- 'min': value.histo.min,
- 'max': value.histo.max,
- 'num': value.histo.num,
- 'sum': value.histo.sum,
- 'sum_squares': value.histo.sum_squares,
- 'bucket_limit': np.array(value.histo.bucket_limit),
- 'bucket': np.array(value.histo.bucket),
- }
- return dic
- # if Tensorboard API, use tensor decoder
- elif value.HasField('metadata'):
- if value.metadata.plugin_data.plugin_name == 'histograms':
- tensor = value.tensor
- dtype = _data_type[tensor.dtype]
- tensor_shape = tuple([i.size for i in tensor.tensor_shape.dim])
- tensor_content = tensor.tensor_content
- tensor_content = np.frombuffer(tensor_content, dtype=dtype)
- if value.tag not in self.hist_exit_tag:
- # record the tag
- self.hist_exit_tag.append(value.tag)
- return tensor_content.reshape(tensor_shape)
- elif value.tag in self.hist_exit_tag:
- tensor = value.tensor
- dtype = _data_type[tensor.dtype]
- tensor_shape = tuple([i.size for i in tensor.tensor_shape.dim])
- tensor_content = tensor.tensor_content
- tensor_content = np.frombuffer(tensor_content, dtype=dtype)
- return tensor_content.reshape(tensor_shape)
- return None
-
- def _get_hparams(self, value):
- if value.HasField('metadata'):
- if value.metadata.plugin_data.plugin_name == 'hparams':
- metadata = value.metadata
- plugin_data = HParamsPluginData()
- plugin_data.ParseFromString(metadata.plugin_data.content)
- return plugin_data
-
- _DECODERS = {
- 'scalar': _get_scalar,
- 'image': _get_image,
- 'text': _get_text,
- 'audio': _get_audio,
- 'hist': _get_hist,
- 'hparams': _get_hparams,
- }
-
- def __init__(
- self,
- fileblock: BytesIO,
- tag_filter: Optional[Iterable] = None,
- types: Iterable = ('scalar',),
- stop_on_error: bool = False
- ):
- """
- Initalize new summary reader
- :param fileblock: Event file block of Tensorboard
- :param tag_filter: A list of tags to leave (`None` for all)
- :param types: A list of types to get.
- :param stop_on_error: Whether stop on a broken file
- """
- self._fileblock = fileblock
-
- self._tag_filter = set(tag_filter) if tag_filter is not None else None
- self._types = set(types)
- self._check_type_names()
- self._stop_on_error = stop_on_error
- # Record the tag, that has been read by the parser.
- # If the tag, in this list appears next,
- # the type is automatically identified.
- self.scalar_exit_tag = []
- self.image_exit_tag = []
- self.text_exit_tag = []
- self.audio_exit_tag = []
- self.hist_exit_tag = []
-
- def _check_type_names(self):
- if self._types is None:
- return
- if not all(
- type_name in self._DECODERS.keys() or type_name == "graph" for type_name in self._types
- ):
- raise ValueError('Invalid type name')
-
- # def _decode_events(self, events: Iterable) -> Optional[Union[SummaryItem, GraphDef]]:
- def _decode_events(self, events: Iterable) \
- -> Optional[Union[SummaryItem]]:
- """
- Convert events to `SummaryItem` instances
- :param events: An iterable with events objects
- :return: A generator with decoded events
- or `None`s if an event can't be decoded
- """
- for event in events:
- # yield None
- step = event.step
- wall_time = event.wall_time
- if event.HasField('summary'):
- for value in event.summary.value:
- tag = value.tag
- # if value.HasField('metadata'):
- # continue
- for value_type in self._types:
- if value_type == "graph":
- continue
- decoder = self._DECODERS[value_type]
- data = decoder(self, value)
- if data is not None:
- yield SummaryItem(
- tag=tag,
- step=step,
- wall_time=wall_time,
- value=data,
- type=value_type
- )
- else:
- yield None
- elif event.HasField('graph_def'):
- graph = GraphDef()
- graph.ParseFromString(event.graph_def)
- yield GraphItem(
- wall_time=wall_time,
- value=graph,
- type='graph'
- )
-
- def _check_tag(self, tag: str) -> bool:
- """
- Check if a tag matches the current tag filter
- :param tag: A string with tag
- :return: A boolean value.
- """
- return self._tag_filter is None or tag in self._tag_filter
-
- def __iter__(self) -> SummaryItem:
- """
- Iterate over events in all the files in the current logdir
- :return: A generator with `SummaryItem` objects
- """
- reader = EventsFileReader(self._fileblock)
- try:
- yield from (
- item for item in self._decode_events(reader)
- if item is not None and all([
- self._check_tag(None if type(item) == GraphItem else item.tag),
- item.type in self._types
- ])
- )
- except EventReadingError:
- if self._stop_on_error:
- raise
- else:
- yield
|