Browse Source

Merge pull request #106 from JoeyHwong-gk/federated

fix federated learning bugs
tags/v0.3.1
KubeEdge Bot GitHub 4 years ago
parent
commit
a42a87abb3
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 303 additions and 180 deletions
  1. +6
    -2
      examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py
  2. +2
    -2
      examples/federated_learning/surface_defect_detection/training_worker/inference.py
  3. +1
    -1
      examples/federated_learning/surface_defect_detection/training_worker/train.py
  4. +1
    -0
      lib/sedna/__init__.py
  5. +1
    -42
      lib/sedna/algorithms/aggregation/__init__.py
  6. +73
    -0
      lib/sedna/algorithms/aggregation/aggregation.py
  7. +3
    -1
      lib/sedna/backend/tensorflow/__init__.py
  8. +0
    -15
      lib/sedna/core/__init__.py
  9. +59
    -41
      lib/sedna/core/federated_learning/federated_learning.py
  10. +54
    -33
      lib/sedna/service/client.py
  11. +63
    -38
      lib/sedna/service/server/aggregation.py
  12. +40
    -5
      lib/sedna/service/server/base.py

+ 6
- 2
examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py View File

@@ -23,14 +23,18 @@ def run_server():
exit_round = int(Context.get_parameters(
"exit_round", 3
))
participants_count = int(Context.get_parameters(
"participants_count", 1
))
agg_ip = Context.get_parameters("AGG_IP", "0.0.0.0")
agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
server = AggregationServer(
servername=aggregation_algorithm,
aggregation=aggregation_algorithm,
host=agg_ip,
http_port=agg_port,
exit_round=exit_round,
ws_size=20 * 1024 * 1024
ws_size=20 * 1024 * 1024,
participants_count=participants_count
)
server.start()



+ 2
- 2
examples/federated_learning/surface_defect_detection/training_worker/inference.py View File

@@ -45,8 +45,8 @@ def main():
test_data = TxtDataParse(data_type="test", func=image_process)
test_data.parse(fl_instance.config.test_dataset_url)

fl_instance.evaluate(test_data)
return fl_instance.inference(test_data.x)


if __name__ == '__main__':
main()
print(main())

+ 1
- 1
examples/federated_learning/surface_defect_detection/training_worker/train.py View File

@@ -65,7 +65,7 @@ def main():
fl_model = FederatedLearning(
estimator=Estimator,
aggregation=aggregation_algorithm)
fl_model.register()
train_jobs = fl_model.train(
train_data=train_data,
valid_data=valid_data,


+ 1
- 0
lib/sedna/__init__.py View File

@@ -13,3 +13,4 @@
# limitations under the License.

from .__version__ import __version__
import sedna.algorithms

+ 1
- 42
lib/sedna/algorithms/aggregation/__init__.py View File

@@ -12,45 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Aggregation algorithms"""

import abc
from copy import deepcopy

import numpy as np

from sedna.common.class_factory import ClassFactory, ClassType

__all__ = ('FedAvg',)


class BaseAggregation(metaclass=abc.ABCMeta):
def __init__(self):
self.total_size = 0
self.weights = None

@abc.abstractmethod
def aggregate(self, weights, size=0):
"""
Aggregation
:param weights: deep learning weight
:param size: numbers of sample in each loop
"""


@ClassFactory.register(ClassType.FL_AGG)
class FedAvg(BaseAggregation, abc.ABC):
"""Federated averaging algorithm"""

def aggregate(self, weights, size=0):
total_sample = self.total_size + size
if not total_sample:
return self.weights
updates = []
for inx, weight in enumerate(weights):
old_weight = self.weights[inx]
row_weight = ((np.array(weight) - old_weight) *
(size / total_sample) + old_weight)
updates.append(row_weight)
self.weights = deepcopy(updates)
return updates
from .aggregation import *

+ 73
- 0
lib/sedna/algorithms/aggregation/aggregation.py View File

@@ -0,0 +1,73 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Aggregation algorithms"""

import abc
from copy import deepcopy
from typing import List

import numpy as np

from sedna.common.class_factory import ClassFactory, ClassType

__all__ = ('AggClient', 'FedAvg',)


class AggClient:
"""Aggregation clients"""
num_samples: int
weights: List


class BaseAggregation(metaclass=abc.ABCMeta):
"""Abstract class of aggregator"""

def __init__(self):
self.total_size = 0
self.weights = None

@abc.abstractmethod
def aggregate(self, clients: List[AggClient]):
"""
Some algorithms can be aggregated in sequence,
but some can be calculated only after all aggregated data is uploaded.
therefore, this abstractmethod should consider that all weights are
uploaded.
:param clients: All clients in federated learning job
:return: final weights
"""


@ClassFactory.register(ClassType.FL_AGG)
class FedAvg(BaseAggregation, abc.ABC):
"""
Federated averaging algorithm : Calculate the average weight
according to the number of samples
"""

def aggregate(self, clients: List[AggClient]):
if not len(clients):
return self.weights
self.total_size = sum([c.num_samples for c in clients])
old_weight = [np.zeros(np.array(c).shape) for c in
next(iter(clients)).weights]
updates = []
for inx, row in enumerate(old_weight):
for c in clients:
row += (np.array(c.weights[inx]) * c.num_samples
/ self.total_size)
updates.append(row.tolist())
self.weights = deepcopy(updates)
return updates

+ 3
- 1
lib/sedna/backend/tensorflow/__init__.py View File

@@ -14,6 +14,7 @@

import os

import numpy as np
import tensorflow as tf

from sedna.backend.base import BackendBase
@@ -74,7 +75,7 @@ class TFBackend(BackendBase):
if not self.has_load:
tf.reset_default_graph()
self.sess = self.load()
return self.estimator.predict(data=data, **kwargs)
return self.estimator.predict(data, **kwargs)

def evaluate(self, data, **kwargs):
if not self.has_load:
@@ -135,4 +136,5 @@ class KerasBackend(TFBackend):
return list(map(lambda x: x.tolist(), self.estimator.get_weights()))

def set_weights(self, weights):
weights = [np.array(x) for x in weights]
self.estimator.set_weights(weights)

+ 0
- 15
lib/sedna/core/__init__.py View File

@@ -1,15 +0,0 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sedna.algorithms import *

+ 59
- 41
lib/sedna/core/federated_learning/federated_learning.py View File

@@ -14,8 +14,6 @@

import time

import asyncio

from sedna.core.base import JobBase
from sedna.common.config import Context
from sedna.common.file_ops import FileOps
@@ -30,6 +28,12 @@ class FederatedLearning(JobBase):
"""

def __init__(self, estimator, aggregation="FedAvg"):
"""
Initial a FederatedLearning job
:param estimator: Customize estimator
:param aggregation: aggregation algorithm for FederatedLearning
"""

protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1")
agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
@@ -43,20 +47,23 @@ class FederatedLearning(JobBase):
super(FederatedLearning, self).__init__(
estimator=estimator, config=config)
self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)

connect_timeout = int(Context.get_parameters("CONNECT_TIMEOUT", "300"))
self.node = None
self.register(timeout=connect_timeout)

def register(self):
def register(self, timeout=300):
self.log.info(
f"Node {self.worker_name} connect to : {self.config.agg_uri}")
self.node = AggregationClient(
url=self.config.agg_uri, client_id=self.worker_name)
loop = asyncio.get_event_loop()
res = loop.run_until_complete(
asyncio.wait_for(self.node.connect(), timeout=300))
url=self.config.agg_uri,
client_id=self.worker_name,
ping_timeout=timeout
)

FileOps.clean_folder([self.config.model_url], clean=False)
self.aggregation = self.aggregation()
self.log.info(f"Federated learning Jobs model prepared -- {res}")
self.log.info(f"{self.worker_name} model prepared")
if callable(self.estimator):
self.estimator = self.estimator()

@@ -64,55 +71,66 @@ class FederatedLearning(JobBase):
valid_data=None,
post_process=None,
**kwargs):
"""
Training task for FederatedLearning
:param train_data: datasource use for train
:param valid_data: datasource use for evaluation
:param post_process: post process
:param kwargs: params for training of customize estimator
"""

callback_func = None
if post_process is not None:
if post_process:
callback_func = ClassFactory.get_cls(
ClassType.CALLBACK, post_process)

round_number = 0
num_samples = len(train_data)
self.aggregation.total_size += num_samples

_flag = True
start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
res = None
while 1:
round_number += 1
start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
self.log.info(
f"Federated learning start at {start},"
f" round_number={round_number}")
res = self.estimator.train(
train_data=train_data, valid_data=valid_data, **kwargs)

self.aggregation.weights = self.estimator.get_weights()
send_data = {"num_samples": num_samples,
"weights": self.aggregation.weights}
received = self.node.send(
send_data, msg_type="update_weight", job_name=self.job_name)
exit_flag = False
if (received and received["type"] == "update_weight"
and received["job_name"] == self.job_name):
recv = received["data"]
if _flag:
round_number += 1
start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
self.log.info(
f"Federated learning start, round_number={round_number}")
res = self.estimator.train(
train_data=train_data, valid_data=valid_data, **kwargs)

rec_client = received["client"]
rec_sample = recv["num_samples"]
current_weights = self.estimator.get_weights()
send_data = {"num_samples": num_samples,
"weights": current_weights}
self.node.send(
send_data, msg_type="update_weight", job_name=self.job_name
)
received = self.node.recv(wait_data_type="recv_weight")
if not received:
_flag = False
continue
_flag = True

self.log.info(
f"Federated learning get weight from "
f"[{rec_client}] : {rec_sample}")
n_weight = self.aggregation.aggregate(
recv["weights"], rec_sample)
self.estimator.set_weights(n_weight)
exit_flag = received.get("exit_flag", "") == "ok"
rec_data = received.get("data", {})
exit_flag = rec_data.get("exit_flag", "")
server_round = int(rec_data.get("round_number"))
total_size = int(rec_data.get("total_sample"))
self.log.info(
f"Federated learning recv weight, "
f"round: {server_round}, total_sample: {total_size}"
)
n_weight = rec_data.get("weights")
self.estimator.set_weights(n_weight)
task_info = {
'currentRound': round_number,
'sampleCount': self.aggregation.total_size,
'sampleCount': total_size,
'startTime': start,
'updateTime': time.strftime(
"%Y-%m-%d %H:%M:%S",
time.localtime())}
"%Y-%m-%d %H:%M:%S", time.localtime())
}
model_paths = self.estimator.save()
task_info_res = self.estimator.model_info(
model_paths, result=res, relpath=self.config.data_path_prefix)
if exit_flag:
if exit_flag == "ok":
self.report_task_info(
task_info,
K8sResourceKindStatus.COMPLETED.value,


+ 54
- 33
lib/sedna/service/client.py View File

@@ -26,6 +26,7 @@ from websockets.exceptions import InvalidStatusCode, WebSocketException
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK

from sedna.common.log import LOGGER
from sedna.common.file_ops import FileOps


@retry(stop_max_attempt_number=3,
@@ -98,8 +99,9 @@ class LCReporter(threading.Thread):
time.localtime()),
"inferenceNumber": self.inference_number,
"hardExampleNumber": self.hard_example_number,
"uploadCloudRatio": self.hard_example_number /
self.inference_number}
"uploadCloudRatio": (self.hard_example_number /
self.inference_number)
}
self.message["ownerInfo"] = info
LCClient.send(self.lc_server,
self.message["name"],
@@ -141,15 +143,18 @@ class AggregationClient:
"ping_interval": interval,
"max_size": min(max_size, 16 * 1024 * 1024)
})
loop = asyncio.get_event_loop()
loop.run_until_complete(
asyncio.wait_for(self.connect(), timeout=timeout)
)

async def connect(self):
LOGGER.info(f"{self.uri} connection by {self.client_id}")

try:
conn = websockets.connect(
self.ws = await asyncio.wait_for(websockets.connect(
self.uri, **self.kwargs
)
self.ws = await conn.__aenter__()
), self._ws_timeout)
await self.ws.send(json.dumps({'type': 'subscribe',
'client_id': self.client_id}))

@@ -162,15 +167,15 @@ class AggregationClient:
LOGGER.info(f"{self.uri} connection lost")
raise
except ConnectionClosedOK:
LOGGER.info(f"{self.uri } connection closed")
LOGGER.info(f"{self.uri} connection closed")
raise
except InvalidStatusCode as err:
LOGGER.info(
f"{self.uri } websocket failed - "
f"{self.uri} websocket failed - "
f"with invalid status code {err.status_code}")
raise
except WebSocketException as err:
LOGGER.info(f"{self.uri } websocket failed - with {err}")
LOGGER.info(f"{self.uri} websocket failed - with {err}")
raise
except OSError as err:
LOGGER.info(f"{self.uri} connection failed - with {err}")
@@ -182,12 +187,16 @@ class AggregationClient:
async def _send(self, data):
for _ in range(self._retry):
try:
await self.ws.send(data)
result = await self.ws.recv()
return result
except Exception:
await asyncio.wait_for(self.ws.send(data), self._ws_timeout)
return
except Exception as err:
LOGGER.info(f"{self.uri} send data failed - with {err}")
time.sleep(self._retry_interval_seconds)
return None
return

async def _recv(self):
result = await self.ws.recv()
return result

def send(self, data, msg_type="message", job_name=""):
loop = asyncio.get_event_loop()
@@ -195,11 +204,18 @@ class AggregationClient:
"type": msg_type, "client": self.client_id,
"data": data, "job_name": job_name
})
data_json = loop.run_until_complete(self._send(j))
if data_json is None:
return
res = json.loads(data_json)
return res
loop.run_until_complete(self._send(j))

def recv(self, wait_data_type=None):
loop = asyncio.get_event_loop()
data = loop.run_until_complete(self._recv())
try:
data = json.loads(data)
except Exception:
pass
if not wait_data_type or (isinstance(data, dict) and
data.get("type", "") == wait_data_type):
return data


class ModelClient:
@@ -237,10 +253,11 @@ class KBClient:
with open(files, "rb") as fin:
files = {"file": fin}
outurl = http_request(url=_url, method="POST", files=files)
if outurl:
outurl = outurl.lstrip("/")
return f"{self.kbserver}/{outurl}"
return files
if FileOps.is_remote(outurl):
return outurl
outurl = outurl.lstrip("/")
FileOps.delete(files)
return f"{self.kbserver}/{outurl}"

def update_db(self, task_info_file):

@@ -250,13 +267,16 @@ class KBClient:
with open(task_info_file, "rb") as fin:
files = {"task": fin}
outurl = http_request(url=_url, method="POST", files=files)
outurl = outurl.lstrip("/")
_id = f"{self.kbserver}/{outurl}"
LOGGER.info(f"Update kb success: {_id}")

except Exception as err:
LOGGER.error(f"Update kb error: {err}")
_id = None
return _id
outurl = None
else:
if not FileOps.is_remote(outurl):
outurl = outurl.lstrip("/")
outurl = f"{self.kbserver}/{outurl}"
FileOps.delete(task_info_file)
return outurl

def update_task_status(self, tasks: str, new_status=1):
data = {
@@ -266,11 +286,12 @@ class KBClient:
_url = f"{self.kbserver}/update/status"
try:
outurl = http_request(url=_url, method="POST", json=data)
outurl = outurl.lstrip("/")
return f"{self.kbserver}/{outurl}"
except Exception as err:
LOGGER.error(f"Update kb error: {err}")
return None

def query_db(self, sample):
pass
outurl = None
if not outurl:
return None
if not FileOps.is_remote(outurl):
outurl = outurl.lstrip("/")
outurl = f"{self.kbserver}/{outurl}"
return outurl

+ 63
- 38
lib/sedna/service/server/aggregation.py View File

@@ -13,7 +13,7 @@
# limitations under the License.

import time
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any

import uuid
from pydantic import BaseModel
@@ -27,6 +27,8 @@ from starlette.types import ASGIApp, Receive, Scope, Send

from sedna.common.log import LOGGER
from sedna.common.utils import get_host_ip
from sedna.common.class_factory import ClassFactory, ClassType
from sedna.algorithms.aggregation import AggClient

from .base import BaseServer

@@ -37,10 +39,9 @@ class WSClientInfo(BaseModel): # pylint: disable=too-few-public-methods
"""
client information
"""

client_id: str
connected_at: float
job_count: int
info: Any


class WSClientInfoList(BaseModel): # pylint: disable=too-few-public-methods
@@ -57,6 +58,9 @@ class WSEventMiddleware: # pylint: disable=too-few-public-methods
servername = scope["path"].lstrip("/")
scope[servername] = self._server
await self._app(scope, receive, send)
# exit agg server if job complete
scope["app"].shutdown = (self._server.exit_check()
and self._server.empty)


class WSServerBase:
@@ -84,7 +88,7 @@ class WSServerBase:
LOGGER.info(f"Adding client {client_id}")
self._clients[client_id] = websocket
self._client_meta[client_id] = WSClientInfo(
client_id=client_id, connected_at=time.time(), job_count=0
client_id=client_id, connected_at=time.time(), info=None
)

async def kick_client(self, client_id: str):
@@ -103,7 +107,6 @@ class WSServerBase:
return self._client_meta.get(client_id)

async def send_message(self, client_id: str, msg: Dict):
self._client_meta[client_id].job_count += 1
for to_client, websocket in self._clients.items():
if to_client == client_id:
continue
@@ -115,40 +118,54 @@ class WSServerBase:
await websocket.send_json({"type": "CLIENT_JOIN",
"data": client_id})

async def client_left(self, client_id: str):
for to_client, websocket in self._clients.items():
if to_client == client_id:
continue
await websocket.send_json({"type": "CLIENT_LEAVE",
"data": client_id})


class Aggregator(WSServerBase):
def __init__(self, **kwargs):
super(Aggregator, self).__init__()
self.exit_round = int(kwargs.get("exit_round", 3))
self.current_round = {}
aggregation = kwargs.get("aggregation", "FedAvg")
self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)
if callable(self.aggregation):
self.aggregation = self.aggregation()
self.participants_count = int(kwargs.get("participants_count", "1"))
self.current_round = 0

async def send_message(self, client_id: str, msg: Dict):
self._client_meta[client_id].job_count += 1
clients = list(self._clients.items())
for to_client, websocket in clients:

if msg.get("type", "") == "update_weight":
if to_client == client_id:
continue
self.current_round[to_client] = self.current_round.get(
to_client, 0) + 1
exit_flag = "ok" if self.exit_check(to_client) else "continue"
msg["exit_flag"] = exit_flag
data = msg.get("data")
if data and msg.get("type", "") == "update_weight":
info = AggClient()
info.num_samples = int(data["num_samples"])
info.weights = data["weights"]
self._client_meta[client_id].info = info
current_clinets = [
x.info for x in self._client_meta.values() if x.info
]
# exit while aggregation job is NOT start
if len(current_clinets) < self.participants_count:
return
self.current_round += 1
weights = self.aggregation.aggregate(current_clinets)
exit_flag = "ok" if self.exit_check() else "continue"

msg["type"] = "recv_weight"
msg["round_number"] = self.current_round
msg["data"] = {
"total_sample": self.aggregation.total_size,
"round_number": self.current_round,
"weights": weights,
"exit_flag": exit_flag
}
for to_client, websocket in self._clients.items():
try:
await websocket.send_json(msg)
except Exception as err:
LOGGER.error(err)
else:
if msg["type"] == "recv_weight":
self._client_meta[to_client].info = None

def exit_check(self, client_id):
current_round = self.current_round.get(client_id, 0)
return current_round >= self.exit_round
def exit_check(self):
return self.current_round >= self.exit_round


class BroadcastWs(WebSocketEndpoint):
@@ -176,7 +193,6 @@ class BroadcastWs(WebSocketEndpoint):
"on_disconnect() called without a valid client_id"
)
self.server.remove_client(self.client_id)
await self.server.client_left(self.client_id)

async def on_receive(self, _websocket: WebSocket, msg: Dict):
command = msg.get("type", "")
@@ -185,50 +201,59 @@ class BroadcastWs(WebSocketEndpoint):
await self.server.client_joined(self.client_id)
self.server.add_client(self.client_id, _websocket)
if self.client_id is None:
raise RuntimeError("on_receive() called without a valid client_id")
raise RuntimeError(
"on_receive() called without a valid client_id")
await self.server.send_message(self.client_id, msg)


class AggregationServer(BaseServer):
def __init__(
self,
servername: str,
aggregation: str,
host: str = None,
http_port: int = 7363,
exit_round: int = 1,
ws_size: int = 10 *
1024 *
1024):
participants_count: int = 1,
ws_size: int = 10 * 1024 * 1024):
if not host:
host = get_host_ip()
super(
AggregationServer,
self).__init__(
servername=servername,
servername=aggregation,
host=host,
http_port=http_port,
ws_size=ws_size)
self.server_name = servername
self.aggregation = aggregation
self.participants_count = participants_count
self.exit_round = max(int(exit_round), 1)
self.app = FastAPI(
routes=[
APIRoute(
f"/{servername}",
f"/{aggregation}",
self.client_info,
response_class=JSONResponse,
),
WebSocketRoute(
f"/{servername}",
f"/{aggregation}",
BroadcastWs
)
],
)
self.app.shutdown = False

def start(self):
"""
Start the server
"""
self.app.add_middleware(WSEventMiddleware, exit_round=self.exit_round)

self.app.add_middleware(
WSEventMiddleware,
exit_round=self.exit_round,
aggregation=self.aggregation,
participants_count=self.participants_count
) # define the aggregation method and exit condition

self.run(self.app, ws_max_size=self.ws_size)

async def client_info(self, request: Request):


+ 40
- 5
lib/sedna/service/server/base.py View File

@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import time
import threading
import asyncio

import uvicorn
from fastapi.middleware.cors import CORSMiddleware

@@ -19,9 +24,25 @@ from sedna.common.log import LOGGER
from sedna.common.utils import get_host_ip


class Server(uvicorn.Server):
def install_signal_handlers(self):
pass

@contextlib.contextmanager
def run_in_thread(self):
thread = threading.Thread(target=self.run, daemon=True)
thread.start()
try:
yield thread
finally:
self.should_exit = True
thread.join()


class BaseServer:
# pylint: disable=too-many-instance-attributes,too-many-arguments
DEBUG = True
WAIT_TIME = 15

def __init__(
self,
@@ -50,14 +71,15 @@ class BaseServer:
self.url = f"{protocal}://{self.host}:{self.http_port}"

def run(self, app, **kwargs):
app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"],
)
if hasattr(app, "add_middleware"):
app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"],
)

LOGGER.info(f"Start {self.server_name} server over {self.url}")

uvicorn.run(
config = uvicorn.Config(
app,
host=self.host,
port=self.http_port,
@@ -65,7 +87,20 @@ class BaseServer:
ssl_certfile=self.certfile,
workers=self.workers,
timeout_keep_alive=self.timeout,
log_level="info",
**kwargs)
server = Server(config=config)
with server.run_in_thread() as current_thread:
return self.wait_stop(current=current_thread)

def wait_stop(self, current):
"""wait the stop flag to shutdown the server"""
while 1:
time.sleep(self.WAIT_TIME)
if not current.isAlive():
return
if getattr(self.app, "shutdown", False):
return

def get_all_urls(self):
url_list = [{"path": route.path, "name": route.name}


Loading…
Cancel
Save