You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

federatedlearningjob.go 18 kB


  1. /*
  2. Copyright 2021 The KubeEdge Authors.
  3. Licensed under the Apache License, Version 2.0 (the "License");
  4. you may not use this file except in compliance with the License.
  5. You may obtain a copy of the License at
  6. http://www.apache.org/licenses/LICENSE-2.0
  7. Unless required by applicable law or agreed to in writing, software
  8. distributed under the License is distributed on an "AS IS" BASIS,
  9. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. See the License for the specific language governing permissions and
  11. limitations under the License.
  12. */
  13. package federatedlearning
  14. import (
  15. "context"
  16. "fmt"
  17. "strconv"
  18. "time"
  19. v1 "k8s.io/api/core/v1"
  20. "k8s.io/apimachinery/pkg/api/errors"
  21. metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
  22. utilrand "k8s.io/apimachinery/pkg/util/rand"
  23. utilruntime "k8s.io/apimachinery/pkg/util/runtime"
  24. "k8s.io/apimachinery/pkg/util/wait"
  25. "k8s.io/apimachinery/pkg/watch"
  26. "k8s.io/client-go/kubernetes"
  27. "k8s.io/client-go/kubernetes/scheme"
  28. v1core "k8s.io/client-go/kubernetes/typed/core/v1"
  29. corelisters "k8s.io/client-go/listers/core/v1"
  30. "k8s.io/client-go/tools/cache"
  31. "k8s.io/client-go/tools/record"
  32. "k8s.io/client-go/util/workqueue"
  33. "k8s.io/klog/v2"
  34. k8scontroller "k8s.io/kubernetes/pkg/controller"
  35. sednav1 "github.com/kubeedge/sedna/pkg/apis/sedna/v1alpha1"
  36. sednaclientset "github.com/kubeedge/sedna/pkg/client/clientset/versioned/typed/sedna/v1alpha1"
  37. sednav1listers "github.com/kubeedge/sedna/pkg/client/listers/sedna/v1alpha1"
  38. "github.com/kubeedge/sedna/pkg/globalmanager/config"
  39. "github.com/kubeedge/sedna/pkg/globalmanager/runtime"
  40. )
  41. const (
  42. // KindName is the kind name of CR this controller controls
  43. KindName = "FederatedLearningJob"
  44. // Name is this controller name
  45. Name = "FederatedLearning"
  46. )
  47. const (
  48. FLJobStageAgg = "Aggregation"
  49. FLJobStageTrain = "Training"
  50. )
  51. // Kind contains the schema.GroupVersionKind for this controller type.
  52. var Kind = sednav1.SchemeGroupVersion.WithKind(KindName)
  53. // Controller ensures that all FLJob objects have corresponding pods to
  54. // run their configured workload.
  55. type Controller struct {
  56. kubeClient kubernetes.Interface
  57. client sednaclientset.SednaV1alpha1Interface
  58. // podStoreSynced returns true if the pod store has been synced at least once.
  59. // Added as a member to the struct to allow injection for testing.
  60. podStoreSynced cache.InformerSynced
  61. // jobStoreSynced returns true if the flJob store has been synced at least once.
  62. // Added as a member to the struct to allow injection for testing.
  63. jobStoreSynced cache.InformerSynced
  64. // A store of jobs
  65. jobLister sednav1listers.FederatedLearningJobLister
  66. // A store of pods, populated by the podController
  67. podStore corelisters.PodLister
  68. // FLJobs that need to be updated
  69. queue workqueue.RateLimitingInterface
  70. recorder record.EventRecorder
  71. cfg *config.ControllerConfig
  72. sendToEdgeFunc runtime.DownstreamSendFunc
  73. }
  74. // Run starts the main goroutine responsible for watching and syncing jobs.
  75. func (c *Controller) Run(stopCh <-chan struct{}) {
  76. workers := 1
  77. defer utilruntime.HandleCrash()
  78. defer c.queue.ShutDown()
  79. klog.Infof("Starting %s controller", Name)
  80. defer klog.Infof("Shutting down %s controller", Name)
  81. if !cache.WaitForNamedCacheSync(Name, stopCh, c.podStoreSynced, c.jobStoreSynced) {
  82. klog.Errorf("failed to wait for %s caches to sync", Name)
  83. return
  84. }
  85. klog.Infof("Starting %s workers", Name)
  86. for i := 0; i < workers; i++ {
  87. go wait.Until(c.worker, time.Second, stopCh)
  88. }
  89. <-stopCh
  90. }
  91. // enqueueByPod enqueues the FederatedLearningJob object of the specified pod.
  92. func (c *Controller) enqueueByPod(pod *v1.Pod, immediate bool) {
  93. controllerRef := metav1.GetControllerOf(pod)
  94. if controllerRef == nil {
  95. return
  96. }
  97. if controllerRef.Kind != Kind.Kind {
  98. return
  99. }
  100. job, err := c.jobLister.FederatedLearningJobs(pod.Namespace).Get(controllerRef.Name)
  101. if err != nil {
  102. return
  103. }
  104. if job.UID != controllerRef.UID {
  105. return
  106. }
  107. c.enqueueController(job, immediate)
  108. }
  109. // When a pod is created, enqueue the controller that manages it and update it's expectations.
  110. func (c *Controller) addPod(obj interface{}) {
  111. pod := obj.(*v1.Pod)
  112. if pod.DeletionTimestamp != nil {
  113. // on a restart of the controller, it's possible a new pod shows up in a state that
  114. // is already pending deletion. Prevent the pod from being a creation observation.
  115. c.deletePod(pod)
  116. return
  117. }
  118. // backoff to queue when PodFailed
  119. immediate := pod.Status.Phase != v1.PodFailed
  120. c.enqueueByPod(pod, immediate)
  121. }
  122. // When a pod is updated, figure out what federatedlearning job manage it and wake them up.
  123. func (c *Controller) updatePod(old, cur interface{}) {
  124. curPod := cur.(*v1.Pod)
  125. oldPod := old.(*v1.Pod)
  126. // no pod update, no queue
  127. if curPod.ResourceVersion == oldPod.ResourceVersion {
  128. return
  129. }
  130. c.addPod(curPod)
  131. }
  132. // deletePod enqueues the FederatedLearningJob obj When a pod is deleted
  133. func (c *Controller) deletePod(obj interface{}) {
  134. pod, ok := obj.(*v1.Pod)
  135. // comment from https://github.com/kubernetes/kubernetes/blob/master/pkg/controller/job/job_controller.go
  136. // When a delete is dropped, the relist will notice a pod in the store not
  137. // in the list, leading to the insertion of a tombstone object which contains
  138. // the deleted key/value. Note that this value might be stale. If the pod
  139. // changed labels the new FederatedLearningJob will not be woken up till the periodic resync.
  140. if !ok {
  141. tombstone, ok := obj.(cache.DeletedFinalStateUnknown)
  142. if !ok {
  143. klog.Warningf("couldn't get object from tombstone %+v", obj)
  144. return
  145. }
  146. pod, ok = tombstone.Obj.(*v1.Pod)
  147. if !ok {
  148. klog.Warningf("tombstone contained object that is not a pod %+v", obj)
  149. return
  150. }
  151. }
  152. c.enqueueByPod(pod, true)
  153. }
  154. // obj could be an *sednav1.FederatedLearningJob, or a DeletionFinalStateUnknown marker item,
  155. // immediate tells the controller to update the status right away, and should
  156. // happen ONLY when there was a successful pod run.
  157. func (c *Controller) enqueueController(obj interface{}, immediate bool) {
  158. key, err := k8scontroller.KeyFunc(obj)
  159. if err != nil {
  160. klog.Warningf("Couldn't get key for object %+v: %v", obj, err)
  161. return
  162. }
  163. backoff := time.Duration(0)
  164. if !immediate {
  165. backoff = runtime.GetBackoff(c.queue, key)
  166. }
  167. c.queue.AddAfter(key, backoff)
  168. }
  169. // worker runs a worker thread that just dequeues items, processes them, and marks them done.
  170. // It enforces that the syncHandler is never invoked concurrently with the same key.
  171. func (c *Controller) worker() {
  172. for c.processNextWorkItem() {
  173. }
  174. }
  175. func (c *Controller) processNextWorkItem() bool {
  176. key, quit := c.queue.Get()
  177. if quit {
  178. return false
  179. }
  180. defer c.queue.Done(key)
  181. forget, err := c.sync(key.(string))
  182. if err == nil {
  183. if forget {
  184. c.queue.Forget(key)
  185. }
  186. return true
  187. }
  188. klog.Warningf("Error syncing federatedlearning job: %v", err)
  189. c.queue.AddRateLimited(key)
  190. return true
  191. }
  192. // sync will sync the FederatedLearningJob with the given key if it has had its expectations fulfilled, meaning
  193. // it did not expect to see any more of its pods created or deleted. This function is not meant to be invoked
  194. // concurrently with the same key.
  195. func (c *Controller) sync(key string) (bool, error) {
  196. startTime := time.Now()
  197. defer func() {
  198. klog.V(4).Infof("Finished syncing federatedlearning job %q (%v)", key, time.Since(startTime))
  199. }()
  200. ns, name, err := cache.SplitMetaNamespaceKey(key)
  201. if err != nil {
  202. return false, err
  203. }
  204. if len(ns) == 0 || len(name) == 0 {
  205. return false, fmt.Errorf("invalid federatedlearning job key %q: either namespace or name is missing", key)
  206. }
  207. sharedJob, err := c.jobLister.FederatedLearningJobs(ns).Get(name)
  208. if err != nil {
  209. if errors.IsNotFound(err) {
  210. klog.V(4).Infof("FLJob has been deleted: %v", key)
  211. return true, nil
  212. }
  213. return false, err
  214. }
  215. flJob := *sharedJob
  216. // set kind for flJob in case that the kind is None
  217. flJob.SetGroupVersionKind(sednav1.SchemeGroupVersion.WithKind("FederatedLearningJob"))
  218. // if flJob was finished previously, we don't want to redo the termination
  219. if IsFLJobFinished(&flJob) {
  220. return true, nil
  221. }
  222. selector, _ := runtime.GenerateSelector(&flJob)
  223. pods, err := c.podStore.Pods(flJob.Namespace).List(selector)
  224. if err != nil {
  225. return false, err
  226. }
  227. activePods := k8scontroller.FilterActivePods(pods)
  228. active := int32(len(activePods))
  229. succeeded, failed := getStatus(pods)
  230. conditions := len(flJob.Status.Conditions)
  231. // flJob first start
  232. if flJob.Status.StartTime == nil {
  233. now := metav1.Now()
  234. flJob.Status.StartTime = &now
  235. }
  236. var manageJobErr error
  237. jobFailed := false
  238. var failureReason string
  239. var failureMessage string
  240. phase := flJob.Status.Phase
  241. if failed > 0 {
  242. jobFailed = true
  243. failureReason = "workerFailed"
  244. failureMessage = "the worker of FLJob failed"
  245. }
  246. if jobFailed {
  247. flJob.Status.Conditions = append(flJob.Status.Conditions, NewFLJobCondition(sednav1.FLJobCondFailed, failureReason, failureMessage))
  248. flJob.Status.Phase = sednav1.FLJobFailed
  249. c.recorder.Event(&flJob, v1.EventTypeWarning, failureReason, failureMessage)
  250. } else {
  251. // in the First time, we create the pods
  252. if len(pods) == 0 {
  253. active, manageJobErr = c.createPod(&flJob)
  254. }
  255. complete := false
  256. if succeeded > 0 && active == 0 {
  257. complete = true
  258. }
  259. if complete {
  260. flJob.Status.Conditions = append(flJob.Status.Conditions, NewFLJobCondition(sednav1.FLJobCondComplete, "", ""))
  261. now := metav1.Now()
  262. flJob.Status.CompletionTime = &now
  263. c.recorder.Event(&flJob, v1.EventTypeNormal, "Completed", "FLJob completed")
  264. flJob.Status.Phase = sednav1.FLJobSucceeded
  265. } else {
  266. flJob.Status.Phase = sednav1.FLJobRunning
  267. }
  268. }
  269. forget := false
  270. // Check if the number of jobs succeeded increased since the last check. If yes "forget" should be true
  271. // This logic is linked to the issue: https://github.com/kubernetes/kubernetes/issues/56853 that aims to
  272. // improve the FLJob backoff policy when parallelism > 1 and few FLJobs failed but others succeed.
  273. // In this case, we should clear the backoff delay.
  274. if flJob.Status.Succeeded < succeeded {
  275. forget = true
  276. }
  277. // no need to update the flJob if the status hasn't changed since last time
  278. if flJob.Status.Active != active || flJob.Status.Succeeded != succeeded || flJob.Status.Failed != failed || len(flJob.Status.Conditions) != conditions || flJob.Status.Phase != phase {
  279. flJob.Status.Active = active
  280. flJob.Status.Succeeded = succeeded
  281. flJob.Status.Failed = failed
  282. if jobFailed && !IsFLJobFinished(&flJob) {
  283. // returning an error will re-enqueue FLJob after the backoff period
  284. return forget, fmt.Errorf("failed pod(s) detected for flJob key %q", key)
  285. }
  286. forget = true
  287. }
  288. return forget, manageJobErr
  289. }
  290. func NewFLJobCondition(conditionType sednav1.FLJobConditionType, reason, message string) sednav1.FLJobCondition {
  291. return sednav1.FLJobCondition{
  292. Type: conditionType,
  293. Status: v1.ConditionTrue,
  294. LastProbeTime: metav1.Now(),
  295. LastHeartbeatTime: metav1.Now(),
  296. Reason: reason,
  297. Message: message,
  298. }
  299. }
  300. // getStatus returns no of succeeded and failed pods running a flJob
  301. func getStatus(pods []*v1.Pod) (succeeded, failed int32) {
  302. succeeded = int32(filterPods(pods, v1.PodSucceeded))
  303. failed = int32(filterPods(pods, v1.PodFailed))
  304. return
  305. }
  306. func (c *Controller) updateFLJobStatus(flJob *sednav1.FederatedLearningJob) error {
  307. jobClient := c.client.FederatedLearningJobs(flJob.Namespace)
  308. var err error
  309. for i := 0; i <= runtime.ResourceUpdateRetries; i = i + 1 {
  310. var newFLJob *sednav1.FederatedLearningJob
  311. newFLJob, err = jobClient.Get(context.TODO(), flJob.Name, metav1.GetOptions{})
  312. if err != nil {
  313. break
  314. }
  315. newFLJob.Status = flJob.Status
  316. if _, err = jobClient.UpdateStatus(context.TODO(), newFLJob, metav1.UpdateOptions{}); err == nil {
  317. break
  318. }
  319. }
  320. return nil
  321. }
  322. // filterPods returns pods based on their phase.
  323. func filterPods(pods []*v1.Pod, phase v1.PodPhase) int {
  324. result := 0
  325. for i := range pods {
  326. if phase == pods[i].Status.Phase {
  327. result++
  328. }
  329. }
  330. return result
  331. }
  332. func IsFLJobFinished(j *sednav1.FederatedLearningJob) bool {
  333. for _, c := range j.Status.Conditions {
  334. if (c.Type == sednav1.FLJobCondComplete || c.Type == sednav1.FLJobCondFailed) && c.Status == v1.ConditionTrue {
  335. return true
  336. }
  337. }
  338. return false
  339. }
  340. func (c *Controller) createPod(job *sednav1.FederatedLearningJob) (active int32, err error) {
  341. active = 0
  342. ctx := context.Background()
  343. modelName := job.Spec.AggregationWorker.Model.Name
  344. model, err := c.client.Models(job.Namespace).Get(ctx, modelName, metav1.GetOptions{})
  345. if err != nil {
  346. return active, fmt.Errorf("failed to get model %s: %w",
  347. modelName, err)
  348. }
  349. secretName := model.Spec.CredentialName
  350. var modelSecret *v1.Secret
  351. if secretName != "" {
  352. modelSecret, _ = c.kubeClient.CoreV1().Secrets(job.Namespace).Get(context.TODO(), secretName, metav1.GetOptions{})
  353. }
  354. participantsCount := strconv.Itoa(len(job.Spec.TrainingWorkers))
  355. // deliver pod for aggregation worker
  356. aggWorker := job.Spec.AggregationWorker
  357. // Configure container mounting and Env information by initial runtime.WorkerParam
  358. var aggPort int32 = 7363
  359. aggWorkerParam := new(runtime.WorkerParam)
  360. aggWorkerParam.Env = map[string]string{
  361. "NAMESPACE": job.Namespace,
  362. "WORKER_NAME": "aggworker-" + utilrand.String(5),
  363. "JOB_NAME": job.Name,
  364. "AGG_BIND_PORT": strconv.Itoa(int(aggPort)),
  365. "PARTICIPANTS_COUNT": participantsCount,
  366. }
  367. aggWorkerParam.WorkerType = FLJobStageAgg
  368. aggWorkerParam.RestartPolicy = v1.RestartPolicyOnFailure
  369. aggWorkerParam.Mounts = append(aggWorkerParam.Mounts,
  370. runtime.WorkerMount{
  371. URL: &runtime.MountURL{
  372. URL: model.Spec.URL,
  373. Secret: modelSecret,
  374. DownloadByInitializer: false,
  375. },
  376. EnvName: "MODEL_URL",
  377. },
  378. )
  379. // create aggpod based on configured parameters
  380. _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &aggWorker.Template, aggWorkerParam)
  381. if err != nil {
  382. return active, err
  383. }
  384. active++
  385. var appIP string
  386. var aggServicePort int32
  387. // FIXME(llhuii): only the case that Spec.NodeName specified is support,
  388. // will support Spec.NodeSelector.
  389. appIP, err = runtime.GetNodeIPByName(c.kubeClient, job.Spec.AggregationWorker.Template.Spec.NodeName)
  390. aggServicePort, err = runtime.CreateKubernetesService(c.kubeClient, job, FLJobStageAgg, aggPort, appIP)
  391. if err != nil {
  392. return active, err
  393. }
  394. // deliver pod for training worker
  395. for _, trainingWorker := range job.Spec.TrainingWorkers {
  396. // get dataseturl through parsing crd of dataset
  397. datasetName := trainingWorker.Dataset.Name
  398. dataset, err := c.client.Datasets(job.Namespace).Get(ctx, datasetName, metav1.GetOptions{})
  399. if err != nil {
  400. return active, fmt.Errorf("failed to get dataset %s: %w",
  401. datasetName, err)
  402. }
  403. secretName := dataset.Spec.CredentialName
  404. var datasetSecret *v1.Secret
  405. if secretName != "" {
  406. datasetSecret, _ = c.kubeClient.CoreV1().Secrets(job.Namespace).Get(context.TODO(), secretName, metav1.GetOptions{})
  407. }
  408. // Configure container mounting and env information
  409. workerParam := new(runtime.WorkerParam)
  410. workerParam.Mounts = append(workerParam.Mounts,
  411. runtime.WorkerMount{
  412. URL: &runtime.MountURL{
  413. URL: model.Spec.URL,
  414. Secret: modelSecret,
  415. },
  416. EnvName: "MODEL_URL",
  417. },
  418. runtime.WorkerMount{
  419. URL: &runtime.MountURL{
  420. URL: dataset.Spec.URL,
  421. Secret: datasetSecret,
  422. },
  423. EnvName: "TRAIN_DATASET_URL",
  424. },
  425. )
  426. workerParam.Env = map[string]string{
  427. "AGG_PORT": strconv.Itoa(int(aggServicePort)),
  428. "AGG_IP": appIP,
  429. "WORKER_NAME": "trainworker-" + utilrand.String(5),
  430. "JOB_NAME": job.Name,
  431. "PARTICIPANTS_COUNT": participantsCount,
  432. "NAMESPACE": job.Namespace,
  433. "MODEL_NAME": modelName,
  434. "DATASET_NAME": datasetName,
  435. "LC_SERVER": c.cfg.LC.Server,
  436. }
  437. workerParam.WorkerType = runtime.TrainPodType
  438. workerParam.HostNetwork = true
  439. workerParam.RestartPolicy = v1.RestartPolicyOnFailure
  440. // create train pod based on configured parameters
  441. _, err = runtime.CreatePodWithTemplate(c.kubeClient, job, &trainingWorker.Template, workerParam)
  442. if err != nil {
  443. return active, err
  444. }
  445. active++
  446. }
  447. return
  448. }
  449. // New creates a new federated learning job controller that keeps the relevant pods
  450. // in sync with their corresponding FederatedLearningJob objects.
  451. func New(cc *runtime.ControllerContext) (runtime.FeatureControllerI, error) {
  452. cfg := cc.Config
  453. podInformer := cc.KubeInformerFactory.Core().V1().Pods()
  454. jobInformer := cc.SednaInformerFactory.Sedna().V1alpha1().FederatedLearningJobs()
  455. eventBroadcaster := record.NewBroadcaster()
  456. eventBroadcaster.StartRecordingToSink(&v1core.EventSinkImpl{Interface: cc.KubeClient.CoreV1().Events("")})
  457. fc := &Controller{
  458. kubeClient: cc.KubeClient,
  459. client: cc.SednaClient.SednaV1alpha1(),
  460. queue: workqueue.NewNamedRateLimitingQueue(workqueue.NewItemExponentialFailureRateLimiter(runtime.DefaultBackOff, runtime.MaxBackOff), "flJob"),
  461. recorder: eventBroadcaster.NewRecorder(scheme.Scheme, v1.EventSource{Component: "flJob-controller"}),
  462. cfg: cfg,
  463. }
  464. jobInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  465. AddFunc: func(obj interface{}) {
  466. fc.enqueueController(obj, true)
  467. fc.syncToEdge(watch.Added, obj)
  468. },
  469. UpdateFunc: func(old, cur interface{}) {
  470. fc.enqueueController(cur, true)
  471. fc.syncToEdge(watch.Added, cur)
  472. },
  473. DeleteFunc: func(obj interface{}) {
  474. fc.enqueueController(obj, true)
  475. fc.syncToEdge(watch.Deleted, obj)
  476. },
  477. })
  478. fc.jobLister = jobInformer.Lister()
  479. fc.jobStoreSynced = jobInformer.Informer().HasSynced
  480. podInformer.Informer().AddEventHandler(cache.ResourceEventHandlerFuncs{
  481. AddFunc: fc.addPod,
  482. UpdateFunc: fc.updatePod,
  483. DeleteFunc: fc.deletePod,
  484. })
  485. fc.podStore = podInformer.Lister()
  486. fc.podStoreSynced = podInformer.Informer().HasSynced
  487. fc.addUpstreamHandler(cc)
  488. return fc, nil
  489. }