
fix bug: Cloud worker not exiting

Signed-off-by: JoeyHwong <joeyhwong@gknow.cn>
tags/v0.3.1
JoeyHwong, 4 years ago
commit dd0dd5cc2d
11 changed files with 199 additions and 152 deletions
  1. +6  -2   examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py
  2. +2  -2   examples/federated_learning/surface_defect_detection/training_worker/inference.py
  3. +1  -0   lib/sedna/__init__.py
  4. +31 -14  lib/sedna/algorithms/aggregation/aggregation.py
  5. +3  -1   lib/sedna/backend/tensorflow/__init__.py
  6. +0  -15  lib/sedna/core/__init__.py
  7. +38 -37  lib/sedna/core/federated_learning/federated_learning.py
  8. +6  -36  lib/sedna/core/joint_inference/joint_inference.py
  9. +5  -2   lib/sedna/service/client.py
  10. +63 -38  lib/sedna/service/server/aggregation.py
  11. +44 -5   lib/sedna/service/server/base.py

+6 -2  examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py

@@ -23,14 +23,18 @@ def run_server():
exit_round = int(Context.get_parameters(
"exit_round", 3
))
participants_count = int(Context.get_parameters(
"participants_count", 1
))
agg_ip = Context.get_parameters("AGG_IP", "0.0.0.0")
agg_port = int(Context.get_parameters("AGG_PORT", "7363"))
server = AggregationServer(
servername=aggregation_algorithm,
aggregation=aggregation_algorithm,
host=agg_ip,
http_port=agg_port,
exit_round=exit_round,
ws_size=20 * 1024 * 1024
ws_size=20 * 1024 * 1024,
participants_count=participants_count
)
server.start()
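
The aggregation worker above is configured entirely through Context parameters. A minimal launch sketch, assuming Context.get_parameters falls back to process environment variables (an assumption about sedna's Context helper, not shown in this diff; the values are illustrative):

import os

os.environ["exit_round"] = "5"            # finish after 5 aggregation rounds
os.environ["participants_count"] = "2"    # wait for 2 training workers per round
os.environ["AGG_IP"] = "0.0.0.0"
os.environ["AGG_PORT"] = "7363"

# run from the example directory so the module above is importable
from aggregate import run_server
run_server()   # builds AggregationServer with the values above and blocks in start()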



+2 -2  examples/federated_learning/surface_defect_detection/training_worker/inference.py

@@ -45,8 +45,8 @@ def main():
test_data = TxtDataParse(data_type="test", func=image_process)
test_data.parse(fl_instance.config.test_dataset_url)

fl_instance.evaluate(test_data)
return fl_instance.inference(test_data.x)


if __name__ == '__main__':
main()
print(main())

+1 -0  lib/sedna/__init__.py

@@ -13,3 +13,4 @@
# limitations under the License.

from .__version__ import __version__
import sedna.algorithms

+31 -14  lib/sedna/algorithms/aggregation/aggregation.py

@@ -16,41 +16,58 @@

import abc
from copy import deepcopy
from typing import List

import numpy as np

from sedna.common.class_factory import ClassFactory, ClassType

__all__ = ('FedAvg',)
__all__ = ('AggClient', 'FedAvg',)


class AggClient:
"""Aggregation clients"""
num_samples: int
weights: List


class BaseAggregation(metaclass=abc.ABCMeta):
"""Abstract class of aggregator"""

def __init__(self):
self.total_size = 0
self.weights = None

@abc.abstractmethod
def aggregate(self, weights, size=0):
def aggregate(self, clients: List[AggClient]):
"""
Aggregation
:param weights: deep learning weight
:param size: numbers of sample in each loop
Some algorithms can aggregate updates sequentially, while others can only
be computed after every client has uploaded its data; this abstract
method therefore assumes that all weights have already been uploaded.
:param clients: All clients in federated learning job
:return: final weights
"""


@ClassFactory.register(ClassType.FL_AGG)
class FedAvg(BaseAggregation, abc.ABC):
"""Federated averaging algorithm"""
"""
Federated averaging algorithm : Calculate the average weight
according to the number of samples
"""

def aggregate(self, weights, size=0):
total_sample = self.total_size + size
if not total_sample:
def aggregate(self, clients: List[AggClient]):
if not len(clients):
return self.weights
self.total_size = sum([c.num_samples for c in clients])
old_weight = [np.zeros(np.array(c).shape) for c in
next(iter(clients)).weights]
updates = []
for inx, weight in enumerate(weights):
old_weight = self.weights[inx]
row_weight = ((np.array(weight) - old_weight) *
(size / total_sample) + old_weight)
updates.append(row_weight)
for inx, row in enumerate(old_weight):
for c in clients:
row += (np.array(c.weights[inx]) * c.num_samples
/ self.total_size)
updates.append(row.tolist())
self.weights = deepcopy(updates)
return updates
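
The aggregator now receives one AggClient record per worker instead of a single weight update, so a whole round is averaged in one call. A minimal sketch of the new contract, assuming FedAvg is exported from sedna.algorithms.aggregation alongside AggClient (as the server-side import later in this commit suggests); the weights below are made up:

from sedna.algorithms.aggregation import AggClient, FedAvg

def make_client(num_samples, weights):
    # One AggClient per training worker: its sample count plus local weights,
    # in the same nested-list layout returned by estimator.get_weights().
    client = AggClient()
    client.num_samples = num_samples
    client.weights = weights
    return client

clients = [
    make_client(100, [[1.0, 1.0], [0.5]]),
    make_client(300, [[3.0, 3.0], [1.5]]),
]

new_weights = FedAvg().aggregate(clients)
# Sample-weighted mean per layer:
# (100*1.0 + 300*3.0) / 400 = 2.5 and (100*0.5 + 300*1.5) / 400 = 1.25
# -> [[2.5, 2.5], [1.25]]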

+3 -1  lib/sedna/backend/tensorflow/__init__.py

@@ -14,6 +14,7 @@

import os

import numpy as np
import tensorflow as tf

from sedna.backend.base import BackendBase
@@ -74,7 +75,7 @@ class TFBackend(BackendBase):
if not self.has_load:
tf.reset_default_graph()
self.sess = self.load()
return self.estimator.predict(data=data, **kwargs)
return self.estimator.predict(data, **kwargs)

def evaluate(self, data, **kwargs):
if not self.has_load:
@@ -135,4 +136,5 @@ class KerasBackend(TFBackend):
return list(map(lambda x: x.tolist(), self.estimator.get_weights()))

def set_weights(self, weights):
weights = [np.array(x) for x in weights]
self.estimator.set_weights(weights)

+0 -15  lib/sedna/core/__init__.py

@@ -1,15 +0,0 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sedna.algorithms import *

+38 -37  lib/sedna/core/federated_learning/federated_learning.py

@@ -80,57 +80,58 @@ class FederatedLearning(JobBase):
"""

callback_func = None
if post_process is not None:
if post_process:
callback_func = ClassFactory.get_cls(
ClassType.CALLBACK, post_process)

round_number = 0
num_samples = len(train_data)
self.aggregation.total_size += num_samples

_flag = True
start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
res = None
while 1:
round_number += 1
start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
self.log.info(
f"Federated learning start at {start},"
f" round_number={round_number}")
res = self.estimator.train(
train_data=train_data, valid_data=valid_data, **kwargs)

self.aggregation.weights = self.estimator.get_weights()
send_data = {"num_samples": num_samples,
"weights": self.aggregation.weights}
self.node.send(send_data,
msg_type="update_weight",
job_name=self.job_name)
received = self.node.recv()
exit_flag = False
if (received and received["type"] == "update_weight"
and received["job_name"] == self.job_name):
recv = received["data"]

rec_client = received["client"]
rec_sample = recv["num_samples"]

if _flag:
round_number += 1
start = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
self.log.info(
f"Federated learning get weight from "
f"[{rec_client}] : {rec_sample}")
n_weight = self.aggregation.aggregate(
recv["weights"], rec_sample)
self.estimator.set_weights(n_weight)
exit_flag = received.get("exit_flag", "") == "ok"

f"Federated learning start at {start},"
f" round_number={round_number}")
res = self.estimator.train(
train_data=train_data, valid_data=valid_data, **kwargs)

current_weights = self.estimator.get_weights()
send_data = {"num_samples": num_samples,
"weights": current_weights}
self.node.send(
send_data, msg_type="update_weight", job_name=self.job_name
)
received = self.node.recv(wait_data_type="recv_weight")
if not received:
_flag = False
continue
_flag = True

rec_data = received.get("data", {})
exit_flag = rec_data.get("exit_flag", "")
server_round = int(rec_data.get("round_number"))
total_size = int(rec_data.get("total_sample"))
self.log.info(
f"Federated learning recv weight, "
f"round: {server_round}, total_sample: {total_size}"
)
n_weight = rec_data.get("weights")
self.estimator.set_weights(n_weight)
task_info = {
'currentRound': round_number,
'sampleCount': self.aggregation.total_size,
'sampleCount': total_size,
'startTime': start,
'updateTime': time.strftime(
"%Y-%m-%d %H:%M:%S",
time.localtime())}
"%Y-%m-%d %H:%M:%S", time.localtime())
}
model_paths = self.estimator.save()
task_info_res = self.estimator.model_info(
model_paths, result=res, relpath=self.config.data_path_prefix)
if exit_flag:
if exit_flag == "ok":
self.report_task_info(
task_info,
K8sResourceKindStatus.COMPLETED.value,
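
With this rewrite the training worker no longer aggregates locally: it uploads its update, waits for the server's recv_weight broadcast, adopts the returned global weights, and stops once the server reports exit_flag == "ok". A condensed, library-free sketch of that round trip (send and recv are hypothetical stand-ins for sedna's websocket client):

def run_round(estimator, train_data, num_samples, send, recv):
    # One pass of the worker-side loop above, stripped of logging and reporting.
    estimator.train(train_data=train_data)
    send({"num_samples": num_samples, "weights": estimator.get_weights()},
         msg_type="update_weight")

    received = recv(wait_data_type="recv_weight")   # None until the server has aggregated
    if not received:
        return "retry"        # resend on the next pass without retraining

    data = received.get("data", {})
    estimator.set_weights(data.get("weights"))      # adopt the global model
    return "exit" if data.get("exit_flag") == "ok" else "continue"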


+6 -36  lib/sedna/core/joint_inference/joint_inference.py

@@ -32,22 +32,13 @@ class TSBigModelService(JobBase):
Provides RESTful interfaces for large-model inference.
"""

def __init__(self, estimator=None):
"""
Initial a big model service for JointInference
:param estimator: Customize estimator
"""

super(TSBigModelService, self).__init__(estimator=estimator)
def __init__(self, estimator=None, config=None):
super(TSBigModelService, self).__init__(
estimator=estimator, config=config)
self.local_ip = self.get_parameters("BIG_MODEL_BIND_IP", get_host_ip())
self.port = int(self.get_parameters("BIG_MODEL_BIND_PORT", "5000"))

def start(self):
"""
Start inference rest server
:return:
"""

if callable(self.estimator):
self.estimator = self.estimator()
if not os.path.exists(self.model_path):
@@ -65,14 +56,6 @@ class TSBigModelService(JobBase):
"""todo: no support yet"""

def inference(self, data=None, post_process=None, **kwargs):
"""
Inference task for IncrementalLearning
:param data: inference sample
:param post_process: post process
:param kwargs: params for inference of big model
:return: inference result
"""

callback_func = None
if callable(post_process):
callback_func = post_process
@@ -92,13 +75,9 @@ class JointInference(JobBase):
Joint inference
"""

def __init__(self, estimator=None):
"""
Initial a JointInference Job
:param estimator: Customize estimator
"""

super(JointInference, self).__init__(estimator=estimator)
def __init__(self, estimator=None, config=None):
super(JointInference, self).__init__(
estimator=estimator, config=config)
self.job_kind = K8sResourceKind.JOINT_INFERENCE_SERVICE.value
self.local_ip = get_host_ip()
self.remote_ip = self.get_parameters(
@@ -137,15 +116,6 @@ class JointInference(JobBase):
"""todo: no support yet"""

def inference(self, data=None, post_process=None, **kwargs):
"""
Inference task for IncrementalLearning
:param data: inference sample
:param post_process: post process
:param kwargs: params for inference of customize estimator
:return: if is hard sample, real result,
little model result, big model result
"""

callback_func = None
if callable(post_process):
callback_func = post_process


+5 -2  lib/sedna/service/client.py

@@ -206,12 +206,15 @@ class AggregationClient:
})
loop.run_until_complete(self._send(j))

def recv(self):
def recv(self, wait_data_type=None):
loop = asyncio.get_event_loop()
data = loop.run_until_complete(self._recv())
try:
return json.loads(data)
data = json.loads(data)
except Exception:
pass
if not wait_data_type or (isinstance(data, dict) and
data.get("type", "") == wait_data_type):
return data
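
recv now optionally filters on the message type: anything other than the awaited wait_data_type yields None, which is what lets the worker loop earlier in this diff tell "no aggregated weights yet" apart from a real reply. The rule restated as a standalone helper (a sketch mirroring the logic above):

import json

def filter_message(raw, wait_data_type=None):
    # Decode if possible, otherwise keep the raw payload, as recv() does above.
    try:
        data = json.loads(raw)
    except Exception:
        data = raw
    if not wait_data_type or (isinstance(data, dict)
                              and data.get("type", "") == wait_data_type):
        return data
    return None   # some other message type: caller treats it as "nothing yet"

assert filter_message('{"type": "recv_weight", "data": {}}', "recv_weight")
assert filter_message('{"type": "CLIENT_JOIN"}', "recv_weight") is None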




+63 -38  lib/sedna/service/server/aggregation.py

@@ -13,7 +13,7 @@
# limitations under the License.

import time
from typing import List, Optional, Dict
from typing import List, Optional, Dict, Any

import uuid
from pydantic import BaseModel
@@ -27,6 +27,8 @@ from starlette.types import ASGIApp, Receive, Scope, Send

from sedna.common.log import LOGGER
from sedna.common.utils import get_host_ip
from sedna.common.class_factory import ClassFactory, ClassType
from sedna.algorithms.aggregation import AggClient

from .base import BaseServer

@@ -37,10 +39,9 @@ class WSClientInfo(BaseModel): # pylint: disable=too-few-public-methods
"""
client information
"""

client_id: str
connected_at: float
job_count: int
info: Any


class WSClientInfoList(BaseModel): # pylint: disable=too-few-public-methods
@@ -57,6 +58,9 @@ class WSEventMiddleware: # pylint: disable=too-few-public-methods
servername = scope["path"].lstrip("/")
scope[servername] = self._server
await self._app(scope, receive, send)
# exit agg server if job complete
scope["app"].shutdown = (self._server.exit_check()
and self._server.empty)


class WSServerBase:
@@ -84,7 +88,7 @@ class WSServerBase:
LOGGER.info(f"Adding client {client_id}")
self._clients[client_id] = websocket
self._client_meta[client_id] = WSClientInfo(
client_id=client_id, connected_at=time.time(), job_count=0
client_id=client_id, connected_at=time.time(), info=None
)

async def kick_client(self, client_id: str):
@@ -103,7 +107,6 @@ class WSServerBase:
return self._client_meta.get(client_id)

async def send_message(self, client_id: str, msg: Dict):
self._client_meta[client_id].job_count += 1
for to_client, websocket in self._clients.items():
if to_client == client_id:
continue
@@ -115,40 +118,54 @@ class WSServerBase:
await websocket.send_json({"type": "CLIENT_JOIN",
"data": client_id})

async def client_left(self, client_id: str):
for to_client, websocket in self._clients.items():
if to_client == client_id:
continue
await websocket.send_json({"type": "CLIENT_LEAVE",
"data": client_id})


class Aggregator(WSServerBase):
def __init__(self, **kwargs):
super(Aggregator, self).__init__()
self.exit_round = int(kwargs.get("exit_round", 3))
self.current_round = {}
aggregation = kwargs.get("aggregation", "FedAvg")
self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)
if callable(self.aggregation):
self.aggregation = self.aggregation()
self.participants_count = int(kwargs.get("participants_count", "1"))
self.current_round = 0

async def send_message(self, client_id: str, msg: Dict):
self._client_meta[client_id].job_count += 1
clients = list(self._clients.items())
for to_client, websocket in clients:

if msg.get("type", "") == "update_weight":
if to_client == client_id:
continue
self.current_round[to_client] = self.current_round.get(
to_client, 0) + 1
exit_flag = "ok" if self.exit_check(to_client) else "continue"
msg["exit_flag"] = exit_flag
data = msg.get("data")
if data and msg.get("type", "") == "update_weight":
info = AggClient()
info.num_samples = int(data["num_samples"])
info.weights = data["weights"]
self._client_meta[client_id].info = info
current_clinets = [
x.info for x in self._client_meta.values() if x.info
]
# exit while aggregation job is NOT start
if len(current_clinets) < self.participants_count:
return
self.current_round += 1
weights = self.aggregation.aggregate(current_clinets)
exit_flag = "ok" if self.exit_check() else "continue"

msg["type"] = "recv_weight"
msg["round_number"] = self.current_round
msg["data"] = {
"total_sample": self.aggregation.total_size,
"round_number": self.current_round,
"weights": weights,
"exit_flag": exit_flag
}
for to_client, websocket in self._clients.items():
try:
await websocket.send_json(msg)
except Exception as err:
LOGGER.error(err)
else:
if msg["type"] == "recv_weight":
self._client_meta[to_client].info = None

def exit_check(self, client_id):
current_round = self.current_round.get(client_id, 0)
return current_round >= self.exit_round
def exit_check(self):
return self.current_round >= self.exit_round


class BroadcastWs(WebSocketEndpoint):
@@ -176,7 +193,6 @@ class BroadcastWs(WebSocketEndpoint):
"on_disconnect() called without a valid client_id"
)
self.server.remove_client(self.client_id)
await self.server.client_left(self.client_id)

async def on_receive(self, _websocket: WebSocket, msg: Dict):
command = msg.get("type", "")
@@ -185,50 +201,59 @@ class BroadcastWs(WebSocketEndpoint):
await self.server.client_joined(self.client_id)
self.server.add_client(self.client_id, _websocket)
if self.client_id is None:
raise RuntimeError("on_receive() called without a valid client_id")
raise RuntimeError(
"on_receive() called without a valid client_id")
await self.server.send_message(self.client_id, msg)


class AggregationServer(BaseServer):
def __init__(
self,
servername: str,
aggregation: str,
host: str = None,
http_port: int = 7363,
exit_round: int = 1,
ws_size: int = 10 *
1024 *
1024):
participants_count: int = 1,
ws_size: int = 10 * 1024 * 1024):
if not host:
host = get_host_ip()
super(
AggregationServer,
self).__init__(
servername=servername,
servername=aggregation,
host=host,
http_port=http_port,
ws_size=ws_size)
self.server_name = servername
self.aggregation = aggregation
self.participants_count = participants_count
self.exit_round = max(int(exit_round), 1)
self.app = FastAPI(
routes=[
APIRoute(
f"/{servername}",
f"/{aggregation}",
self.client_info,
response_class=JSONResponse,
),
WebSocketRoute(
f"/{servername}",
f"/{aggregation}",
BroadcastWs
)
],
)
self.app.shutdown = False

def start(self):
"""
Start the server
"""
self.app.add_middleware(WSEventMiddleware, exit_round=self.exit_round)

self.app.add_middleware(
WSEventMiddleware,
exit_round=self.exit_round,
aggregation=self.aggregation,
participants_count=self.participants_count
) # define the aggregation method and exit condition

self.run(self.app, ws_max_size=self.ws_size)

async def client_info(self, request: Request):
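
The server now keeps one AggClient record per connected worker and only aggregates once participants_count updates are buffered, broadcasting the result as a single recv_weight message. For reference, a sketch of the two message shapes the handler above reads and writes (only fields used in this diff are shown; values are illustrative):

# Worker -> server, once per local training pass:
update_weight_msg = {
    "type": "update_weight",
    "data": {
        "num_samples": 100,              # local sample count used for weighting
        "weights": [[1.0, 1.0], [0.5]],  # estimator.get_weights() layout
    },
}

# Server -> every connected worker, once per aggregation round:
recv_weight_msg = {
    "type": "recv_weight",
    "round_number": 1,
    "data": {
        "total_sample": 400,
        "round_number": 1,
        "weights": [[2.5, 2.5], [1.25]],  # FedAvg output
        "exit_flag": "continue",          # becomes "ok" once exit_round is reached
    },
}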


+44 -5  lib/sedna/service/server/base.py

@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import time
import threading
import asyncio

import uvicorn
from fastapi.middleware.cors import CORSMiddleware

@@ -19,9 +24,27 @@ from sedna.common.log import LOGGER
from sedna.common.utils import get_host_ip


class Server(uvicorn.Server):
def install_signal_handlers(self):
pass

@contextlib.contextmanager
def run_in_thread(self):
thread = threading.Thread(target=self.run)
thread.start()
try:
while not self.started:
time.sleep(1e-3)
yield
finally:
self.should_exit = True
thread.join()


class BaseServer:
# pylint: disable=too-many-instance-attributes,too-many-arguments
DEBUG = True
WAIT_TIME = 15

def __init__(
self,
@@ -50,14 +73,15 @@ class BaseServer:
self.url = f"{protocal}://{self.host}:{self.http_port}"

def run(self, app, **kwargs):
app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"],
)
if hasattr(app, "add_middleware"):
app.add_middleware(
CORSMiddleware, allow_origins=["*"], allow_credentials=True,
allow_methods=["*"], allow_headers=["*"],
)

LOGGER.info(f"Start {self.server_name} server over {self.url}")

uvicorn.run(
config = uvicorn.Config(
app,
host=self.host,
port=self.http_port,
@@ -65,7 +89,22 @@ class BaseServer:
ssl_certfile=self.certfile,
workers=self.workers,
timeout_keep_alive=self.timeout,
log_level="info",
**kwargs)
server = Server(config=config)
with server.run_in_thread():
loop = asyncio.get_event_loop()
stop = loop.create_task(self.wait_stop(loop.create_future()))
loop.run_until_complete(stop)
return

async def wait_stop(self, stop):
"""wait the stop flag to shutdown the server"""
while 1:
await asyncio.sleep(self.WAIT_TIME)
if getattr(self.app, "shutdown", False):
stop.set_result(1)
return

def get_all_urls(self):
url_list = [{"path": route.path, "name": route.name}
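
This file carries the core of the "Cloud worker not exiting" fix: uvicorn.run() used to block forever, so the aggregation worker could never terminate after the job finished. The server now runs in a background thread while the main loop waits for app.shutdown (set by WSEventMiddleware once the required rounds are done and no clients remain), then flips should_exit so uvicorn stops and the process can exit. A self-contained sketch of the same pattern, using a plain FastAPI app and a made-up /stop route in place of the middleware:

import asyncio
import contextlib
import threading
import time

import uvicorn
from fastapi import FastAPI

app = FastAPI()
app.shutdown = False          # flag the real code sets from WSEventMiddleware


@app.get("/stop")
def stop():
    # Stand-in for the "job complete and no clients connected" condition.
    app.shutdown = True
    return {"shutdown": True}


class Server(uvicorn.Server):
    """uvicorn server that can be started and stopped from another thread."""

    def install_signal_handlers(self):
        pass                  # leave signal handling to the main thread

    @contextlib.contextmanager
    def run_in_thread(self):
        thread = threading.Thread(target=self.run)
        thread.start()
        try:
            while not self.started:
                time.sleep(1e-3)
            yield
        finally:
            self.should_exit = True   # ask uvicorn to stop serving
            thread.join()


async def wait_stop(check_every=1.0):
    # Poll the shutdown flag until something sets it.
    while not getattr(app, "shutdown", False):
        await asyncio.sleep(check_every)


if __name__ == "__main__":
    server = Server(uvicorn.Config(app, host="127.0.0.1", port=7363, log_level="info"))
    with server.run_in_thread():
        asyncio.run(wait_stop())
    # leaving the context manager joins the server thread, so the process exits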

