1. Add transmitter, client_choose, and aggregation interfaces to the lib.
2. Add an example showing how to use the newly added interfaces.

Signed-off-by: Jie Pu <pujie2@huawei.com>
Signed-off-by: XinYao1994 <xyao@cs.hku.hk>
@@ -0,0 +1,54 @@
apiVersion: sedna.io/v1alpha1
kind: FederatedLearningJob
metadata:
  name: yolo-v5
spec:
  pretrainedModel: # optional
    name: "yolo-v5-pretrained-model"
  transmitter: # optional
    ws: { } # optional, used by default
    s3: # optional, but at least one of ws/s3 must be set
      aggDataPath: "s3://sedna/fl/aggregation_data"
      credentialName: mysecret
  aggregationWorker:
    model:
      name: "yolo-v5-model"
    template:
      spec:
        nodeName: "sedna-control-plane"
        containers:
          - image: kubeedge/sedna-fl-aggregation:mistnetyolo
            name: agg-worker
            imagePullPolicy: IfNotPresent
            env: # user-defined environment variables
              - name: "cut_layer"
                value: "4"
              - name: "epsilon"
                value: "100"
              - name: "aggregation_algorithm"
                value: "mistnet"
              - name: "batch_size"
                value: "10"
            resources: # user-defined resources
              limits:
                memory: 8Gi
  trainingWorkers:
    - dataset:
        name: "coco-dataset"
      template:
        spec:
          nodeName: "edge-node"
          containers:
            - image: kubeedge/sedna-fl-train:mistnetyolo
              name: train-worker
              imagePullPolicy: IfNotPresent
              args: [ "-i", "1" ]
              env: # user-defined environment variables
                - name: "batch_size"
                  value: "32"
                - name: "learning_rate"
                  value: "0.001"
                - name: "epochs"
                  value: "1"
              resources: # user-defined resources
                limits:
                  memory: 2Gi
@@ -0,0 +1,238 @@
# Collaboratively Train Yolo-v5 Using MistNet on COCO128 Dataset

This case introduces how to train a federated learning job with an aggregation algorithm named MistNet in an object
detection scenario on the COCO128 dataset. Data is scattered in different places (such as edge nodes, cameras, and
others) and cannot be aggregated at the server due to data privacy and bandwidth constraints. As a result, we cannot use all the
data for training. In some cases, edge nodes have limited computing resources and may even have no training capability, so the
edge cannot obtain updated weights from a local training process. Therefore, traditional algorithms (e.g., federated
averaging), which aggregate the updated weights trained by different edge clients, cannot work in this scenario.
MistNet is proposed to address this issue.

MistNet partitions a DNN model into two parts: a lightweight feature extractor on the edge side that generates meaningful
features from the raw data, and a classifier containing most of the model layers on the cloud side that is iteratively trained for
the specific task. MistNet achieves acceptable model utility while greatly reducing privacy leakage from the released
intermediate features.
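
The split point is controlled by the `cut_layer` parameter that appears in the job specs below. As a rough, hypothetical sketch (assuming PyTorch and a toy layer list, not the actual MistNet implementation, which is provided by the underlying training framework), partitioning a model at a cut layer could look like this:

```
import torch.nn as nn

# Illustrative only: split a toy model at `cut_layer`, mirroring MistNet's
# edge-side feature extractor and cloud-side classifier.
def split_at_cut_layer(layers, cut_layer):
    edge_part = nn.Sequential(*layers[:cut_layer])   # runs on the edge node
    cloud_part = nn.Sequential(*layers[cut_layer:])  # trained on the cloud
    return edge_part, cloud_part

layers = [nn.Conv2d(3, 16, 3), nn.ReLU(),
          nn.Conv2d(16, 32, 3), nn.ReLU(),
          nn.Flatten(), nn.LazyLinear(80)]
edge_extractor, cloud_classifier = split_at_cut_layer(layers, cut_layer=4)
# The edge worker only uploads edge_extractor(x), i.e. intermediate features,
# never the raw images.
```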
## Object Detection Experiment

> Assume that there are two edge nodes and a cloud node. Data on the edge nodes cannot be migrated to the cloud due to privacy issues.
> Based on this scenario, we will demonstrate the object detection example.
### Prepare Nodes

```
CLOUD_NODE="cloud-node-name"
EDGE1_NODE="edge1-node-name"
EDGE2_NODE="edge2-node-name"
```
### Install Sedna

Follow the [Sedna installation document](/docs/setup/install.md) to install Sedna.
### Prepare Dataset

Download the [dataset](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip) and partition the data:

```
wget https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
unzip coco128.zip -d data
rm coco128.zip
python partition.py ./data 2
```
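
`partition.py` ships with the example and is not part of this PR; the snippet below is only a hypothetical sketch of what such a partitioner could do (a round-robin split of the COCO128 images into N shards), not the real script:

```
import os
import shutil
import sys

def partition(data_dir, num_parts):
    # Hypothetical sketch: shard the COCO128 images round-robin into
    # <data_dir>/1 .. <data_dir>/<num_parts> (labels omitted for brevity).
    image_dir = os.path.join(data_dir, "coco128", "images", "train2017")
    images = sorted(f for f in os.listdir(image_dir) if f.endswith(".jpg"))
    for part in range(1, num_parts + 1):
        os.makedirs(os.path.join(data_dir, str(part)), exist_ok=True)
    for idx, name in enumerate(images):
        part = idx % num_parts + 1  # round-robin assignment
        shutil.copy(os.path.join(image_dir, name),
                    os.path.join(data_dir, str(part), name))

if __name__ == "__main__":
    partition(sys.argv[1], int(sys.argv[2]))
```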
Move `./data/1` to `/data` on `EDGE1_NODE`:

```
mkdir -p /data
mv ./data/1 /data/
```

Move `./data/2` to `/data` on `EDGE2_NODE`:

```
mkdir -p /data
mv ./data/2 /data/
```
### Prepare Images

This example uses these images:

1. aggregation worker: `kubeedge/sedna-example-federated-learning-mistnet:v0.3.0`
2. train worker: `kubeedge/sedna-example-federated-learning-mistnet-client:v0.3.0`

These images are generated by the script [build_image.sh](/examples/build_image.sh).
### Create Federated Learning Job

#### Create Dataset

Create the dataset for `$EDGE1_NODE`:

```
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: Dataset
metadata:
  name: "edge1-coco-dataset"
spec:
  url: "/data/test.txt"
  format: "txt"
  nodeName: $EDGE1_NODE
EOF
```

Create the dataset for `$EDGE2_NODE`:

```
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: Dataset
metadata:
  name: "edge2-coco-dataset"
spec:
  url: "/data/test.txt"
  format: "txt"
  nodeName: $EDGE2_NODE
EOF
```
#### Create Model

Create the directory `/model` on the host of `$EDGE1_NODE`:

```
mkdir /model
```

Create the directory `/model` on the host of `$EDGE2_NODE`:

```
mkdir /model
```

```
TODO: put the pretrained model on the nodes.
```

Create the model:

```
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: Model
metadata:
  name: "yolo-v5-model"
spec:
  url: "/model/yolo.pb"
  format: "pb"
EOF
```
#### Start Federated Learning Job

```
kubectl create -f - <<EOF
apiVersion: sedna.io/v1alpha1
kind: FederatedLearningJob
metadata:
  name: mistnet-on-mnist-dataset
spec:
  stopCondition:
    operator: "or"  # or "and"
    conditions:
      - operator: ">"
        threshold: 100
        metric: rounds
      - operator: ">"
        threshold: 0.95
        metric: targetAccuracy
      - operator: "<"
        threshold: 0.03
        metric: deltaLoss
  aggregationTrigger:
    condition:
      operator: ">"
      threshold: 5
      metric: num_of_ready_clients
  aggregationWorker:
    model:
      name: "yolo-v5-model"
    template:
      spec:
        nodeName: $CLOUD_NODE
        containers:
          - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-aggregation:v0.4.0
            name: agg-worker
            imagePullPolicy: IfNotPresent
            env: # user-defined environment variables
              - name: "cut_layer"
                value: "4"
              - name: "epsilon"
                value: "100"
              - name: "aggregation_algorithm"
                value: "mistnet"
              - name: "batch_size"
                value: "10"
            resources: # user-defined resources
              limits:
                memory: 2Gi
  trainingWorkers:
    - dataset:
        name: "edge1-coco-dataset"
      template:
        spec:
          nodeName: $EDGE1_NODE
          containers:
            - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0
              name: train-worker
              imagePullPolicy: IfNotPresent
              env: # user-defined environment variables
                - name: "batch_size"
                  value: "32"
                - name: "learning_rate"
                  value: "0.001"
                - name: "epochs"
                  value: "2"
              resources: # user-defined resources
                limits:
                  memory: 2Gi
    - dataset:
        name: "edge2-coco-dataset"
      template:
        spec:
          nodeName: $EDGE2_NODE
          containers:
            - image: kubeedge/sedna-example-federated-learning-mistnet-on-mnist-dataset-train:v0.4.0
              name: train-worker
              imagePullPolicy: IfNotPresent
              env: # user-defined environment variables
                - name: "batch_size"
                  value: "32"
                - name: "learning_rate"
                  value: "0.001"
                - name: "epochs"
                  value: "2"
              resources: # user-defined resources
                limits:
                  memory: 2Gi
EOF
```
```
TODO: show the benefit of MistNet, e.g., compare the results of FedAvg and MistNet.
```
### Check Federated Learning Status

```
kubectl get federatedlearningjob mistnet-on-mnist-dataset
```
### Check Federated Learning Train Result

After the job is completed, you will find the generated model in the `/model` directory on `$EDGE1_NODE` and `$EDGE2_NODE`.
@@ -0,0 +1,35 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from interface import mistnet, s3_transmitter, simple_chooser
from interface import Dataset, Estimator
from sedna.service.server import AggregationServer


def run_server():
    data = Dataset()
    estimator = Estimator()
    server = AggregationServer(
        data=data,
        estimator=estimator,
        aggregation=mistnet,
        transmitter=s3_transmitter,
        chooser=simple_chooser)
    server.start()


if __name__ == '__main__':
    run_server()
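
# Note: when launched by a FederatedLearningJob, the aggregation worker reads its
# settings from environment variables rather than from code: cut_layer, epsilon,
# aggregation_algorithm and batch_size come from the job spec above, while the
# transmitter settings (TRANSMITTER, S3_ENDPOINT_URL, ACCESS_KEY_ID,
# SECRET_ACCESS_KEY, AGG_DATA_PATH) are read by BaseConfig; they are expected to
# be set in the container environment before `python aggregate.py` runs.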
@@ -0,0 +1,149 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from sedna.algorithms.aggregation import MistNet
from sedna.algorithms.client_choose import SimpleClientChoose
from sedna.common.config import Context
from sedna.core.federated_learning import FederatedLearning

simple_chooser = SimpleClientChoose(per_round=1)

# It has been determined that mistnet is required here.
mistnet = MistNet(cut_layer=Context.get_parameters("cut_layer"),
                  epsilon=Context.get_parameters("epsilon"))

# The function `get_transmitter_from_config()` returns an object instance.
s3_transmitter = FederatedLearning.get_transmitter_from_config()

class Dataset:
    def __init__(self) -> None:
        self.parameters = {
            "datasource": "YOLO",
            "data_params": "./coco128.yaml",
            # Where the dataset is located
            "data_path": "./data/COCO",
            "train_path": "./data/COCO/coco128/images/train2017/",
            "test_path": "./data/COCO/coco128/images/train2017/",
            # number of training examples
            "num_train_examples": 128,
            # number of testing examples
            "num_test_examples": 128,
            # number of classes
            "num_classes": 80,
            # image size
            "image_size": 640,
            "classes":
                [
                    "person",
                    "bicycle",
                    "car",
                    "motorcycle",
                    "airplane",
                    "bus",
                    "train",
                    "truck",
                    "boat",
                    "traffic light",
                    "fire hydrant",
                    "stop sign",
                    "parking meter",
                    "bench",
                    "bird",
                    "cat",
                    "dog",
                    "horse",
                    "sheep",
                    "cow",
                    "elephant",
                    "bear",
                    "zebra",
                    "giraffe",
                    "backpack",
                    "umbrella",
                    "handbag",
                    "tie",
                    "suitcase",
                    "frisbee",
                    "skis",
                    "snowboard",
                    "sports ball",
                    "kite",
                    "baseball bat",
                    "baseball glove",
                    "skateboard",
                    "surfboard",
                    "tennis racket",
                    "bottle",
                    "wine glass",
                    "cup",
                    "fork",
                    "knife",
                    "spoon",
                    "bowl",
                    "banana",
                    "apple",
                    "sandwich",
                    "orange",
                    "broccoli",
                    "carrot",
                    "hot dog",
                    "pizza",
                    "donut",
                    "cake",
                    "chair",
                    "couch",
                    "potted plant",
                    "bed",
                    "dining table",
                    "toilet",
                    "tv",
                    "laptop",
                    "mouse",
                    "remote",
                    "keyboard",
                    "cell phone",
                    "microwave",
                    "oven",
                    "toaster",
                    "sink",
                    "refrigerator",
                    "book",
                    "clock",
                    "vase",
                    "scissors",
                    "teddy bear",
                    "hair drier",
                    "toothbrush",
                ],
            "partition_size": 128,
        }

class Estimator:
    def __init__(self) -> None:
        self.model = None
        self.hyperparameters = {
            "type": "yolov5",
            "rounds": 1,
            "target_accuracy": 0.99,
            "epochs": 500,
            "batch_size": 16,
            "optimizer": "SGD",
            "linear_lr": False,
            # The machine learning model
            "model_name": "yolov5",
            "model_config": "./yolov5s.yaml",
            "train_params": "./hyp.scratch.yaml"
        }
@@ -0,0 +1,33 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from interface import mistnet, s3_transmitter
from interface import Dataset, Estimator
from sedna.core.federated_learning import FederatedLearning


def main():
    data = Dataset()
    estimator = Estimator()
    # the dataset is passed via the constructor; train() takes no arguments
    fl_model = FederatedLearning(
        data=data,
        estimator=estimator,
        aggregation=mistnet,
        transmitter=s3_transmitter)
    fl_model.train()


if __name__ == '__main__':
    main()
@@ -13,3 +13,4 @@
# limitations under the License.

from . import aggregation
from .aggregation import FedAvg, MistNet
@@ -104,3 +104,14 @@ class FedAvg(BaseAggregation, abc.ABC):
            updates.append(row.tolist())
        self.weights = deepcopy(updates)
        return updates


@ClassFactory.register(ClassType.FL_AGG)
class MistNet(BaseAggregation, abc.ABC):
    def __init__(self, cut_layer, epsilon=100):
        super().__init__()
        self.cut_layer = cut_layer
        self.epsilon = epsilon
        # Assumption: FederatedLearning reads `aggregation.parameters`
        # (including the "type" key) when configuring the backend.
        self.parameters = {
            "type": "mistnet",
            "cut_layer": cut_layer,
            "epsilon": epsilon,
        }

    def aggregate(self, clients: List[AggClient]):
        # MistNet does not average client weights on the server; feature
        # aggregation is handled by the underlying training framework.
        pass
@@ -0,0 +1,15 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .client_choose import SimpleClientChoose
@@ -0,0 +1,36 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc


class AbstractClientChoose(metaclass=abc.ABCMeta):
    """
    Abstract class of ClientChoose, which provides base client choose
    algorithm interfaces in federated learning.
    """

    def __init__(self):
        pass


class SimpleClientChoose(AbstractClientChoose):
    """
    A Simple Implementation of Client Choose.
    """

    def __init__(self, per_round=1):
        super().__init__()
        self.per_round = per_round
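
# Usage sketch (assumption: the actual selection is done by the aggregation
# server; this class only carries how many clients are chosen per round):
#
#     chooser = SimpleClientChoose(per_round=2)
#     ready_clients = ["edge1", "edge2", "edge3"]
#     selected = ready_clients[:chooser.per_round]
#
# The server would then wait for updates from `selected` only in that round.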
@@ -0,0 +1,15 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .transmitter import S3Transmitter, WSTransmitter
@@ -0,0 +1,64 @@
# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod


class AbstractTransmitter(ABC):
    """
    Abstract class of Transmitter, which provides base transmission
    interfaces between edge and cloud.
    """

    @abstractmethod
    def recv(self):
        pass

    @abstractmethod
    def send(self, data):
        pass


class WSTransmitter(AbstractTransmitter, ABC):
    """
    An implementation of Transmitter based on WebSocket.
    """

    def recv(self):
        pass

    def send(self, data):
        pass


class S3Transmitter(AbstractTransmitter, ABC):
    """
    An implementation of Transmitter based on S3 protocol.
    """

    def __init__(self,
                 s3_endpoint_url,
                 access_key,
                 secret_key,
                 transmitter_url):
        self.s3_endpoint_url = s3_endpoint_url
        self.access_key = access_key
        self.secret_key = secret_key
        self.transmitter_url = transmitter_url

    def recv(self):
        pass

    def send(self, data):
        pass
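
# A hedged sketch of how S3Transmitter.send() could be fleshed out with boto3
# (assumption: boto3 is available and `data` is a local file path; the bucket and
# key below mirror the aggDataPath "s3://sedna/fl/aggregation_data" used in the
# example job and are illustrative, not the actual Sedna implementation):
#
#     import boto3
#
#     s3 = boto3.client("s3",
#                       endpoint_url=self.s3_endpoint_url,
#                       aws_access_key_id=self.access_key,
#                       aws_secret_access_key=self.secret_key)
#     s3.upload_file(Filename=data, Bucket="sedna", Key="fl/aggregation_data")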
@@ -269,9 +269,16 @@ class BaseConfig(ConfigSerializable):
    # the name of FederatedLearningJob and other jobs
    job_name = os.getenv("JOB_NAME", "sedna")

    pretrained_model_url = os.getenv("PRETRAINED_MODEL_URL", "./")
    model_url = os.getenv("MODEL_URL")
    model_name = os.getenv("MODEL_NAME")
    transmitter = os.getenv("TRANSMITTER", "ws")
    agg_data_path = os.getenv("AGG_DATA_PATH", "./")
    s3_endpoint_url = os.getenv("S3_ENDPOINT_URL", "")
    access_key_id = os.getenv("ACCESS_KEY_ID", "")
    secret_access_key = os.getenv("SECRET_ACCESS_KEY", "")

    # user parameters
    parameters = os.getenv("PARAMETERS")
@@ -13,17 +13,23 @@
# limitations under the License.

import asyncio
import sys
import time

from plato.clients import registry as client_registry
from plato.config import Config

from sedna.algorithms.transmitter import S3Transmitter, WSTransmitter
from sedna.common.class_factory import ClassFactory, ClassType
from sedna.common.config import BaseConfig, Context
from sedna.common.constant import K8sResourceKindStatus
from sedna.common.file_ops import FileOps
from sedna.core.base import JobBase
from sedna.service.client import AggregationClient


class FederatedLearningV0(JobBase):
    """
    Federated learning enables multiple actors to build a common, robust
    machine learning model without sharing data, thus allowing to address
@@ -50,6 +56,7 @@ class FederatedLearning(JobBase):
            aggregation="FedAvg"
        )
    """

    def __init__(self, estimator, aggregation="FedAvg"):
        protocol = Context.get_parameters("AGG_PROTOCOL", "ws")
@@ -178,3 +185,65 @@ class FederatedLearning(JobBase):
            task_info,
            K8sResourceKindStatus.RUNNING.value,
            task_info_res)

class FederatedLearning:
    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None) -> None:
        # set parameters
        server = Config.server._asdict()
        clients = Config.clients._asdict()
        datastore = Config.data._asdict()
        train = Config.trainer._asdict()

        if data is not None:
            for xkey in data.parameters:
                datastore[xkey] = data.parameters[xkey]
            Config.data = Config.namedtuple_from_dict(datastore)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            for xkey in estimator.hyperparameters:
                train[xkey] = estimator.hyperparameters[xkey]
            Config.trainer = Config.namedtuple_from_dict(train)

        if aggregation is not None:
            Config.algorithm = Config.namedtuple_from_dict(
                aggregation.parameters)
            if aggregation.parameters["type"] == "mistnet":
                clients["type"] = "mistnet"
                server["type"] = "mistnet"

        if isinstance(transmitter, S3Transmitter):
            server["address"] = Context.get_parameters("AGG_IP")
            server["port"] = Context.get_parameters("AGG_PORT")
            server["s3_endpoint_url"] = transmitter.s3_endpoint_url
            # the aggregation data path, e.g. s3://sedna/fl/aggregation_data
            server["s3_bucket"] = transmitter.transmitter_url
            server["access_key"] = transmitter.access_key
            server["secret_key"] = transmitter.secret_key
        elif isinstance(transmitter, WSTransmitter):
            pass

        Config.server = Config.namedtuple_from_dict(server)
        Config.clients = Config.namedtuple_from_dict(clients)
        # Config.store()

        # create a client
        self.client = client_registry.get(model=self.model)
        self.client.configure()

    @classmethod
    def get_transmitter_from_config(cls):
        if BaseConfig.transmitter == "ws":
            return WSTransmitter()
        elif BaseConfig.transmitter == "s3":
            return S3Transmitter(
                s3_endpoint_url=BaseConfig.s3_endpoint_url,
                access_key=BaseConfig.access_key_id,
                secret_key=BaseConfig.secret_access_key,
                transmitter_url=BaseConfig.agg_data_path)

    def train(self):
        # asyncio.run() is only available in Python 3.7+
        if sys.version_info < (3, 7):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.client.start_client())
        else:
            asyncio.run(self.client.start_client())