|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161 |
- # Copyright 2020 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Define the message handler."""
- import uuid
- from queue import Queue, Empty
- from threading import Lock
-
- from mindinsight.debugger.common.exceptions.exceptions import DebuggerParamValueError
- from mindinsight.debugger.common.log import LOGGER as log
- from mindinsight.debugger.stream_handler.base_handler import StreamHandlerBase
-
-
class EventHandler(StreamHandlerBase):
    """
    Message Handler.

    Events are kept in a fixed-size cache of ``max_limit`` slots that is reused
    from the start once the last slot has been written. Each stored event gets
    ``metadata['pos']`` set to ``"<flag>:<index>"``, the position of the slot
    the *next* event will occupy, so a client can poll with the last position
    it received. ``_cur_flag`` tags the current pass over the cache and
    ``_prev_flag`` the previous one; both are rotated when the cache wraps and
    regenerated by ``clean()``. Requests for a position that has not been
    written yet are parked in ``_pending_requests`` and released by the next
    ``put()`` or ``clean()``.
    """

    max_limit = 1000  # the max number of items in cache

    def __init__(self):
        self._prev_flag = str(uuid.uuid4())
        self._cur_flag = str(uuid.uuid4())
        self._next_idx = 0  # slot the next put() writes to
        self._event_cache = [None] * self.max_limit
        self._pending_requests = {}  # request id -> Queue(maxsize=1) awaiting an event
        self._lock = Lock()

    @property
    def next_pos(self):
        """The next pos to be updated in cache, as ``"<cur_flag>:<next_idx>"``."""
        return ':'.join([self._cur_flag, str(self._next_idx)])

    def has_pos(self, pos):
        """
        Get the cached event at `pos` if it is available.

        A flag that is neither the current nor the previous one (for example a
        position issued before `clean()`) is treated as index 0 of the current pass.

        Args:
            pos (str): Position in the form ``"<flag>:<index>"`` or ``"<index>"``.

        Returns:
            Union[dict, None], the cached event, or None if the slot is empty or no
            longer holds the event that `pos` refers to.

        Raises:
            DebuggerParamValueError: If the index in `pos` is invalid.
        """
        cur_flag, cur_idx = self._parse_pos(pos)
        if cur_flag not in (self._cur_flag, self._prev_flag):
            cur_flag, cur_idx = self._cur_flag, 0
        event = self._event_cache[cur_idx]
        if event is None:
            return None
        # Slots below `_next_idx` were written during the current pass; slots at
        # or above it still hold events from the previous pass.
        in_current_pass = cur_flag == self._cur_flag and cur_idx < self._next_idx
        in_previous_pass = cur_flag == self._prev_flag and cur_idx >= self._next_idx
        if in_current_pass or in_previous_pass:
            return event
        return None

    def clean(self):
        """Clean event cache and release every pending request with pos '0'."""
        with self._lock:
            self._prev_flag = str(uuid.uuid4())
            self._cur_flag = str(uuid.uuid4())
            self._next_idx = 0
            self._event_cache = [None] * self.max_limit
            # Count before releasing: clean_pending_requests() empties the dict,
            # so counting afterwards would always report 0.
            pending_count = len(self._pending_requests)
            value = {'metadata': {'pos': '0'}}
            self.clean_pending_requests(value)
            log.debug("Clean event cache. %d request is waiting.", pending_count)

    def put(self, value):
        """
        Put value into event_cache.

        The value is tagged in place with ``value['metadata']['pos']`` set to the
        position after its own slot, stored in the cache, and handed to every
        pending request.

        Args:
            value (dict): The event to be put into cache. It is modified in place.

        Raises:
            DebuggerParamValueError: If `value` is not a dict.
        """
        if not isinstance(value, dict):
            log.error("Dict type required when put event message.")
            raise DebuggerParamValueError("Dict type required when put event message.")

        with self._lock:
            log.debug("Put the %d-th message into queue. \n %d requests is waiting.",
                      self._next_idx, len(self._pending_requests))
            cur_pos = self._next_idx
            # update next pos; on wrap-around start a new pass with a fresh flag
            self._next_idx += 1
            if self._next_idx >= self.max_limit:
                self._next_idx = 0
                self._prev_flag = self._cur_flag
                self._cur_flag = str(uuid.uuid4())
            # set next pos
            if not value.get('metadata'):
                value['metadata'] = {}
            value['metadata']['pos'] = self.next_pos
            self._event_cache[cur_pos] = value
            # feed the value for pending requests
            self.clean_pending_requests(value)

    def clean_pending_requests(self, value):
        """
        Deliver `value` to every pending request and forget them.

        Callers in this class hold `self._lock` while calling this.

        Args:
            value (dict): The event handed to each waiting request.
        """
        for request in self._pending_requests.values():
            request.put(value)
        self._pending_requests = {}

    def get(self, filter_condition=None):
        """
        Get the pos-th value from event_cache according to filter_condition.

        A cached event is returned immediately. Otherwise the call blocks until
        the next `put()`/`clean()` delivers an event, or until the wait times out.

        Args:
            filter_condition (str): The position of the event in cache, in the form
                ``"<flag>:<index>"`` or ``"<index>"``. Default: None, treated as
                ``"0"``.

        Returns:
            object, the pos-th event, or ``{'metadata': {'pos': filter_condition}}``
            if no event arrived before the timeout.

        Raises:
            DebuggerParamValueError: If the index in `filter_condition` is invalid.
        """
        if filter_condition is None:
            # The documented default previously crashed in _parse_pos (None.split).
            filter_condition = '0'
        flag, idx = self._parse_pos(filter_condition)
        cur_id = str(uuid.uuid4())
        with self._lock:
            # reset the pos after the cache is re-initialized.
            if flag not in (self._cur_flag, self._prev_flag):
                flag, idx = self._cur_flag, 0
            # Only the slot the next put() will write during the current pass has
            # nothing to return yet. A previous-pass position with the same index
            # still refers to an unread event. Waiting there would let the next
            # put() overwrite it, and the client would skip a whole pass.
            is_next_slot = flag == self._cur_flag and idx == self._next_idx
            event = self._event_cache[idx]
            if event and not is_next_slot:
                return event
            # wait for the event
            cur_queue = Queue(maxsize=1)
            self._pending_requests[cur_id] = cur_queue
        # block until event has been received (outside the lock so put() can run)
        return self._wait_for_event(cur_id, cur_queue, filter_condition)

    def _parse_pos(self, pos):
        """
        Split `pos` into its flag and index.

        Args:
            pos (str): Position in the form ``"<flag>:<index>"`` or ``"<index>"``.

        Returns:
            tuple[str, int], the flag ('' unless `pos` contains exactly one ':')
            and the index.

        Raises:
            DebuggerParamValueError: If the index is not an integer or is outside
                ``[0, max_limit)``.
        """
        elements = pos.split(':')
        try:
            idx = int(elements[-1])
        except ValueError as err:
            log.error("Invalid index. The index in pos should be digit but get pos:%s", pos)
            raise DebuggerParamValueError("Invalid pos.") from err

        if idx < 0 or idx >= self.max_limit:
            log.error("Invalid index. The index in pos should between [0, %d)", self.max_limit)
            raise DebuggerParamValueError(f"Invalid pos. {idx}")
        flag = elements[0] if len(elements) == 2 else ''

        return flag, idx

    def _wait_for_event(self, cur_id, cur_queue, pos):
        """
        Wait for the pos-th event to be delivered to `cur_queue`.

        Args:
            cur_id (str): Key of the request in `_pending_requests`.
            cur_queue (Queue): Queue the event is delivered to.
            pos (str): Position the caller asked for; echoed back on timeout.

        Returns:
            object, the delivered event, or ``{'metadata': {'pos': pos}}`` on timeout.
        """
        try:
            # set the timeout to 25 seconds which is less than the timeout limit from UI
            event = cur_queue.get(timeout=25)
        except Empty:
            event = None

        if event is None:
            with self._lock:
                # The request may already have been released by put()/clean().
                if self._pending_requests.pop(cur_id, None) is not None:
                    log.debug("Clean timeout request. Left pending requests: %d",
                              len(self._pending_requests))
            event = {'metadata': {'pos': pos}}

        return event
|