1. add transmitter, client_choose, aggregation interface to Lib. 2. add example of how to use new added interface. Signed-off-by: Jie Pu <pujie2@huawei.com> Signed-off-by: XinYao1994 <xyao@cs.hku.hk> updatetags/v0.4.0^2
| @@ -0,0 +1,54 @@ | |||||
| apiVersion: sedna.io/v1alpha1 | |||||
| kind: FederatedLearningJob | |||||
| metadata: | |||||
| name: yolo-v5 | |||||
| spec: | |||||
| pretrainedModel: # option | |||||
| name: "yolo-v5-pretrained-model" | |||||
| transmitter: # option | |||||
| ws: { } # option, by default | |||||
| s3: # option, but at least one | |||||
| aggDataPath: "s3://sedna/fl/aggregation_data" | |||||
| credentialName: mysecret | |||||
| aggregationWorker: | |||||
| model: | |||||
| name: "yolo-v5-model" | |||||
| template: | |||||
| spec: | |||||
| nodeName: "sedna-control-plane" | |||||
| containers: | |||||
| - image: kubeedge/sedna-fl-aggregation:mistnetyolo | |||||
| name: agg-worker | |||||
| imagePullPolicy: IfNotPresent | |||||
| env: # user defined environments | |||||
| - name: "cut_layer" | |||||
| value: "4" | |||||
| - name: "epsilon" | |||||
| value: "100" | |||||
| - name: "aggregation_algorithm" | |||||
| value: "mistnet" | |||||
| - name: "batch_size" | |||||
| resources: # user defined resources | |||||
| limits: | |||||
| memory: 8Gi | |||||
| trainingWorkers: | |||||
| - dataset: | |||||
| name: "coco-dataset" | |||||
| template: | |||||
| spec: | |||||
| nodeName: "edge-node" | |||||
| containers: | |||||
| - image: kubeedge/sedna-fl-train:mistnetyolo | |||||
| name: train-worker | |||||
| imagePullPolicy: IfNotPresent | |||||
| args: [ "-i", "1" ] | |||||
| env: # user defined environments | |||||
| - name: "batch_size" | |||||
| value: "32" | |||||
| - name: "learning_rate" | |||||
| value: "0.001" | |||||
| - name: "epochs" | |||||
| value: "1" | |||||
| resources: # user defined resources | |||||
| limits: | |||||
| memory: 2Gi | |||||
| @@ -0,0 +1,238 @@ | |||||
| # Collaboratively Train Yolo-v5 Using MistNet on COCO128 Dataset | |||||
| This case introduces how to train a federated learning job with an aggregation algorithm named MistNet in a YOLOv5 | |||||
| object detection scenario on the COCO128 dataset. Data is scattered in different places (such as edge nodes, cameras, and | |||||
| others) and cannot be aggregated at the server due to data privacy and bandwidth. As a result, we cannot use all the | |||||
| data for training. In some cases, edge nodes have limited computing resources and even have no training capability. The | |||||
| edge cannot gain the updated weights from the training process. Therefore, traditional algorithms (e.g., federated | |||||
| average), which usually aggregate the updated weights trained by different edge clients, cannot work in this scenario. | |||||
| MistNet is proposed to address this issue. | |||||
| MistNet partitions a DNN model into two parts, a lightweight feature extractor at the edge side to generate meaningful | |||||
| features from the raw data, and a classifier including the most model layers at the cloud to be iteratively trained for | |||||
| specific tasks. MistNet achieves acceptable model utility while greatly reducing privacy leakage from the released | |||||
| intermediate features. | |||||
| ## Object Detection Experiment | |||||
| > Assume that there are two edge nodes and a cloud node. Data on the edge nodes cannot be migrated to the cloud due to privacy issues. | |||||
| > Based on this scenario, we will demonstrate the object detection example. | |||||
| ### Prepare Nodes | |||||
| ``` | |||||
| CLOUD_NODE="cloud-node-name" | |||||
| EDGE1_NODE="edge1-node-name" | |||||
| EDGE2_NODE="edge2-node-name" | |||||
| ``` | |||||
| ### Install Sedna | |||||
| Follow the [Sedna installation document](/docs/setup/install.md) to install Sedna. | |||||
| ### Prepare Dataset | |||||
| Download [dataset](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) and do data partition | |||||
| ``` | |||||
| wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip | |||||
| unzip coco128.zip -d data | |||||
| rm coco128.zip | |||||
| python partition.py ./data 2 | |||||
| ``` | |||||
| move ```./data/1``` to `/data` of ```EDGE1_NODE```. | |||||
| ``` | |||||
| mkdir -p /data | |||||
| cd /data | |||||
| mv ./data/1 ./ | |||||
| ``` | |||||
| move ```./data/2``` to `/data` of ```EDGE2_NODE```. | |||||
| ``` | |||||
| mkdir -p /data | |||||
| cd /data | |||||
| mv ./data/2 ./ | |||||
| ``` | |||||
| ### Prepare Images | |||||
| This example uses these images: | |||||
| 1. aggregation worker: ```kubeedge/sedna-example-federated-learning-mistnet:v0.3.0``` | |||||
| 2. train worker: ```kubeedge/sedna-example-federated-learning-mistnet-client:v0.3.0``` | |||||
| These images are generated by the script [build_images.sh](/examples/build_image.sh). | |||||
| ### Create Federated Learning Job | |||||
| #### Create Dataset | |||||
| create dataset for `$EDGE1_NODE` | |||||
| ``` | |||||
| kubectl create -f - <<EOF | |||||
| apiVersion: sedna.io/v1alpha1 | |||||
| kind: Dataset | |||||
| metadata: | |||||
| name: "coco-dataset" | |||||
| spec: | |||||
| url: "/data/test.txt" | |||||
| format: "txt" | |||||
| nodeName: edge-node | |||||
| EOF | |||||
| ``` | |||||
| create dataset for `$EDGE2_NODE` | |||||
| ``` | |||||
| kubectl create -f - <<EOF | |||||
| apiVersion: sedna.io/v1alpha1 | |||||
| kind: Dataset | |||||
| metadata: | |||||
| name: "coco-dataset" | |||||
| spec: | |||||
| url: "/data/test.txt" | |||||
| format: "txt" | |||||
| nodeName: edge-node | |||||
| EOF | |||||
| ``` | |||||
| #### Create Model | |||||
| create the directory `/model` in the host of `$EDGE1_NODE` | |||||
| ``` | |||||
| mkdir /model | |||||
| ``` | |||||
| create the directory `/model` in the host of `$EDGE2_NODE` | |||||
| ``` | |||||
| mkdir /model | |||||
| ``` | |||||
| ``` | |||||
| TODO: put pretrained model on nodes. | |||||
| ``` | |||||
| create model | |||||
| ``` | |||||
| kubectl create -f - <<EOF | |||||
| apiVersion: sedna.io/v1alpha1 | |||||
| kind: Model | |||||
| metadata: | |||||
| name: "yolo-v5-model" | |||||
| spec: | |||||
| url: "/model/yolo.pb" | |||||
| format: "pb" | |||||
| EOF | |||||
| ``` | |||||
| #### Start Federated Learning Job | |||||
| ``` | |||||
| kubectl create -f - <<EOF | |||||
| apiVersion: sedna.io/v1alpha1 | |||||
| kind: FederatedLearningJob | |||||
| metadata: | |||||
| name: mistnet-on-mnist-dataset | |||||
| spec: | |||||
| stopCondition: | |||||
| operator: "or" # and | |||||
| conditions: | |||||
| - operator: ">" | |||||
| threshold: 100 | |||||
| metric: rounds | |||||
| - operator: ">" | |||||
| threshold: 0.95 | |||||
| metric: targetAccuracy | |||||
| - operator: "<" | |||||
| threshold: 0.03 | |||||
| metric: deltaLoss | |||||
| aggregationTrigger: | |||||
| condition: | |||||
| operator: ">" | |||||
| threshold: 5 | |||||
| metric: num_of_ready_clients | |||||
| aggregationWorker: | |||||
| model: | |||||
| name: "mistnet-on-mnist-model" | |||||
| template: | |||||
| spec: | |||||
| nodeName: $CLOUD_NODE | |||||
| containers: | |||||
| - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-aggregation:v0.4.0 | |||||
| name: agg-worker | |||||
| imagePullPolicy: IfNotPresent | |||||
| env: # user defined environments | |||||
| - name: "cut_layer" | |||||
| value: "4" | |||||
| - name: "epsilon" | |||||
| value: "100" | |||||
| - name: "aggregation_algorithm" | |||||
| value: "mistnet" | |||||
| - name: "batch_size" | |||||
| value: "10" | |||||
| resources: # user defined resources | |||||
| limits: | |||||
| memory: 2Gi | |||||
| trainingWorkers: | |||||
| - dataset: | |||||
| name: "edge1-surface-defect-detection-dataset" | |||||
| template: | |||||
| spec: | |||||
| nodeName: $EDGE1_NODE | |||||
| containers: | |||||
| - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0 | |||||
| name: train-worker | |||||
| imagePullPolicy: IfNotPresent | |||||
| env: # user defined environments | |||||
| - name: "batch_size" | |||||
| value: "32" | |||||
| - name: "learning_rate" | |||||
| value: "0.001" | |||||
| - name: "epochs" | |||||
| value: "2" | |||||
| resources: # user defined resources | |||||
| limits: | |||||
| memory: 2Gi | |||||
| - dataset: | |||||
| name: "edge2-surface-defect-detection-dataset" | |||||
| template: | |||||
| spec: | |||||
| nodeName: $EDGE2_NODE | |||||
| containers: | |||||
| - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0 | |||||
| name: train-worker | |||||
| imagePullPolicy: IfNotPresent | |||||
| env: # user defined environments | |||||
| - name: "batch_size" | |||||
| value: "32" | |||||
| - name: "learning_rate" | |||||
| value: "0.001" | |||||
| - name: "epochs" | |||||
| value: "2" | |||||
| resources: # user defined resources | |||||
| limits: | |||||
| memory: 2Gi | |||||
| EOF | |||||
| ``` | |||||
| ``` | |||||
| TODO: show the benefit of MistNet, for example by comparing the results of FedAvg and MistNet. | |||||
| ``` | |||||
| ### Check Federated Learning Status | |||||
| ``` | |||||
| kubectl get federatedlearningjob mistnet-on-mnist-dataset | |||||
| ``` | |||||
| ### Check Federated Learning Train Result | |||||
| After the job completes, you will find the generated model in the directory `/model` on `$EDGE1_NODE` and `$EDGE2_NODE`. | |||||
| @@ -0,0 +1,35 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from interface import mistnet, s3_transmitter, simple_chooser | |||||
| from interface import Dataset, Estimator | |||||
| from sedna.service.server import AggregationServer | |||||
def run_server():
    """Assemble the example's aggregation server and start it.

    The server is given a fresh ``Dataset`` and ``Estimator`` together with
    the module-level ``mistnet`` aggregation, ``s3_transmitter`` and
    ``simple_chooser`` objects imported from ``interface``.
    """
    server_options = {
        "data": Dataset(),
        "estimator": Estimator(),
        "aggregation": mistnet,
        "transmitter": s3_transmitter,
        "chooser": simple_chooser,
    }
    AggregationServer(**server_options).start()


if __name__ == '__main__':
    run_server()
| @@ -0,0 +1,149 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from sedna.algorithms.aggregation import MistNet | |||||
| from sedna.algorithms.client_choose import SimpleClientChoose | |||||
| from sedna.common.config import Context | |||||
| from sedna.core.federated_learning import FederatedLearning | |||||
# Client-selection strategy shared by the example scripts, configured with
# per_round=1.
simple_chooser = SimpleClientChoose(per_round=1)

# It has been determined that mistnet is required here.
# `cut_layer` and `epsilon` are looked up through Context.get_parameters().
mistnet = MistNet(cut_layer=Context.get_parameters("cut_layer"),
                  epsilon=Context.get_parameters("epsilon"))

# The function `get_transmitter_from_config()` returns an object instance.
# NOTE: despite the variable name, the concrete transmitter type is chosen by
# `get_transmitter_from_config()` from configuration, not fixed to S3 here.
s3_transmitter = FederatedLearning.get_transmitter_from_config()
class Dataset:
    """Dataset description for the YOLO / COCO128 example.

    The instance only carries a ``parameters`` dictionary: data source name,
    file locations, example counts, image size, the class-name list and the
    partition size. Note that ``test_path`` points at the same directory as
    ``train_path``.
    """

    def __init__(self) -> None:
        # Class names, in index order; a new list is built per instance.
        label_names = [
            "person",
            "bicycle",
            "car",
            "motorcycle",
            "airplane",
            "bus",
            "train",
            "truck",
            "boat",
            "traffic light",
            "fire hydrant",
            "stop sign",
            "parking meter",
            "bench",
            "bird",
            "cat",
            "dog",
            "horse",
            "sheep",
            "cow",
            "elephant",
            "bear",
            "zebra",
            "giraffe",
            "backpack",
            "umbrella",
            "handbag",
            "tie",
            "suitcase",
            "frisbee",
            "skis",
            "snowboard",
            "sports ball",
            "kite",
            "baseball bat",
            "baseball glove",
            "skateboard",
            "surfboard",
            "tennis racket",
            "bottle",
            "wine glass",
            "cup",
            "fork",
            "knife",
            "spoon",
            "bowl",
            "banana",
            "apple",
            "sandwich",
            "orange",
            "broccoli",
            "carrot",
            "hot dog",
            "pizza",
            "donut",
            "cake",
            "chair",
            "couch",
            "potted plant",
            "bed",
            "dining table",
            "toilet",
            "tv",
            "laptop",
            "mouse",
            "remote",
            "keyboard",
            "cell phone",
            "microwave",
            "oven",
            "toaster",
            "sink",
            "refrigerator",
            "book",
            "clock",
            "vase",
            "scissors",
            "teddy bear",
            "hair drier",
            "toothbrush",
        ]

        # Where the data lives on disk.
        locations = {
            "datasource": "YOLO",
            "data_params": "./coco128.yaml",
            "data_path": "./data/COCO",
            "train_path": "./data/COCO/coco128/images/train2017/",
            "test_path": "./data/COCO/coco128/images/train2017/",
        }

        # Sizes and label information.
        shape_info = {
            "num_train_examples": 128,
            "num_test_examples": 128,
            "num_classes": 80,
            "image_size": 640,
            "classes": label_names,
            "partition_size": 128,
        }

        self.parameters = {**locations, **shape_info}
class Estimator:
    """Estimator description for the YOLOv5 example.

    Holds a ``model`` slot (initially ``None``) and a ``hyperparameters``
    dictionary with the training schedule and the model/config file names.
    """

    def __init__(self) -> None:
        self.model = None

        # Training schedule and optimizer settings.
        schedule = {
            "type": "yolov5",
            "rounds": 1,
            "target_accuracy": 0.99,
            "epochs": 500,
            "batch_size": 16,
            "optimizer": "SGD",
            "linear_lr": False,
        }

        # The machine learning model and its configuration files.
        model_files = {
            "model_name": "yolov5",
            "model_config": "./yolov5s.yaml",
            "train_params": "./hyp.scratch.yaml",
        }

        self.hyperparameters = {**schedule, **model_files}
| @@ -0,0 +1,33 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from interface import mistnet, s3_transmitter | |||||
| from interface import Dataset, Estimator | |||||
| from sedna.core.federated_learning import FederatedLearning | |||||
def main():
    """Run the training-worker side of the MistNet example.

    Builds the example ``Dataset`` and ``Estimator``, wires them into a
    ``FederatedLearning`` job together with the shared ``mistnet``
    aggregation and ``s3_transmitter`` objects, and starts training.
    """
    data = Dataset()
    estimator = Estimator()

    # The dataset parameters are consumed by the FederatedLearning
    # constructor; train() takes no arguments, so passing ``data`` to it
    # would raise TypeError and leave the data configuration unused.
    fl_model = FederatedLearning(
        data=data,
        estimator=estimator,
        aggregation=mistnet,
        transmitter=s3_transmitter)
    fl_model.train()


if __name__ == '__main__':
    main()
| @@ -13,3 +13,4 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| from . import aggregation | from . import aggregation | ||||
| from .aggregation import FedAvg, MistNet | |||||
| @@ -104,3 +104,14 @@ class FedAvg(BaseAggregation, abc.ABC): | |||||
| updates.append(row.tolist()) | updates.append(row.tolist()) | ||||
| self.weights = deepcopy(updates) | self.weights = deepcopy(updates) | ||||
| return updates | return updates | ||||
@ClassFactory.register(ClassType.FL_AGG)
class MistNet(BaseAggregation, abc.ABC):
    """Aggregation descriptor for the MistNet algorithm.

    Parameters
    ----------
    cut_layer
        Value identifying the layer at which the model is split; stored
        as given.
    epsilon
        Value stored as given; defaults to 100.
    """

    def __init__(self, cut_layer, epsilon=100):
        super().__init__()
        self.cut_layer = cut_layer
        self.epsilon = epsilon
        # FederatedLearning reads ``aggregation.parameters`` (including the
        # "type" key) when building its configuration, so expose the
        # settings in that form as well as through the attributes above.
        self.parameters = {
            "type": "mistnet",
            "cut_layer": cut_layer,
            "epsilon": epsilon,
        }

    def aggregate(self, clients: List[AggClient]):
        """Placeholder: performs no aggregation and returns ``None``."""
        pass
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from .client_choose import SimpleClientChoose | |||||
| @@ -0,0 +1,36 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| import abc | |||||
class AbstractClientChoose(metaclass=abc.ABCMeta):
    """
    Base class for client-choosing algorithms used in federated learning.

    It declares no abstract methods, so it only serves as the common
    ancestor of concrete client choosers.
    """

    def __init__(self):
        """Initialise the chooser; the base class keeps no state."""
class SimpleClientChoose(AbstractClientChoose):
    """
    Minimal client chooser that only records a ``per_round`` setting.
    """

    def __init__(self, per_round=1):
        """Store ``per_round`` (default 1) on the instance."""
        super().__init__()
        self.per_round = per_round
| @@ -0,0 +1,15 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from .transmitter import S3Transmitter, WSTransmitter | |||||
| @@ -0,0 +1,64 @@ | |||||
| # Copyright 2021 The KubeEdge Authors. | |||||
| # | |||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| # you may not use this file except in compliance with the License. | |||||
| # You may obtain a copy of the License at | |||||
| # | |||||
| # http://www.apache.org/licenses/LICENSE-2.0 | |||||
| # | |||||
| # Unless required by applicable law or agreed to in writing, software | |||||
| # distributed under the License is distributed on an "AS IS" BASIS, | |||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| # See the License for the specific language governing permissions and | |||||
| # limitations under the License. | |||||
| from abc import ABC, abstractmethod | |||||
class AbstractTransmitter(ABC):
    """
    Base transmission interface between edge and cloud.

    Concrete transmitters must implement both ``recv`` and ``send``.
    """

    @abstractmethod
    def recv(self):
        """Receive data; must be overridden by subclasses."""

    @abstractmethod
    def send(self, data):
        """Send ``data``; must be overridden by subclasses."""
class WSTransmitter(AbstractTransmitter, ABC):
    """
    WebSocket-based transmitter.

    Both operations are currently placeholders that do nothing and
    return ``None``.
    """

    def recv(self):
        """Placeholder receive operation."""

    def send(self, data):
        """Placeholder send operation; ``data`` is ignored."""
class S3Transmitter(AbstractTransmitter, ABC):
    """
    Transmitter based on the S3 protocol.

    The constructor only records the connection settings; ``recv`` and
    ``send`` are currently placeholders that do nothing.
    """

    def __init__(self,
                 s3_endpoint_url,
                 access_key,
                 secret_key,
                 transmitter_url):
        """Store the endpoint URL, the key pair and the transmitter URL."""
        self.s3_endpoint_url, self.access_key = s3_endpoint_url, access_key
        self.secret_key, self.transmitter_url = secret_key, transmitter_url

    def recv(self):
        """Placeholder receive operation."""

    def send(self, data):
        """Placeholder send operation; ``data`` is ignored."""
| @@ -269,9 +269,16 @@ class BaseConfig(ConfigSerializable): | |||||
| # the name of FederatedLearningJob and others Job | # the name of FederatedLearningJob and others Job | ||||
| job_name = os.getenv("JOB_NAME", "sedna") | job_name = os.getenv("JOB_NAME", "sedna") | ||||
| pretrained_model_url = os.getenv("PRETRAINED_MODEL_URL", "./") | |||||
| model_url = os.getenv("MODEL_URL") | model_url = os.getenv("MODEL_URL") | ||||
| model_name = os.getenv("MODEL_NAME") | model_name = os.getenv("MODEL_NAME") | ||||
| transmitter = os.getenv("TRANSMITTER", "ws") | |||||
| agg_data_path = os.getenv("AGG_DATA_PATH", "./") | |||||
| s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "") | |||||
| access_key_id = os.getenv("ACCESS_KEY_ID", "") | |||||
| secret_access_key = os.getenv("SECRET_ACCESS_KEY", "") | |||||
| # user parameter | # user parameter | ||||
| parameters = os.getenv("PARAMETERS") | parameters = os.getenv("PARAMETERS") | ||||
| @@ -13,17 +13,23 @@ | |||||
| # limitations under the License. | # limitations under the License. | ||||
| import asyncio | |||||
| import sys | |||||
| import time | import time | ||||
| from sedna.core.base import JobBase | |||||
| from sedna.common.config import Context | |||||
| from sedna.common.file_ops import FileOps | |||||
| from plato.clients import registry as client_registry | |||||
| from plato.config import Config | |||||
| from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter | |||||
| from sedna.common.class_factory import ClassFactory, ClassType | from sedna.common.class_factory import ClassFactory, ClassType | ||||
| from sedna.service.client import AggregationClient | |||||
| from sedna.common.config import BaseConfig, Context | |||||
| from sedna.common.constant import K8sResourceKindStatus | from sedna.common.constant import K8sResourceKindStatus | ||||
| from sedna.common.file_ops import FileOps | |||||
| from sedna.core.base import JobBase | |||||
| from sedna.service.client import AggregationClient | |||||
| class FederatedLearning(JobBase): | |||||
| class FederatedLearningV0(JobBase): | |||||
| """ | """ | ||||
| Federated learning enables multiple actors to build a common, robust | Federated learning enables multiple actors to build a common, robust | ||||
| machine learning model without sharing data, thus allowing to address | machine learning model without sharing data, thus allowing to address | ||||
| @@ -50,6 +56,7 @@ class FederatedLearning(JobBase): | |||||
| aggregation="FedAvg" | aggregation="FedAvg" | ||||
| ) | ) | ||||
| """ | """ | ||||
| def __init__(self, estimator, aggregation="FedAvg"): | def __init__(self, estimator, aggregation="FedAvg"): | ||||
| protocol = Context.get_parameters("AGG_PROTOCOL", "ws") | protocol = Context.get_parameters("AGG_PROTOCOL", "ws") | ||||
| @@ -178,3 +185,65 @@ class FederatedLearning(JobBase): | |||||
| task_info, | task_info, | ||||
| K8sResourceKindStatus.RUNNING.value, | K8sResourceKindStatus.RUNNING.value, | ||||
| task_info_res) | task_info_res) | ||||
class FederatedLearning:
    """Federated-learning client configured from user-supplied objects.

    The constructor copies settings from the given objects into the global
    ``Config`` sections (data, trainer, algorithm, server, clients) and then
    creates and configures a client through ``client_registry``.

    Parameters
    ----------
    data
        Optional object exposing a ``parameters`` mapping that is merged
        into ``Config.data``.
    estimator
        Optional object exposing ``model`` and a ``hyperparameters`` mapping
        that is merged into ``Config.trainer``.
    aggregation
        Optional object exposing a ``parameters`` mapping with a ``"type"``
        key; it replaces ``Config.algorithm``.
    transmitter
        Optional transmitter instance; an ``S3Transmitter`` additionally
        populates the server's address and S3 settings.
    """

    def __init__(self, data=None, estimator=None, aggregation=None,
                 transmitter=None) -> None:
        # Start from the current configuration sections as plain dicts.
        server = Config.server._asdict()
        clients = Config.clients._asdict()
        datastore = Config.data._asdict()
        train = Config.trainer._asdict()

        if data is not None:
            datastore.update(data.parameters)
            Config.data = Config.namedtuple_from_dict(datastore)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            train.update(estimator.hyperparameters)
            Config.trainer = Config.namedtuple_from_dict(train)

        if aggregation is not None:
            Config.algorithm = Config.namedtuple_from_dict(
                aggregation.parameters)
            if aggregation.parameters["type"] == "mistnet":
                clients["type"] = "mistnet"
                server["type"] = "mistnet"

        if isinstance(transmitter, S3Transmitter):
            server["address"] = Context.get_parameters("AGG_IP")
            server["port"] = Context.get_parameters("AGG_PORT")
            server["s3_endpoint_url"] = transmitter.s3_endpoint_url
            # S3Transmitter stores its location as ``transmitter_url`` and
            # defines no ``s3_bucket`` attribute, so fall back to it rather
            # than raising AttributeError.
            server["s3_bucket"] = getattr(
                transmitter, "s3_bucket", transmitter.transmitter_url)
            server["access_key"] = transmitter.access_key
            server["secret_key"] = transmitter.secret_key
        elif isinstance(transmitter, WSTransmitter):
            # Nothing extra to configure for the WebSocket transmitter.
            pass

        Config.server = Config.namedtuple_from_dict(server)
        Config.clients = Config.namedtuple_from_dict(clients)

        # create a client
        self.client = client_registry.get(model=self.model)
        self.client.configure()

    @classmethod
    def get_transmitter_from_config(cls):
        """Build the transmitter selected by ``BaseConfig.transmitter``.

        Returns
        -------
        WSTransmitter for ``"ws"``, an ``S3Transmitter`` built from the S3
        settings in ``BaseConfig`` for ``"s3"``, and ``None`` for any other
        value.
        """
        if BaseConfig.transmitter == "ws":
            return WSTransmitter()
        if BaseConfig.transmitter == "s3":
            return S3Transmitter(s3_endpoint_url=BaseConfig.s3_endpoint_url,
                                 access_key=BaseConfig.access_key_id,
                                 secret_key=BaseConfig.secret_access_key,
                                 transmitter_url=BaseConfig.agg_data_path)
        # Unrecognised transmitter names yield no transmitter.
        return None

    def train(self):
        """Run the client's ``start_client()`` coroutine to completion."""
        # asyncio.run() exists from Python 3.7 on. Compare the version tuple
        # instead of a character of sys.version, which misreads "3.10" and
        # later as minor version 1.
        if sys.version_info < (3, 7):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.client.start_client())
        else:
            asyncio.run(self.client.start_client())