You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

federatedlearningjob.go 18 kB


  1. /*
  2. Copyright 2021 The KubeEdge Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package federatedlearning
  14. import (
  15. "context"
  16. "fmt"
  17. "strconv"
  18. "time"
  19. v1 "k8s.io/api/core/v1"
  20. "k8s.io/apimachinery/pkg/api/errors"
  21. metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
  22. utilrand "k8s.io/apimachinery/pkg/util/rand"
  23. utilruntime "k8s.io/apimachinery/pkg/util/runtime"
  24. "k8s.io/apimachinery/pkg/util/wait"
  25. "k8s.io/client-go/kubernetes"
  26. "k8s.io/client-go/kubernetes/scheme"
  27. v1core "k8s.io/client-go/kubernetes/typed/core/v1"
  28. corelisters "k8s.io/client-go/listers/core/v1"
  29. "k8s.io/client-go/tools/cache"
  30. "k8s.io/client-go/tools/record"
  31. "k8s.io/client-go/util/workqueue"
  32. "k8s.io/klog/v2"
  33. k8scontroller "k8s.io/kubernetes/pkg/controller"
  34. sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1"
  35. sednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/typed/sedna/v1alpha1"
  36. sednav1listers "github.com/kubeedge/sedna/pkg/client/listers/sedna/v1alpha1"
  37. "github.com/kubeedge/sedna/pkg/globalmanager/config"
  38. "github.com/kubeedge/sedna/pkg/globalmanager/runtime"
  39. )
  // KindName is the kind of the custom resource this controller manages.
const KindName = "FederatedLearningJob"

// Name is the name of this controller; it appears in log messages and is the
// name reported while waiting for the informer caches to sync.
const Name = "FederatedLearning"
  // FLJobStageAgg is the stage name of the aggregation worker; it is used as
// that worker's WorkerType and as the stage of its Kubernetes service.
const FLJobStageAgg = "Aggregation"

// FLJobStageTrain is the stage name of the training workers.
const FLJobStageTrain = "Training"
  50. // Kind contains the schema.GroupVersionKind for this controller type.
  51. var Kind = sednav1.SchemeGroupVersion.WithKind(KindName)
  52. // Controller ensures that all FLJob objects have corresponding pods to
  53. // run their configured workload.
  54. type Controller struct {
  55. kubeClient kubernetes.Interface
  56. client sednaclientset.SednaV1alpha1Interface
  57. // podStoreSynced returns true if the pod store has been synced at least once.
  58. // Added as a member to the struct to allow injection for testing.
  59. podStoreSynced cache.InformerSynced
  60. // jobStoreSynced returns true if the flJob store has been synced at least once.
  61. // Added as a member to the struct to allow injection for testing.
  62. jobStoreSynced cache.InformerSynced
  63. // A store of jobs
  64. jobLister sednav1listers.FederatedLearningJobLister
  65. // A store of pods, populated by the podController
  66. podStore corelisters.PodLister
  67. // FLJobs that need to be updated
  68. queue workqueue.RateLimitingInterface
  69. recorder record.EventRecorder
  70. cfg *config.ControllerConfig
  71. }
  72. // Run starts the main goroutine responsible for watching and syncing jobs.
  73. func (c *Controller) Run(stopCh <-chan struct{}) {
  74. workers := 1
  75. defer utilruntime.HandleCrash()
  76. defer c.queue.ShutDown()
  77. klog.Infof("Starting %s controller", Name)
  78. defer klog.Infof("Shutting down %s controller", Name)
  79. if !cache.WaitForNamedCacheSync(Name, stopCh, c.podStoreSynced, c.jobStoreSynced) {
  80. klog.Errorf("failed to wait for %s caches to sync", Name)
  81. return
  82. }
  83. klog.Infof("Starting %s workers", Name)
  84. for i := 0; i < workers; i++ {
  85. go wait.Until(c.worker, time.Second, stopCh)
  86. }
  87. <-stopCh
  88. }
  89. // enqueueByPod enqueues the FederatedLearningJob object of the specified pod.
  90. func (c *Controller) enqueueByPod(pod *v1.Pod, immediate bool) {
  91. controllerRef := metav1.GetControllerOf(pod)
  92. if controllerRef == nil {
  93. return
  94. }
  95. if controllerRef.Kind != Kind.Kind {
  96. return
  97. }
  98. job, err := c.jobLister.FederatedLearningJobs(pod.Namespace).Get(controllerRef.Name)
  99. if err != nil {
  100. return
  101. }
  102. if job.UID != controllerRef.UID {
  103. return
  104. }
  105. c.enqueueController(job, immediate)
  106. }
  107. // When a pod is created, enqueue the controller that manages it and update it's expectations.
  108. func (c *Controller) addPod(obj interface{}) {
  109. pod := obj.(*v1.Pod)
  110. if pod.DeletionTimestamp != nil {
  111. // on a restart of the controller, it's possible a new pod shows up in a state that
  112. // is already pending deletion. Prevent the pod from being a creation observation.
  113. c.deletePod(pod)
  114. return
  115. }
  116. // backoff to queue when PodFailed
  117. immediate := pod.Status.Phase != v1.PodFailed
  118. c.enqueueByPod(pod, immediate)
  119. }
  120. // When a pod is updated, figure out what federatedlearning job manage it and wake them up.
  121. func (c *Controller) updatePod(old, cur interface{}) {
  122. curPod := cur.(*v1.Pod)
  123. oldPod := old.(*v1.Pod)
  124. // no pod update, no queue
  125. if curPod.ResourceVersion == oldPod.ResourceVersion {
  126. return
  127. }
  128. c.addPod(curPod)
  129. }
  130. // deletePod enqueues the FederatedLearningJob obj When a pod is deleted
  131. func (c *Controller) deletePod(obj interface{}) {
  132. pod, ok := obj.(*v1.Pod)
  133. // comment from https://github.com/kubernetes/kubernetes/blob/master/pkg/controller/job/job_controller.go
  134. // When a delete is dropped, the relist will notice a pod in the store not
  135. // in the list, leading to the insertion of a tombstone object which contains
  136. // the deleted key/value. Note that this value might be stale. If the pod
  137. // changed labels the new FederatedLearningJob will not be woken up till the periodic resync.
  138. if !ok {
  139. tombstone, ok := obj.(cache.DeletedFinalStateUnknown)
  140. if !ok {
  141. klog.Warningf("couldn't get object from tombstone %+v", obj)
  142. return
  143. }
  144. pod, ok = tombstone.Obj.(*v1.Pod)
  145. if !ok {
  146. klog.Warningf("tombstone contained object that is not a pod %+v", obj)
  147. return
  148. }
  149. }
  150. c.enqueueByPod(pod, true)
  151. }
  152. // obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
  153. // immediate tells the controller to update the status right away, and should
  154. // happen ONLY when there was a successful pod run.
  155. func (c *Controller) enqueueController(obj interface{}, immediate bool) {
  156. key, err := k8scontroller.KeyFunc(obj)
  157. if err != nil {
  158. klog.Warningf("Couldn't get key for object %+v: %v", obj, err)
  159. return
  160. }
  161. backoff := time.Duration(0)
  162. if !immediate {
  163. backoff = runtime.GetBackoff(c.queue, key)
  164. }
  165. c.queue.AddAfter(key, backoff)
  166. }
  167. // worker runs a worker thread that just dequeues items, processes them, and marks them done.
  168. // It enforces that the syncHandler is never invoked concurrently with the same key.
  169. func (c *Controller) worker() {
  170. for c.processNextWorkItem() {
  171. }
  172. }
  173. func (c *Controller) processNextWorkItem() bool {
  174. key, quit := c.queue.Get()
  175. if quit {
  176. return false
  177. }
  178. defer c.queue.Done(key)
  179. forget, err := c.sync(key.(string))
  180. if err == nil {
  181. if forget {
  182. c.queue.Forget(key)
  183. }
  184. return true
  185. }
  186. klog.Warningf("Error syncing federatedlearning job: %v", err)
  187. c.queue.AddRateLimited(key)
  188. return true
  189. }
  // sync reconciles the FederatedLearningJob identified by key ("namespace/name").
//
// It lists the job's pods. The first time the job is seen without pods, it
// creates them. It then derives the job phase and conditions from the pod
// phases and persists the status through the API server when it changed.
//
// The returned bool reports whether the key's rate-limiting history may be
// forgotten. A non-nil error makes the caller re-queue the key with backoff.
// This function must not be invoked concurrently with the same key.
func (c *Controller) sync(key string) (bool, error) {
	startTime := time.Now()
	defer func() {
		klog.V(4).Infof("Finished syncing federatedlearning job %q (%v)", key, time.Since(startTime))
	}()

	ns, name, err := cache.SplitMetaNamespaceKey(key)
	if err != nil {
		return false, err
	}
	if len(ns) == 0 || len(name) == 0 {
		return false, fmt.Errorf("invalid federatedlearning job key %q: either namespace or name is missing", key)
	}

	sharedJob, err := c.jobLister.FederatedLearningJobs(ns).Get(name)
	if err != nil {
		if errors.IsNotFound(err) {
			klog.V(4).Infof("FLJob has been deleted: %v", key)
			return true, nil
		}
		return false, err
	}

	// Lister objects are shared with the informer cache and must not be
	// mutated. A shallow copy would still share Status.Conditions, whose
	// backing array the appends below could write into.
	flJob := sharedJob.DeepCopy()
	// set kind for flJob in case that the kind is None
	flJob.SetGroupVersionKind(Kind)

	// if flJob was finished previously, we don't want to redo the termination
	if IsFLJobFinished(flJob) {
		return true, nil
	}

	selector, err := runtime.GenerateSelector(flJob)
	if err != nil {
		return false, fmt.Errorf("failed to generate pod selector for flJob %q: %w", key, err)
	}
	pods, err := c.podStore.Pods(flJob.Namespace).List(selector)
	if err != nil {
		return false, err
	}

	active := int32(len(k8scontroller.FilterActivePods(pods)))
	succeeded, failed := getStatus(pods)

	// Snapshot what is needed to detect a status change further down.
	conditions := len(flJob.Status.Conditions)
	phase := flJob.Status.Phase

	// flJob first start
	if flJob.Status.StartTime == nil {
		now := metav1.Now()
		flJob.Status.StartTime = &now
	}

	var manageJobErr error
	// Any failed worker pod fails the whole job.
	jobFailed := failed > 0
	if jobFailed {
		failureReason := "workerFailed"
		failureMessage := "the worker of FLJob failed"
		flJob.Status.Conditions = append(flJob.Status.Conditions, NewFLJobCondition(sednav1.FLJobCondFailed, failureReason, failureMessage))
		flJob.Status.Phase = sednav1.FLJobFailed
		c.recorder.Event(flJob, v1.EventTypeWarning, failureReason, failureMessage)
	} else {
		// in the First time, we create the pods
		if len(pods) == 0 {
			active, manageJobErr = c.createPod(flJob)
		}

		if succeeded > 0 && active == 0 {
			flJob.Status.Conditions = append(flJob.Status.Conditions, NewFLJobCondition(sednav1.FLJobCondComplete, "", ""))
			now := metav1.Now()
			flJob.Status.CompletionTime = &now
			c.recorder.Event(flJob, v1.EventTypeNormal, "Completed", "FLJob completed")
			flJob.Status.Phase = sednav1.FLJobSucceeded
		} else {
			flJob.Status.Phase = sednav1.FLJobRunning
		}
	}

	forget := false
	// Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
	// This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
	// improve the FLJob backoff policy when parallelism > 1 and few FLJobs failed but others succeed.
	// In this case, we should clear the backoff delay.
	if flJob.Status.Succeeded < succeeded {
		forget = true
	}

	// no need to update the flJob if the status hasn't changed since last time
	if flJob.Status.Active != active || flJob.Status.Succeeded != succeeded || flJob.Status.Failed != failed || len(flJob.Status.Conditions) != conditions || flJob.Status.Phase != phase {
		flJob.Status.Active = active
		flJob.Status.Succeeded = succeeded
		flJob.Status.Failed = failed

		// Persist the new status; previously it was computed and then dropped.
		if err := c.updateFLJobStatus(flJob); err != nil {
			return forget, err
		}

		if jobFailed && !IsFLJobFinished(flJob) {
			// returning an error will re-enqueue FLJob after the backoff period
			return forget, fmt.Errorf("failed pod(s) detected for flJob key %q", key)
		}

		forget = true
	}

	return forget, manageJobErr
}
  288. func NewFLJobCondition(conditionType sednav1.FLJobConditionType, reason, message string) sednav1.FLJobCondition {
  289. return sednav1.FLJobCondition{
  290. Type: conditionType,
  291. Status: v1.ConditionTrue,
  292. LastProbeTime: metav1.Now(),
  293. LastHeartbeatTime: metav1.Now(),
  294. Reason: reason,
  295. Message: message,
  296. }
  297. }
  298. // getStatus returns no of succeeded and failed pods running a flJob
  299. func getStatus(pods []*v1.Pod) (succeeded, failed int32) {
  300. succeeded = int32(filterPods(pods, v1.PodSucceeded))
  301. failed = int32(filterPods(pods, v1.PodFailed))
  302. return
  303. }
  // updateFLJobStatus writes flJob.Status to the API server through the status
// subresource.
//
// Each attempt re-reads the latest stored object and overwrites only its
// status. A failed UpdateStatus is retried up to runtime.ResourceUpdateRetries
// more times; a failed read stops the loop. It returns nil once an update
// succeeds, otherwise the last error. Previously the error was always
// discarded, so callers could not tell that the status was never written.
func (c *Controller) updateFLJobStatus(flJob *sednav1.FederatedLearningJob) error {
	jobClient := c.client.FederatedLearningJobs(flJob.Namespace)
	var err error
	for i := 0; i <= runtime.ResourceUpdateRetries; i++ {
		var newFLJob *sednav1.FederatedLearningJob
		newFLJob, err = jobClient.Get(context.TODO(), flJob.Name, metav1.GetOptions{})
		if err != nil {
			break
		}
		newFLJob.Status = flJob.Status
		if _, err = jobClient.UpdateStatus(context.TODO(), newFLJob, metav1.UpdateOptions{}); err == nil {
			return nil
		}
	}
	return fmt.Errorf("failed to update status of flJob %s/%s: %w", flJob.Namespace, flJob.Name, err)
}
  320. // filterPods returns pods based on their phase.
  321. func filterPods(pods []*v1.Pod, phase v1.PodPhase) int {
  322. result := 0
  323. for i := range pods {
  324. if phase == pods[i].Status.Phase {
  325. result++
  326. }
  327. }
  328. return result
  329. }
  330. func IsFLJobFinished(j *sednav1.FederatedLearningJob) bool {
  331. for _, c := range j.Status.Conditions {
  332. if (c.Type == sednav1.FLJobCondComplete || c.Type == sednav1.FLJobCondFailed) && c.Status == v1.ConditionTrue {
  333. return true
  334. }
  335. }
  336. return false
  337. }
  // createPod creates the workloads of a federated learning job. These are one
// aggregation worker pod, a Kubernetes service for the aggregation worker, and
// one training worker pod per entry in job.Spec.TrainingWorkers.
//
// Everything the pods depend on is resolved before any pod is created: the
// Model, every Dataset, their credential secrets, and the aggregation node's
// IP. A missing or unreadable reference therefore fails the call without
// leaving a half-created job behind. This matters because sync only calls
// createPod when the job has no pods at all.
//
// It returns the number of pods created. When an error occurs after some
// pods were created, active still counts those pods.
func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
	ctx := context.Background()

	// getSecret fetches a credential secret from the job's namespace. An empty
	// name means the resource has no credential and yields a nil secret. A
	// lookup failure is reported rather than ignored, because the object
	// returned alongside an error is not a usable credential.
	getSecret := func(secretName string) (*v1.Secret, error) {
		if secretName == "" {
			return nil, nil
		}
		secret, getErr := c.kubeClient.CoreV1().Secrets(job.Namespace).Get(ctx, secretName, metav1.GetOptions{})
		if getErr != nil {
			return nil, fmt.Errorf("failed to get secret %s: %w", secretName, getErr)
		}
		return secret, nil
	}

	modelName := job.Spec.AggregationWorker.Model.Name
	model, err := c.client.Models(job.Namespace).Get(ctx, modelName, metav1.GetOptions{})
	if err != nil {
		return active, fmt.Errorf("failed to get model %s: %w",
			modelName, err)
	}
	modelSecret, err := getSecret(model.Spec.CredentialName)
	if err != nil {
		return active, err
	}

	// Resolve the dataset of every training worker up front, indexed like
	// job.Spec.TrainingWorkers.
	type trainingDataset struct {
		name   string
		url    string
		secret *v1.Secret
	}
	datasets := make([]trainingDataset, 0, len(job.Spec.TrainingWorkers))
	for _, trainingWorker := range job.Spec.TrainingWorkers {
		datasetName := trainingWorker.Dataset.Name
		dataset, getErr := c.client.Datasets(job.Namespace).Get(ctx, datasetName, metav1.GetOptions{})
		if getErr != nil {
			return active, fmt.Errorf("failed to get dataset %s: %w",
				datasetName, getErr)
		}
		datasetSecret, getErr := getSecret(dataset.Spec.CredentialName)
		if getErr != nil {
			return active, getErr
		}
		datasets = append(datasets, trainingDataset{
			name:   datasetName,
			url:    dataset.Spec.URL,
			secret: datasetSecret,
		})
	}

	// FIXME(llhuii): only the case that Spec.NodeName specified is support,
	// will support Spec.NodeSelector.
	aggNodeName := job.Spec.AggregationWorker.Template.Spec.NodeName
	appIP, err := runtime.GetNodeIPByName(c.kubeClient, aggNodeName)
	if err != nil {
		// Previously this error was overwritten unchecked, handing an empty
		// AGG_IP to the training workers.
		return active, fmt.Errorf("failed to get IP of node %q for the aggregation worker: %w", aggNodeName, err)
	}

	participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))

	// deliver pod for aggregation worker
	aggWorker := job.Spec.AggregationWorker
	// Configure container mounting and Env information by initial runtime.WorkerParam
	var aggPort int32 = 7363
	aggWorkerParam := new(runtime.WorkerParam)
	aggWorkerParam.Env = map[string]string{
		"NAMESPACE":          job.Namespace,
		"WORKER_NAME":        "aggworker-" + utilrand.String(5),
		"JOB_NAME":           job.Name,
		"AGG_BIND_PORT":      strconv.Itoa(int(aggPort)),
		"PARTICIPANTS_COUNT": participantsCount,
	}
	aggWorkerParam.WorkerType = FLJobStageAgg
	aggWorkerParam.RestartPolicy = v1.RestartPolicyOnFailure
	aggWorkerParam.Mounts = append(aggWorkerParam.Mounts,
		runtime.WorkerMount{
			URL: &runtime.MountURL{
				URL:                   model.Spec.URL,
				Secret:                modelSecret,
				DownloadByInitializer: false,
			},
			EnvName: "MODEL_URL",
		},
	)

	// create aggpod based on configured parameters
	if _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, aggWorkerParam); err != nil {
		return active, err
	}
	active++

	aggServicePort, err := runtime.CreateKubernetesService(c.kubeClient, job, FLJobStageAgg, aggPort, appIP)
	if err != nil {
		return active, err
	}

	// deliver pod for training worker
	for i, trainingWorker := range job.Spec.TrainingWorkers {
		ds := datasets[i]

		// Configure container mounting and env information
		workerParam := new(runtime.WorkerParam)
		workerParam.Mounts = append(workerParam.Mounts,
			runtime.WorkerMount{
				URL: &runtime.MountURL{
					URL:    model.Spec.URL,
					Secret: modelSecret,
				},
				EnvName: "MODEL_URL",
			},
			runtime.WorkerMount{
				URL: &runtime.MountURL{
					URL:    ds.url,
					Secret: ds.secret,
				},
				EnvName: "TRAIN_DATASET_URL",
			},
		)
		workerParam.Env = map[string]string{
			"AGG_PORT":           strconv.Itoa(int(aggServicePort)),
			"AGG_IP":             appIP,
			"WORKER_NAME":        "trainworker-" + utilrand.String(5),
			"JOB_NAME":           job.Name,
			"PARTICIPANTS_COUNT": participantsCount,
			"NAMESPACE":          job.Namespace,
			"MODEL_NAME":         modelName,
			"DATASET_NAME":       ds.name,
			"LC_SERVER":          c.cfg.LC.Server,
		}
		workerParam.WorkerType = runtime.TrainPodType
		workerParam.HostNetwork = true
		workerParam.RestartPolicy = v1.RestartPolicyOnFailure

		// create train pod based on configured parameters
		if _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, workerParam); err != nil {
			return active, err
		}
		active++
	}

	return active, nil
}
  447. // New creates a new federated learning job controller that keeps the relevant pods
  448. // in sync with their corresponding FederatedLearningJob objects.
  449. func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
  450. cfg := cc.Config
  451. podInformer := cc.KubeInformerFactory.Core().V1().Pods()
  452. jobInformer := cc.SednaInformerFactory.Sedna().V1alpha1().FederatedLearningJobs()
  453. eventBroadcaster := record.NewBroadcaster()
  454. eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: cc.KubeClient.CoreV1().Events("")})
  455. fc := &Controller{
  456. kubeClient: cc.KubeClient,
  457. client: cc.SednaClient.SednaV1alpha1(),
  458. queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "flJob"),
  459. recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "flJob-controller"}),
  460. cfg: cfg,
  461. }
  462. jobInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  463. AddFunc: func(obj interface{}) {
  464. fc.enqueueController(obj, true)
  465. },
  466. UpdateFunc: func(old, cur interface{}) {
  467. fc.enqueueController(cur, true)
  468. },
  469. DeleteFunc: func(obj interface{}) {
  470. fc.enqueueController(obj, true)
  471. },
  472. })
  473. fc.jobLister = jobInformer.Lister()
  474. fc.jobStoreSynced = jobInformer.Informer().HasSynced
  475. podInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  476. AddFunc: fc.addPod,
  477. UpdateFunc: fc.updatePod,
  478. DeleteFunc: fc.deletePod,
  479. })
  480. fc.podStore = podInformer.Lister()
  481. fc.podStoreSynced = podInformer.Informer().HasSynced
  482. fc.addUpstreamHandler(cc)
  483. return fc, nil
  484. }