You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

federatedlearningjob.go 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594
  1. /*
  2. Copyright 2021 The KubeEdge Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package federatedlearning
  14. import (
  15. "context"
  16. "fmt"
  17. "strconv"
  18. "time"
  19. v1 "k8s.io/api/core/v1"
  20. "k8s.io/apimachinery/pkg/api/errors"
  21. metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
  22. utilrand "k8s.io/apimachinery/pkg/util/rand"
  23. utilruntime "k8s.io/apimachinery/pkg/util/runtime"
  24. "k8s.io/apimachinery/pkg/util/wait"
  25. "k8s.io/apimachinery/pkg/watch"
  26. "k8s.io/client-go/kubernetes"
  27. "k8s.io/client-go/kubernetes/scheme"
  28. v1core "k8s.io/client-go/kubernetes/typed/core/v1"
  29. corelisters "k8s.io/client-go/listers/core/v1"
  30. "k8s.io/client-go/tools/cache"
  31. "k8s.io/client-go/tools/record"
  32. "k8s.io/client-go/util/workqueue"
  33. "k8s.io/klog/v2"
  34. k8scontroller "k8s.io/kubernetes/pkg/controller"
  35. sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1"
  36. sednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/typed/sedna/v1alpha1"
  37. sednav1listers "github.com/kubeedge/sedna/pkg/client/listers/sedna/v1alpha1"
  38. "github.com/kubeedge/sedna/pkg/globalmanager/config"
  39. "github.com/kubeedge/sedna/pkg/globalmanager/runtime"
  40. )
  41. const (
  42. // KindName is the kind name of CR this controller controls
  43. KindName = "FederatedLearningJob"
  44. // Name is this controller name
  45. Name = "FederatedLearning"
  46. )
  47. const (
  48. jobStageAgg = "Aggregation"
  49. jobStageTrain = "Training"
  50. )
  51. // Kind contains the schema.GroupVersionKind for this controller type.
  52. var Kind = sednav1.SchemeGroupVersion.WithKind(KindName)
  53. // Controller ensures that all FederatedLearningJob objects have corresponding pods to
  54. // run their configured workload.
  55. type Controller struct {
  56. kubeClient kubernetes.Interface
  57. client sednaclientset.SednaV1alpha1Interface
  58. // podStoreSynced returns true if the pod store has been synced at least once.
  59. // Added as a member to the struct to allow injection for testing.
  60. podStoreSynced cache.InformerSynced
  61. // jobStoreSynced returns true if the FederatedLearningJob store has been synced at least once.
  62. // Added as a member to the struct to allow injection for testing.
  63. jobStoreSynced cache.InformerSynced
  64. // A store of jobs
  65. jobLister sednav1listers.FederatedLearningJobLister
  66. // A store of pods, populated by the podController
  67. podStore corelisters.PodLister
  68. // FLJobs that need to be updated
  69. queue workqueue.RateLimitingInterface
  70. recorder record.EventRecorder
  71. cfg *config.ControllerConfig
  72. sendToEdgeFunc runtime.DownstreamSendFunc
  73. }
  74. // Run starts the main goroutine responsible for watching and syncing jobs.
  75. func (c *Controller) Run(stopCh <-chan struct{}) {
  76. workers := 1
  77. defer utilruntime.HandleCrash()
  78. defer c.queue.ShutDown()
  79. klog.Infof("Starting %s controller", Name)
  80. defer klog.Infof("Shutting down %s controller", Name)
  81. if !cache.WaitForNamedCacheSync(Name, stopCh, c.podStoreSynced, c.jobStoreSynced) {
  82. klog.Errorf("failed to wait for %s caches to sync", Name)
  83. return
  84. }
  85. klog.Infof("Starting %s workers", Name)
  86. for i := 0; i < workers; i++ {
  87. go wait.Until(c.worker, time.Second, stopCh)
  88. }
  89. <-stopCh
  90. }
  91. // enqueueByPod enqueues the FederatedLearningJob object of the specified pod.
  92. func (c *Controller) enqueueByPod(pod *v1.Pod, immediate bool) {
  93. controllerRef := metav1.GetControllerOf(pod)
  94. if controllerRef == nil {
  95. return
  96. }
  97. if controllerRef.Kind != Kind.Kind {
  98. return
  99. }
  100. job, err := c.jobLister.FederatedLearningJobs(pod.Namespace).Get(controllerRef.Name)
  101. if err != nil {
  102. return
  103. }
  104. if job.UID != controllerRef.UID {
  105. return
  106. }
  107. c.enqueueController(job, immediate)
  108. }
  109. // When a pod is created, enqueue the controller that manages it and update it's expectations.
  110. func (c *Controller) addPod(obj interface{}) {
  111. pod := obj.(*v1.Pod)
  112. if pod.DeletionTimestamp != nil {
  113. // on a restart of the controller, it's possible a new pod shows up in a state that
  114. // is already pending deletion. Prevent the pod from being a creation observation.
  115. c.deletePod(pod)
  116. return
  117. }
  118. // backoff to queue when PodFailed
  119. immediate := pod.Status.Phase != v1.PodFailed
  120. c.enqueueByPod(pod, immediate)
  121. }
  122. // When a pod is updated, figure out what federatedlearning job manage it and wake them up.
  123. func (c *Controller) updatePod(old, cur interface{}) {
  124. curPod := cur.(*v1.Pod)
  125. oldPod := old.(*v1.Pod)
  126. // no pod update, no queue
  127. if curPod.ResourceVersion == oldPod.ResourceVersion {
  128. return
  129. }
  130. c.addPod(curPod)
  131. }
  132. // deletePod enqueues the FederatedLearningJob obj When a pod is deleted
  133. func (c *Controller) deletePod(obj interface{}) {
  134. pod, ok := obj.(*v1.Pod)
  135. // comment from https://github.com/kubernetes/kubernetes/blob/master/pkg/controller/job/job_controller.go
  136. // When a delete is dropped, the relist will notice a pod in the store not
  137. // in the list, leading to the insertion of a tombstone object which contains
  138. // the deleted key/value. Note that this value might be stale. If the pod
  139. // changed labels the new FederatedLearningJob will not be woken up till the periodic resync.
  140. if !ok {
  141. tombstone, ok := obj.(cache.DeletedFinalStateUnknown)
  142. if !ok {
  143. klog.Warningf("couldn't get object from tombstone %+v", obj)
  144. return
  145. }
  146. pod, ok = tombstone.Obj.(*v1.Pod)
  147. if !ok {
  148. klog.Warningf("tombstone contained object that is not a pod %+v", obj)
  149. return
  150. }
  151. }
  152. c.enqueueByPod(pod, true)
  153. }
  154. // obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
  155. // immediate tells the controller to update the status right away, and should
  156. // happen ONLY when there was a successful pod run.
  157. func (c *Controller) enqueueController(obj interface{}, immediate bool) {
  158. key, err := k8scontroller.KeyFunc(obj)
  159. if err != nil {
  160. klog.Warningf("Couldn't get key for object %+v: %v", obj, err)
  161. return
  162. }
  163. backoff := time.Duration(0)
  164. if !immediate {
  165. backoff = runtime.GetBackoff(c.queue, key)
  166. }
  167. c.queue.AddAfter(key, backoff)
  168. }
  169. // worker runs a worker thread that just dequeues items, processes them, and marks them done.
  170. // It enforces that the syncHandler is never invoked concurrently with the same key.
  171. func (c *Controller) worker() {
  172. for c.processNextWorkItem() {
  173. }
  174. }
  175. func (c *Controller) processNextWorkItem() bool {
  176. key, quit := c.queue.Get()
  177. if quit {
  178. return false
  179. }
  180. defer c.queue.Done(key)
  181. forget, err := c.sync(key.(string))
  182. if err == nil {
  183. if forget {
  184. c.queue.Forget(key)
  185. }
  186. return true
  187. }
  188. klog.Warningf("Error syncing federatedlearning job: %v", err)
  189. c.queue.AddRateLimited(key)
  190. return true
  191. }
  192. // sync will sync the FederatedLearningJob with the given key if it has had its expectations fulfilled, meaning
  193. // it did not expect to see any more of its pods created or deleted. This function is not meant to be invoked
  194. // concurrently with the same key.
  195. func (c *Controller) sync(key string) (bool, error) {
  196. startTime := time.Now()
  197. defer func() {
  198. klog.V(4).Infof("Finished syncing federatedlearning job %q (%v)", key, time.Since(startTime))
  199. }()
  200. ns, name, err := cache.SplitMetaNamespaceKey(key)
  201. if err != nil {
  202. return false, err
  203. }
  204. if len(ns) == 0 || len(name) == 0 {
  205. return false, fmt.Errorf("invalid federatedlearning job key %q: either namespace or name is missing", key)
  206. }
  207. sharedJob, err := c.jobLister.FederatedLearningJobs(ns).Get(name)
  208. if err != nil {
  209. if errors.IsNotFound(err) {
  210. klog.V(4).Infof("%s %v has been deleted", Name, key)
  211. return true, nil
  212. }
  213. return false, err
  214. }
  215. job := *sharedJob
  216. // set kind for FederatedLearningJob in case that the kind is None
  217. job.SetGroupVersionKind(Kind)
  218. // if job was finished previously, we don't want to redo the termination
  219. if IsJobFinished(&job) {
  220. return true, nil
  221. }
  222. selector, _ := runtime.GenerateSelector(&job)
  223. pods, err := c.podStore.Pods(job.Namespace).List(selector)
  224. if err != nil {
  225. return false, err
  226. }
  227. activePods := k8scontroller.FilterActivePods(pods)
  228. active := int32(len(activePods))
  229. succeeded, failed := countPods(pods)
  230. conditions := len(job.Status.Conditions)
  231. // set StartTime when job is handled firstly
  232. if job.Status.StartTime == nil {
  233. now := metav1.Now()
  234. job.Status.StartTime = &now
  235. }
  236. var manageJobErr error
  237. jobFailed := false
  238. var failureReason string
  239. var failureMessage string
  240. phase := job.Status.Phase
  241. if failed > 0 {
  242. jobFailed = true
  243. failureReason = "workerFailed"
  244. failureMessage = "the worker of FederatedLearningJob failed"
  245. }
  246. if jobFailed {
  247. job.Status.Conditions = append(job.Status.Conditions, NewJobCondition(sednav1.FLJobCondFailed, failureReason, failureMessage))
  248. job.Status.Phase = sednav1.FLJobFailed
  249. c.recorder.Event(&job, v1.EventTypeWarning, failureReason, failureMessage)
  250. } else {
  251. // in the First time, we create the pods
  252. if len(pods) == 0 {
  253. active, manageJobErr = c.createPod(&job)
  254. }
  255. complete := false
  256. if succeeded > 0 && active == 0 {
  257. complete = true
  258. }
  259. if complete {
  260. job.Status.Conditions = append(job.Status.Conditions, NewJobCondition(sednav1.FLJobCondComplete, "", ""))
  261. now := metav1.Now()
  262. job.Status.CompletionTime = &now
  263. c.recorder.Event(&job, v1.EventTypeNormal, "Completed", "FederatedLearningJob completed")
  264. job.Status.Phase = sednav1.FLJobSucceeded
  265. } else {
  266. job.Status.Phase = sednav1.FLJobRunning
  267. }
  268. }
  269. forget := false
  270. // Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
  271. // This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
  272. // improve the job backoff policy when parallelism > 1 and few FLJobs failed but others succeed.
  273. // In this case, we should clear the backoff delay.
  274. if job.Status.Succeeded < succeeded {
  275. forget = true
  276. }
  277. // no need to update the job if the status hasn't changed since last time
  278. if job.Status.Active != active || job.Status.Succeeded != succeeded || job.Status.Failed != failed || len(job.Status.Conditions) != conditions || job.Status.Phase != phase {
  279. job.Status.Active = active
  280. job.Status.Succeeded = succeeded
  281. job.Status.Failed = failed
  282. c.updateJobStatus(&job)
  283. if jobFailed && !IsJobFinished(&job) {
  284. // returning an error will re-enqueue FederatedLearningJob after the backoff period
  285. return forget, fmt.Errorf("failed pod(s) detected for FederatedLearningJob key %q", key)
  286. }
  287. forget = true
  288. }
  289. return forget, manageJobErr
  290. }
  291. func NewJobCondition(conditionType sednav1.FLJobConditionType, reason, message string) sednav1.FLJobCondition {
  292. return sednav1.FLJobCondition{
  293. Type: conditionType,
  294. Status: v1.ConditionTrue,
  295. LastProbeTime: metav1.Now(),
  296. LastHeartbeatTime: metav1.Now(),
  297. Reason: reason,
  298. Message: message,
  299. }
  300. }
  301. // countPods returns number of succeeded and failed pods
  302. func countPods(pods []*v1.Pod) (succeeded, failed int32) {
  303. succeeded = int32(filterPods(pods, v1.PodSucceeded))
  304. failed = int32(filterPods(pods, v1.PodFailed))
  305. return
  306. }
  307. func (c *Controller) updateJobStatus(job *sednav1.FederatedLearningJob) error {
  308. jobClient := c.client.FederatedLearningJobs(job.Namespace)
  309. return runtime.RetryUpdateStatus(job.Name, job.Namespace, func() error {
  310. newJob, err := jobClient.Get(context.TODO(), job.Name, metav1.GetOptions{})
  311. if err != nil {
  312. return err
  313. }
  314. newJob.Status = job.Status
  315. _, err = jobClient.UpdateStatus(context.TODO(), newJob, metav1.UpdateOptions{})
  316. return err
  317. })
  318. }
  319. // filterPods returns pods based on their phase.
  320. func filterPods(pods []*v1.Pod, phase v1.PodPhase) int {
  321. result := 0
  322. for i := range pods {
  323. if phase == pods[i].Status.Phase {
  324. result++
  325. }
  326. }
  327. return result
  328. }
  329. func IsJobFinished(j *sednav1.FederatedLearningJob) bool {
  330. for _, c := range j.Status.Conditions {
  331. if (c.Type == sednav1.FLJobCondComplete || c.Type == sednav1.FLJobCondFailed) && c.Status == v1.ConditionTrue {
  332. return true
  333. }
  334. }
  335. return false
  336. }
  337. func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
  338. active = 0
  339. ctx := context.Background()
  340. modelName := job.Spec.AggregationWorker.Model.Name
  341. model, err := c.client.Models(job.Namespace).Get(ctx, modelName, metav1.GetOptions{})
  342. if err != nil {
  343. return active, fmt.Errorf("failed to get model %s: %w",
  344. modelName, err)
  345. }
  346. secretName := model.Spec.CredentialName
  347. var modelSecret *v1.Secret
  348. if secretName != "" {
  349. modelSecret, _ = c.kubeClient.CoreV1().Secrets(job.Namespace).Get(context.TODO(), secretName, metav1.GetOptions{})
  350. }
  351. participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
  352. // deliver pod for aggregation worker
  353. aggWorker := job.Spec.AggregationWorker
  354. // Configure aggregation worker's mounts and envs
  355. var aggPort int32 = 7363
  356. var aggWorkerParam runtime.WorkerParam
  357. aggWorkerParam.Env = map[string]string{
  358. "NAMESPACE": job.Namespace,
  359. "WORKER_NAME": "aggworker-" + utilrand.String(5),
  360. "JOB_NAME": job.Name,
  361. "AGG_BIND_PORT": strconv.Itoa(int(aggPort)),
  362. "PARTICIPANTS_COUNT": participantsCount,
  363. }
  364. aggWorkerParam.WorkerType = jobStageAgg
  365. aggWorkerParam.RestartPolicy = v1.RestartPolicyOnFailure
  366. aggWorkerParam.Mounts = append(aggWorkerParam.Mounts,
  367. runtime.WorkerMount{
  368. URL: &runtime.MountURL{
  369. URL: model.Spec.URL,
  370. Secret: modelSecret,
  371. DownloadByInitializer: false,
  372. },
  373. EnvName: "MODEL_URL",
  374. },
  375. )
  376. // create aggpod based on configured parameters
  377. _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, &aggWorkerParam)
  378. if err != nil {
  379. return active, fmt.Errorf("failed to create aggregation worker: %w", err)
  380. }
  381. active++
  382. var appIP string
  383. var aggServicePort int32
  384. // FIXME(llhuii): only the case that Spec.NodeName specified is support,
  385. // will support Spec.NodeSelector.
  386. appIP, err = runtime.GetNodeIPByName(c.kubeClient, job.Spec.AggregationWorker.Template.Spec.NodeName)
  387. if err != nil {
  388. return active, err
  389. }
  390. aggServicePort, err = runtime.CreateKubernetesService(c.kubeClient, job, jobStageAgg, aggPort, appIP)
  391. if err != nil {
  392. return active, err
  393. }
  394. // deliver pod for training worker
  395. for i, trainingWorker := range job.Spec.TrainingWorkers {
  396. // get dataseturl through parsing crd of dataset
  397. datasetName := trainingWorker.Dataset.Name
  398. dataset, err := c.client.Datasets(job.Namespace).Get(ctx, datasetName, metav1.GetOptions{})
  399. if err != nil {
  400. return active, fmt.Errorf("failed to get dataset %s: %w",
  401. datasetName, err)
  402. }
  403. secretName := dataset.Spec.CredentialName
  404. var datasetSecret *v1.Secret
  405. if secretName != "" {
  406. datasetSecret, _ = c.kubeClient.CoreV1().Secrets(job.Namespace).Get(context.TODO(), secretName, metav1.GetOptions{})
  407. }
  408. // Configure training worker's mounts and envs
  409. var workerParam runtime.WorkerParam
  410. workerParam.Mounts = append(workerParam.Mounts,
  411. runtime.WorkerMount{
  412. URL: &runtime.MountURL{
  413. URL: model.Spec.URL,
  414. Secret: modelSecret,
  415. },
  416. EnvName: "MODEL_URL",
  417. },
  418. runtime.WorkerMount{
  419. URL: &runtime.MountURL{
  420. URL: dataset.Spec.URL,
  421. Secret: datasetSecret,
  422. },
  423. EnvName: "TRAIN_DATASET_URL",
  424. },
  425. )
  426. workerParam.Env = map[string]string{
  427. "AGG_PORT": strconv.Itoa(int(aggServicePort)),
  428. "AGG_IP": appIP,
  429. "WORKER_NAME": "trainworker-" + utilrand.String(5),
  430. "JOB_NAME": job.Name,
  431. "PARTICIPANTS_COUNT": participantsCount,
  432. "NAMESPACE": job.Namespace,
  433. "MODEL_NAME": modelName,
  434. "DATASET_NAME": datasetName,
  435. "LC_SERVER": c.cfg.LC.Server,
  436. }
  437. workerParam.WorkerType = runtime.TrainPodType
  438. workerParam.HostNetwork = true
  439. workerParam.RestartPolicy = v1.RestartPolicyOnFailure
  440. // create training worker based on configured parameters
  441. _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, &workerParam)
  442. if err != nil {
  443. return active, fmt.Errorf("failed to create %dth training worker: %w", i, err)
  444. }
  445. active++
  446. }
  447. return
  448. }
  449. // New creates a new federated learning job controller that keeps the relevant pods
  450. // in sync with their corresponding FederatedLearningJob objects.
  451. func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
  452. cfg := cc.Config
  453. podInformer := cc.KubeInformerFactory.Core().V1().Pods()
  454. jobInformer := cc.SednaInformerFactory.Sedna().V1alpha1().FederatedLearningJobs()
  455. eventBroadcaster := record.NewBroadcaster()
  456. eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: cc.KubeClient.CoreV1().Events("")})
  457. fc := &Controller{
  458. kubeClient: cc.KubeClient,
  459. client: cc.SednaClient.SednaV1alpha1(),
  460. queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), Name),
  461. recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: Name + "-controller"}),
  462. cfg: cfg,
  463. }
  464. jobInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  465. AddFunc: func(obj interface{}) {
  466. fc.enqueueController(obj, true)
  467. // when a federated learning job is added,
  468. // send it to edge's LC.
  469. fc.syncToEdge(watch.Added, obj)
  470. },
  471. UpdateFunc: func(old, cur interface{}) {
  472. fc.enqueueController(cur, true)
  473. // when a federated learning job is updated,
  474. // send it to edge's LC as Added event.
  475. fc.syncToEdge(watch.Added, cur)
  476. },
  477. DeleteFunc: func(obj interface{}) {
  478. fc.enqueueController(obj, true)
  479. // when a federated learning job is deleted,
  480. // send it to edge's LC.
  481. fc.syncToEdge(watch.Deleted, obj)
  482. },
  483. })
  484. fc.jobLister = jobInformer.Lister()
  485. fc.jobStoreSynced = jobInformer.Informer().HasSynced
  486. podInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  487. AddFunc: fc.addPod,
  488. UpdateFunc: fc.updatePod,
  489. DeleteFunc: fc.deletePod,
  490. })
  491. fc.podStore = podInformer.Lister()
  492. fc.podStoreSynced = podInformer.Informer().HasSynced
  493. return fc, nil
  494. }