You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

federatedlearningjob.go 25 kB


  1. /*
  2. Copyright 2021 The KubeEdge Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package federatedlearning
  14. import (
  15. "context"
  16. "fmt"
  17. "strconv"
  18. "sync"
  19. "time"
  20. v1 "k8s.io/api/core/v1"
  21. "k8s.io/apimachinery/pkg/api/errors"
  22. metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
  23. "k8s.io/apimachinery/pkg/labels"
  24. utilrand "k8s.io/apimachinery/pkg/util/rand"
  25. utilruntime "k8s.io/apimachinery/pkg/util/runtime"
  26. "k8s.io/apimachinery/pkg/util/wait"
  27. "k8s.io/apimachinery/pkg/watch"
  28. "k8s.io/client-go/kubernetes"
  29. "k8s.io/client-go/kubernetes/scheme"
  30. v1core "k8s.io/client-go/kubernetes/typed/core/v1"
  31. corelisters "k8s.io/client-go/listers/core/v1"
  32. "k8s.io/client-go/tools/cache"
  33. "k8s.io/client-go/tools/record"
  34. "k8s.io/client-go/util/workqueue"
  35. "k8s.io/klog/v2"
  36. sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1"
  37. sednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/typed/sedna/v1alpha1"
  38. sednav1listers "github.com/kubeedge/sedna/pkg/client/listers/sedna/v1alpha1"
  39. "github.com/kubeedge/sedna/pkg/globalmanager/config"
  40. "github.com/kubeedge/sedna/pkg/globalmanager/runtime"
  41. "github.com/kubeedge/sedna/pkg/globalmanager/utils"
  42. )
  43. const (
  44. // KindName is the kind name of CR this controller controls
  45. KindName = "FederatedLearningJob"
  46. // Name is this controller name
  47. Name = "FederatedLearning"
  48. )
  49. const (
  50. jobStageAgg = "Aggregation"
  51. jobStageTrain = "Training"
  52. )
  53. // Kind contains the schema.GroupVersionKind for this controller type.
  54. var Kind = sednav1.SchemeGroupVersion.WithKind(KindName)
  55. // Controller ensures that all FederatedLearningJob objects have corresponding pods to
  56. // run their configured workload.
  57. type Controller struct {
  58. kubeClient kubernetes.Interface
  59. client sednaclientset.SednaV1alpha1Interface
  60. // podStoreSynced returns true if the pod store has been synced at least once.
  61. // Added as a member to the struct to allow injection for testing.
  62. podStoreSynced cache.InformerSynced
  63. // jobStoreSynced returns true if the FederatedLearningJob store has been synced at least once.
  64. // Added as a member to the struct to allow injection for testing.
  65. jobStoreSynced cache.InformerSynced
  66. // A store of jobs
  67. jobLister sednav1listers.FederatedLearningJobLister
  68. // A store of pods, populated by the podController
  69. podStore corelisters.PodLister
  70. // FLJobs that need to be updated
  71. queue workqueue.RateLimitingInterface
  72. recorder record.EventRecorder
  73. cfg *config.ControllerConfig
  74. sendToEdgeFunc runtime.DownstreamSendFunc
  75. // map to record the pods that are recreated
  76. recreatedPods sync.Map
  77. flSelector labels.Selector
  78. aggServiceHost string
  79. preventRecreation bool
  80. }
  81. // Run starts the main goroutine responsible for watching and syncing jobs.
  82. func (c *Controller) Run(stopCh <-chan struct{}) {
  83. workers := 1
  84. defer utilruntime.HandleCrash()
  85. defer c.queue.ShutDown()
  86. klog.Infof("Starting %s controller", Name)
  87. defer klog.Infof("Shutting down %s controller", Name)
  88. if !cache.WaitForNamedCacheSync(Name, stopCh, c.podStoreSynced, c.jobStoreSynced) {
  89. klog.Errorf("failed to wait for %s caches to sync", Name)
  90. return
  91. }
  92. klog.Infof("Starting %s workers", Name)
  93. for i := 0; i < workers; i++ {
  94. go wait.Until(c.worker, time.Second, stopCh)
  95. }
  96. <-stopCh
  97. }
  98. // enqueueByPod enqueues the FederatedLearningJob object of the specified pod.
  99. func (c *Controller) enqueueByPod(pod *v1.Pod, immediate bool) {
  100. controllerRef := metav1.GetControllerOf(pod)
  101. if controllerRef == nil {
  102. return
  103. }
  104. if controllerRef.Kind != Kind.Kind {
  105. return
  106. }
  107. job, err := c.jobLister.FederatedLearningJobs(pod.Namespace).Get(controllerRef.Name)
  108. if err != nil {
  109. return
  110. }
  111. if job.UID != controllerRef.UID {
  112. return
  113. }
  114. c.enqueueController(job, immediate)
  115. }
  116. // When a pod is created, enqueue the controller that manages it and update it's expectations.
  117. func (c *Controller) addPod(obj interface{}) {
  118. pod := obj.(*v1.Pod)
  119. if pod.DeletionTimestamp != nil {
  120. // on a restart of the controller, it's possible a new pod shows up in a state that
  121. // is already pending deletion. Prevent the pod from being a creation observation.
  122. c.deletePod(pod)
  123. return
  124. }
  125. // backoff to queue when PodFailed
  126. immediate := pod.Status.Phase != v1.PodFailed
  127. c.enqueueByPod(pod, immediate)
  128. }
  129. // When a pod is updated, figure out what federatedlearning job manage it and wake them up.
  130. func (c *Controller) updatePod(old, cur interface{}) {
  131. curPod := cur.(*v1.Pod)
  132. oldPod := old.(*v1.Pod)
  133. // no pod update, no queue
  134. if curPod.ResourceVersion == oldPod.ResourceVersion {
  135. return
  136. }
  137. c.addPod(curPod)
  138. }
  139. // deletePod enqueues the FederatedLearningJob obj When a pod is deleted
  140. func (c *Controller) deletePod(obj interface{}) {
  141. pod, ok := obj.(*v1.Pod)
  142. // comment from https://github.com/kubernetes/kubernetes/blob/master/pkg/controller/job/job_controller.go
  143. // When a delete is dropped, the relist will notice a pod in the store not
  144. // in the list, leading to the insertion of a tombstone object which contains
  145. // the deleted key/value. Note that this value might be stale. If the pod
  146. // changed labels the new FederatedLearningJob will not be woken up till the periodic resync.
  147. if !ok {
  148. tombstone, ok := obj.(cache.DeletedFinalStateUnknown)
  149. if !ok {
  150. klog.Warningf("couldn't get object from tombstone %+v", obj)
  151. return
  152. }
  153. pod, ok = tombstone.Obj.(*v1.Pod)
  154. if !ok {
  155. klog.Warningf("tombstone contained object that is not a pod %+v", obj)
  156. return
  157. }
  158. }
  159. c.enqueueByPod(pod, true)
  160. // when the CRD is updated, do not recreate the pod
  161. // if c.preventRecreation is true, do not recreate the pod
  162. if c.preventRecreation {
  163. return
  164. }
  165. // if pod is manually deleted, recreate it
  166. // first check if the pod is owned by a FederatedLearningJob
  167. controllerRef := metav1.GetControllerOf(pod)
  168. if controllerRef == nil || controllerRef.Kind != Kind.Kind {
  169. return
  170. }
  171. _, err := c.jobLister.FederatedLearningJobs(pod.Namespace).Get(controllerRef.Name)
  172. if err != nil {
  173. if errors.IsNotFound(err) {
  174. // The FederatedLearningJob has been deleted, and the Pod should not be rebuilt.
  175. klog.Infof("FederatedLearningJob %s/%s not found, skipping pod recreation", pod.Namespace, controllerRef.Name)
  176. return
  177. }
  178. klog.Errorf("Error getting FederatedLearningJob %s/%s: %v", pod.Namespace, controllerRef.Name, err)
  179. return
  180. }
  181. // then check if the pod is already in the map
  182. if _, exists := c.recreatedPods.Load(pod.Name); exists {
  183. return
  184. }
  185. // if not, recreate it
  186. klog.Infof("Pod %s/%s deleted, recreating...", pod.Namespace, pod.Name)
  187. // Create a deep copy of the old pod
  188. newPod := pod.DeepCopy()
  189. // Reset the resource version and UID as they are unique to each object
  190. newPod.ResourceVersion = ""
  191. newPod.UID = ""
  192. // Clear the status
  193. newPod.Status = v1.PodStatus{}
  194. // Remove the deletion timestamp
  195. newPod.DeletionTimestamp = nil
  196. // Remove the deletion grace period seconds
  197. newPod.DeletionGracePeriodSeconds = nil
  198. _, err = c.kubeClient.CoreV1().Pods(pod.Namespace).Create(context.TODO(), newPod, metav1.CreateOptions{})
  199. if err != nil {
  200. return
  201. }
  202. klog.Infof("Successfully recreated pod %s/%s", newPod.Namespace, newPod.Name)
  203. // mark the pod as recreated
  204. c.recreatedPods.Store(newPod.Name, true)
  205. // set a timer to delete the record from the map after a while
  206. go func() {
  207. time.Sleep(5 * time.Second)
  208. c.recreatedPods.Delete(pod.Name)
  209. }()
  210. }
  211. // obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
  212. // immediate tells the controller to update the status right away, and should
  213. // happen ONLY when there was a successful pod run.
  214. func (c *Controller) enqueueController(obj interface{}, immediate bool) {
  215. key, err := cache.DeletionHandlingMetaNamespaceKeyFunc(obj)
  216. if err != nil {
  217. klog.Warningf("Couldn't get key for object %+v: %v", obj, err)
  218. return
  219. }
  220. backoff := time.Duration(0)
  221. if !immediate {
  222. backoff = runtime.GetBackoff(c.queue, key)
  223. }
  224. c.queue.AddAfter(key, backoff)
  225. }
  226. // worker runs a worker thread that just dequeues items, processes them, and marks them done.
  227. // It enforces that the syncHandler is never invoked concurrently with the same key.
  228. func (c *Controller) worker() {
  229. for c.processNextWorkItem() {
  230. }
  231. }
  232. func (c *Controller) processNextWorkItem() bool {
  233. key, quit := c.queue.Get()
  234. if quit {
  235. return false
  236. }
  237. defer c.queue.Done(key)
  238. forget, err := c.sync(key.(string))
  239. if err == nil {
  240. if forget {
  241. c.queue.Forget(key)
  242. }
  243. return true
  244. }
  245. klog.Warningf("Error syncing federatedlearning job: %v", err)
  246. c.queue.AddRateLimited(key)
  247. return true
  248. }
  249. // sync will sync the FederatedLearningJob with the given key if it has had its expectations fulfilled, meaning
  250. // it did not expect to see any more of its pods created or deleted. This function is not meant to be invoked
  251. // concurrently with the same key.
  252. func (c *Controller) sync(key string) (bool, error) {
  253. startTime := time.Now()
  254. defer func() {
  255. klog.V(4).Infof("Finished syncing federatedlearning job %q (%v)", key, time.Since(startTime))
  256. }()
  257. ns, name, err := cache.SplitMetaNamespaceKey(key)
  258. if err != nil {
  259. return false, err
  260. }
  261. if len(ns) == 0 || len(name) == 0 {
  262. return false, fmt.Errorf("invalid federatedlearning job key %q: either namespace or name is missing", key)
  263. }
  264. sharedJob, err := c.jobLister.FederatedLearningJobs(ns).Get(name)
  265. if err != nil {
  266. if errors.IsNotFound(err) {
  267. klog.V(4).Infof("%s %v has been deleted", Name, key)
  268. return true, nil
  269. }
  270. return false, err
  271. }
  272. job := *sharedJob
  273. // set kind for FederatedLearningJob in case that the kind is None
  274. job.SetGroupVersionKind(Kind)
  275. // if job was finished previously, we don't want to redo the termination
  276. if IsJobFinished(&job) {
  277. return true, nil
  278. }
  279. c.flSelector, _ = runtime.GenerateSelector(&job)
  280. pods, err := c.podStore.Pods(job.Namespace).List(c.flSelector)
  281. if err != nil {
  282. return false, err
  283. }
  284. activePods := utils.FilterActivePods(pods)
  285. active := int32(len(activePods))
  286. var activeAgg int32
  287. var activeTrain int32
  288. succeeded, failed := countPods(pods)
  289. conditions := len(job.Status.Conditions)
  290. // set StartTime when job is handled firstly
  291. if job.Status.StartTime == nil {
  292. now := metav1.Now()
  293. job.Status.StartTime = &now
  294. }
  295. var manageJobErr error
  296. var manageAggErr error
  297. var manageTrainErr error
  298. jobFailed := false
  299. var failureReason string
  300. var failureMessage string
  301. phase := job.Status.Phase
  302. if failed > 0 {
  303. jobFailed = true
  304. failureReason = "workerFailed"
  305. failureMessage = "the worker of FederatedLearningJob failed"
  306. }
  307. if jobFailed {
  308. job.Status.Conditions = append(job.Status.Conditions, NewJobCondition(sednav1.FLJobCondFailed, failureReason, failureMessage))
  309. job.Status.Phase = sednav1.FLJobFailed
  310. c.recorder.Event(&job, v1.EventTypeWarning, failureReason, failureMessage)
  311. } else {
  312. // in the First time, we create the pods
  313. if len(pods) == 0 {
  314. activeAgg, manageAggErr = c.createAggPod(&job)
  315. createServiceErr := c.createService(&job)
  316. if createServiceErr != nil {
  317. return false, createServiceErr
  318. }
  319. activeTrain, manageTrainErr = c.createTrainPod(&job)
  320. active = activeAgg + activeTrain
  321. }
  322. complete := false
  323. if succeeded > 0 && active == 0 {
  324. complete = true
  325. }
  326. if complete {
  327. job.Status.Conditions = append(job.Status.Conditions, NewJobCondition(sednav1.FLJobCondComplete, "", ""))
  328. now := metav1.Now()
  329. job.Status.CompletionTime = &now
  330. c.recorder.Event(&job, v1.EventTypeNormal, "Completed", "FederatedLearningJob completed")
  331. job.Status.Phase = sednav1.FLJobSucceeded
  332. } else {
  333. job.Status.Phase = sednav1.FLJobRunning
  334. }
  335. }
  336. // Combine manageAggErr and manageTrainErr into a single error
  337. if manageAggErr != nil || manageTrainErr != nil {
  338. manageJobErr = fmt.Errorf("aggregator error: %v, training error: %v", manageAggErr, manageTrainErr)
  339. }
  340. forget := false
  341. // Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
  342. // This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
  343. // improve the job backoff policy when parallelism > 1 and few FLJobs failed but others succeed.
  344. // In this case, we should clear the backoff delay.
  345. if job.Status.Succeeded < succeeded {
  346. forget = true
  347. }
  348. // no need to update the job if the status hasn't changed since last time
  349. if job.Status.Active != active || job.Status.Succeeded != succeeded || job.Status.Failed != failed || len(job.Status.Conditions) != conditions || job.Status.Phase != phase {
  350. job.Status.Active = active
  351. job.Status.Succeeded = succeeded
  352. job.Status.Failed = failed
  353. c.updateJobStatus(&job)
  354. if jobFailed && !IsJobFinished(&job) {
  355. // returning an error will re-enqueue FederatedLearningJob after the backoff period
  356. return forget, fmt.Errorf("failed pod(s) detected for FederatedLearningJob key %q", key)
  357. }
  358. forget = true
  359. }
  360. return forget, manageJobErr
  361. }
  362. func NewJobCondition(conditionType sednav1.FLJobConditionType, reason, message string) sednav1.FLJobCondition {
  363. return sednav1.FLJobCondition{
  364. Type: conditionType,
  365. Status: v1.ConditionTrue,
  366. LastProbeTime: metav1.Now(),
  367. LastHeartbeatTime: metav1.Now(),
  368. Reason: reason,
  369. Message: message,
  370. }
  371. }
  372. // countPods returns number of succeeded and failed pods
  373. func countPods(pods []*v1.Pod) (succeeded, failed int32) {
  374. succeeded = int32(filterPods(pods, v1.PodSucceeded))
  375. failed = int32(filterPods(pods, v1.PodFailed))
  376. return
  377. }
  378. func (c *Controller) updateJobStatus(job *sednav1.FederatedLearningJob) error {
  379. jobClient := c.client.FederatedLearningJobs(job.Namespace)
  380. return runtime.RetryUpdateStatus(job.Name, job.Namespace, func() error {
  381. newJob, err := jobClient.Get(context.TODO(), job.Name, metav1.GetOptions{})
  382. if err != nil {
  383. return err
  384. }
  385. newJob.Status = job.Status
  386. _, err = jobClient.UpdateStatus(context.TODO(), newJob, metav1.UpdateOptions{})
  387. return err
  388. })
  389. }
  390. // filterPods returns pods based on their phase.
  391. func filterPods(pods []*v1.Pod, phase v1.PodPhase) int {
  392. result := 0
  393. for i := range pods {
  394. if phase == pods[i].Status.Phase {
  395. result++
  396. }
  397. }
  398. return result
  399. }
  400. func IsJobFinished(j *sednav1.FederatedLearningJob) bool {
  401. for _, c := range j.Status.Conditions {
  402. if (c.Type == sednav1.FLJobCondComplete || c.Type == sednav1.FLJobCondFailed) && c.Status == v1.ConditionTrue {
  403. return true
  404. }
  405. }
  406. return false
  407. }
  408. func (c *Controller) getSecret(namespace, name, ownerStr string) (secret *v1.Secret, err error) {
  409. if name != "" {
  410. secret, err = c.kubeClient.CoreV1().Secrets(namespace).Get(context.TODO(), name, metav1.GetOptions{})
  411. if err != nil {
  412. err = fmt.Errorf("failed to get the secret %s for %s: %w",
  413. name,
  414. ownerStr, err)
  415. }
  416. }
  417. return
  418. }
  419. func (c *Controller) getModelAndItsSecret(ctx context.Context, namespace, name string) (model *sednav1.Model, secret *v1.Secret, err error) {
  420. if name != "" {
  421. model, err = c.client.Models(namespace).Get(ctx, name, metav1.GetOptions{})
  422. if err != nil {
  423. err = fmt.Errorf("failed to get the model %s: %w", name, err)
  424. }
  425. }
  426. if model != nil {
  427. secret, err = c.getSecret(
  428. namespace,
  429. model.Spec.CredentialName,
  430. fmt.Sprintf("model %s", name),
  431. )
  432. }
  433. return
  434. }
  435. func (c *Controller) getDatasetAndItsSecret(ctx context.Context, namespace, name string) (dataset *sednav1.Dataset, secret *v1.Secret, err error) {
  436. if name != "" {
  437. dataset, err = c.client.Datasets(namespace).Get(ctx, name, metav1.GetOptions{})
  438. if err != nil {
  439. err = fmt.Errorf("failed to get the dataset %s: %w", name, err)
  440. }
  441. }
  442. if dataset != nil {
  443. secret, err = c.getSecret(
  444. namespace,
  445. dataset.Spec.CredentialName,
  446. fmt.Sprintf("model %s", name),
  447. )
  448. }
  449. return
  450. }
  451. // addWorkerMount adds CR(e.g., model, dataset)'s url to worker mount.
  452. func (c *Controller) addWorkerMount(workerParam *runtime.WorkerParam, url string, envName string,
  453. secret *v1.Secret, downloadByInitializer bool) {
  454. if url != "" {
  455. workerParam.Mounts = append(workerParam.Mounts,
  456. runtime.WorkerMount{
  457. URL: &runtime.MountURL{
  458. URL: url,
  459. Secret: secret,
  460. DownloadByInitializer: downloadByInitializer,
  461. },
  462. EnvName: envName,
  463. },
  464. )
  465. }
  466. }
  467. // addTransmitterToWorkerParam adds transmitter to the WorkerParam
  468. func (c *Controller) addTransmitterToWorkerParam(param *runtime.WorkerParam, job *sednav1.FederatedLearningJob) error {
  469. transmitter := job.Spec.Transmitter
  470. if transmitter.S3 != nil {
  471. param.Env["TRANSMITTER"] = "s3"
  472. url := transmitter.S3.AggregationDataPath
  473. secret, err := c.getSecret(
  474. job.Namespace,
  475. transmitter.S3.CredentialName,
  476. fmt.Sprintf("for aggregationData: %s", url))
  477. if err != nil {
  478. return err
  479. }
  480. param.Mounts = append(param.Mounts,
  481. runtime.WorkerMount{
  482. URL: &runtime.MountURL{
  483. URL: url,
  484. Secret: secret,
  485. },
  486. EnvName: "AGG_DATA_PATH",
  487. },
  488. )
  489. } else {
  490. param.Env["TRANSMITTER"] = "ws"
  491. }
  492. return nil
  493. }
  494. func (c *Controller) createAggPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
  495. active = 0
  496. ctx := context.Background()
  497. pretrainedModelName := job.Spec.PretrainedModel.Name
  498. pretrainedModel, pretrainedModelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, pretrainedModelName)
  499. if err != nil {
  500. return active, err
  501. }
  502. modelName := job.Spec.AggregationWorker.Model.Name
  503. model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
  504. if err != nil {
  505. return active, fmt.Errorf("failed to get aggregation model: %w", err)
  506. }
  507. participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
  508. // deliver pod for aggregation worker
  509. aggWorker := job.Spec.AggregationWorker
  510. // Configure aggregation worker's mounts and envs
  511. var aggPort int32 = 7363
  512. var aggWorkerParam runtime.WorkerParam
  513. aggWorkerParam.Env = map[string]string{
  514. "NAMESPACE": job.Namespace,
  515. "WORKER_NAME": "aggworker-" + utilrand.String(5),
  516. "JOB_NAME": job.Name,
  517. "AGG_BIND_PORT": strconv.Itoa(int(aggPort)),
  518. "PARTICIPANTS_COUNT": participantsCount,
  519. }
  520. if err := c.addTransmitterToWorkerParam(&aggWorkerParam, job); err != nil {
  521. return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
  522. }
  523. aggWorkerParam.WorkerType = jobStageAgg
  524. aggWorkerParam.RestartPolicy = v1.RestartPolicyOnFailure
  525. c.addWorkerMount(&aggWorkerParam, model.Spec.URL, "MODEL_URL",
  526. modelSecret, true)
  527. if pretrainedModel != nil {
  528. c.addWorkerMount(&aggWorkerParam, pretrainedModel.Spec.URL, "PRETRAINED_MODEL_URL",
  529. pretrainedModelSecret, true)
  530. }
  531. aggWorker.Template.Name = fmt.Sprintf("%s-aggworker", job.Name)
  532. // create aggpod based on configured parameters
  533. _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, &aggWorkerParam)
  534. if err != nil {
  535. return active, fmt.Errorf("failed to create aggregation worker: %w", err)
  536. }
  537. klog.Infof("create aggpod success")
  538. active++
  539. return
  540. }
  541. func (c *Controller) createTrainPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
  542. active = 0
  543. ctx := context.Background()
  544. pretrainedModelName := job.Spec.PretrainedModel.Name
  545. pretrainedModel, pretrainedModelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, pretrainedModelName)
  546. if err != nil {
  547. return active, fmt.Errorf("failed to get pretrained model: %w", err)
  548. }
  549. modelName := job.Spec.AggregationWorker.Model.Name
  550. model, modelSecret, err := c.getModelAndItsSecret(ctx, job.Namespace, modelName)
  551. if err != nil {
  552. return active, fmt.Errorf("failed to get aggregation model: %w", err)
  553. }
  554. var aggPort int32 = 7363
  555. participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
  556. // deliver pod for training worker
  557. for i, trainingWorker := range job.Spec.TrainingWorkers {
  558. // Configure training worker's mounts and envs
  559. var workerParam runtime.WorkerParam
  560. c.addWorkerMount(&workerParam, model.Spec.URL, "MODEL_URL", modelSecret, true)
  561. if pretrainedModel != nil {
  562. c.addWorkerMount(&workerParam, pretrainedModel.Spec.URL, "PRETRAINED_MODEL_URL",
  563. pretrainedModelSecret, true)
  564. }
  565. datasetName := trainingWorker.Dataset.Name
  566. dataset, datasetSecret, err := c.getDatasetAndItsSecret(ctx, job.Namespace, datasetName)
  567. if err != nil {
  568. return active, err
  569. }
  570. c.addWorkerMount(&workerParam, dataset.Spec.URL, "TRAIN_DATASET_URL",
  571. datasetSecret, true)
  572. workerParam.Env = map[string]string{
  573. "AGG_PORT": strconv.Itoa(int(aggPort)),
  574. "AGG_IP": c.aggServiceHost,
  575. "WORKER_NAME": "trainworker-" + utilrand.String(5),
  576. "JOB_NAME": job.Name,
  577. "PARTICIPANTS_COUNT": participantsCount,
  578. "NAMESPACE": job.Namespace,
  579. "MODEL_NAME": modelName,
  580. "DATASET_NAME": datasetName,
  581. "LC_SERVER": c.cfg.LC.Server,
  582. }
  583. workerParam.WorkerType = runtime.TrainPodType
  584. workerParam.HostNetwork = true
  585. workerParam.RestartPolicy = v1.RestartPolicyOnFailure
  586. if err := c.addTransmitterToWorkerParam(&workerParam, job); err != nil {
  587. return active, fmt.Errorf("failed to add transmitter to worker param: %w", err)
  588. }
  589. trainingWorker.Template.Name = fmt.Sprintf("%s-trainworker-%d", job.Name, i)
  590. // create training worker based on configured parameters
  591. _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, &workerParam)
  592. if err != nil {
  593. return active, fmt.Errorf("failed to create %dth training worker: %w", i, err)
  594. }
  595. active++
  596. }
  597. return
  598. }
  599. // New creates a new federated learning job controller that keeps the relevant pods
  600. // in sync with their corresponding FederatedLearningJob objects.
  601. func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
  602. cfg := cc.Config
  603. podInformer := cc.KubeInformerFactory.Core().V1().Pods()
  604. jobInformer := cc.SednaInformerFactory.Sedna().V1alpha1().FederatedLearningJobs()
  605. eventBroadcaster := record.NewBroadcaster()
  606. eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: cc.KubeClient.CoreV1().Events("")})
  607. fc := &Controller{
  608. kubeClient: cc.KubeClient,
  609. client: cc.SednaClient.SednaV1alpha1(),
  610. queue: workqueue.NewRateLimitingQueueWithConfig(
  611. workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff),
  612. workqueue.RateLimitingQueueConfig{Name: Name},
  613. ),
  614. recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: Name + "-controller"}),
  615. cfg: cfg,
  616. }
  617. jobInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  618. AddFunc: func(obj interface{}) {
  619. fc.enqueueController(obj, true)
  620. // when a federated learning job is added,
  621. // send it to edge's LC.
  622. fc.syncToEdge(watch.Added, obj)
  623. },
  624. UpdateFunc: fc.updateJob,
  625. DeleteFunc: func(obj interface{}) {
  626. fc.enqueueController(obj, true)
  627. // when a federated learning job is deleted,
  628. // send it to edge's LC.
  629. fc.syncToEdge(watch.Deleted, obj)
  630. },
  631. })
  632. fc.jobLister = jobInformer.Lister()
  633. fc.jobStoreSynced = jobInformer.Informer().HasSynced
  634. podInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  635. AddFunc: fc.addPod,
  636. UpdateFunc: fc.updatePod,
  637. DeleteFunc: fc.deletePod,
  638. })
  639. fc.podStore = podInformer.Lister()
  640. fc.podStoreSynced = podInformer.Informer().HasSynced
  641. return fc, nil
  642. }
  643. func (c *Controller) updateJob(old, cur interface{}) {
  644. oldJob, ok := old.(*sednav1.FederatedLearningJob)
  645. if !ok {
  646. return
  647. }
  648. curJob, ok := cur.(*sednav1.FederatedLearningJob)
  649. if !ok {
  650. return
  651. }
  652. if oldJob.ResourceVersion == curJob.ResourceVersion {
  653. return
  654. }
  655. if oldJob.Generation != curJob.Generation {
  656. pods, err := c.podStore.Pods(curJob.Namespace).List(c.flSelector)
  657. if err != nil {
  658. klog.Errorf("Failed to list pods: %v", err)
  659. }
  660. c.preventRecreation = true
  661. for _, pod := range pods {
  662. // delete all pods
  663. c.kubeClient.CoreV1().Pods(pod.Namespace).Delete(context.TODO(), pod.Name, metav1.DeleteOptions{})
  664. klog.Infof("CRD modified, so we deleted pod %s/%s", pod.Namespace, pod.Name)
  665. }
  666. klog.Infof("CRD modified, so we deleted all pods, and will create new pods")
  667. curJob.SetGroupVersionKind(Kind)
  668. _, err = c.createAggPod(curJob)
  669. if err != nil {
  670. klog.Errorf("Failed to create aggregation worker: %v", err)
  671. }
  672. _, err = c.createTrainPod(curJob)
  673. if err != nil {
  674. klog.Errorf("Failed to create training workers: %v", err)
  675. }
  676. // update the job status
  677. c.client.FederatedLearningJobs(curJob.Namespace).Update(context.TODO(), curJob, metav1.UpdateOptions{})
  678. }
  679. c.preventRecreation = false
  680. c.enqueueController(curJob, true)
  681. // when a federated learning job is updated,
  682. // send it to edge's LC as Added event.
  683. c.syncToEdge(watch.Added, curJob)
  684. }
  685. // create edgemesh service for the job
  686. func (c *Controller) createService(job *sednav1.FederatedLearningJob) (err error) {
  687. var aggPort int32 = 7363
  688. c.aggServiceHost, err = runtime.CreateEdgeMeshService(c.kubeClient, job, jobStageAgg, aggPort)
  689. if err != nil {
  690. return err
  691. }
  692. return nil
  693. }