
Use multiple processes to calculate events.

1. To accelerate summary file parsing, multiple processes are used. As the first step of MindInsight parsing performance optimization, only the _load_single_file function was changed.

2. This PR improves summary parsing throughput dramatically (roughly by a factor of cpu_count).

3. Changes are mainly in the _load_single_file function (a sketch of the pattern is included below).

In the future, a more general concurrent computing framework will be needed for MindInsight. See the Gitee wiki doc for details.
tags/v0.6.0-beta
wangshuide2020 · 5 years ago
commit 7877f33b70
9 changed files with 148 additions and 81 deletions
1. mindinsight/backend/run.py (+5, -6)
2. mindinsight/conf/constants.py (+2, -0)
3. mindinsight/datavisual/data_transform/data_loader.py (+8, -3)
4. mindinsight/datavisual/data_transform/data_manager.py (+4, -4)
5. mindinsight/datavisual/data_transform/events_data.py (+1, -0)
6. mindinsight/datavisual/data_transform/ms_data_loader.py (+101, -48)
7. mindinsight/datavisual/data_transform/tensor_container.py (+2, -3)
8. mindinsight/datavisual/processors/tensor_processor.py (+3, -3)
9. mindinsight/scripts/stop.py (+22, -14)
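Before the per-file diffs, here is a minimal, self-contained sketch of the pattern _load_single_file adopts below: a ProcessPoolExecutor parses events in worker processes while a semaphore caps the number of unfinished tasks so memory use stays bounded. This is illustrative only, not code from this commit; parse_event, load_file, and the sample data are placeholders.

import concurrent.futures as futures
import threading


def parse_event(event_str):
    """Placeholder for the CPU-heavy per-event work (protobuf decode, tensor handling)."""
    return len(event_str)


def load_file(event_strings, concurrency=4):
    """Parse events with at most `concurrency` tasks in flight at any time."""
    results = []
    semaphore = threading.Semaphore(value=concurrency)

    def _on_done(future):
        try:
            results.append(future.result())
        finally:
            # Free the slot so the submission loop can hand out the next event.
            semaphore.release()

    with futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
        for event_str in event_strings:
            semaphore.acquire()  # blocks once `concurrency` events are unfinished
            future = executor.submit(parse_event, event_str)
            future.add_done_callback(_on_done)
    return results


if __name__ == "__main__":
    print(load_file([b"x" * n for n in range(10)]))

The commit applies the same idea inside _SummaryParser._load_single_file, with the callback additionally routing the parsed TensorEvents back into EventsData on the main process.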

mindinsight/backend/run.py (+5, -6)

@@ -236,9 +236,10 @@ def start():
     process = subprocess.Popen(
         shlex.split(cmd),
         shell=False,
-        stdin=subprocess.PIPE,
-        stdout=subprocess.PIPE,
-        stderr=subprocess.PIPE
+        # Change stdout to DEVNULL to prevent broken pipe error when creating new processes.
+        stdin=subprocess.DEVNULL,
+        stdout=subprocess.DEVNULL,
+        stderr=subprocess.STDOUT
     )

     # sleep 1 second for gunicorn appplication to load modules
@@ -246,9 +247,7 @@ def start():


     # check if gunicorn application is running
     if process.poll() is not None:
-        _, stderr = process.communicate()
-        for line in stderr.decode().split('\n'):
-            console.error(line)
+        console.error("Start MindInsight failed. See log for details.")
     else:
         state_result = _check_server_start_stat(errorlog_abspath, log_size)
         # print gunicorn start state to stdout


mindinsight/conf/constants.py (+2, -0)

@@ -14,6 +14,7 @@
 # ============================================================================
 """Constants module for mindinsight settings."""
 import logging
+import os


 ####################################
 # Global default settings.
@@ -48,6 +49,7 @@ API_PREFIX = '/v1/mindinsight'
 # Datavisual default settings.
 ####################################
 MAX_THREADS_COUNT = 15
+MAX_PROCESSES_COUNT = max(os.cpu_count() or 0, 15)

 MAX_TAG_SIZE_PER_EVENTS_DATA = 300
 DEFAULT_STEP_SIZES_PER_TAG = 500


mindinsight/datavisual/data_transform/data_loader.py (+8, -3)

@@ -34,8 +34,13 @@ class DataLoader:
         self._summary_dir = summary_dir
         self._loader = None

-    def load(self):
-        """Load the data when loader is exist."""
+    def load(self, workers_count=1):
+        """Load the data when loader is exist.
+
+        Args:
+            workers_count (int): The count of workers. Default value is 1.
+        """
+
         if self._loader is None:
             ms_dataloader = MSDataLoader(self._summary_dir)
             loaders = [ms_dataloader]
@@ -48,7 +53,7 @@ class DataLoader:
logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir) logger.warning("No valid files can be loaded, summary_dir: %s.", self._summary_dir)
raise exceptions.SummaryLogPathInvalid() raise exceptions.SummaryLogPathInvalid()


self._loader.load()
self._loader.load(workers_count)


def get_events_data(self): def get_events_data(self):
""" """


mindinsight/datavisual/data_transform/data_manager.py (+4, -4)

@@ -510,7 +510,7 @@ class _DetailCacheManager(_BaseCacheManager):
logger.debug("delete loader %s", loader_id) logger.debug("delete loader %s", loader_id)
self._loader_pool.pop(loader_id) self._loader_pool.pop(loader_id)


def _execute_loader(self, loader_id):
def _execute_loader(self, loader_id, workers_count):
""" """
Load data form data_loader. Load data form data_loader.


@@ -518,7 +518,7 @@ class _DetailCacheManager(_BaseCacheManager):


         Args:
             loader_id (str): An ID for `Loader`.
+            workers_count (int): The count of workers.
         """
         try:
             with self._loader_pool_mutex:
@@ -527,7 +527,7 @@ class _DetailCacheManager(_BaseCacheManager):
logger.debug("Loader %r has been deleted, will not load data.", loader_id) logger.debug("Loader %r has been deleted, will not load data.", loader_id)
return return


loader.data_loader.load()
loader.data_loader.load(workers_count)


# Update loader cache status to CACHED. # Update loader cache status to CACHED.
# Loader with cache status CACHED should remain the same cache status. # Loader with cache status CACHED should remain the same cache status.
@@ -584,7 +584,7 @@ class _DetailCacheManager(_BaseCacheManager):
         futures = []
         loader_pool = self._get_snapshot_loader_pool()
         for loader_id in loader_pool:
-            future = executor.submit(self._execute_loader, loader_id)
+            future = executor.submit(self._execute_loader, loader_id, threads_count)
             futures.append(future)
         wait(futures, return_when=ALL_COMPLETED)
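For orientation, a rough standalone sketch (illustrative only; execute_loader stands in for _execute_loader and the run names are made up) of how the thread pool above passes its own size down, so that each loader can later split the CPU budget among its parser processes:

from concurrent.futures import ALL_COMPLETED, ThreadPoolExecutor, wait


def execute_loader(loader_id, workers_count):
    """Each loader learns how many sibling loader threads exist and sizes its process pool from that."""
    print("loader %s runs alongside %s concurrent loader threads" % (loader_id, workers_count))


threads_count = 3
loader_ids = ["run-a", "run-b", "run-c"]
with ThreadPoolExecutor(max_workers=threads_count) as executor:
    futures = [executor.submit(execute_loader, loader_id, threads_count) for loader_id in loader_ids]
    wait(futures, return_when=ALL_COMPLETED)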




mindinsight/datavisual/data_transform/events_data.py (+1, -0)

@@ -85,6 +85,7 @@ class EventsData:
         deleted_tag = self._check_tag_out_of_spec(plugin_name)
         if deleted_tag is not None:
             if tag in self._deleted_tags:
+                logger.debug("Tag is in deleted tags: %s.", tag)
                 return
             self.delete_tensor_event(deleted_tag)




mindinsight/datavisual/data_transform/ms_data_loader.py (+101, -48)

@@ -19,12 +19,17 @@ This module is used to load the MindSpore training log file.
 Each instance will read an entire run, a run can contain one or
 more log file.
 """
+import concurrent.futures as futures
+import math
+import os
 import re
 import struct
+import threading

 from google.protobuf.message import DecodeError
 from google.protobuf.text_format import ParseError

+from mindinsight.conf import settings
 from mindinsight.datavisual.common import exceptions
 from mindinsight.datavisual.common.enums import PluginNameEnum
 from mindinsight.datavisual.common.log import logger
@@ -32,13 +37,13 @@ from mindinsight.datavisual.data_access.file_handler import FileHandler
 from mindinsight.datavisual.data_transform.events_data import EventsData
 from mindinsight.datavisual.data_transform.events_data import TensorEvent
 from mindinsight.datavisual.data_transform.graph import MSGraph
-from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
-from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
-from mindinsight.datavisual.utils import crc32
-from mindinsight.utils.exceptions import UnknownError
 from mindinsight.datavisual.data_transform.histogram import Histogram
 from mindinsight.datavisual.data_transform.histogram_container import HistogramContainer
 from mindinsight.datavisual.data_transform.tensor_container import TensorContainer
+from mindinsight.datavisual.proto_files import mindinsight_anf_ir_pb2 as anf_ir_pb2
+from mindinsight.datavisual.proto_files import mindinsight_summary_pb2 as summary_pb2
+from mindinsight.datavisual.utils import crc32
+from mindinsight.utils.exceptions import UnknownError

 HEADER_SIZE = 8
 CRC_STR_SIZE = 4
@@ -79,11 +84,14 @@ class MSDataLoader:
"we will reload all files in path %s.", self._summary_dir) "we will reload all files in path %s.", self._summary_dir)
self.__init__(self._summary_dir) self.__init__(self._summary_dir)


def load(self):
def load(self, workers_count=1):
""" """
Load all log valid files. Load all log valid files.


When the file is reloaded, it will continue to load from where it left off. When the file is reloaded, it will continue to load from where it left off.

Args:
workers_count (int): The count of workers. Default value is 1.
""" """
logger.debug("Start to load data in ms data loader.") logger.debug("Start to load data in ms data loader.")
filenames = self.filter_valid_files() filenames = self.filter_valid_files()
@@ -95,7 +103,7 @@ class MSDataLoader:
         self._check_files_deleted(filenames, old_filenames)

         for parser in self._parser_list:
-            parser.parse_files(filenames, events_data=self._events_data)
+            parser.parse_files(workers_count, filenames, events_data=self._events_data)

     def filter_valid_files(self):
         """
@@ -125,11 +133,12 @@ class _Parser:
         self._latest_mtime = 0
         self._summary_dir = summary_dir

-    def parse_files(self, filenames, events_data):
+    def parse_files(self, workers_count, filenames, events_data):
         """
         Load files and parse files content.

         Args:
+            workers_count (int): The count of workers.
             filenames (list[str]): File name list.
             events_data (EventsData): The container of event data.
         """
@@ -177,7 +186,7 @@ class _Parser:
 class _PbParser(_Parser):
     """This class is used to parse pb file."""

-    def parse_files(self, filenames, events_data):
+    def parse_files(self, workers_count, filenames, events_data):
         pb_filenames = self.filter_files(filenames)
         pb_filenames = self.sort_files(pb_filenames)
         for filename in pb_filenames:
@@ -255,11 +264,12 @@ class _SummaryParser(_Parser):
         self._summary_file_handler = None
         self._events_data = None

-    def parse_files(self, filenames, events_data):
+    def parse_files(self, workers_count, filenames, events_data):
         """
         Load summary file and parse file content.

         Args:
+            workers_count (int): The count of workers.
             filenames (list[str]): File name list.
             events_data (EventsData): The container of event data.
         """
@@ -285,7 +295,7 @@ class _SummaryParser(_Parser):


             self._latest_file_size = new_size
             try:
-                self._load_single_file(self._summary_file_handler)
+                self._load_single_file(self._summary_file_handler, workers_count)
             except UnknownError as ex:
                 logger.warning("Parse summary file failed, detail: %r,"
                                "file path: %s.", str(ex), file_path)
@@ -304,36 +314,75 @@ class _SummaryParser(_Parser):
                                     lambda filename: (re.search(r'summary\.\d+', filename)
                                                       and not filename.endswith("_lineage")), filenames))

-    def _load_single_file(self, file_handler):
+    def _load_single_file(self, file_handler, workers_count):
         """
         Load a log file data.

         Args:
             file_handler (FileHandler): A file handler.
+            workers_count (int): The count of workers.
         """
-        logger.debug("Load single summary file, file path: %s.", file_handler.file_path)
-        while True:
-            start_offset = file_handler.offset
-            try:
-                event_str = self._event_load(file_handler)
-                if event_str is None:
-                    file_handler.reset_offset(start_offset)
-                    break
-
-                event = summary_pb2.Event.FromString(event_str)
-                self._event_parse(event)
-            except exceptions.CRCFailedError:
-                file_handler.reset_offset(start_offset)
-                logger.warning("Check crc faild and ignore this file, file_path=%s, "
-                               "offset=%s.", file_handler.file_path, file_handler.offset)
-                break
-            except (OSError, DecodeError, exceptions.MindInsightException) as ex:
-                logger.warning("Parse log file fail, and ignore this file, detail: %r,"
-                               "file path: %s.", str(ex), file_handler.file_path)
-                break
-            except Exception as ex:
-                logger.exception(ex)
-                raise UnknownError(str(ex))
+
+        default_concurrency = 1
+        cpu_count = os.cpu_count()
+        if cpu_count is None:
+            concurrency = default_concurrency
+        else:
+            concurrency = min(math.floor(cpu_count / workers_count),
+                              math.floor(settings.MAX_PROCESSES_COUNT / workers_count))
+        if concurrency <= 0:
+            concurrency = default_concurrency
+        logger.debug("Load single summary file, file path: %s, concurrency: %s.", file_handler.file_path, concurrency)
+
+        semaphore = threading.Semaphore(value=concurrency)
+        with futures.ProcessPoolExecutor(max_workers=concurrency) as executor:
+            while True:
+                start_offset = file_handler.offset
+                try:
+                    event_str = self._event_load(file_handler)
+                    if event_str is None:
+                        file_handler.reset_offset(start_offset)
+                        break
+
+                    # Make sure we have at most concurrency tasks not finished to save memory.
+                    semaphore.acquire()
+                    future = executor.submit(self._event_parse, event_str, self._latest_filename)
+
+                    def _add_tensor_event_callback(future_value):
+                        try:
+                            tensor_values = future_value.result()
+                            for tensor_value in tensor_values:
+                                if tensor_value.plugin_name == PluginNameEnum.GRAPH.value:
+                                    try:
+                                        graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
+                                    except KeyError:
+                                        graph_tags = []
+
+                                    summary_tags = self.filter_files(graph_tags)
+                                    for tag in summary_tags:
+                                        self._events_data.delete_tensor_event(tag)
+
+                                self._events_data.add_tensor_event(tensor_value)
+                        except Exception as exc:
+                            # Log exception for debugging.
+                            logger.exception(exc)
+                            raise
+                        finally:
+                            semaphore.release()
+
+                    future.add_done_callback(_add_tensor_event_callback)
+                except exceptions.CRCFailedError:
+                    file_handler.reset_offset(start_offset)
+                    logger.warning("Check crc faild and ignore this file, file_path=%s, "
+                                   "offset=%s.", file_handler.file_path, file_handler.offset)
+                    break
+                except (OSError, DecodeError, exceptions.MindInsightException) as ex:
+                    logger.warning("Parse log file fail, and ignore this file, detail: %r,"
+                                   "file path: %s.", str(ex), file_handler.file_path)
+                    break
+                except Exception as ex:
+                    logger.exception(ex)
+                    raise UnknownError(str(ex))

     def _event_load(self, file_handler):
         """
@@ -381,20 +430,29 @@ class _SummaryParser(_Parser):


         return event_str

-    def _event_parse(self, event):
+    @staticmethod
+    def _event_parse(event_str, latest_file_name):
         """
         Transform `Event` data to tensor_event and update it to EventsData.

+        This method is static to avoid sending unnecessary objects to other processes.
+
         Args:
-            event (Event): Message event in summary proto, data read from file handler.
+            event (str): Message event string in summary proto, data read from file handler.
+            latest_file_name (str): Latest file name.
         """

         plugins = {
             'scalar_value': PluginNameEnum.SCALAR,
             'image': PluginNameEnum.IMAGE,
             'histogram': PluginNameEnum.HISTOGRAM,
             'tensor': PluginNameEnum.TENSOR
         }
+        logger.debug("Start to parse event string. Event string len: %s.", len(event_str))
+        event = summary_pb2.Event.FromString(event_str)
+        logger.debug("Deserialize event string completed.")
+
+        ret_tensor_events = []
         if event.HasField('summary'):
             for value in event.summary.value:
                 for plugin in plugins:
@@ -402,6 +460,7 @@ class _SummaryParser(_Parser):
                         continue
                     plugin_name_enum = plugins[plugin]
                     tensor_event_value = getattr(value, plugin)
+                    logger.debug("Processing plugin value: %s.", plugin_name_enum)

                     if plugin == 'histogram':
                         tensor_event_value = HistogramContainer(tensor_event_value)
@@ -419,29 +478,23 @@ class _SummaryParser(_Parser):
                                                tag='{}/{}'.format(value.tag, plugin_name_enum.value),
                                                plugin_name=plugin_name_enum.value,
                                                value=tensor_event_value,
-                                               filename=self._latest_filename)
-                    self._events_data.add_tensor_event(tensor_event)
+                                               filename=latest_file_name)
+                    logger.debug("Tensor event generated, plugin is %s, tag is %s, step is %s.",
+                                 plugin_name_enum, value.tag, event.step)
+                    ret_tensor_events.append(tensor_event)

         elif event.HasField('graph_def'):
             graph = MSGraph()
             graph.build_graph(event.graph_def)
             tensor_event = TensorEvent(wall_time=event.wall_time,
                                        step=event.step,
-                                       tag=self._latest_filename,
+                                       tag=latest_file_name,
                                        plugin_name=PluginNameEnum.GRAPH.value,
                                        value=graph,
-                                       filename=self._latest_filename)
-
-            try:
-                graph_tags = self._events_data.list_tags_by_plugin(PluginNameEnum.GRAPH.value)
-            except KeyError:
-                graph_tags = []
-
-            summary_tags = self.filter_files(graph_tags)
-            for tag in summary_tags:
-                self._events_data.delete_tensor_event(tag)
-
-            self._events_data.add_tensor_event(tensor_event)
+                                       filename=latest_file_name)
+            ret_tensor_events.append(tensor_event)
+
+        return ret_tensor_events

     @staticmethod
     def _compare_summary_file(current_file, dst_file):
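As a quick reference, the per-file concurrency used above can be restated as a tiny standalone function (illustrative only; compute_concurrency is not a name from this commit), assuming MAX_PROCESSES_COUNT = max(os.cpu_count() or 0, 15) from constants.py:

import math


def compute_concurrency(workers_count, cpu_count, max_processes_count):
    """Bound the number of parser processes a single summary file may use."""
    default_concurrency = 1
    if cpu_count is None:
        return default_concurrency
    concurrency = min(math.floor(cpu_count / workers_count),
                      math.floor(max_processes_count / workers_count))
    return concurrency if concurrency > 0 else default_concurrency


# Example: with 8 CPUs, MAX_PROCESSES_COUNT becomes max(8, 15) = 15; three concurrent
# loader threads (workers_count=3) leave each file min(8 // 3, 15 // 3) = 2 processes.
print(compute_concurrency(workers_count=3, cpu_count=8, max_processes_count=15))  # prints 2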


mindinsight/datavisual/data_transform/tensor_container.py (+2, -3)

@@ -199,8 +199,8 @@ class TensorContainer:


     def __init__(self, tensor_message):
         self._lock = threading.Lock
-        self._msg = tensor_message
-        self._dims = tensor_message.dims
+        # Original dims can not be pickled to transfer to other process, so tuple is used.
+        self._dims = tuple(tensor_message.dims)
         self._data_type = tensor_message.data_type
         self._np_array = None
         self._data = _get_data_from_tensor(tensor_message)
@@ -265,5 +265,4 @@ class TensorContainer:
logger.error("Reshape array fail, detail: %r", str(ex)) logger.error("Reshape array fail, detail: %r", str(ex))
return return


self._msg = None
self._np_array = ndarray self._np_array = ndarray

mindinsight/datavisual/processors/tensor_processor.py (+3, -3)

@@ -245,7 +245,7 @@ class TensorProcessor(BaseProcessor):
             # This value is an instance of TensorContainer
             value = tensor.value
             value_dict = {
-                "dims": tuple(value.dims),
+                "dims": value.dims,
                 "data_type": anf_ir_pb2.DataType.Name(value.data_type)
             }
             if detail and detail == 'stats':
@@ -313,7 +313,7 @@ class TensorProcessor(BaseProcessor):
"wall_time": tensor.wall_time, "wall_time": tensor.wall_time,
"step": tensor.step, "step": tensor.step,
"value": { "value": {
"dims": tuple(value.dims),
"dims": value.dims,
"data_type": anf_ir_pb2.DataType.Name(value.data_type), "data_type": anf_ir_pb2.DataType.Name(value.data_type),
"data": res_data.tolist(), "data": res_data.tolist(),
"statistics": get_statistics_dict(value, flatten_data) "statistics": get_statistics_dict(value, flatten_data)
@@ -362,7 +362,7 @@ class TensorProcessor(BaseProcessor):
"wall_time": tensor.wall_time, "wall_time": tensor.wall_time,
"step": tensor.step, "step": tensor.step,
"value": { "value": {
"dims": tuple(value.dims),
"dims": value.dims,
"data_type": anf_ir_pb2.DataType.Name(value.data_type), "data_type": anf_ir_pb2.DataType.Name(value.data_type),
"histogram_buckets": buckets, "histogram_buckets": buckets,
"statistics": get_statistics_dict(value, None) "statistics": get_statistics_dict(value, None)


mindinsight/scripts/stop.py (+22, -14)

@@ -103,21 +103,17 @@ class Command(BaseCommand):
         self.logfile.info('Stop mindinsight with port %s and pid %s.', port, pid)

         process = psutil.Process(pid)
-        child_pids = [child.pid for child in process.children()]
+        processes_to_kill = [process]
+        # Set recursive to True to kill grand children processes.
+        for child in process.children(recursive=True):
+            processes_to_kill.append(child)

-        # kill gunicorn master process
-        try:
-            os.kill(pid, signal.SIGKILL)
-        except PermissionError:
-            self.console.info('kill pid %s failed due to permission error', pid)
-            sys.exit(1)
-
-        # cleanup gunicorn worker processes
-        for child_pid in child_pids:
+        for proc in processes_to_kill:
+            self.logfile.info('Stopping mindinsight process %s.', proc.pid)
             try:
-                os.kill(child_pid, signal.SIGKILL)
-            except ProcessLookupError:
-                pass
+                proc.send_signal(signal.SIGKILL)
+            except psutil.Error as ex:
+                self.logfile.warning("Stop process %s failed. Detail: %s.", proc.pid, str(ex))

         for hook in HookUtils.instance().hooks():
             hook.on_shutdown(self.logfile)
@@ -154,7 +150,19 @@ class Command(BaseCommand):
             if user != process.username():
                 continue

-            pid = process.pid if process.ppid() == 1 else process.ppid()
+            gunicorn_master_process = process
+
+            # The gunicorn master process might have grand children (eg forked by process pool).
+            while True:
+                parent_process = gunicorn_master_process.parent()
+                if parent_process is None or parent_process.pid == 1:
+                    break
+                parent_cmd = parent_process.cmdline()
+                if ' '.join(parent_cmd).find(self.cmd_regex) == -1:
+                    break
+                gunicorn_master_process = parent_process
+
+            pid = gunicorn_master_process.pid

             for open_file in process.open_files():
                 if open_file.path.endswith(self.access_log_path):
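Because the parser now forks worker processes, stopping MindInsight has to account for the whole process tree. A minimal standalone sketch of the psutil pattern used above (kill_process_tree and target_pid are illustrative names, not from the commit):

import signal

import psutil


def kill_process_tree(target_pid):
    """Send SIGKILL to a process and every descendant, including grandchildren."""
    process = psutil.Process(target_pid)
    processes_to_kill = [process]
    # recursive=True also collects grandchildren, e.g. workers forked by a process pool.
    processes_to_kill.extend(process.children(recursive=True))

    for proc in processes_to_kill:
        try:
            proc.send_signal(signal.SIGKILL)
        except psutil.Error:
            # The process may already be gone or may belong to another user.
            pass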

