@@ -19,12 +19,17 @@ This module is used to load the MindSpore training log file.
Each instance will read an entire run, and a run can contain one or
more log files.
"""
import concurrent.futures as futures
import math
import os
import re
import struct
import threading
from google.protobuf.message import DecodeError
from google.protobuf.text_format import ParseError
from mindinsight.conf import settings
from mindinsight.datavisual.common import exceptions
from mindinsight.datavisual.common.enums import PluginNameEnum
from mindinsight.datavisual.common.log import logger
@@ -32,13 +37,13 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler
from mindinsight.datavisual.data_transform.events_data import EventsData
from mindinsight.datavisual.data_transform.events_data import TensorEvent
from mindinsight.datavisual.data_transform.graph import MSGraph
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import UnknownError
from mindinsight.datavisual.data_transform.histogram import Histogram
from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer
from mindinsight.datavisual.data_transform.tensor_container import TensorContainer
from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
from mindinsight.datavisual.utils import crc32
from mindinsight.utils.exceptions import UnknownError
HEADER_SIZE = 8
CRC_STR_SIZE = 4
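# Record layout constants (assumed layout, consumed in _event_load): each
# serialized event starts with an 8-byte length header, and CRC fields are 4 bytes.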
@@ -79,11 +84,14 @@ class MSDataLoader:
"we will reload all files in path %s.", self._summary_dir)
self.__init__(self._summary_dir)
def load(self):
def load(self, workers_count=1):
"""
Load all valid log files.
When a file is reloaded, loading continues from where it left off.
Args:
workers_count (int): The count of workers. Default value is 1.
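Example (an illustrative sketch; the directory path is hypothetical):
    >>> loader = MSDataLoader('/path/to/summary_dir')
    >>> loader.load(workers_count=2)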
"""
logger.debug("Start to load data in ms data loader.")
filenames = self.filter_valid_files()
@@ -95,7 +103,7 @@ class MSDataLoader:
self._check_files_deleted(filenames, old_filenames)
for parser in self._parser_list:
parser.parse_files(filenames, events_data=self._events_data)
parser.parse_files(workers_count, filenames, events_data=self._events_data)
def filter_valid_files(self):
"""
@@ -125,11 +133,12 @@ class _Parser:
self._latest_mtime = 0
self._summary_dir = summary_dir
def parse_files(self, filenames, events_data):
def parse_files(self, workers_count, filenames, events_data):
"""
Load files and parse their content.
Args:
workers_count (int): The count of workers.
filenames (list[str]): File name list.
events_data (EventsData): The container of event data.
"""
@@ -177,7 +186,7 @@ class _Parser:
class _PbParser(_Parser):
"""This class is used to parse pb file."""
def parse_files(self, filenames, events_data):
def parse_files(self, workers_count, filenames, events_data):
pb_filenames = self.filter_files(filenames)
pb_filenames = self.sort_files(pb_filenames)
for filename in pb_filenames:
@@ -255,11 +264,12 @@ class _SummaryParser(_Parser):
self._summary_file_handler = None
self._events_data = None
def parse_files(self, filenames, events_data):
def parse_files(self, workers_count, filenames, events_data):
"""
Load summary file and parse file content.
Args:
workers_count (int): The count of workers.
filenames (list[str]): File name list.
events_data (EventsData): The container of event data.
"""
@@ -285,7 +295,7 @@ class _SummaryParser(_Parser):
self._latest_file_size = new_size
try:
self._load_single_file(self._summary_file_handler)
self._load_single_file(self._summary_file_handler, workers_count)
except UnknownError as ex:
logger.warning("Parse summary file failed, detail: %r,"
"file path: %s.", str(ex), file_path)
@@ -304,36 +314,75 @@ class _SummaryParser(_Parser):
lambda filename: (re.search(r'summary\.\d+', filename)
and not filename.endswith("_lineage")), filenames))
def _load_single_file(self, file_handler):
def _load_single_file(self, file_handler, workers_count):
"""
Load a log file data.
Args:
file_handler (FileHandler): A file handler.
workers_count (int): The count of workers.
"""
logger.debug("Load single summary file, file path: %s.", file_handler.file_path)
while True:
start_offset = file_handler.offset
try:
event_str = self._event_load(file_handler)
if event_str is None:
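# Divide the process budget evenly among loader workers: each worker gets at
# most cpu_count / workers_count processes, capped by
# MAX_PROCESSES_COUNT / workers_count, and always at least one.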
default_concurrency = 1
cpu_count = os.cpu_count()
if cpu_count is None:
concurrency = default_concurrency
else:
concurrency = min(math.floor(cpu_count / workers_count),
math.floor(settings.MAX_PROCESSES_COUNT / workers_count))
if concurrency <= 0:
concurrency = default_concurrency
logger.debug("Load single summary file, file path: %s, concurrency: %s.", file_handler.file_path, concurrency)
semaphore = threading.Semaphore(value=concurrency)
with futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
while True:
start_offset = file_handler.offset
try:
event_str = self._event_load(file_handler)
if event_str is None:
file_handler.reset_offset(start_offset)
break
# Make sure we have at most concurrency tasks not finished to save memory.
semaphore.acquire()
future = executor.submit(self._event_parse, event_str, self._latest_filename)
def _add_tensor_event_callback(future_value):
try:
tensor_values = future_value.result()
for tensor_value in tensor_values:
if tensor_value.plugin_name == PluginNameEnum.GRAPH.value:
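# A newly parsed graph replaces graphs loaded earlier: stale graph
# tags are removed before the new graph event is added.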
try:
graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
except KeyError:
graph_tags = []
summary_tags = self.filter_files(graph_tags)
for tag in summary_tags:
self._events_data.delete_tensor_event(tag)
self._events_data.add_tensor_event(tensor_value)
except Exception as exc:
# Log exception for debugging.
logger.exception(exc)
raise
finally:
semaphore.release()
future.add_done_callback(_add_tensor_event_callback)
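# With ProcessPoolExecutor, done-callbacks run in the submitting process, so
# releasing the semaphore in the callback keeps at most `concurrency` parse
# tasks in flight.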
except exceptions.CRCFailedError:
file_handler.reset_offset(start_offset)
logger.warning("Check crc faild and ignore this file, file_path=%s, "
"offset=%s.", file_handler.file_path, file_handler.offset)
break
event = summary_pb2.Event.FromString(event_str)
self._event_parse(event)
except exceptions.CRCFailedError:
file_handler.reset_offset(start_offset)
logger.warning("Check crc faild and ignore this file, file_path=%s, "
"offset=%s.", file_handler.file_path, file_handler.offset)
break
except (OSError, DecodeError, exceptions.MindInsightException) as ex:
logger.warning("Parse log file fail, and ignore this file, detail: %r,"
"file path: %s.", str(ex), file_handler.file_path)
break
except Exception as ex:
logger.exception(ex)
raise UnknownError(str(ex))
except (OSError, DecodeError, exceptions.MindInsightException) as ex:
logger.warning("Parse log file fail, and ignore this file, detail: %r,"
"file path: %s.", str(ex), file_handler.file_path)
break
except Exception as ex:
logger.exception(ex)
raise UnknownError(str(ex))
def _event_load(self, file_handler):
"""
@@ -381,20 +430,29 @@ class _SummaryParser(_Parser):
return event_str
def _event_parse(self, event):
@staticmethod
def _event_parse(event_str, latest_file_name):
"""
Transform `Event` data into a list of tensor events and return it.
This method is static to avoid sending unnecessary objects to other processes.
Args:
event (Event): Message event in summary proto, data read from file handler.
event_str (str): Message event string in summary proto, data read from file handler.
latest_file_name (str): Latest file name.
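Returns:
    list[TensorEvent], the tensor events parsed from the event string.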
"""
plugins = {
'scalar_value': PluginNameEnum.SCALAR,
'image': PluginNameEnum.IMAGE,
'histogram': PluginNameEnum.HISTOGRAM,
'tensor': PluginNameEnum.TENSOR
}
logger.debug("Start to parse event string. Event string len: %s.", len(event_str))
event = summary_pb2.Event.FromString(event_str)
logger.debug("Deserialize event string completed.")
ret_tensor_events = []
if event.HasField('summary'):
for value in event.summary.value:
for plugin in plugins:
@@ -402,6 +460,7 @@ class _SummaryParser(_Parser):
continue
plugin_name_enum = plugins[plugin]
tensor_event_value = getattr(value, plugin)
logger.debug("Processing plugin value: %s.", plugin_name_enum)
if plugin == 'histogram':
tensor_event_value = HistogramContainer(tensor_event_value)
@@ -419,29 +478,23 @@ class _SummaryParser(_Parser):
tag='{}/{}'.format(value.tag, plugin_name_enum.value),
plugin_name=plugin_name_enum.value,
value=tensor_event_value,
filename=self._latest_filename)
self._events_data.add_tensor_event(tensor_event)
filename=latest_file_name)
logger.debug("Tensor event generated, plugin is %s, tag is %s, step is %s.",
plugin_name_enum, value.tag, event.step)
ret_tensor_events.append(tensor_event)
elif event.HasField('graph_def'):
graph = MSGraph()
graph.build_graph(event.graph_def)
tensor_event = TensorEvent(wall_time=event.wall_time,
step=event.step,
tag=self._latest_filename,
tag=latest_file_name,
plugin_name=PluginNameEnum.GRAPH.value,
value=graph,
filename=self._latest_filename)
try:
graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
except KeyError:
graph_tags = []
summary_tags = self.filter_files(graph_tags)
for tag in summary_tags:
self._events_data.delete_tensor_event(tag)
filename=latest_file_name)
ret_tensor_events.append(tensor_event)
self._events_data.add_tensor_event(tensor_event)
return ret_tensor_events
@staticmethod
def _compare_summary_file(current_file, dst_file):