diff --git a/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py b/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py index 5e32fc36..84ac6f68 100644 --- a/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py +++ b/examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py @@ -23,14 +23,18 @@ def run_server(): exit_round = int(Context.get_parameters( "exit_round", 3 )) + participants_count = int(Context.get_parameters( + "participants_count", 1 + )) agg_ip = Context.get_parameters("AGG_IP", "0.0.0.0") agg_port = int(Context.get_parameters("AGG_PORT", "7363")) server = AggregationServer( - servername=aggregation_algorithm, + aggregation=aggregation_algorithm, host=agg_ip, http_port=agg_port, exit_round=exit_round, - ws_size=20 * 1024 * 1024 + ws_size=20 * 1024 * 1024, + participants_count=participants_count ) server.start() diff --git a/examples/federated_learning/surface_defect_detection/training_worker/inference.py b/examples/federated_learning/surface_defect_detection/training_worker/inference.py index 9ed01864..ee35b918 100644 --- a/examples/federated_learning/surface_defect_detection/training_worker/inference.py +++ b/examples/federated_learning/surface_defect_detection/training_worker/inference.py @@ -45,8 +45,8 @@ def main(): test_data = TxtDataParse(data_type="test", func=image_process) test_data.parse(fl_instance.config.test_dataset_url) - fl_instance.evaluate(test_data) + return fl_instance.inference(test_data.x) if __name__ == '__main__': - main() + print(main()) diff --git a/examples/federated_learning/surface_defect_detection/training_worker/train.py b/examples/federated_learning/surface_defect_detection/training_worker/train.py index 74ffddff..4fd9a112 100644 --- a/examples/federated_learning/surface_defect_detection/training_worker/train.py +++ b/examples/federated_learning/surface_defect_detection/training_worker/train.py @@ -65,7 +65,7 @@ def main(): fl_model = FederatedLearning( estimator=Estimator, aggregation=aggregation_algorithm) - fl_model.register() + train_jobs = fl_model.train( train_data=train_data, valid_data=valid_data, diff --git a/lib/sedna/__init__.py b/lib/sedna/__init__.py index 09032824..f9084e46 100644 --- a/lib/sedna/__init__.py +++ b/lib/sedna/__init__.py @@ -13,3 +13,4 @@ # limitations under the License. from .__version__ import __version__ +import sedna.algorithms diff --git a/lib/sedna/algorithms/aggregation/__init__.py b/lib/sedna/algorithms/aggregation/__init__.py index d5af3b2e..bcaac502 100644 --- a/lib/sedna/algorithms/aggregation/__init__.py +++ b/lib/sedna/algorithms/aggregation/__init__.py @@ -12,45 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Aggregation algorithms""" - -import abc -from copy import deepcopy - -import numpy as np - -from sedna.common.class_factory import ClassFactory, ClassType - -__all__ = ('FedAvg',) - - -class BaseAggregation(metaclass=abc.ABCMeta): - def __init__(self): - self.total_size = 0 - self.weights = None - - @abc.abstractmethod - def aggregate(self, weights, size=0): - """ - Aggregation - :param weights: deep learning weight - :param size: numbers of sample in each loop - """ - - -@ClassFactory.register(ClassType.FL_AGG) -class FedAvg(BaseAggregation, abc.ABC): - """Federated averaging algorithm""" - - def aggregate(self, weights, size=0): - total_sample = self.total_size + size - if not total_sample: - return self.weights - updates = [] - for inx, weight in enumerate(weights): - old_weight = self.weights[inx] - row_weight = ((np.array(weight) - old_weight) * - (size / total_sample) + old_weight) - updates.append(row_weight) - self.weights = deepcopy(updates) - return updates +from .aggregation import * diff --git a/lib/sedna/algorithms/aggregation/aggregation.py b/lib/sedna/algorithms/aggregation/aggregation.py new file mode 100644 index 00000000..de56b24b --- /dev/null +++ b/lib/sedna/algorithms/aggregation/aggregation.py @@ -0,0 +1,73 @@ +# Copyright 2021 The KubeEdge Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Aggregation algorithms""" + +import abc +from copy import deepcopy +from typing import List + +import numpy as np + +from sedna.common.class_factory import ClassFactory, ClassType + +__all__ = ('AggClient', 'FedAvg',) + + +class AggClient: + """Aggregation clients""" + num_samples: int + weights: List + + +class BaseAggregation(metaclass=abc.ABCMeta): + """Abstract class of aggregator""" + + def __init__(self): + self.total_size = 0 + self.weights = None + + @abc.abstractmethod + def aggregate(self, clients: List[AggClient]): + """ + Some algorithms can be aggregated in sequence, + but some can be calculated only after all aggregated data is uploaded. + therefore, this abstractmethod should consider that all weights are + uploaded. + :param clients: All clients in federated learning job + :return: final weights + """ + + +@ClassFactory.register(ClassType.FL_AGG) +class FedAvg(BaseAggregation, abc.ABC): + """ + Federated averaging algorithm : Calculate the average weight + according to the number of samples + """ + + def aggregate(self, clients: List[AggClient]): + if not len(clients): + return self.weights + self.total_size = sum([c.num_samples for c in clients]) + old_weight = [np.zeros(np.array(c).shape) for c in + next(iter(clients)).weights] + updates = [] + for inx, row in enumerate(old_weight): + for c in clients: + row += (np.array(c.weights[inx]) * c.num_samples + / self.total_size) + updates.append(row.tolist()) + self.weights = deepcopy(updates) + return updates diff --git a/lib/sedna/backend/tensorflow/__init__.py b/lib/sedna/backend/tensorflow/__init__.py index cf15fc4f..c029b2e5 100644 --- a/lib/sedna/backend/tensorflow/__init__.py +++ b/lib/sedna/backend/tensorflow/__init__.py @@ -14,6 +14,7 @@ import os +import numpy as np import tensorflow as tf from sedna.backend.base import BackendBase @@ -74,7 +75,7 @@ class TFBackend(BackendBase): if not self.has_load: tf.reset_default_graph() self.sess = self.load() - return self.estimator.predict(data=data, **kwargs) + return self.estimator.predict(data, **kwargs) def evaluate(self, data, **kwargs): if not self.has_load: @@ -135,4 +136,5 @@ class KerasBackend(TFBackend): return list(map(lambda x: x.tolist(), self.estimator.get_weights())) def set_weights(self, weights): + weights = [np.array(x) for x in weights] self.estimator.set_weights(weights) diff --git a/lib/sedna/core/__init__.py b/lib/sedna/core/__init__.py index 958382e3..e69de29b 100644 --- a/lib/sedna/core/__init__.py +++ b/lib/sedna/core/__init__.py @@ -1,15 +0,0 @@ -# Copyright 2021 The KubeEdge Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from sedna.algorithms import * diff --git a/lib/sedna/core/federated_learning/federated_learning.py b/lib/sedna/core/federated_learning/federated_learning.py index 5030fdd5..ac2a4a37 100644 --- a/lib/sedna/core/federated_learning/federated_learning.py +++ b/lib/sedna/core/federated_learning/federated_learning.py @@ -14,8 +14,6 @@ import time -import asyncio - from sedna.core.base import JobBase from sedna.common.config import Context from sedna.common.file_ops import FileOps @@ -30,6 +28,12 @@ class FederatedLearning(JobBase): """ def __init__(self, estimator, aggregation="FedAvg"): + """ + Initial a FederatedLearning job + :param estimator: Customize estimator + :param aggregation: aggregation algorithm for FederatedLearning + """ + protocol = Context.get_parameters("AGG_PROTOCOL", "ws") agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1") agg_port = int(Context.get_parameters("AGG_PORT", "7363")) @@ -43,20 +47,23 @@ class FederatedLearning(JobBase): super(FederatedLearning, self).__init__( estimator=estimator, config=config) self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation) + + connect_timeout = int(Context.get_parameters("CONNECT_TIMEOUT", "300")) self.node = None + self.register(timeout=connect_timeout) - def register(self): + def register(self, timeout=300): self.log.info( f"Node {self.worker_name} connect to : {self.config.agg_uri}") self.node = AggregationClient( - url=self.config.agg_uri, client_id=self.worker_name) - loop = asyncio.get_event_loop() - res = loop.run_until_complete( - asyncio.wait_for(self.node.connect(), timeout=300)) + url=self.config.agg_uri, + client_id=self.worker_name, + ping_timeout=timeout + ) FileOps.clean_folder([self.config.model_url], clean=False) self.aggregation = self.aggregation() - self.log.info(f"Federated learning Jobs model prepared -- {res}") + self.log.info(f"{self.worker_name} model prepared") if callable(self.estimator): self.estimator = self.estimator() @@ -64,55 +71,66 @@ class FederatedLearning(JobBase): valid_data=None, post_process=None, **kwargs): + """ + Training task for FederatedLearning + :param train_data: datasource use for train + :param valid_data: datasource use for evaluation + :param post_process: post process + :param kwargs: params for training of customize estimator + """ + callback_func = None - if post_process is not None: + if post_process: callback_func = ClassFactory.get_cls( ClassType.CALLBACK, post_process) round_number = 0 num_samples = len(train_data) - self.aggregation.total_size += num_samples - + _flag = True + start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + res = None while 1: - round_number += 1 - start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - self.log.info( - f"Federated learning start at {start}," - f" round_number={round_number}") - res = self.estimator.train( - train_data=train_data, valid_data=valid_data, **kwargs) - - self.aggregation.weights = self.estimator.get_weights() - send_data = {"num_samples": num_samples, - "weights": self.aggregation.weights} - received = self.node.send( - send_data, msg_type="update_weight", job_name=self.job_name) - exit_flag = False - if (received and received["type"] == "update_weight" - and received["job_name"] == self.job_name): - recv = received["data"] + if _flag: + round_number += 1 + start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + self.log.info( + f"Federated learning start, round_number={round_number}") + res = self.estimator.train( + train_data=train_data, valid_data=valid_data, **kwargs) - rec_client = received["client"] - rec_sample = recv["num_samples"] + current_weights = self.estimator.get_weights() + send_data = {"num_samples": num_samples, + "weights": current_weights} + self.node.send( + send_data, msg_type="update_weight", job_name=self.job_name + ) + received = self.node.recv(wait_data_type="recv_weight") + if not received: + _flag = False + continue + _flag = True - self.log.info( - f"Federated learning get weight from " - f"[{rec_client}] : {rec_sample}") - n_weight = self.aggregation.aggregate( - recv["weights"], rec_sample) - self.estimator.set_weights(n_weight) - exit_flag = received.get("exit_flag", "") == "ok" + rec_data = received.get("data", {}) + exit_flag = rec_data.get("exit_flag", "") + server_round = int(rec_data.get("round_number")) + total_size = int(rec_data.get("total_sample")) + self.log.info( + f"Federated learning recv weight, " + f"round: {server_round}, total_sample: {total_size}" + ) + n_weight = rec_data.get("weights") + self.estimator.set_weights(n_weight) task_info = { 'currentRound': round_number, - 'sampleCount': self.aggregation.total_size, + 'sampleCount': total_size, 'startTime': start, 'updateTime': time.strftime( - "%Y-%m-%d %H:%M:%S", - time.localtime())} + "%Y-%m-%d %H:%M:%S", time.localtime()) + } model_paths = self.estimator.save() task_info_res = self.estimator.model_info( model_paths, result=res, relpath=self.config.data_path_prefix) - if exit_flag: + if exit_flag == "ok": self.report_task_info( task_info, K8sResourceKindStatus.COMPLETED.value, diff --git a/lib/sedna/service/client.py b/lib/sedna/service/client.py index 0ed29307..d342fbc9 100644 --- a/lib/sedna/service/client.py +++ b/lib/sedna/service/client.py @@ -26,6 +26,7 @@ from websockets.exceptions import InvalidStatusCode, WebSocketException from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK from sedna.common.log import LOGGER +from sedna.common.file_ops import FileOps @retry(stop_max_attempt_number=3, @@ -98,8 +99,9 @@ class LCReporter(threading.Thread): time.localtime()), "inferenceNumber": self.inference_number, "hardExampleNumber": self.hard_example_number, - "uploadCloudRatio": self.hard_example_number / - self.inference_number} + "uploadCloudRatio": (self.hard_example_number / + self.inference_number) + } self.message["ownerInfo"] = info LCClient.send(self.lc_server, self.message["name"], @@ -141,15 +143,18 @@ class AggregationClient: "ping_interval": interval, "max_size": min(max_size, 16 * 1024 * 1024) }) + loop = asyncio.get_event_loop() + loop.run_until_complete( + asyncio.wait_for(self.connect(), timeout=timeout) + ) async def connect(self): LOGGER.info(f"{self.uri} connection by {self.client_id}") try: - conn = websockets.connect( + self.ws = await asyncio.wait_for(websockets.connect( self.uri, **self.kwargs - ) - self.ws = await conn.__aenter__() + ), self._ws_timeout) await self.ws.send(json.dumps({'type': 'subscribe', 'client_id': self.client_id})) @@ -162,15 +167,15 @@ class AggregationClient: LOGGER.info(f"{self.uri} connection lost") raise except ConnectionClosedOK: - LOGGER.info(f"{self.uri } connection closed") + LOGGER.info(f"{self.uri} connection closed") raise except InvalidStatusCode as err: LOGGER.info( - f"{self.uri } websocket failed - " + f"{self.uri} websocket failed - " f"with invalid status code {err.status_code}") raise except WebSocketException as err: - LOGGER.info(f"{self.uri } websocket failed - with {err}") + LOGGER.info(f"{self.uri} websocket failed - with {err}") raise except OSError as err: LOGGER.info(f"{self.uri} connection failed - with {err}") @@ -182,12 +187,16 @@ class AggregationClient: async def _send(self, data): for _ in range(self._retry): try: - await self.ws.send(data) - result = await self.ws.recv() - return result - except Exception: + await asyncio.wait_for(self.ws.send(data), self._ws_timeout) + return + except Exception as err: + LOGGER.info(f"{self.uri} send data failed - with {err}") time.sleep(self._retry_interval_seconds) - return None + return + + async def _recv(self): + result = await self.ws.recv() + return result def send(self, data, msg_type="message", job_name=""): loop = asyncio.get_event_loop() @@ -195,11 +204,18 @@ class AggregationClient: "type": msg_type, "client": self.client_id, "data": data, "job_name": job_name }) - data_json = loop.run_until_complete(self._send(j)) - if data_json is None: - return - res = json.loads(data_json) - return res + loop.run_until_complete(self._send(j)) + + def recv(self, wait_data_type=None): + loop = asyncio.get_event_loop() + data = loop.run_until_complete(self._recv()) + try: + data = json.loads(data) + except Exception: + pass + if not wait_data_type or (isinstance(data, dict) and + data.get("type", "") == wait_data_type): + return data class ModelClient: @@ -237,10 +253,11 @@ class KBClient: with open(files, "rb") as fin: files = {"file": fin} outurl = http_request(url=_url, method="POST", files=files) - if outurl: - outurl = outurl.lstrip("/") - return f"{self.kbserver}/{outurl}" - return files + if FileOps.is_remote(outurl): + return outurl + outurl = outurl.lstrip("/") + FileOps.delete(files) + return f"{self.kbserver}/{outurl}" def update_db(self, task_info_file): @@ -250,13 +267,16 @@ class KBClient: with open(task_info_file, "rb") as fin: files = {"task": fin} outurl = http_request(url=_url, method="POST", files=files) - outurl = outurl.lstrip("/") - _id = f"{self.kbserver}/{outurl}" - LOGGER.info(f"Update kb success: {_id}") + except Exception as err: LOGGER.error(f"Update kb error: {err}") - _id = None - return _id + outurl = None + else: + if not FileOps.is_remote(outurl): + outurl = outurl.lstrip("/") + outurl = f"{self.kbserver}/{outurl}" + FileOps.delete(task_info_file) + return outurl def update_task_status(self, tasks: str, new_status=1): data = { @@ -266,11 +286,12 @@ class KBClient: _url = f"{self.kbserver}/update/status" try: outurl = http_request(url=_url, method="POST", json=data) - outurl = outurl.lstrip("/") - return f"{self.kbserver}/{outurl}" except Exception as err: LOGGER.error(f"Update kb error: {err}") - return None - - def query_db(self, sample): - pass + outurl = None + if not outurl: + return None + if not FileOps.is_remote(outurl): + outurl = outurl.lstrip("/") + outurl = f"{self.kbserver}/{outurl}" + return outurl diff --git a/lib/sedna/service/server/aggregation.py b/lib/sedna/service/server/aggregation.py index 72cf6468..3529d6b4 100644 --- a/lib/sedna/service/server/aggregation.py +++ b/lib/sedna/service/server/aggregation.py @@ -13,7 +13,7 @@ # limitations under the License. import time -from typing import List, Optional, Dict +from typing import List, Optional, Dict, Any import uuid from pydantic import BaseModel @@ -27,6 +27,8 @@ from starlette.types import ASGIApp, Receive, Scope, Send from sedna.common.log import LOGGER from sedna.common.utils import get_host_ip +from sedna.common.class_factory import ClassFactory, ClassType +from sedna.algorithms.aggregation import AggClient from .base import BaseServer @@ -37,10 +39,9 @@ class WSClientInfo(BaseModel): # pylint: disable=too-few-public-methods """ client information """ - client_id: str connected_at: float - job_count: int + info: Any class WSClientInfoList(BaseModel): # pylint: disable=too-few-public-methods @@ -57,6 +58,9 @@ class WSEventMiddleware: # pylint: disable=too-few-public-methods servername = scope["path"].lstrip("/") scope[servername] = self._server await self._app(scope, receive, send) + # exit agg server if job complete + scope["app"].shutdown = (self._server.exit_check() + and self._server.empty) class WSServerBase: @@ -84,7 +88,7 @@ class WSServerBase: LOGGER.info(f"Adding client {client_id}") self._clients[client_id] = websocket self._client_meta[client_id] = WSClientInfo( - client_id=client_id, connected_at=time.time(), job_count=0 + client_id=client_id, connected_at=time.time(), info=None ) async def kick_client(self, client_id: str): @@ -103,7 +107,6 @@ class WSServerBase: return self._client_meta.get(client_id) async def send_message(self, client_id: str, msg: Dict): - self._client_meta[client_id].job_count += 1 for to_client, websocket in self._clients.items(): if to_client == client_id: continue @@ -115,40 +118,54 @@ class WSServerBase: await websocket.send_json({"type": "CLIENT_JOIN", "data": client_id}) - async def client_left(self, client_id: str): - for to_client, websocket in self._clients.items(): - if to_client == client_id: - continue - await websocket.send_json({"type": "CLIENT_LEAVE", - "data": client_id}) - class Aggregator(WSServerBase): def __init__(self, **kwargs): super(Aggregator, self).__init__() self.exit_round = int(kwargs.get("exit_round", 3)) - self.current_round = {} + aggregation = kwargs.get("aggregation", "FedAvg") + self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation) + if callable(self.aggregation): + self.aggregation = self.aggregation() + self.participants_count = int(kwargs.get("participants_count", "1")) + self.current_round = 0 async def send_message(self, client_id: str, msg: Dict): - self._client_meta[client_id].job_count += 1 - clients = list(self._clients.items()) - for to_client, websocket in clients: - - if msg.get("type", "") == "update_weight": - if to_client == client_id: - continue - self.current_round[to_client] = self.current_round.get( - to_client, 0) + 1 - exit_flag = "ok" if self.exit_check(to_client) else "continue" - msg["exit_flag"] = exit_flag + data = msg.get("data") + if data and msg.get("type", "") == "update_weight": + info = AggClient() + info.num_samples = int(data["num_samples"]) + info.weights = data["weights"] + self._client_meta[client_id].info = info + current_clinets = [ + x.info for x in self._client_meta.values() if x.info + ] + # exit while aggregation job is NOT start + if len(current_clinets) < self.participants_count: + return + self.current_round += 1 + weights = self.aggregation.aggregate(current_clinets) + exit_flag = "ok" if self.exit_check() else "continue" + + msg["type"] = "recv_weight" + msg["round_number"] = self.current_round + msg["data"] = { + "total_sample": self.aggregation.total_size, + "round_number": self.current_round, + "weights": weights, + "exit_flag": exit_flag + } + for to_client, websocket in self._clients.items(): try: await websocket.send_json(msg) except Exception as err: LOGGER.error(err) + else: + if msg["type"] == "recv_weight": + self._client_meta[to_client].info = None - def exit_check(self, client_id): - current_round = self.current_round.get(client_id, 0) - return current_round >= self.exit_round + def exit_check(self): + return self.current_round >= self.exit_round class BroadcastWs(WebSocketEndpoint): @@ -176,7 +193,6 @@ class BroadcastWs(WebSocketEndpoint): "on_disconnect() called without a valid client_id" ) self.server.remove_client(self.client_id) - await self.server.client_left(self.client_id) async def on_receive(self, _websocket: WebSocket, msg: Dict): command = msg.get("type", "") @@ -185,50 +201,59 @@ class BroadcastWs(WebSocketEndpoint): await self.server.client_joined(self.client_id) self.server.add_client(self.client_id, _websocket) if self.client_id is None: - raise RuntimeError("on_receive() called without a valid client_id") + raise RuntimeError( + "on_receive() called without a valid client_id") await self.server.send_message(self.client_id, msg) class AggregationServer(BaseServer): def __init__( self, - servername: str, + aggregation: str, host: str = None, http_port: int = 7363, exit_round: int = 1, - ws_size: int = 10 * - 1024 * - 1024): + participants_count: int = 1, + ws_size: int = 10 * 1024 * 1024): if not host: host = get_host_ip() super( AggregationServer, self).__init__( - servername=servername, + servername=aggregation, host=host, http_port=http_port, ws_size=ws_size) - self.server_name = servername + self.aggregation = aggregation + self.participants_count = participants_count self.exit_round = max(int(exit_round), 1) self.app = FastAPI( routes=[ APIRoute( - f"/{servername}", + f"/{aggregation}", self.client_info, response_class=JSONResponse, ), WebSocketRoute( - f"/{servername}", + f"/{aggregation}", BroadcastWs ) ], ) + self.app.shutdown = False def start(self): """ Start the server """ - self.app.add_middleware(WSEventMiddleware, exit_round=self.exit_round) + + self.app.add_middleware( + WSEventMiddleware, + exit_round=self.exit_round, + aggregation=self.aggregation, + participants_count=self.participants_count + ) # define the aggregation method and exit condition + self.run(self.app, ws_max_size=self.ws_size) async def client_info(self, request: Request): diff --git a/lib/sedna/service/server/base.py b/lib/sedna/service/server/base.py index c6e787dd..d5302e31 100644 --- a/lib/sedna/service/server/base.py +++ b/lib/sedna/service/server/base.py @@ -12,6 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import time +import threading +import asyncio + import uvicorn from fastapi.middleware.cors import CORSMiddleware @@ -19,9 +24,25 @@ from sedna.common.log import LOGGER from sedna.common.utils import get_host_ip +class Server(uvicorn.Server): + def install_signal_handlers(self): + pass + + @contextlib.contextmanager + def run_in_thread(self): + thread = threading.Thread(target=self.run, daemon=True) + thread.start() + try: + yield thread + finally: + self.should_exit = True + thread.join() + + class BaseServer: # pylint: disable=too-many-instance-attributes,too-many-arguments DEBUG = True + WAIT_TIME = 15 def __init__( self, @@ -50,14 +71,15 @@ class BaseServer: self.url = f"{protocal}://{self.host}:{self.http_port}" def run(self, app, **kwargs): - app.add_middleware( - CORSMiddleware, allow_origins=["*"], allow_credentials=True, - allow_methods=["*"], allow_headers=["*"], - ) + if hasattr(app, "add_middleware"): + app.add_middleware( + CORSMiddleware, allow_origins=["*"], allow_credentials=True, + allow_methods=["*"], allow_headers=["*"], + ) LOGGER.info(f"Start {self.server_name} server over {self.url}") - uvicorn.run( + config = uvicorn.Config( app, host=self.host, port=self.http_port, @@ -65,7 +87,20 @@ class BaseServer: ssl_certfile=self.certfile, workers=self.workers, timeout_keep_alive=self.timeout, + log_level="info", **kwargs) + server = Server(config=config) + with server.run_in_thread() as current_thread: + return self.wait_stop(current=current_thread) + + def wait_stop(self, current): + """wait the stop flag to shutdown the server""" + while 1: + time.sleep(self.WAIT_TIME) + if not current.isAlive(): + return + if getattr(self.app, "shutdown", False): + return def get_all_urls(self): url_list = [{"path": route.path, "name": route.name}