diff --git a/pkg/globalmanager/common.go b/pkg/globalmanager/common.go index 157574b0..52565304 100644 --- a/pkg/globalmanager/common.go +++ b/pkg/globalmanager/common.go @@ -245,6 +245,10 @@ func injectWorkerPara(pod *v1.Pod, workerPara *WorkerPara, object CommonInterfac // force to set hostnetwork pod.Spec.HostNetwork = true } + + if pod.Spec.RestartPolicy == "" { + pod.Spec.RestartPolicy = workerPara.restartPolicy + } } // createPodWithTemplate creates and returns a pod object given a crd object, pod template, and workerPara diff --git a/pkg/globalmanager/federatedlearningjob.go b/pkg/globalmanager/federatedlearningjob.go index 476009b5..1a1a06a3 100644 --- a/pkg/globalmanager/federatedlearningjob.go +++ b/pkg/globalmanager/federatedlearningjob.go @@ -429,11 +429,11 @@ func (fc *FederatedController) createPod(job *sednav1.FederatedLearningJob) (act // Configure container mounting and Env information by initial WorkerPara var aggPort int32 = 7363 - var aggContainer *WorkerPara = new(WorkerPara) - aggContainer.volumeMountList = []string{aggModelConPath} - aggContainer.volumeList = []string{modelPath} - aggContainer.volumeMapName = []string{"model"} - aggContainer.env = map[string]string{ + var aggWorkerPara *WorkerPara = new(WorkerPara) + aggWorkerPara.volumeMountList = []string{aggModelConPath} + aggWorkerPara.volumeList = []string{modelPath} + aggWorkerPara.volumeMapName = []string{"model"} + aggWorkerPara.env = map[string]string{ "MODEL": modelstring, "WORKER_NAME": "aggworker-" + utilrand.String(5), "JOB_NAME": job.Name, @@ -443,10 +443,11 @@ func (fc *FederatedController) createPod(job *sednav1.FederatedLearningJob) (act "AGG_BIND_PORT": strconv.Itoa(int(aggPort)), } - aggContainer.workerType = FLJobStageAgg + aggWorkerPara.workerType = FLJobStageAgg + aggWorkerPara.restartPolicy = v1.RestartPolicyOnFailure // create aggpod based on configured parameters - _, err = createPodWithTemplate(fc.kubeClient, job, &aggWorker.Template, aggContainer) + _, err = createPodWithTemplate(fc.kubeClient, job, &aggWorker.Template, aggWorkerPara) if err != nil { return active, err } @@ -485,11 +486,11 @@ func (fc *FederatedController) createPod(job *sednav1.FederatedLearningJob) (act trainModelURL := trainModelConPath // Configure container mounting and Env information by initial WorkerPara - var trainContainer *WorkerPara = new(WorkerPara) - trainContainer.volumeMountList = []string{trainDataConPath, trainModelConPath} - trainContainer.volumeList = []string{datasetParent, modelPath} - trainContainer.volumeMapName = []string{"data", "model"} - trainContainer.env = map[string]string{ + var workerPara *WorkerPara = new(WorkerPara) + workerPara.volumeMountList = []string{trainDataConPath, trainModelConPath} + workerPara.volumeList = []string{datasetParent, modelPath} + workerPara.volumeMapName = []string{"data", "model"} + workerPara.env = map[string]string{ "DATASET": datasetstring, "AGG_PORT": strconv.Itoa(int(aggServicePort)), "AGG_IP": appIP, @@ -503,9 +504,11 @@ func (fc *FederatedController) createPod(job *sednav1.FederatedLearningJob) (act "DATASET_NAME": datasetName, "LC_SERVER": fc.cfg.LC.Server, } - // create trainpod based on configured parameters - trainContainer.workerType = "train" - _, err = createPodWithTemplate(fc.kubeClient, job, &trainingWorker.Template, trainContainer) + workerPara.workerType = "train" + workerPara.hostNetwork = true + workerPara.restartPolicy = v1.RestartPolicyOnFailure + // create train pod based on configured parameters + _, err = createPodWithTemplate(fc.kubeClient, job, &trainingWorker.Template, workerPara) if err != nil { return active, err } diff --git a/pkg/globalmanager/incrementallearningjob.go b/pkg/globalmanager/incrementallearningjob.go index c3c6377c..c2491875 100644 --- a/pkg/globalmanager/incrementallearningjob.go +++ b/pkg/globalmanager/incrementallearningjob.go @@ -612,6 +612,11 @@ func (jc *IncrementalJobController) createPod(job *sednav1.IncrementalLearningJo "LC_SERVER": jc.cfg.LC.Server, } } + + // set the default policy instead of Always policy + workerPara.restartPolicy = v1.RestartPolicyOnFailure + workerPara.hostNetwork = true + // create pod based on podtype _, err = createPodWithTemplate(jc.kubeClient, job, podTemplate, workerPara) if err != nil { @@ -639,11 +644,11 @@ func (jc *IncrementalJobController) createInferPod(job *sednav1.IncrementalLearn inferModelURL := dataPrefix + inferModelPath // Configure container mounting and Env information by initial WorkerPara - var inferContainer *WorkerPara = new(WorkerPara) - inferContainer.volumeMountList = []string{inferModelConPath} - inferContainer.volumeList = []string{inferModelParent} - inferContainer.volumeMapName = []string{"model"} - inferContainer.env = map[string]string{ + var workerParam *WorkerPara = new(WorkerPara) + workerParam.volumeMountList = []string{inferModelConPath} + workerParam.volumeList = []string{inferModelParent} + workerParam.volumeMapName = []string{"model"} + workerParam.env = map[string]string{ "WORKER_NAME": "inferworker-" + utilrand.String(5), "MODEL_URL": inferModelURL, "NAMESPACE": job.Namespace, @@ -651,11 +656,11 @@ func (jc *IncrementalJobController) createInferPod(job *sednav1.IncrementalLearn "LC_SERVER": jc.cfg.LC.Server, } - inferContainer.workerType = "inference" - inferContainer.hostNetwork = true + workerParam.workerType = "inference" + workerParam.hostNetwork = true // create edge pod - _, err = createPodWithTemplate(jc.kubeClient, job, &job.Spec.DeploySpec.Template, inferContainer) + _, err = createPodWithTemplate(jc.kubeClient, job, &job.Spec.DeploySpec.Template, workerParam) return err } diff --git a/pkg/globalmanager/types.go b/pkg/globalmanager/types.go index acffd39a..dc0e428c 100644 --- a/pkg/globalmanager/types.go +++ b/pkg/globalmanager/types.go @@ -19,6 +19,7 @@ package globalmanager import ( "encoding/json" + v1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/runtime/schema" @@ -33,6 +34,8 @@ type WorkerPara struct { workerType string // if true, force to use hostNetwork hostNetwork bool + + restartPolicy v1.RestartPolicy } // CommonInterface describes the commom interface of CRs