Browse Source

Merge pull request #186 from JoeyHwong-gk/fl_fix

fix env variable error in aggregation
tags/v0.4.0
KubeEdge Bot GitHub 4 years ago
parent
commit
938326a169
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 6 deletions
  1. +1
    -4
      examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py
  2. +5
    -2
      lib/sedna/service/server/aggregation.py

+ 1
- 4
examples/federated_learning/surface_defect_detection/aggregation_worker/aggregate.py View File

@@ -26,12 +26,9 @@ def run_server():
participants_count = int(Context.get_parameters( participants_count = int(Context.get_parameters(
"participants_count", 1 "participants_count", 1
)) ))
agg_ip = Context.get_parameters("AGG_IP", "0.0.0.0")
agg_port = int(Context.get_parameters("AGG_PORT", "7363"))

server = AggregationServer( server = AggregationServer(
aggregation=aggregation_algorithm, aggregation=aggregation_algorithm,
host=agg_ip,
http_port=agg_port,
exit_round=exit_round, exit_round=exit_round,
ws_size=20 * 1024 * 1024, ws_size=20 * 1024 * 1024,
participants_count=participants_count participants_count=participants_count


+ 5
- 2
lib/sedna/service/server/aggregation.py View File

@@ -26,6 +26,7 @@ from starlette.endpoints import WebSocketEndpoint
from starlette.types import ASGIApp, Receive, Scope, Send from starlette.types import ASGIApp, Receive, Scope, Send


from sedna.common.log import LOGGER from sedna.common.log import LOGGER
from sedna.common.config import Context
from sedna.common.utils import get_host_ip from sedna.common.utils import get_host_ip
from sedna.common.class_factory import ClassFactory, ClassType from sedna.common.class_factory import ClassFactory, ClassType
from sedna.algorithms.aggregation import AggClient from sedna.algorithms.aggregation import AggClient
@@ -211,12 +212,14 @@ class AggregationServer(BaseServer):
self, self,
aggregation: str, aggregation: str,
host: str = None, host: str = None,
http_port: int = 7363,
http_port: int = None,
exit_round: int = 1, exit_round: int = 1,
participants_count: int = 1, participants_count: int = 1,
ws_size: int = 10 * 1024 * 1024): ws_size: int = 10 * 1024 * 1024):
if not host: if not host:
host = get_host_ip()
host = Context.get_parameters("AGG_BIND_IP", get_host_ip())
if not http_port:
http_port = int(Context.get_parameters("AGG_BIND_PORT", 7363))
super( super(
AggregationServer, AggregationServer,
self).__init__( self).__init__(


Loading…
Cancel
Save