fix federated learning bugs (tags/v0.3.1)
@@ -23,14 +23,18 @@ def run_server():
     exit_round = int(Context.get_parameters(
         "exit_round", 3
     ))
+    participants_count = int(Context.get_parameters(
+        "participants_count", 1
+    ))
     agg_ip = Context.get_parameters("AGG_IP", "0.0.0.0")
     agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
     server = AggregationServer(
-        servername=aggregation_algorithm,
+        aggregation=aggregation_algorithm,
        host=agg_ip,
        http_port=agg_port,
        exit_round=exit_round,
-        ws_size=20 * 1024 * 1024
+        ws_size=20 * 1024 * 1024,
+        participants_count=participants_count
    )
    server.start()
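
Note: `Context.get_parameters` resolves these settings from environment variables, so the two new knobs can be exercised without code changes. A minimal launch sketch (the values below are hypothetical, not part of the diff):

    import os

    # Hypothetical settings; Context.get_parameters falls back to env vars.
    os.environ["exit_round"] = "5"            # stop after 5 aggregation rounds
    os.environ["participants_count"] = "2"    # aggregate once 2 clients upload
    os.environ["AGG_IP"] = "0.0.0.0"
    os.environ["AGG_PORT"] = "7363"
    run_server()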
@@ -45,8 +45,8 @@ def main():
     test_data = TxtDataParse(data_type="test", func=image_process)
     test_data.parse(fl_instance.config.test_dataset_url)
     fl_instance.evaluate(test_data)
+    return fl_instance.inference(test_data.x)


 if __name__ == '__main__':
-    main()
+    print(main())
@@ -65,7 +65,7 @@ def main():
     fl_model = FederatedLearning(
         estimator=Estimator,
         aggregation=aggregation_algorithm)
-    fl_model.register()
     train_jobs = fl_model.train(
         train_data=train_data,
         valid_data=valid_data,
@@ -13,3 +13,4 @@
 # limitations under the License.

 from .__version__ import __version__
+import sedna.algorithms
@@ -12,45 +12,4 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-"""Aggregation algorithms"""
-
-import abc
-from copy import deepcopy
-
-import numpy as np
-
-from sedna.common.class_factory import ClassFactory, ClassType
-
-__all__ = ('FedAvg',)
-
-
-class BaseAggregation(metaclass=abc.ABCMeta):
-    def __init__(self):
-        self.total_size = 0
-        self.weights = None
-
-    @abc.abstractmethod
-    def aggregate(self, weights, size=0):
-        """
-        Aggregation
-        :param weights: deep learning weight
-        :param size: numbers of sample in each loop
-        """
-
-
-@ClassFactory.register(ClassType.FL_AGG)
-class FedAvg(BaseAggregation, abc.ABC):
-    """Federated averaging algorithm"""
-
-    def aggregate(self, weights, size=0):
-        total_sample = self.total_size + size
-        if not total_sample:
-            return self.weights
-        updates = []
-        for inx, weight in enumerate(weights):
-            old_weight = self.weights[inx]
-            row_weight = ((np.array(weight) - old_weight) *
-                          (size / total_sample) + old_weight)
-            updates.append(row_weight)
-        self.weights = deepcopy(updates)
-        return updates
+
+from .aggregation import *
| @@ -0,0 +1,73 @@ | |||
| # Copyright 2021 The KubeEdge Authors. | |||
| # | |||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||
| # you may not use this file except in compliance with the License. | |||
| # You may obtain a copy of the License at | |||
| # | |||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||
| # | |||
| # Unless required by applicable law or agreed to in writing, software | |||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
| # See the License for the specific language governing permissions and | |||
| # limitations under the License. | |||
| """Aggregation algorithms""" | |||
| import abc | |||
| from copy import deepcopy | |||
| from typing import List | |||
| import numpy as np | |||
| from sedna.common.class_factory import ClassFactory, ClassType | |||
| __all__ = ('AggClient', 'FedAvg',) | |||
| class AggClient: | |||
| """Aggregation clients""" | |||
| num_samples: int | |||
| weights: List | |||
| class BaseAggregation(metaclass=abc.ABCMeta): | |||
| """Abstract class of aggregator""" | |||
| def __init__(self): | |||
| self.total_size = 0 | |||
| self.weights = None | |||
| @abc.abstractmethod | |||
| def aggregate(self, clients: List[AggClient]): | |||
| """ | |||
| Some algorithms can be aggregated in sequence, | |||
| but some can be calculated only after all aggregated data is uploaded. | |||
| therefore, this abstractmethod should consider that all weights are | |||
| uploaded. | |||
| :param clients: All clients in federated learning job | |||
| :return: final weights | |||
| """ | |||
| @ClassFactory.register(ClassType.FL_AGG) | |||
| class FedAvg(BaseAggregation, abc.ABC): | |||
| """ | |||
| Federated averaging algorithm : Calculate the average weight | |||
| according to the number of samples | |||
| """ | |||
| def aggregate(self, clients: List[AggClient]): | |||
| if not len(clients): | |||
| return self.weights | |||
| self.total_size = sum([c.num_samples for c in clients]) | |||
| old_weight = [np.zeros(np.array(c).shape) for c in | |||
| next(iter(clients)).weights] | |||
| updates = [] | |||
| for inx, row in enumerate(old_weight): | |||
| for c in clients: | |||
| row += (np.array(c.weights[inx]) * c.num_samples | |||
| / self.total_size) | |||
| updates.append(row.tolist()) | |||
| self.weights = deepcopy(updates) | |||
| return updates | |||
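
Note: a small worked example (not part of the diff) of what the new `FedAvg.aggregate` computes; the `AggClient` values are filled in by hand:

    from sedna.algorithms.aggregation import AggClient, FedAvg

    # Two hypothetical clients; c2 holds three times as many samples as c1.
    c1, c2 = AggClient(), AggClient()
    c1.num_samples, c1.weights = 1, [[0.0, 0.0]]
    c2.num_samples, c2.weights = 3, [[4.0, 8.0]]

    # Sample-weighted mean: (1*0.0 + 3*4.0)/4 = 3.0 and (1*0.0 + 3*8.0)/4 = 6.0
    print(FedAvg().aggregate([c1, c2]))   # -> [[3.0, 6.0]]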
@@ -14,6 +14,7 @@
 import os

+import numpy as np
 import tensorflow as tf

 from sedna.backend.base import BackendBase
@@ -74,7 +75,7 @@ class TFBackend(BackendBase):
         if not self.has_load:
             tf.reset_default_graph()
             self.sess = self.load()
-        return self.estimator.predict(data=data, **kwargs)
+        return self.estimator.predict(data, **kwargs)

     def evaluate(self, data, **kwargs):
         if not self.has_load:
@@ -135,4 +136,5 @@ class KerasBackend(TFBackend):
         return list(map(lambda x: x.tolist(), self.estimator.get_weights()))

     def set_weights(self, weights):
+        weights = [np.array(x) for x in weights]
         self.estimator.set_weights(weights)
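
Note: the `np.array` conversion is the inverse of `get_weights` above, which calls `tolist()` so that weights can travel as JSON over the websocket. A sketch with a hypothetical payload:

    import numpy as np

    # Hypothetical weights as received from the aggregator (plain JSON lists).
    json_weights = [[[0.1, 0.2], [0.3, 0.4]], [0.0, 0.0]]
    as_arrays = [np.array(x) for x in json_weights]   # what Keras expects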
@@ -1,15 +0,0 @@
-# Copyright 2021 The KubeEdge Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from sedna.algorithms import *
@@ -14,8 +14,6 @@
 import time
-import asyncio

 from sedna.core.base import JobBase
 from sedna.common.config import Context
 from sedna.common.file_ops import FileOps
@@ -30,6 +28,12 @@ class FederatedLearning(JobBase):
     """

     def __init__(self, estimator, aggregation="FedAvg"):
+        """
+        Initialize a FederatedLearning job
+        :param estimator: customized estimator
+        :param aggregation: aggregation algorithm for FederatedLearning
+        """
         protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
         agg_ip = Context.get_parameters("AGG_IP", "127.0.0.1")
         agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
@@ -43,20 +47,23 @@ class FederatedLearning(JobBase):
         super(FederatedLearning, self).__init__(
             estimator=estimator, config=config)
         self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)
+        connect_timeout = int(Context.get_parameters("CONNECT_TIMEOUT", "300"))
+        self.node = None
+        self.register(timeout=connect_timeout)

-    def register(self):
+    def register(self, timeout=300):
         self.log.info(
             f"Node {self.worker_name} connect to : {self.config.agg_uri}")
         self.node = AggregationClient(
-            url=self.config.agg_uri, client_id=self.worker_name)
-        loop = asyncio.get_event_loop()
-        res = loop.run_until_complete(
-            asyncio.wait_for(self.node.connect(), timeout=300))
+            url=self.config.agg_uri,
+            client_id=self.worker_name,
+            ping_timeout=timeout
+        )
         FileOps.clean_folder([self.config.model_url], clean=False)
         self.aggregation = self.aggregation()
-        self.log.info(f"Federated learning Jobs model prepared -- {res}")
+        self.log.info(f"{self.worker_name} model prepared")
         if callable(self.estimator):
             self.estimator = self.estimator()
@@ -64,55 +71,66 @@ def main():
               valid_data=None,
               post_process=None,
               **kwargs):
+        """
+        Training task for FederatedLearning
+        :param train_data: datasource used for training
+        :param valid_data: datasource used for evaluation
+        :param post_process: post-process callback
+        :param kwargs: training parameters of the customized estimator
+        """
         callback_func = None
-        if post_process is not None:
+        if post_process:
             callback_func = ClassFactory.get_cls(
                 ClassType.CALLBACK, post_process)
         round_number = 0
         num_samples = len(train_data)
-        self.aggregation.total_size += num_samples
+        _flag = True
+        start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        res = None
         while 1:
-            round_number += 1
-            start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
-            self.log.info(
-                f"Federated learning start at {start},"
-                f" round_number={round_number}")
-            res = self.estimator.train(
-                train_data=train_data, valid_data=valid_data, **kwargs)
-            self.aggregation.weights = self.estimator.get_weights()
-            send_data = {"num_samples": num_samples,
-                         "weights": self.aggregation.weights}
-            received = self.node.send(
-                send_data, msg_type="update_weight", job_name=self.job_name)
-            exit_flag = False
-            if (received and received["type"] == "update_weight"
-                    and received["job_name"] == self.job_name):
-                recv = received["data"]
-                rec_client = received["client"]
-                rec_sample = recv["num_samples"]
-                self.log.info(
-                    f"Federated learning get weight from "
-                    f"[{rec_client}] : {rec_sample}")
-                n_weight = self.aggregation.aggregate(
-                    recv["weights"], rec_sample)
-                self.estimator.set_weights(n_weight)
-                exit_flag = received.get("exit_flag", "") == "ok"
+            if _flag:
+                round_number += 1
+                start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+                self.log.info(
+                    f"Federated learning start, round_number={round_number}")
+                res = self.estimator.train(
+                    train_data=train_data, valid_data=valid_data, **kwargs)
+                current_weights = self.estimator.get_weights()
+                send_data = {"num_samples": num_samples,
+                             "weights": current_weights}
+                self.node.send(
+                    send_data, msg_type="update_weight", job_name=self.job_name
+                )
+            received = self.node.recv(wait_data_type="recv_weight")
+            if not received:
+                _flag = False
+                continue
+            _flag = True
+            rec_data = received.get("data", {})
+            exit_flag = rec_data.get("exit_flag", "")
+            server_round = int(rec_data.get("round_number"))
+            total_size = int(rec_data.get("total_sample"))
+            self.log.info(
+                f"Federated learning recv weight, "
+                f"round: {server_round}, total_sample: {total_size}"
+            )
+            n_weight = rec_data.get("weights")
+            self.estimator.set_weights(n_weight)
             task_info = {
                 'currentRound': round_number,
-                'sampleCount': self.aggregation.total_size,
+                'sampleCount': total_size,
                 'startTime': start,
                 'updateTime': time.strftime(
-                    "%Y-%m-%d %H:%M:%S",
-                    time.localtime())}
+                    "%Y-%m-%d %H:%M:%S", time.localtime())
+            }
             model_paths = self.estimator.save()
             task_info_res = self.estimator.model_info(
                 model_paths, result=res, relpath=self.config.data_path_prefix)
-            if exit_flag:
+            if exit_flag == "ok":
                 self.report_task_info(
                     task_info,
                     K8sResourceKindStatus.COMPLETED.value,
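
Note: the worker-side protocol per round is now: train, push an `update_weight` message, then block on `recv` until the server broadcasts the aggregated `recv_weight`. A condensed sketch under that reading (`node`, `estimator` and `train_data` stand in for the objects above; the job name is hypothetical):

    estimator.train(train_data=train_data)
    node.send({"num_samples": len(train_data),
               "weights": estimator.get_weights()},
              msg_type="update_weight", job_name="fl-job")
    reply = node.recv(wait_data_type="recv_weight")
    if reply:
        data = reply["data"]                       # aggregated payload
        estimator.set_weights(data["weights"])
        finished = data.get("exit_flag") == "ok"   # server ends the job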
@@ -26,6 +26,7 @@ from websockets.exceptions import InvalidStatusCode, WebSocketException
 from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK

 from sedna.common.log import LOGGER
+from sedna.common.file_ops import FileOps


 @retry(stop_max_attempt_number=3,
@@ -98,8 +99,9 @@ class LCReporter(threading.Thread):
                 time.localtime()),
             "inferenceNumber": self.inference_number,
             "hardExampleNumber": self.hard_example_number,
-            "uploadCloudRatio": self.hard_example_number /
-            self.inference_number}
+            "uploadCloudRatio": (self.hard_example_number /
+                                 self.inference_number)
+        }
         self.message["ownerInfo"] = info
         LCClient.send(self.lc_server,
                       self.message["name"],
@@ -141,15 +143,18 @@ class AggregationClient:
             "ping_interval": interval,
             "max_size": min(max_size, 16 * 1024 * 1024)
         })
+        loop = asyncio.get_event_loop()
+        loop.run_until_complete(
+            asyncio.wait_for(self.connect(), timeout=timeout)
+        )

     async def connect(self):
         LOGGER.info(f"{self.uri} connection by {self.client_id}")
         try:
-            conn = websockets.connect(
+            self.ws = await asyncio.wait_for(websockets.connect(
                 self.uri, **self.kwargs
-            )
-            self.ws = await conn.__aenter__()
+            ), self._ws_timeout)
             await self.ws.send(json.dumps({'type': 'subscribe',
                                            'client_id': self.client_id}))
@@ -162,15 +167,15 @@ class AggregationClient:
             LOGGER.info(f"{self.uri} connection lost")
             raise
         except ConnectionClosedOK:
-            LOGGER.info(f"{self.uri } connection closed")
+            LOGGER.info(f"{self.uri} connection closed")
             raise
         except InvalidStatusCode as err:
             LOGGER.info(
-                f"{self.uri } websocket failed - "
+                f"{self.uri} websocket failed - "
                 f"with invalid status code {err.status_code}")
             raise
         except WebSocketException as err:
-            LOGGER.info(f"{self.uri } websocket failed - with {err}")
+            LOGGER.info(f"{self.uri} websocket failed - with {err}")
             raise
         except OSError as err:
             LOGGER.info(f"{self.uri} connection failed - with {err}")
@@ -182,12 +187,16 @@ class AggregationClient:
     async def _send(self, data):
         for _ in range(self._retry):
             try:
-                await self.ws.send(data)
-                result = await self.ws.recv()
-                return result
-            except Exception:
+                await asyncio.wait_for(self.ws.send(data), self._ws_timeout)
+                return
+            except Exception as err:
+                LOGGER.info(f"{self.uri} send data failed - with {err}")
                 time.sleep(self._retry_interval_seconds)
-        return None
+        return
+
+    async def _recv(self):
+        result = await self.ws.recv()
+        return result

     def send(self, data, msg_type="message", job_name=""):
         loop = asyncio.get_event_loop()
@@ -195,11 +204,18 @@
         j = json.dumps({
             "type": msg_type, "client": self.client_id,
             "data": data, "job_name": job_name
         })
-        data_json = loop.run_until_complete(self._send(j))
-        if data_json is None:
-            return
-        res = json.loads(data_json)
-        return res
+        loop.run_until_complete(self._send(j))
+
+    def recv(self, wait_data_type=None):
+        loop = asyncio.get_event_loop()
+        data = loop.run_until_complete(self._recv())
+        try:
+            data = json.loads(data)
+        except Exception:
+            pass
+        if not wait_data_type or (isinstance(data, dict) and
+                                  data.get("type", "") == wait_data_type):
+            return data


 class ModelClient:
@@ -237,10 +253,11 @@ class KBClient:
             with open(files, "rb") as fin:
                 files = {"file": fin}
                 outurl = http_request(url=_url, method="POST", files=files)
-        if outurl:
-            outurl = outurl.lstrip("/")
-            return f"{self.kbserver}/{outurl}"
-        return files
+        if FileOps.is_remote(outurl):
+            return outurl
+        outurl = outurl.lstrip("/")
+        FileOps.delete(files)
+        return f"{self.kbserver}/{outurl}"

     def update_db(self, task_info_file):
@@ -250,13 +267,16 @@ class KBClient:
             with open(task_info_file, "rb") as fin:
                 files = {"task": fin}
                 outurl = http_request(url=_url, method="POST", files=files)
-            outurl = outurl.lstrip("/")
-            _id = f"{self.kbserver}/{outurl}"
-            LOGGER.info(f"Update kb success: {_id}")
         except Exception as err:
             LOGGER.error(f"Update kb error: {err}")
-            _id = None
-        return _id
+            outurl = None
+        else:
+            if not FileOps.is_remote(outurl):
+                outurl = outurl.lstrip("/")
+                outurl = f"{self.kbserver}/{outurl}"
+            FileOps.delete(task_info_file)
+        return outurl

     def update_task_status(self, tasks: str, new_status=1):
         data = {
@@ -266,11 +286,12 @@ class KBClient:
         _url = f"{self.kbserver}/update/status"
         try:
             outurl = http_request(url=_url, method="POST", json=data)
-            outurl = outurl.lstrip("/")
-            return f"{self.kbserver}/{outurl}"
         except Exception as err:
             LOGGER.error(f"Update kb error: {err}")
-            return None
-
-    def query_db(self, sample):
-        pass
+            outurl = None
+        if not outurl:
+            return None
+        if not FileOps.is_remote(outurl):
+            outurl = outurl.lstrip("/")
+            outurl = f"{self.kbserver}/{outurl}"
+        return outurl
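
Note: all three KB helpers now share one normalization rule: a URL that is already remote passes through untouched, while anything else is treated as a path relative to the KB server. A sketch of the rule with hypothetical values:

    # FileOps.is_remote comes from sedna.common.file_ops; values are made up.
    kbserver = "http://kb-server:9020"
    outurl = "index/task.json"
    if not FileOps.is_remote(outurl):
        outurl = f"{kbserver}/{outurl.lstrip('/')}"
    # -> "http://kb-server:9020/index/task.json"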
@@ -13,7 +13,7 @@
 # limitations under the License.

 import time
-from typing import List, Optional, Dict
+from typing import List, Optional, Dict, Any
 import uuid

 from pydantic import BaseModel
@@ -27,6 +27,8 @@ from starlette.types import ASGIApp, Receive, Scope, Send
 from sedna.common.log import LOGGER
 from sedna.common.utils import get_host_ip
 from sedna.common.class_factory import ClassFactory, ClassType
+from sedna.algorithms.aggregation import AggClient
+
 from .base import BaseServer
@@ -37,10 +39,9 @@ class WSClientInfo(BaseModel):  # pylint: disable=too-few-public-methods
     """
     client information
     """

     client_id: str
     connected_at: float
-    job_count: int
+    info: Any


 class WSClientInfoList(BaseModel):  # pylint: disable=too-few-public-methods
@@ -57,6 +58,9 @@ class WSEventMiddleware:  # pylint: disable=too-few-public-methods
         servername = scope["path"].lstrip("/")
         scope[servername] = self._server
         await self._app(scope, receive, send)
+        # shut the aggregation server down once the job has completed
+        scope["app"].shutdown = (self._server.exit_check()
+                                 and self._server.empty)


 class WSServerBase:
@@ -84,7 +88,7 @@ class WSServerBase:
         LOGGER.info(f"Adding client {client_id}")
         self._clients[client_id] = websocket
         self._client_meta[client_id] = WSClientInfo(
-            client_id=client_id, connected_at=time.time(), job_count=0
+            client_id=client_id, connected_at=time.time(), info=None
         )

     async def kick_client(self, client_id: str):
@@ -103,7 +107,6 @@ class WSServerBase:
         return self._client_meta.get(client_id)

     async def send_message(self, client_id: str, msg: Dict):
-        self._client_meta[client_id].job_count += 1
         for to_client, websocket in self._clients.items():
             if to_client == client_id:
                 continue
@@ -115,40 +118,54 @@
                 await websocket.send_json({"type": "CLIENT_JOIN",
                                            "data": client_id})

     async def client_left(self, client_id: str):
         for to_client, websocket in self._clients.items():
             if to_client == client_id:
                 continue
             await websocket.send_json({"type": "CLIENT_LEAVE",
                                        "data": client_id})


 class Aggregator(WSServerBase):
     def __init__(self, **kwargs):
         super(Aggregator, self).__init__()
         self.exit_round = int(kwargs.get("exit_round", 3))
-        self.current_round = {}
+        aggregation = kwargs.get("aggregation", "FedAvg")
+        self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)
+        if callable(self.aggregation):
+            self.aggregation = self.aggregation()
+        self.participants_count = int(kwargs.get("participants_count", "1"))
+        self.current_round = 0

     async def send_message(self, client_id: str, msg: Dict):
-        self._client_meta[client_id].job_count += 1
-        clients = list(self._clients.items())
-        for to_client, websocket in clients:
-            if msg.get("type", "") == "update_weight":
-                if to_client == client_id:
-                    continue
-                self.current_round[to_client] = self.current_round.get(
-                    to_client, 0) + 1
-                exit_flag = "ok" if self.exit_check(to_client) else "continue"
-                msg["exit_flag"] = exit_flag
+        data = msg.get("data")
+        if data and msg.get("type", "") == "update_weight":
+            info = AggClient()
+            info.num_samples = int(data["num_samples"])
+            info.weights = data["weights"]
+            self._client_meta[client_id].info = info
+            current_clients = [
+                x.info for x in self._client_meta.values() if x.info
+            ]
+            # not every participant has uploaded yet: hold off aggregating
+            if len(current_clients) < self.participants_count:
+                return
+            self.current_round += 1
+            weights = self.aggregation.aggregate(current_clients)
+            exit_flag = "ok" if self.exit_check() else "continue"
+            msg["type"] = "recv_weight"
+            msg["round_number"] = self.current_round
+            msg["data"] = {
+                "total_sample": self.aggregation.total_size,
+                "round_number": self.current_round,
+                "weights": weights,
+                "exit_flag": exit_flag
+            }
+        for to_client, websocket in self._clients.items():
             try:
                 await websocket.send_json(msg)
             except Exception as err:
                 LOGGER.error(err)
+            else:
+                if msg["type"] == "recv_weight":
+                    self._client_meta[to_client].info = None

-    def exit_check(self, client_id):
-        current_round = self.current_round.get(client_id, 0)
-        return current_round >= self.exit_round
+    def exit_check(self):
+        return self.current_round >= self.exit_round
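
Note: round bookkeeping is now global rather than per client: `current_round` advances once per aggregated cohort of `participants_count` uploads, and `exit_check` compares it against `exit_round`. A toy illustration (constructing the class directly, purely for illustration; it assumes FedAvg is registered by its import):

    agg = Aggregator(exit_round=3, participants_count=2)
    assert not agg.exit_check()   # current_round == 0
    agg.current_round = 3         # three cohorts have been aggregated
    assert agg.exit_check()       # next broadcast carries exit_flag == "ok"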
 class BroadcastWs(WebSocketEndpoint):
@@ -176,7 +193,6 @@ class BroadcastWs(WebSocketEndpoint):
                 "on_disconnect() called without a valid client_id"
             )
         self.server.remove_client(self.client_id)
-        await self.server.client_left(self.client_id)

     async def on_receive(self, _websocket: WebSocket, msg: Dict):
         command = msg.get("type", "")
@@ -185,50 +201,59 @@ class BroadcastWs(WebSocketEndpoint):
             await self.server.client_joined(self.client_id)
             self.server.add_client(self.client_id, _websocket)
         if self.client_id is None:
-            raise RuntimeError("on_receive() called without a valid client_id")
+            raise RuntimeError(
+                "on_receive() called without a valid client_id")
         await self.server.send_message(self.client_id, msg)

 class AggregationServer(BaseServer):
     def __init__(
             self,
-            servername: str,
+            aggregation: str,
             host: str = None,
             http_port: int = 7363,
             exit_round: int = 1,
-            ws_size: int = 10 *
-            1024 *
-            1024):
+            participants_count: int = 1,
+            ws_size: int = 10 * 1024 * 1024):
         if not host:
             host = get_host_ip()
         super(
             AggregationServer,
             self).__init__(
-            servername=servername,
+            servername=aggregation,
             host=host,
             http_port=http_port,
             ws_size=ws_size)
-        self.server_name = servername
+        self.aggregation = aggregation
+        self.participants_count = participants_count
         self.exit_round = max(int(exit_round), 1)
         self.app = FastAPI(
             routes=[
                 APIRoute(
-                    f"/{servername}",
+                    f"/{aggregation}",
                     self.client_info,
                     response_class=JSONResponse,
                 ),
                 WebSocketRoute(
-                    f"/{servername}",
+                    f"/{aggregation}",
                     BroadcastWs
                 )
             ],
         )
+        self.app.shutdown = False

     def start(self):
         """
         Start the server
         """
-        self.app.add_middleware(WSEventMiddleware, exit_round=self.exit_round)
+        # pass the aggregation method and exit conditions to the middleware
+        self.app.add_middleware(
+            WSEventMiddleware,
+            exit_round=self.exit_round,
+            aggregation=self.aggregation,
+            participants_count=self.participants_count
+        )
         self.run(self.app, ws_max_size=self.ws_size)

     async def client_info(self, request: Request):
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import contextlib
+import time
+import threading
+import asyncio
+
 import uvicorn
 from fastapi.middleware.cors import CORSMiddleware
@@ -19,9 +24,25 @@ from sedna.common.log import LOGGER
 from sedna.common.utils import get_host_ip


+class Server(uvicorn.Server):
+    def install_signal_handlers(self):
+        pass
+
+    @contextlib.contextmanager
+    def run_in_thread(self):
+        thread = threading.Thread(target=self.run, daemon=True)
+        thread.start()
+        try:
+            yield thread
+        finally:
+            self.should_exit = True
+            thread.join()
+
+
 class BaseServer:
     # pylint: disable=too-many-instance-attributes,too-many-arguments
     DEBUG = True
+    WAIT_TIME = 15

     def __init__(
             self,
@@ -50,14 +71,15 @@ class BaseServer:
         self.url = f"{protocal}://{self.host}:{self.http_port}"

     def run(self, app, **kwargs):
-        app.add_middleware(
-            CORSMiddleware, allow_origins=["*"], allow_credentials=True,
-            allow_methods=["*"], allow_headers=["*"],
-        )
+        if hasattr(app, "add_middleware"):
+            app.add_middleware(
+                CORSMiddleware, allow_origins=["*"], allow_credentials=True,
+                allow_methods=["*"], allow_headers=["*"],
+            )

         LOGGER.info(f"Start {self.server_name} server over {self.url}")

-        uvicorn.run(
+        config = uvicorn.Config(
             app,
             host=self.host,
             port=self.http_port,
@@ -65,7 +87,20 @@ class BaseServer:
             ssl_certfile=self.certfile,
             workers=self.workers,
             timeout_keep_alive=self.timeout,
+            log_level="info",
             **kwargs)
+        server = Server(config=config)
+        with server.run_in_thread() as current_thread:
+            return self.wait_stop(current=current_thread)
+
+    def wait_stop(self, current):
+        """wait for the stop flag to shut the server down"""
+        while 1:
+            time.sleep(self.WAIT_TIME)
+            if not current.is_alive():
+                return
+            if getattr(self.app, "shutdown", False):
+                return
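
Note: `run()` now serves uvicorn from a daemon thread and polls for a stop condition instead of blocking forever, which is what lets the aggregator exit once `WSEventMiddleware` sets `app.shutdown`. A sketch of the lifecycle under that reading (`app` is any ASGI app):

    config = uvicorn.Config(app, host="0.0.0.0", port=7363)
    server = Server(config=config)
    with server.run_in_thread() as thread:
        while thread.is_alive() and not getattr(app, "shutdown", False):
            time.sleep(15)
    # leaving the block sets server.should_exit and joins the thread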

     def get_all_urls(self):
         url_list = [{"path": route.path, "name": route.name}