From 2e80c376f6a2797c8ec5e66f96b1914604e182d7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=8B=E6=B1=9F=E5=AE=9E=E9=AA=8C=E5=AE=A4?= Date: Mon, 26 Oct 2020 16:13:21 +0800 Subject: [PATCH] update distribute-train-operator --- distribute-train-operator/README.md | 26 + .../docs/crds/distribute-train-cr.yaml | 65 ++ .../docs/crds/distribute-train-crd.yaml | 61 ++ .../distribute-train-operator_deploy.yaml | 47 ++ distribute-train-operator/pom.xml | 150 +++++ .../DistributeTrainOperatorApplication.java | 35 + .../DistributeTrainOperatorManager.java | 199 ++++++ .../operator/action/OperatorRunner.java | 58 ++ .../org/onebrain/operator/action/PodInfo.java | 44 ++ .../deployer/AbstractResourceCreateInfo.java | 41 ++ .../deployer/ChildResourceCreateInfo.java | 227 +++++++ .../operator/action/deployer/JobDeployer.java | 35 + .../action/deployer/ServiceDeployer.java | 33 + .../action/deployer/StatefulSetDeployer.java | 33 + .../action/deployer/impl/BaseJobDeployer.java | 246 +++++++ .../deployer/impl/BaseServiceDeployer.java | 73 +++ .../impl/BaseStatefulSetDeployer.java | 246 +++++++ .../action/handler/AddActionHandler.java | 614 ++++++++++++++++++ .../action/handler/DeleteActionHandler.java | 88 +++ .../handler/DistributeTrainActionHandler.java | 33 + .../api/pod/DefaultPodExecListener.java | 85 +++ .../org/onebrain/operator/api/pod/PodApi.java | 177 +++++ .../operator/api/pod/StdPodExecListener.java | 83 +++ .../onebrain/operator/config/KubeConfig.java | 66 ++ .../operator/constants/CrdConstants.java | 34 + .../operator/constants/KubeConstants.java | 40 ++ .../operator/constants/NumberConstant.java | 43 ++ .../operator/context/KubeContext.java | 117 ++++ .../controller/DistributeTrainController.java | 131 ++++ .../operator/crd/DistributeTrain.java | 47 ++ .../operator/crd/DistributeTrainList.java | 27 + .../operator/crd/DistributeTrainSpec.java | 108 +++ .../operator/crd/DistributeTrainStatus.java | 55 ++ .../operator/crd/DoneableDistributeTrain.java | 31 + 
.../operator/enums/AccessModeEnum.java | 56 ++ .../operator/exception/OperatorException.java | 49 ++ .../operator/properties/KubeProperties.java | 34 + .../operator/redis/AbstractKeyPrefix.java | 65 ++ .../onebrain/operator/redis/RedisService.java | 290 +++++++++ .../operator/redis/key/OperatorKey.java | 45 ++ .../utils/DistributeTrainClientHolder.java | 41 ++ .../operator/utils/FastjsonUtils.java | 188 ++++++ .../org/onebrain/operator/utils/IOUtils.java | 56 ++ .../onebrain/operator/utils/RedisUtils.java | 289 +++++++++ .../operator/utils/SpringContextHolder.java | 99 +++ .../onebrain/operator/watcher/JobHandler.java | 111 ++++ .../onebrain/operator/watcher/JobWatcher.java | 71 ++ .../operator/watcher/KubeWatcherManager.java | 120 ++++ .../src/main/resources/key/id_rsa | 27 + .../src/main/resources/key/id_rsa.pub | 1 + .../src/main/resources/kubeconfig | 19 + .../src/main/resources/shell/pretreatment | 46 ++ ...stributeTrainOperatorApplicationTests.java | 43 ++ 53 files changed, 5048 insertions(+) create mode 100644 distribute-train-operator/README.md create mode 100644 distribute-train-operator/docs/crds/distribute-train-cr.yaml create mode 100644 distribute-train-operator/docs/crds/distribute-train-crd.yaml create mode 100644 distribute-train-operator/docs/deploy/distribute-train-operator_deploy.yaml create mode 100644 distribute-train-operator/pom.xml create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/DistributeTrainOperatorApplication.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/DistributeTrainOperatorManager.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/OperatorRunner.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/PodInfo.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/AbstractResourceCreateInfo.java create mode 100644 
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ChildResourceCreateInfo.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/JobDeployer.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ServiceDeployer.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/StatefulSetDeployer.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseJobDeployer.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseServiceDeployer.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseStatefulSetDeployer.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/AddActionHandler.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DeleteActionHandler.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DistributeTrainActionHandler.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/DefaultPodExecListener.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/PodApi.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/StdPodExecListener.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/config/KubeConfig.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/constants/CrdConstants.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/constants/KubeConstants.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/constants/NumberConstant.java create mode 100644 
distribute-train-operator/src/main/java/org/onebrain/operator/context/KubeContext.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/controller/DistributeTrainController.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrain.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainList.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainSpec.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainStatus.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/crd/DoneableDistributeTrain.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/enums/AccessModeEnum.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/exception/OperatorException.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/properties/KubeProperties.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/redis/AbstractKeyPrefix.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/redis/RedisService.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/redis/key/OperatorKey.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/utils/DistributeTrainClientHolder.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/utils/FastjsonUtils.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/utils/IOUtils.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/utils/RedisUtils.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/utils/SpringContextHolder.java create mode 100644 
distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobHandler.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobWatcher.java create mode 100644 distribute-train-operator/src/main/java/org/onebrain/operator/watcher/KubeWatcherManager.java create mode 100644 distribute-train-operator/src/main/resources/key/id_rsa create mode 100644 distribute-train-operator/src/main/resources/key/id_rsa.pub create mode 100644 distribute-train-operator/src/main/resources/kubeconfig create mode 100644 distribute-train-operator/src/main/resources/shell/pretreatment create mode 100644 distribute-train-operator/src/test/java/org/onebrain/operator/DistributeTrainOperatorApplicationTests.java diff --git a/distribute-train-operator/README.md b/distribute-train-operator/README.md new file mode 100644 index 0000000..926fb8b --- /dev/null +++ b/distribute-train-operator/README.md @@ -0,0 +1,26 @@ +# 之江天枢-分布式训练 operator +该模块是分布式训练CRD的控制器,管理分布式训练容器生命周期,为分布式训练容器注入其他容器ip。 + +## 源码部署 + +### 准备环境 +安装如下软件环境。 +- OpenJDK:1.8+ +- Redis: 3.0+ +- Maven: 3.0+ + +### 下载源码 +``` bash +git clone https://codeup.teambition.com/zhejianglab/distribute-train-operator.git +# 进入项目根目录 +cd distribute-train-operator +``` + +### 构建 +``` bash +# 构建,生成的 jar 包位于 ./target/distribute-train-operator-1.0.jar +mvn clean compile package +``` + +### 部署 +部署过程参看文档:[部署 分布式训练operator](http://tianshu.org.cn/?/course/1.html) diff --git a/distribute-train-operator/docs/crds/distribute-train-cr.yaml b/distribute-train-operator/docs/crds/distribute-train-cr.yaml new file mode 100644 index 0000000..40b2f72 --- /dev/null +++ b/distribute-train-operator/docs/crds/distribute-train-cr.yaml @@ -0,0 +1,65 @@ +apiVersion: onebrain.oneflow.org/v1alpha1 +kind: DistributeTrain +metadata: + name: dt-resnet50 + namespace: resnet50 + labels: + key: value +spec: + size: 3 + image: {{IMAGE}} + imagePullPolicy: IfNotPresent + masterCmd: export NODE_IPS=`cat /home/hostfile.json |jq -r 
'.[]|.ip'|paste -d "," -s` && cd /workspace/Classification/cnns && rm -rf core.* && rm -rf ./output/snapshots/* && python3 of_cnn_train_val.py --train_data_dir=$DATA_ROOT/train --train_data_part_num=$TRAIN_DATA_PART_NUM --val_data_dir=$DATA_ROOT/validation --val_data_part_num=$VAL_DATA_PART_NUM --num_nodes=$NODE_NUM --node_ips="$NODE_IPS" --gpu_num_per_node=$GPU_NUM_PER_NODE --model_update="momentum" --learning_rate=0.256 --loss_print_every_n_iter=1 --batch_size_per_device=64 --val_batch_size_per_device=64 --num_epoch=1 --model="resnet50" --model_save_dir=/model + masterResources: + requests: + nvidia.com/gpu: 2 + memory: "16Gi" + cpu: "2" + limits: + nvidia.com/gpu: 2 + memory: "16Gi" + cpu: "2" + slaveCmd: export NODE_IPS=`cat /home/hostfile.json |jq -r '.[]|.ip'|paste -d "," -s` && cd /workspace/Classification/cnns && rm -rf core.* && rm -rf ./output/snapshots/* && python3 of_cnn_train_val.py --train_data_dir=$DATA_ROOT/train --train_data_part_num=$TRAIN_DATA_PART_NUM --val_data_dir=$DATA_ROOT/validation --val_data_part_num=$VAL_DATA_PART_NUM --num_nodes=$NODE_NUM --node_ips="$NODE_IPS" --gpu_num_per_node=$GPU_NUM_PER_NODE --model_update="momentum" --learning_rate=0.256 --loss_print_every_n_iter=1 --batch_size_per_device=64 --val_batch_size_per_device=64 --num_epoch=1 --model="resnet50" --model_save_dir=/model + slaveResources: + requests: + nvidia.com/gpu: 2 + memory: "16Gi" + cpu: "2" + limits: + nvidia.com/gpu: 2 + memory: "16Gi" + cpu: "2" + nodeSelector: + kubernetes.io/hostname: node02 + env: + - name: ENABLE_USER_OP + value: 'True' + - name: DATA_ROOT + value: '/dataset' + - name: NODE_NUM + value: 3 + - name: GPU_NUM_PER_NODE + value: 2 + - name: ONEFLOW_DEBUG_MODE + value: "" + - name: TRAIN_DATA_PART_NUM + value: 6 + - name: VAL_DATA_PART_NUM + value: 6 + - name: NCCL_DEBUG + value: INFO + datasetStorage: + name: pvc-dataset + nfs: + path: {{DATASET}} + server: {{NFS}} + workspaceStorage: + name: pvc-workspace + nfs: + path: /nfs/resnet50/workspace + 
server: {{WORKSPACE}} + modelStorage: + name: pvc-model + nfs: + path: /nfs/resnet50/model + server: {{MODEL}} \ No newline at end of file diff --git a/distribute-train-operator/docs/crds/distribute-train-crd.yaml b/distribute-train-operator/docs/crds/distribute-train-crd.yaml new file mode 100644 index 0000000..07bea5a --- /dev/null +++ b/distribute-train-operator/docs/crds/distribute-train-crd.yaml @@ -0,0 +1,61 @@ +--- +apiVersion: apiextensions.k8s.io/v1beta1 +kind: CustomResourceDefinition +metadata: + name: distributetrains.onebrain.oneflow.org +spec: + group: onebrain.oneflow.org + names: + kind: DistributeTrain + singular: distributetrain + plural: distributetrains + shortNames: + - dt + scope: Namespaced + subresources: + status: {} + version: v1alpha1 + validation: + openAPIV3Schema: + properties: + apiVersion: + type: string + kind: + type: string + metadata: + type: object + spec: + properties: + image: + type: string + imagePullPolicy: + type: string + size: + format: int32 + type: integer + masterCmd: + type: string + slaveCmd: + type: string + masterResources: + type: object + slaveResources: + type: object + nodeSelector: + type: object + initContainer: + type: object + datasetStorage: + type: object + workspaceStorage: + type: object + modelStorage: + type: object + required: + - image + - imagePullPolicy + - size + - masterCmd + - slaveCmd + - workspaceStorage + type: object \ No newline at end of file diff --git a/distribute-train-operator/docs/deploy/distribute-train-operator_deploy.yaml b/distribute-train-operator/docs/deploy/distribute-train-operator_deploy.yaml new file mode 100644 index 0000000..fe87131 --- /dev/null +++ b/distribute-train-operator/docs/deploy/distribute-train-operator_deploy.yaml @@ -0,0 +1,47 @@ +kind: Deployment +apiVersion: apps/v1 +metadata: + name: distribute-train-operator + namespace: test-ns + labels: + name: distribute-train-operator +spec: + replicas: 1 + selector: + matchLabels: + name: distribute-train-operator 
+ template: + metadata: + labels: + name: distribute-train-operator + spec: + containers: + - name: distribute-train-operator + image: {{IMAGE}} + ports: + - containerPort: 8080 + protocol: TCP + volumeMounts: + - mountPath: /root/config + name: config-volume + env: + - name: JAR_BALL + value: "distribute-train-operator-1.0.jar --k8s.kubeconfig=/root/config --spring.redis.host=192.168.1.104" + imagePullPolicy: IfNotPresent + volumes: + - name: config-volume + hostPath: + path: /root/.kube/config + restartPolicy: Always + terminationGracePeriodSeconds: 30 + securityContext: + runAsUser: 0 + schedulerName: default-scheduler + strategy: + type: RollingUpdate + rollingUpdate: + maxUnavailable: 1 + maxSurge: 1 + revisionHistoryLimit: 7 + progressDeadlineSeconds: 600 + diff --git a/distribute-train-operator/pom.xml b/distribute-train-operator/pom.xml new file mode 100644 index 0000000..d0ac31b --- /dev/null +++ b/distribute-train-operator/pom.xml @@ -0,0 +1,150 @@ + + + 4.0.0 + + org.springframework.boot + spring-boot-starter-parent + 2.2.5.RELEASE + + + org.onebrain + distribute-train-operator + 1.0 + distribute-train-operator + distribute-train operator + + + UTF-8 + UTF-8 + 1.8 + 4.9.0 + + + + + + org.springframework.boot + spring-boot-starter-web + + + + + io.fabric8 + kubernetes-client + ${fabric.io.version} + + + io.fabric8 + kubernetes-assertions + 4.0.0 + test + + + + + org.springframework.boot + spring-boot-configuration-processor + + + + + org.springframework.boot + spring-boot-starter-data-redis + + + redis.clients + jedis + + + + + commons-io + commons-io + 2.6 + + + org.apache.commons + commons-compress + 1.19 + + + commons-codec + commons-codec + + + + + cn.hutool + hutool-all + 5.1.1 + + + + com.google.guava + guava + 27.0.1-jre + + + + com.alibaba + fastjson + 1.2.54 + + + + org.projectlombok + lombok + true + + + + org.springframework.boot + spring-boot-starter-test + test + + + + + + + org.springframework.boot + spring-boot-maven-plugin + + + + 
org.apache.maven.plugins + maven-surefire-plugin + + true + + + + + + + + public + aliyun nexus + http://maven.aliyun.com/nexus/content/groups/public/ + + true + + + + + + + public + aliyun nexus + http://maven.aliyun.com/nexus/content/groups/public/ + + true + + + false + + + + + diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/DistributeTrainOperatorApplication.java b/distribute-train-operator/src/main/java/org/onebrain/operator/DistributeTrainOperatorApplication.java new file mode 100644 index 0000000..1479e00 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/DistributeTrainOperatorApplication.java @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.scheduling.annotation.EnableAsync; +/** + * @description Operator bootstrap class: a Spring Boot application (with @EnableAsync) whose main method hands control to SpringApplication.run + * @date 2020-09-03 + */ +@SpringBootApplication +@EnableAsync +public class DistributeTrainOperatorApplication { + + public static void main(String[] args) { + SpringApplication.run(DistributeTrainOperatorApplication.class, args); + } + +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/DistributeTrainOperatorManager.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/DistributeTrainOperatorManager.java new file mode 100644 index 0000000..7756545 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/DistributeTrainOperatorManager.java @@ -0,0 +1,199 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.action; + +import cn.hutool.core.util.StrUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.google.common.collect.Maps; +import io.fabric8.kubernetes.api.model.apiextensions.*; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.dsl.MixedOperation; +import io.fabric8.kubernetes.client.dsl.Resource; +import io.fabric8.kubernetes.client.dsl.base.CustomResourceDefinitionContext; +import io.fabric8.kubernetes.client.informers.SharedIndexInformer; +import io.fabric8.kubernetes.client.informers.SharedInformerFactory; +import io.fabric8.kubernetes.client.internal.SerializationUtils; +import io.fabric8.kubernetes.internal.KubernetesDeserializer; +import lombok.extern.slf4j.Slf4j; +import org.onebrain.operator.controller.DistributeTrainController; +import org.onebrain.operator.crd.DistributeTrain; +import org.onebrain.operator.crd.DistributeTrainList; +import org.onebrain.operator.crd.DoneableDistributeTrain; +import org.onebrain.operator.utils.DistributeTrainClientHolder; +import org.onebrain.operator.utils.SpringContextHolder; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.stereotype.Component; + +import java.util.Map; + +import static org.onebrain.operator.constants.CrdConstants.*; + +/** + * @description Main operator manager: creates the DistributeTrain CRD when it is missing, then builds the custom-resource client and informer and registers the DistributeTrainController bean + * @date 2020-09-23 + */ +@Component +@Slf4j +public class DistributeTrainOperatorManager { + + public static final String NAMESPACE_DEFAULT = "default"; + public static final String TYPE_STRING = "string"; + public static final String TYPE_INTEGER = "integer"; + public static final String TYPE_OBJECT = "object"; + public static 
final String TYPE_ARRAY = "array"; + public static final String FORMAT_INT_32 = "int32"; + @Autowired + private KubernetesClient client; + + private CustomResourceDefinition crd; + + private String namespace; + + /** + * Checks whether the CRD exists and creates it if it does not; also resolves the working namespace (falling back to NAMESPACE_DEFAULT) and stores both in this instance + * @throws JsonProcessingException + */ + public void createCrdIfNotExists() throws JsonProcessingException { + String namespace = client.getNamespace(); + if (namespace == null) { + log.info("No namespace found via config, assuming default."); + namespace = NAMESPACE_DEFAULT; + } + this.namespace = namespace; + log.info("Using namespace : {}", namespace); + // Check whether the CRD already exists + CustomResourceDefinition crd = client.customResourceDefinitions().withName(CRD_NAME).get(); + if(crd == null){ + Map<String, JSONSchemaProps> crdPropsMap = buildCrdProperties(); + log.info("crd props map is : 【{}】",crdPropsMap); + // Not found: build and create it + CustomResourceDefinition distributeTrainCustomResourceDefinition = new CustomResourceDefinitionBuilder() + .withApiVersion(CRD_API_VERSION) + .withNewMetadata() + .withName(CRD_NAME) + .endMetadata() + .withNewSpec() + .withGroup(CRD_GROUP) + .withVersion(CRD_VERSION) + .withScope(CRD_SCOPE) + .withNewNames() + .withKind(CRD_KIND) + .withSingular(CRD_SINGULAR_NAME) + .withPlural(CRD_PLURAL_NAME) + .withShortNames(CRD_SHORT_NAME) + .endNames() + .withNewValidation() + .withNewOpenAPIV3Schema() + .addToProperties(crdPropsMap) + .endOpenAPIV3Schema() + .endValidation() + .endSpec() + .build(); + distributeTrainCustomResourceDefinition = client.customResourceDefinitions().create(distributeTrainCustomResourceDefinition); + log.info("create crd successfully : \n{}", SerializationUtils.dumpAsYaml(distributeTrainCustomResourceDefinition)); + crd = distributeTrainCustomResourceDefinition; + } + // Register the custom kind with the Kubernetes deserializer + KubernetesDeserializer.registerCustomKind(CRD_GROUP + StrUtil.SLASH + CRD_VERSION, CRD_KIND, DistributeTrain.class); + this.crd = crd; + } + + /** + * Initializes the informer; reads the crd and namespace fields assigned by createCrdIfNotExists + */ + public void initInformer(){ + CustomResourceDefinitionContext 
distributeTrainCustomResourceDefinitionContext = new CustomResourceDefinitionContext.Builder() + .withVersion(CRD_VERSION) + .withScope(CRD_SCOPE) + .withGroup(CRD_GROUP) + .withPlural(CRD_PLURAL_NAME) + .build(); + + SharedInformerFactory informerFactory = client.informers(); + + MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> distributeTrainClient = client.customResources(this.crd, DistributeTrain.class, DistributeTrainList.class, DoneableDistributeTrain.class); + SharedIndexInformer<DistributeTrain> distributeTrainSharedIndexInformer = informerFactory.sharedIndexInformerForCustomResource(distributeTrainCustomResourceDefinitionContext, DistributeTrain.class, DistributeTrainList.class, 10 * 60 * 1000); + // Keep the client reachable through a static holder + DistributeTrainClientHolder.setDistributeTrainClient(distributeTrainClient); + // Manually register the controller in the Spring IoC container (constructor args: client, informer, namespace) + BeanDefinitionBuilder beanDefinitionBuilder = BeanDefinitionBuilder.genericBeanDefinition(DistributeTrainController.class); + DefaultListableBeanFactory beanFactory = (DefaultListableBeanFactory)((ConfigurableApplicationContext) SpringContextHolder.applicationContext).getBeanFactory(); + beanDefinitionBuilder.addConstructorArgValue(distributeTrainClient); + beanDefinitionBuilder.addConstructorArgValue(distributeTrainSharedIndexInformer); + beanDefinitionBuilder.addConstructorArgValue(namespace); + beanFactory.registerBeanDefinition("org.onebrain.operator.controller.DistributeTrainController", beanDefinitionBuilder.getRawBeanDefinition()); + + // Fetch the container-managed controller + DistributeTrainController controller = SpringContextHolder.getBean(DistributeTrainController.class); + // Register the informer listeners + controller.create(); + informerFactory.startAllRegisteredInformers(); + // Wait until ready + controller.run(); + } + + /** + * Builds the CRD schema properties + * @return map from property name (apiVersion, kind, metadata, spec) to its schema + */ + private Map<String, JSONSchemaProps> buildCrdProperties(){ + Map<String, JSONSchemaProps> properties = Maps.newHashMap(); + JSONSchemaProps stringType = new JSONSchemaPropsBuilder() + .withType(TYPE_STRING) + .build(); + JSONSchemaProps intType = new JSONSchemaPropsBuilder() + .withType(TYPE_INTEGER) + 
.withFormat(FORMAT_INT_32) + .build(); + JSONSchemaProps objectType = new JSONSchemaPropsBuilder() + .withType(TYPE_OBJECT) + .build(); + JSONSchemaProps arrayType = new JSONSchemaPropsBuilder() + .withType(TYPE_ARRAY) + .withNewItems() + .endItems() + .build(); + + // Add the validation rules for the spec properties + JSONSchemaProps specObjectType = new JSONSchemaPropsBuilder() + .addToProperties("image", stringType) + .addToProperties("imagePullPolicy", stringType) + .addToProperties("size", intType) + .addToProperties("env", arrayType) + .addToProperties("masterCmd", stringType) + .addToProperties("slaveCmd", stringType) + .addToProperties("masterResources", objectType) + .addToProperties("slaveResources", objectType) + .addToProperties("nodeSelector", objectType) + .addToProperties("initContainer", objectType) + .addToProperties("datasetStorage", objectType) + .addToProperties("workspaceStorage", objectType) + .addToProperties("modelStorage", objectType) + .withType("object") + .addToRequired("image", "imagePullPolicy", "size", "masterCmd", "slaveCmd", "workspaceStorage") + .build(); + properties.put("apiVersion", stringType); + properties.put("kind", stringType); + properties.put("metadata", objectType); + properties.put("spec", specObjectType); + return properties; + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/OperatorRunner.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/OperatorRunner.java new file mode 100644 index 0000000..e6f5002 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/OperatorRunner.java @@ -0,0 +1,58 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.action; + +import lombok.extern.slf4j.Slf4j; +import org.onebrain.operator.watcher.KubeWatcherManager; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.ApplicationArguments; +import org.springframework.boot.ApplicationRunner; +import org.springframework.stereotype.Component; + +/** + * @description Operator run entry point (a Spring ApplicationRunner) + * @date 2020-09-23 + */ +@Component +@Slf4j +public class OperatorRunner implements ApplicationRunner { + + @Autowired + private DistributeTrainOperatorManager operatorManager; + + @Autowired + private KubeWatcherManager watcherManager; + + /** + * Runs the operator start-up logic once the Spring container has fully started: ensure the CRD, start the job watcher, then initialize the informer + * @param args application arguments (not read by this method) + * @throws Exception + */ + @Override + public void run(ApplicationArguments args) throws Exception { + // Check whether the CRD already exists and create it if it does not + operatorManager.createCrdIfNotExists(); + + // Start the job watcher + watcherManager.startWatching(); + log.info("job watcher is running"); + + // Initialize the informer + operatorManager.initInformer(); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/PodInfo.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/PodInfo.java new file mode 100644 index 0000000..b863ab3 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/PodInfo.java @@ -0,0 +1,44 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.action; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * @description Pod information holder (IP address and role); accessors, constructors and builder are generated by Lombok + * @date 2020-09-23 + */ +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class PodInfo { + + /** + * IP address + */ + private String ip; + + /** + * Role + */ + private String role; +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/AbstractResourceCreateInfo.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/AbstractResourceCreateInfo.java new file mode 100644 index 0000000..a8f031e --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/AbstractResourceCreateInfo.java @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.action.deployer; + +import cn.hutool.core.util.RandomUtil; +import lombok.Data; +import lombok.experimental.Accessors; + +/** + * @description Abstract base class for the information needed to create resources + * @date 2020-04-30 + */ +@Data +@Accessors(chain = true) +public abstract class AbstractResourceCreateInfo { + + + /** + * Generates a random string by delegating to RandomUtil.randomString + * @param digits number of characters (NOTE(review): presumably must not be null, since it is passed straight through - confirm) + * @return the generated random string + */ + protected static String getRandomStr(Integer digits){ + return RandomUtil.randomString(digits); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ChildResourceCreateInfo.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ChildResourceCreateInfo.java new file mode 100644 index 0000000..94d27eb --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ChildResourceCreateInfo.java @@ -0,0 +1,227 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.action.deployer; + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.core.util.StrUtil; +import io.fabric8.kubernetes.api.model.*; +import lombok.Data; +import lombok.experimental.Accessors; +import org.onebrain.operator.constants.KubeConstants; +import org.onebrain.operator.constants.NumberConstant; +import org.onebrain.operator.crd.DistributeTrain; + +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * @description 暂存创建子资源所需的信息 + * @date 2020-06-16 + */ +@Data +@Accessors(chain = true) +public class ChildResourceCreateInfo extends AbstractResourceCreateInfo { + + public static final String SLAVE_TEMPLATE = "{}-slave-{}"; + public static final String MASTER_TEMPLATE = "{}-master-{}"; + public static final String SVC_TEMPLATE = "{}-svc"; + /** + * 父级名称(分布式训练名称) + */ + private String parentName; + + /** + * job名称 + */ + private String jobName; + + /** + * statefullSet名称 + */ + private String statefulSetName; + + /** + * 服务名称 + */ + private String svcName; + + /** + * 命名空间 + */ + private String namespace; + + /** + * 镜像 + */ + private String image; + + /** + * 镜像拉取策略 + */ + private String imagePullPolicy; + + /** + * 标签 + */ + private Map labels; + + /** + * master副本数 + */ + private Integer masterReplicas; + + /** + * slave副本数 + */ + private Integer slaveReplicas; + + /** + * master命令 + */ + private String masterCmd; + + /** + * slave命令 + */ + private String slaveCmd; + + /** + * master 资源节点限制 + */ + private ResourceRequirements masterResources; + + /** + * slave 资源节点限制 + */ + private ResourceRequirements slaveResources; + + /** + * 节点调度选择器 + */ + private Map nodeSelector; + + /** + * 初始化容器 + */ + private Container initContainer; + + /** + * 工作目录挂载 + */ + private Volume workspaceVolume; + + /** + * 数据集目录挂载 + */ + private Volume datasetVolume; + + /** + * 模型目录挂载 + */ + private 
Volume modelVolume; + + /** + * 环境变量 + */ + private List env; + + /** + * 拥有者信息 + */ + private OwnerReference ownerReference; + + /** + * 将分布式训练转换为K8S的资源信息 + * @param distributeTrain 分布式训练 + * @return ChildResourceCreateInfo + */ + public static ChildResourceCreateInfo fromCr(DistributeTrain distributeTrain){ + ChildResourceCreateInfo info = new ChildResourceCreateInfo(); + //ownerReferece信息 + info.generateOwnerReference(distributeTrain); + //各种资源的名称 + info.setNamespace(distributeTrain.getMetadata().getNamespace()); + info.setParentName(distributeTrain.getMetadata().getName()); + info.generateResoureName(); + //标签 + info.setLabels(distributeTrain.getMetadata().getLabels()); + //镜像 + info.setImage(distributeTrain.getSpec().getImage()) + .setImagePullPolicy(distributeTrain.getSpec().getImagePullPolicy()); + //副本数 + Integer size = distributeTrain.getSpec().getSize(); + info.setMasterReplicas(NumberConstant.NUMBER_1); + info.setSlaveReplicas(size - NumberConstant.NUMBER_1); + //命令行 + info.setMasterCmd(distributeTrain.getSpec().getMasterCmd()) + .setSlaveCmd(distributeTrain.getSpec().getSlaveCmd()); + //挂载 + Optional.ofNullable(distributeTrain.getSpec().getWorkspaceStorage()) + .ifPresent(v -> info.setWorkspaceVolume(v)); + Optional.ofNullable(distributeTrain.getSpec().getDatasetStorage()) + .ifPresent(v -> info.setDatasetVolume(v)); + Optional.ofNullable(distributeTrain.getSpec().getModelStorage()) + .ifPresent(v -> info.setModelVolume(v)); + + //主从两组资源限制 + Optional.ofNullable(distributeTrain.getSpec().getMasterResources()) + .ifPresent(v -> info.setMasterResources(v)); + Optional.ofNullable(distributeTrain.getSpec().getSlaveResources()) + .ifPresent(v -> info.setSlaveResources(v)); + + //环境变量 + List env = distributeTrain.getSpec().getEnv(); + if(CollectionUtil.isNotEmpty(env)){ + env = env.stream().filter(e -> !KubeConstants.ENV_NODE_NUM.equals(e.getName())).collect(Collectors.toList()); + info.setEnv(env); + } + + //node调度 + 
info.setNodeSelector(distributeTrain.getSpec().getNodeSelector()); + + //init-container + info.setInitContainer(distributeTrain.getSpec().getInitContainer()); + + return info; + } + + /** + * 生成资源名称 + */ + private void generateResoureName(){ + String suffix = getRandomStr(NumberConstant.NUMBER_5); + this.statefulSetName = StrUtil.format(SLAVE_TEMPLATE, this.parentName, suffix); + this.jobName = StrUtil.format(MASTER_TEMPLATE, this.parentName, suffix); + this.svcName = StrUtil.format(SVC_TEMPLATE, this.parentName); + } + + /** + * 生成所有者信息 + * @param distributeTrain 分布式训练 + */ + private void generateOwnerReference(DistributeTrain distributeTrain){ + this.ownerReference = new OwnerReferenceBuilder() + .withApiVersion(distributeTrain.getApiVersion()) + .withKind(distributeTrain.getKind()) + .withName(distributeTrain.getMetadata().getName()) + .withNewUid(distributeTrain.getMetadata().getUid()) + .build(); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/JobDeployer.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/JobDeployer.java new file mode 100644 index 0000000..68f4161 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/JobDeployer.java @@ -0,0 +1,35 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.action.deployer; + +import io.fabric8.kubernetes.api.model.batch.JobBuilder; + +/** + * @description Job部署接口 规范部署方法 + * T 必须是AbstractResourceCreateInfo 的子类型 + * @date 2020-09-23 + */ +public interface JobDeployer { + + /** + * 构建 Job信息 + * @param info 资源信息 + * @return Job构建者 + */ + JobBuilder deploy(T info); +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ServiceDeployer.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ServiceDeployer.java new file mode 100644 index 0000000..4221c04 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ServiceDeployer.java @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.action.deployer; + +import io.fabric8.kubernetes.api.model.ServiceBuilder; + +/** + * @description service部署器接口 + * @date 2020-09-23 + */ +public interface ServiceDeployer { + /** + * 构建service信息 + * @param info 资源信息 + * @return 服务构建者 + */ + ServiceBuilder deploy(T info); +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/StatefulSetDeployer.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/StatefulSetDeployer.java new file mode 100644 index 0000000..3be8d96 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/StatefulSetDeployer.java @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.action.deployer; + +import io.fabric8.kubernetes.api.model.apps.StatefulSetBuilder; + +/** + * @description statefulset部署器接口 + * @date 2020-09-23 + */ +public interface StatefulSetDeployer { + /** + * 构建service信息 + * @param info 资源信息 + * @return StatefulSet构建者 + */ + StatefulSetBuilder deploy(T info); +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseJobDeployer.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseJobDeployer.java new file mode 100644 index 0000000..68bf788 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseJobDeployer.java @@ -0,0 +1,246 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */

package org.onebrain.operator.action.deployer.impl;

import cn.hutool.core.collection.CollectionUtil;
import com.google.common.collect.Lists;
import io.fabric8.kubernetes.api.model.CapabilitiesBuilder;
import io.fabric8.kubernetes.api.model.Container;
import io.fabric8.kubernetes.api.model.ContainerPortBuilder;
import io.fabric8.kubernetes.api.model.EnvVar;
import io.fabric8.kubernetes.api.model.EnvVarBuilder;
import io.fabric8.kubernetes.api.model.SecurityContextBuilder;
import io.fabric8.kubernetes.api.model.Volume;
import io.fabric8.kubernetes.api.model.VolumeBuilder;
import io.fabric8.kubernetes.api.model.VolumeMount;
import io.fabric8.kubernetes.api.model.VolumeMountBuilder;
import io.fabric8.kubernetes.api.model.batch.JobBuilder;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.JobDeployer;
import org.onebrain.operator.constants.KubeConstants;

import java.util.*;

import static org.onebrain.operator.constants.NumberConstant.LONG_NUMBER_0;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_1;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_22;

/**
 * @description Job deployer; builds the master Job of a distributed training
 * @date 2020-09-23
 */
public class BaseJobDeployer implements JobDeployer<ChildResourceCreateInfo> {

    public static final String PVC_WORKSPACE = "pvc-workspace";
    public static final String SSH = "ssh";
    public static final String WORKSPACE = "/workspace";
    public static final String PVC_DATASET = "pvc-dataset";
    public static final String DATASET = "/dataset";
    public static final String PVC_MODEL = "pvc-model";
    public static final String MODEL = "/model";
    public static final String MEMORY = "Memory";
    public static final String DEV_SHM = "/dev/shm";
    public static final String BIN_BASH = "/bin/bash";
    public static final String IPC_LOCK = "IPC_LOCK";
    public static final String RESTART_POLICY_NEVER = "Never";

    /**
     * Builds the Job definition for the master node.
     *
     * @param info resource information
     * @return a Job builder carrying metadata, pod template, optional init
     *         container and optional node selector
     */
    @Override
    public JobBuilder deploy(ChildResourceCreateInfo info) {
        // container, volumes and the mounts derived from the volumes
        Container container = buildContainer(info);
        List<Volume> volumes = buildVolumes(info);
        container.setVolumeMounts(buildVolumeMounts(volumes));

        // start command
        container.setCommand(Collections.singletonList(BIN_BASH));
        // Wait steps before training starts:
        //  1. wait until the pretreatment file exists at /home/pretreatment, then run it
        //  2. wait until the service (svc) name resolves
        //  3. run the master command
        List<String> cmdLines = Arrays.asList(
                "while [ ! -f /home/pretreatment ]; do echo pretreatment not exist >> pretreatment.log; sleep 1;done && chmod a+x /home/pretreatment && bash /home/pretreatment ",
                "until nslookup " + info.getSvcName() + "; do sleep 5; done",
                info.getMasterCmd());
        container.setArgs(Arrays.asList("-c", CollectionUtil.join(cmdLines, " && ")));

        // permissions
        container.setSecurityContext(new SecurityContextBuilder()
                .withAllowPrivilegeEscalation(true)
                .withCapabilities(new CapabilitiesBuilder()
                        .withAdd(Collections.singletonList(IPC_LOCK))
                        .build())
                .build());

        // user-defined labels
        Map<String, String> customizeLabels = CollectionUtil.isNotEmpty(info.getLabels())
                ? info.getLabels()
                : new HashMap<>();

        JobBuilder builder = new JobBuilder();
        builder.withNewMetadata()
                    .withName(info.getJobName())
                    .withNamespace(info.getNamespace())
                    .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                    .addToLabels(customizeLabels)
                    .addToOwnerReferences(info.getOwnerReference())
                .endMetadata()
                .withNewSpec()
                    // one pod in parallel
                    .withParallelism(NUMBER_1)
                    // one completion in total
                    .withCompletions(NUMBER_1)
                    // retry limit on failure
                    .withBackoffLimit(KubeConstants.BACKOFFLIMIT)
                    .withNewTemplate()
                        .withNewMetadata()
                            .withName(info.getJobName())
                            .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                            .addToLabels(KubeConstants.JOB_LABEL, info.getJobName())
                            .addToLabels(customizeLabels)
                        .endMetadata()
                        .withNewSpec()
                            // terminate immediately when shutdown is requested
                            .withTerminationGracePeriodSeconds(LONG_NUMBER_0)
                            .addToContainers(container)
                            .addToVolumes(volumes.toArray(new Volume[0]))
                            .withRestartPolicy(RESTART_POLICY_NEVER)
                        .endSpec()
                    .endTemplate()
                .endSpec();

        // init container (optional); the edit chain writes back into builder
        Container initContainer = info.getInitContainer();
        if (initContainer != null) {
            builder.editSpec()
                    .editTemplate()
                        .editSpec()
                            .addToInitContainers(initContainer)
                        .endSpec()
                    .endTemplate()
                .endSpec();
        }

        // pin scheduling to the selected nodes (optional)
        if (CollectionUtil.isNotEmpty(info.getNodeSelector())) {
            builder.editSpec()
                    .editTemplate()
                        .editSpec()
                            .addToNodeSelector(info.getNodeSelector())
                        .endSpec()
                    .endTemplate()
                .endSpec();
        }

        return builder;
    }

    /**
     * Builds the master container: name, image, ssh port, environment and resources.
     *
     * @param info resource information
     * @return the container
     */
    private Container buildContainer(ChildResourceCreateInfo info) {
        Container container = new Container();
        // image
        container.setName(KubeConstants.MASTER_CONTAINER_NAME);
        container.setImage(info.getImage());
        container.setImagePullPolicy(info.getImagePullPolicy());
        // port
        container.setPorts(Collections.singletonList(new ContainerPortBuilder()
                .withContainerPort(NUMBER_22)
                .withName(SSH)
                .build()));
        // environment: total node count first, then the user-supplied variables
        List<EnvVar> envVars = Lists.newArrayList(new EnvVarBuilder()
                .withName(KubeConstants.ENV_NODE_NUM)
                .withValue(String.valueOf(info.getSlaveReplicas() + info.getMasterReplicas()))
                .build());
        if (info.getEnv() != null) {
            envVars.addAll(info.getEnv());
        }
        container.setEnv(envVars);

        // resource requirements
        if (info.getMasterResources() != null) {
            container.setResources(info.getMasterResources());
        }
        return container;
    }

    /**
     * Collects the volumes of the pod: the optional workspace, dataset and model
     * volumes followed by the always-present in-memory shm volume.
     *
     * @param info resource information
     * @return the volumes
     */
    private List<Volume> buildVolumes(ChildResourceCreateInfo info) {
        List<Volume> volumes = new LinkedList<>();
        for (Volume candidate : Arrays.asList(
                info.getWorkspaceVolume(), info.getDatasetVolume(), info.getModelVolume())) {
            if (candidate != null) {
                volumes.add(candidate);
            }
        }
        // shm is always present
        volumes.add(new VolumeBuilder()
                .withName(KubeConstants.VOLUME_SHM)
                .withNewEmptyDir()
                    .withMedium(MEMORY)
                .endEmptyDir()
                .build());
        return volumes;
    }

    /**
     * Builds the volume mounts: one mount per recognised PVC volume name, in the
     * order of {@code volumes}, followed by the shm mount.
     *
     * @param volumes the volumes of the pod
     * @return the volume mounts
     */
    private List<VolumeMount> buildVolumeMounts(List<Volume> volumes) {
        List<VolumeMount> volumeMounts = new LinkedList<>();
        for (Volume volume : volumes) {
            String mountPath = mountPathFor(volume.getName());
            if (mountPath != null) {
                volumeMounts.add(new VolumeMountBuilder()
                        .withName(volume.getName())
                        .withMountPath(mountPath)
                        .build());
            }
        }

        volumeMounts.add(new VolumeMountBuilder()
                .withName(KubeConstants.VOLUME_SHM)
                .withMountPath(DEV_SHM)
                .build());
        return volumeMounts;
    }

    /**
     * Maps a recognised PVC volume name to its mount path.
     *
     * @param volumeName the volume name, may be {@code null}
     * @return the mount path, or {@code null} when the name is not a known PVC volume
     */
    private static String mountPathFor(String volumeName) {
        if (PVC_WORKSPACE.equals(volumeName)) {
            return WORKSPACE;
        }
        if (PVC_DATASET.equals(volumeName)) {
            return DATASET;
        }
        if (PVC_MODEL.equals(volumeName)) {
            return MODEL;
        }
        return null;
    }
}
diff --git
a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseServiceDeployer.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseServiceDeployer.java new file mode 100644 index 0000000..e18a625 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseServiceDeployer.java @@ -0,0 +1,73 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.action.deployer.impl; + +import cn.hutool.core.collection.CollectionUtil; +import io.fabric8.kubernetes.api.model.IntOrString; +import io.fabric8.kubernetes.api.model.ServiceBuilder; +import org.onebrain.operator.action.deployer.ChildResourceCreateInfo; +import org.onebrain.operator.action.deployer.ServiceDeployer; +import org.onebrain.operator.constants.KubeConstants; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.onebrain.operator.constants.NumberConstant.NUMBER_22; +import static org.onebrain.operator.constants.NumberConstant.NUMBER_30000; + +/** + * @description Service部署器 + * @date 2020-09-23 + */ +public class BaseServiceDeployer implements ServiceDeployer { + + public static final String WEB_SSH = "web-ssh"; + public static final String NONE = "None"; + + /** + * 构建service信息 + * @param info 资源信息 + * @return + */ + @Override + public ServiceBuilder deploy(ChildResourceCreateInfo info) { + + //用户自定义的标签 + Map customizeLabels = CollectionUtil.isNotEmpty(info.getLabels())? 
info.getLabels(): new HashMap<>(); + + return new ServiceBuilder() + .withNewMetadata() + .withName(info.getSvcName()) + .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName()) + .addToLabels(customizeLabels) + .withNamespace(info.getNamespace()) + .addToOwnerReferences(info.getOwnerReference()) + .endMetadata() + .withNewSpec() + .addNewPort() + .withPort(NUMBER_30000) + .withTargetPort(new IntOrString(NUMBER_22)) + .withName(WEB_SSH) + .endPort() + .withClusterIP(NONE) + //选择带有分布式训练的节点 + .withSelector(Collections.singletonMap(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())) + .endSpec(); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseStatefulSetDeployer.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseStatefulSetDeployer.java new file mode 100644 index 0000000..1d11251 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseStatefulSetDeployer.java @@ -0,0 +1,246 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */

package org.onebrain.operator.action.deployer.impl;

import cn.hutool.core.collection.CollectionUtil;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import io.fabric8.kubernetes.api.model.CapabilitiesBuilder;
import io.fabric8.kubernetes.api.model.Container;
import io.fabric8.kubernetes.api.model.ContainerPortBuilder;
import io.fabric8.kubernetes.api.model.EnvVar;
import io.fabric8.kubernetes.api.model.EnvVarBuilder;
import io.fabric8.kubernetes.api.model.LabelSelector;
import io.fabric8.kubernetes.api.model.SecurityContextBuilder;
import io.fabric8.kubernetes.api.model.Volume;
import io.fabric8.kubernetes.api.model.VolumeBuilder;
import io.fabric8.kubernetes.api.model.VolumeMount;
import io.fabric8.kubernetes.api.model.VolumeMountBuilder;
import io.fabric8.kubernetes.api.model.apps.StatefulSetBuilder;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.StatefulSetDeployer;
import org.onebrain.operator.constants.KubeConstants;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.onebrain.operator.constants.NumberConstant.LONG_NUMBER_0;
import static org.onebrain.operator.constants.NumberConstant.LONG_NUMBER_60;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_22;

/**
 * @description StatefulSet deployer; builds the slave StatefulSet of a distributed training
 * @date 2020-09-23
 */
public class BaseStatefulSetDeployer implements StatefulSetDeployer<ChildResourceCreateInfo> {

    public static final String SSH = "ssh";
    public static final String PVC_WORKSPACE = "pvc-workspace";
    public static final String WORKSPACE = "/workspace";
    public static final String PVC_DATASET = "pvc-dataset";
    public static final String DATASET = "/dataset";
    public static final String PVC_MODEL = "pvc-model";
    public static final String MODEL = "/model";
    public static final String MEMORY = "Memory";
    public static final String DEV_SHM = "/dev/shm";
    public static final String BIN_BASH = "/bin/bash";
    public static final String IPC_LOCK = "IPC_LOCK";

    /**
     * Builds the StatefulSet definition for the slave nodes.
     *
     * @param info resource information
     * @return a StatefulSet builder carrying metadata, selector, pod template,
     *         optional init container and optional node selector
     */
    @Override
    public StatefulSetBuilder deploy(ChildResourceCreateInfo info) {
        // label selector matching the pods of this StatefulSet
        LabelSelector labelSelector = new LabelSelector();
        labelSelector.setMatchLabels(
                ImmutableMap.of(KubeConstants.STATEFULSET_LABEL, info.getStatefulSetName()));

        // volumes, container and the mounts derived from the volumes
        List<Volume> volumes = buildVolumes(info);
        Container container = buildContainer(info);
        container.setVolumeMounts(buildVolumeMounts(volumes));

        // Start command:
        //  1. wait until the pretreatment file exists at /home/pretreatment, then run it
        //  2. wait until the service (svc) name resolves
        //  3. run the slave command
        List<String> cmdLines = Arrays.asList(
                "while [ ! -f /home/pretreatment ]; do echo pretreatment not exist >> pretreatment.log; sleep 1;done && chmod a+x /home/pretreatment && bash /home/pretreatment ",
                "until nslookup " + info.getSvcName() + "; do sleep 5; done",
                info.getSlaveCmd());
        container.setCommand(Collections.singletonList(BIN_BASH));
        container.setArgs(Arrays.asList("-c", CollectionUtil.join(cmdLines, " && ")));

        // permissions
        container.setSecurityContext(new SecurityContextBuilder()
                .withAllowPrivilegeEscalation(true)
                .withCapabilities(new CapabilitiesBuilder()
                        .withAdd(Collections.singletonList(IPC_LOCK))
                        .build())
                .build());

        // user-defined labels
        Map<String, String> customizeLabels = CollectionUtil.isNotEmpty(info.getLabels())
                ? info.getLabels()
                : new HashMap<>();

        StatefulSetBuilder builder = new StatefulSetBuilder();
        builder.withNewMetadata()
                    .withName(info.getStatefulSetName())
                    .withNamespace(info.getNamespace())
                    .addToOwnerReferences(info.getOwnerReference())
                    .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                .endMetadata()
                .withNewSpec()
                    .withSelector(labelSelector)
                    // NOTE(review): serviceName is set to the StatefulSet's own name, while
                    // the Service built by BaseServiceDeployer is named info.getSvcName();
                    // confirm whether the governing service should be svcName instead.
                    .withServiceName(info.getStatefulSetName())
                    .withReplicas(info.getSlaveReplicas())
                    .withNewTemplate()
                        .withNewMetadata()
                            .withName(info.getStatefulSetName())
                            .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                            .addToLabels(KubeConstants.STATEFULSET_LABEL, info.getStatefulSetName())
                            .addToLabels(customizeLabels)
                        .endMetadata()
                        .withNewSpec()
                            // NOTE(review): this setter used to be called twice (LONG_NUMBER_0,
                            // then LONG_NUMBER_60); the later call won, so only the effective
                            // value is kept. Confirm 60 is intended (the Job uses 0).
                            .withTerminationGracePeriodSeconds(LONG_NUMBER_60)
                            .addToContainers(container)
                            .addToVolumes(volumes.toArray(new Volume[0]))
                        .endSpec()
                    .endTemplate()
                .endSpec();

        // init container (optional); the edit chain writes back into builder
        Container initContainer = info.getInitContainer();
        if (initContainer != null) {
            builder.editSpec()
                    .editTemplate()
                        .editSpec()
                            .addToInitContainers(initContainer)
                        .endSpec()
                    .endTemplate()
                .endSpec();
        }

        // pin scheduling to the selected nodes (optional)
        if (CollectionUtil.isNotEmpty(info.getNodeSelector())) {
            builder.editSpec()
                    .editTemplate()
                        .editSpec()
                            .addToNodeSelector(info.getNodeSelector())
                        .endSpec()
                    .endTemplate()
                .endSpec();
        }

        return builder;
    }

    /**
     * Builds the slave container: name, image, ssh port, environment and resources.
     *
     * @param info resource information
     * @return the container
     */
    private Container buildContainer(ChildResourceCreateInfo info) {
        Container container = new Container();
        // image
        container.setName(KubeConstants.SLAVE_CONTAINER_NAME);
        container.setImage(info.getImage());
        container.setImagePullPolicy(info.getImagePullPolicy());
        // port
        container.setPorts(Collections.singletonList(new ContainerPortBuilder()
                .withContainerPort(NUMBER_22)
                .withName(SSH)
                .build()));
        // environment: total node count first, then the user-supplied variables
        List<EnvVar> envVars = Lists.newArrayList(new EnvVarBuilder()
                .withName(KubeConstants.ENV_NODE_NUM)
                .withValue(String.valueOf(info.getSlaveReplicas() + info.getMasterReplicas()))
                .build());
        if (info.getEnv() != null) {
            envVars.addAll(info.getEnv());
        }
        container.setEnv(envVars);

        // resource requirements
        if (info.getSlaveResources() != null) {
            container.setResources(info.getSlaveResources());
        }
        return container;
    }

    /**
     * Collects the volumes of the pod: the optional workspace, dataset and model
     * volumes followed by the always-present in-memory shm volume.
     *
     * @param info resource information
     * @return the volumes
     */
    private List<Volume> buildVolumes(ChildResourceCreateInfo info) {
        // Start from an empty list. Seeding it with buildVolumes(info) made the
        // method call itself unconditionally and overflow the stack.
        List<Volume> volumes = new LinkedList<>();
        for (Volume candidate : Arrays.asList(
                info.getWorkspaceVolume(), info.getDatasetVolume(), info.getModelVolume())) {
            if (candidate != null) {
                volumes.add(candidate);
            }
        }
        // shm is always present
        volumes.add(new VolumeBuilder()
                .withName(KubeConstants.VOLUME_SHM)
                .withNewEmptyDir()
                    .withMedium(MEMORY)
                .endEmptyDir()
                .build());
        return volumes;
    }

    /**
     * Builds the volume mounts: one mount per recognised PVC volume name, in the
     * order of {@code volumes}, followed by the shm mount.
     *
     * @param volumes the volumes of the pod
     * @return the volume mounts
     */
    private List<VolumeMount> buildVolumeMounts(List<Volume> volumes) {
        List<VolumeMount> volumeMounts = new LinkedList<>();
        for (Volume volume : volumes) {
            String mountPath = mountPathFor(volume.getName());
            if (mountPath != null) {
                volumeMounts.add(new VolumeMountBuilder()
                        .withName(volume.getName())
                        .withMountPath(mountPath)
                        .build());
            }
        }

        volumeMounts.add(new VolumeMountBuilder()
                .withName(KubeConstants.VOLUME_SHM)
                .withMountPath(DEV_SHM)
                .build());
        return volumeMounts;
    }

    /**
     * Maps a recognised PVC volume name to its mount path.
     *
     * @param volumeName the volume name, may be {@code null}
     * @return the mount path, or {@code null} when the name is not a known PVC volume
     */
    private static String mountPathFor(String volumeName) {
        if (PVC_WORKSPACE.equals(volumeName)) {
            return WORKSPACE;
        }
        if (PVC_DATASET.equals(volumeName)) {
            return DATASET;
        }
        if (PVC_MODEL.equals(volumeName)) {
            return MODEL;
        }
        return null;
    }
}
diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/AddActionHandler.java
b/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/AddActionHandler.java new file mode 100644 index 0000000..99f6f52 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/AddActionHandler.java @@ -0,0 +1,614 @@
/**
 * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================
 */

package org.onebrain.operator.action.handler;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import io.fabric8.kubernetes.api.model.ObjectMeta;
import io.fabric8.kubernetes.api.model.Pod;
import io.fabric8.kubernetes.api.model.Service;
import io.fabric8.kubernetes.api.model.ServiceBuilder;
import io.fabric8.kubernetes.api.model.apps.StatefulSet;
import io.fabric8.kubernetes.api.model.apps.StatefulSetBuilder;
import io.fabric8.kubernetes.api.model.batch.Job;
import io.fabric8.kubernetes.api.model.batch.JobBuilder;
import io.fabric8.kubernetes.client.KubernetesClient;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.action.PodInfo;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.JobDeployer;
import org.onebrain.operator.action.deployer.ServiceDeployer;
import org.onebrain.operator.action.deployer.StatefulSetDeployer;
import org.onebrain.operator.action.deployer.impl.BaseJobDeployer;
import org.onebrain.operator.action.deployer.impl.BaseServiceDeployer;
import org.onebrain.operator.action.deployer.impl.BaseStatefulSetDeployer;
import org.onebrain.operator.api.pod.PodApi;
import org.onebrain.operator.constants.KubeConstants;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.crd.DistributeTrainSpec;
import org.onebrain.operator.crd.DistributeTrainStatus;
import org.onebrain.operator.exception.OperatorException;
import org.onebrain.operator.redis.RedisService;
import org.onebrain.operator.redis.key.OperatorKey;
import org.onebrain.operator.utils.DistributeTrainClientHolder;
import org.onebrain.operator.utils.IOUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.onebrain.operator.constants.KubeConstants.CHARSET;
import static org.onebrain.operator.constants.KubeConstants.JOB_LABEL;
import static org.onebrain.operator.constants.KubeConstants.MASTER_CONTAINER_NAME;
import static org.onebrain.operator.constants.KubeConstants.SLAVE_CONTAINER_NAME;
import static org.onebrain.operator.constants.KubeConstants.STATEFULSET_LABEL;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_2;

/**
 * @description Handler for the "add" event of a distributed training. It creates the slave
 * StatefulSet and the master Job, distributes the pretreatment script, ssh keys and hostfile to
 * every pod, and finally creates the Service that lets the job leave its waiting loop.
 * @date 2020-09-23
 */
@Component("addActionHandler")
@Slf4j
public class AddActionHandler implements DistributeTrainActionHandler {

    public static final String JOB_WATCHER = "job-watcher-";
    public static final String PRETREATMENT = "pretreatment";
    public static final String JOB_NAME = "job-name";
    public static final String RUNNING = "Running";
    public static final String MASTER = "master";
    public static final String SLAVE = "slave";
    public static final String PRETREATMENT_TARGET_DIR = "/home/pretreatment";
    public static final String IP = "ip";
    public static final String ROLE = "role";
    public static final String HOSTFILE_TARGET_DIR = "/home/hostfile.json";

    /**
     * Pause between two polls of the pod list, in seconds.
     */
    private static final long POLL_INTERVAL_SECONDS = 2L;

    /**
     * Upper bound for the polling loops, in hours (same bound as the fabric8 waits below).
     */
    private static final long POD_WAIT_TIMEOUT_HOURS = 2L;

    @Autowired
    private KubernetesClient client;

    @Autowired
    private PodApi podApi;

    /**
     * key: uid of the training (owner reference uid), value: pod information of that training.
     * Entries only live for the duration of {@link #doAction(DistributeTrain)}.
     */
    private final Map<String, List<PodInfo>> dtMap = new ConcurrentHashMap<>();

    @Autowired
    private RedisService redis;

    /**
     * Thread pool that runs the add-event tasks.
     * NOTE(review): the queue holds one task and the rejection policy is DiscardOldestPolicy, so
     * when all threads are busy a newly submitted task replaces the queued one, which is dropped
     * without any log. Confirm that silently losing an add event is acceptable.
     */
    private final ThreadPoolExecutor pool = new ThreadPoolExecutor(5, 10, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), new ThreadFactory() {
        private final AtomicInteger mThreadNum = new AtomicInteger(1);

        @Override
        public Thread newThread(Runnable r) {
            return new Thread(r, JOB_WATCHER + mThreadNum.getAndIncrement());
        }
    }, new ThreadPoolExecutor.DiscardOldestPolicy());

    /**
     * Task that handles one add event on the pool.
     */
    class HandlerActionTask implements Runnable {

        private final DistributeTrain distributeTrain;

        public HandlerActionTask(DistributeTrain distributeTrain) {
            this.distributeTrain = distributeTrain;
        }

        @Override
        public void run() {
            doAction(distributeTrain);
        }
    }

    /**
     * Runs the whole deployment sequence for one distributed training.
     * <p>
     * Any exception is caught here: the redis consumption record is removed and the custom
     * resource is deleted again, so nothing propagates to the caller.
     *
     * @param distributeTrain distributed training to deploy
     */
    public void doAction(DistributeTrain distributeTrain) {
        log.info("doAction=>distributeTrain : 【{}】", distributeTrain);
        ChildResourceCreateInfo info = null;
        try {
            //de-duplicate on the uid kubernetes assigned to the DistributeTrain
            if (null != redis.get(OperatorKey.CR, distributeTrain.getMetadata().getUid())) {
                log.info("distribute train 【{}】 in namespace 【{}】 already exists", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace());
                return;
            } else {
                //record the consumption in redis
                redis.set(OperatorKey.CR, distributeTrain.getMetadata().getUid(), System.currentTimeMillis());
            }

            //validate parameters, then derive everything needed for the child resources
            validateParams(distributeTrain);
            info = ChildResourceCreateInfo.fromCr(distributeTrain);
            //create the statefulSet for the slaves
            createStatefulSet(info);
            //wait until every statefulSet pod is ready
            waitUntilStatefulSetReady(info);
            //create the job; at this point the job sits in a waiting loop
            createJob(info);
            //wait until the job pod is ready
            waitUntilJobReady(info);
            //copy the pretreatment script to /home/pretreatment in every pod
            copyPretreatmentShell(info);
            //collect the ips of the statefulSet pods and the job pod
            validateAndCollectPods(info);
            //distribute key pair and authorized_keys to ~/.ssh of every node
            sshAuthWithoutPass(info);
            //generate the hostfile and copy it to every node
            generateAndUploadHostFile(info);
            //release the waiting loop of the job
            releaseInterLock(info);
            //status update is currently disabled
            //updateStatus(info, distributeTrain);
            //register a listener for the job
            registerJobListener(info);

            log.info("all parts of【{}】 are ready", info.getParentName());
        } catch (Exception e) {
            // pass the exception as the last argument so that the stack trace is logged;
            // the previous "【{}】" placeholder had no argument and was printed literally
            log.error("doAction error for distribute train 【{}】", distributeTrain.getMetadata().getName(), e);
            //remove the consumption record
            redis.del(OperatorKey.CR, distributeTrain.getMetadata().getUid());
            //recycle what has been created
            if (info != null) {
                recycleCr(info);
            }
        } finally {
            // the pod info is only needed while deploying; drop it so dtMap does not grow forever
            forgetPodInfos(info);
        }
    }

    /**
     * Submits the add event to the thread pool.
     *
     * @param distributeTrain distributed training information
     */
    @Override
    public void handlerAction(DistributeTrain distributeTrain) {
        log.info("handlerAction=>distributeTrain : 【{}】", distributeTrain);
        pool.execute(new HandlerActionTask(distributeTrain));
    }

    /**
     * Checks that the custom resource carries usable parameters.
     *
     * @param distributeTrain distributed training
     * @throws OperatorException when size is missing or below 2, or a command line is empty
     */
    private void validateParams(DistributeTrain distributeTrain) {
        log.info("validateParams=>distributeTrain : 【{}】", distributeTrain);
        Integer size = distributeTrain.getSpec().getSize();
        // a missing size used to fail with a NullPointerException on unboxing
        if (size == null || size < NUMBER_2) {
            throw new OperatorException("size must be greater than 1");
        }
        String masterCmd = distributeTrain.getSpec().getMasterCmd();
        String slaveCmd = distributeTrain.getSpec().getSlaveCmd();
        if (StrUtil.isEmpty(slaveCmd) || StrUtil.isEmpty(masterCmd)) {
            throw new OperatorException("cmd lines must not be empty");
        }
    }

    /**
     * Copies the pretreatment script to every pod of the training.
     *
     * @param info resource information
     */
    private void copyPretreatmentShell(ChildResourceCreateInfo info) {
        log.info("start to copy pretreatment for 【{}】 ", info.getParentName());
        try {
            String path = System.getProperty(KubeConstants.USER_DIR_SYSTEM_PROPERTY) + File.separator + PRETREATMENT;
            //extract the script from the classpath once
            if (!FileUtil.exist(path)) {
                FileUtil.writeFromStream(new ClassPathResource("/shell/pretreatment").getInputStream(), path);
            }
            File pretreatment = new File(path);
            //upload to the target directory of every pod
            List<Pod> pods = getPods(info);
            for (int i = 0; i < pods.size(); i++) {
                Pod pod = pods.get(i);
                podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerNameAt(i), pretreatment, PRETREATMENT_TARGET_DIR);
            }
        } catch (Exception e) {
            log.error("copy pretreatment shell error for 【{}】", info.getParentName(), e);
            throw new OperatorException("exception is thrown when copy pretreatment for 【" + info.getParentName() + "】 : \n" + e.getMessage());
        }
    }

    /**
     * Creates the statefulSet unless one with the same name already exists.
     *
     * @param info resource information
     */
    private void createStatefulSet(ChildResourceCreateInfo info) {
        log.info("createStatefulSet=>childResourceCreateInfo : 【{}】", info);
        StatefulSet statefulSet = client.apps().statefulSets()
                .inNamespace(info.getNamespace())
                .withName(info.getStatefulSetName()).get();
        //already exists
        if (statefulSet != null) {
            log.info("statefulSet 【{}】 already exists", statefulSet.getMetadata().getName());
            return;
        }
        //does not exist yet: create it
        StatefulSetDeployer deployer = new BaseStatefulSetDeployer();
        StatefulSetBuilder builder = deployer.deploy(info);
        statefulSet = builder.build();
        client.apps().statefulSets().create(statefulSet);
        log.info("create statefulSet【{}】 successfully", statefulSet.getMetadata().getName());
    }

    /**
     * Blocks until all replicas of the statefulSet are ready (at most 2 hours).
     *
     * @param info resource information
     */
    private void waitUntilStatefulSetReady(ChildResourceCreateInfo info) {
        log.info("wait for statefulSet 【{}】 in namespace 【{}】 ready", info.getStatefulSetName(), info.getNamespace());
        try {
            client.apps().statefulSets()
                    .inNamespace(info.getNamespace())
                    .withName(info.getStatefulSetName())
                    //the resource or its status may not be populated yet, so guard against null
                    .waitUntilCondition(c ->
                                    c != null
                                            && c.getStatus() != null
                                            && c.getStatus().getReplicas() != null
                                            && ObjectUtil.equal(c.getStatus().getReplicas(), c.getStatus().getReadyReplicas()),
                            NUMBER_2, TimeUnit.HOURS);
            log.info("statefulSet 【{}】 in namespace 【{}】 is ready", info.getStatefulSetName(), info.getNamespace());
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                //keep the interrupt flag for the pool
                Thread.currentThread().interrupt();
            }
            log.error("wait until statefulSet 【{}】 ready error", info.getStatefulSetName(), e);
            throw new OperatorException("exception is thrown when waiting for statefulSet 【" + info.getStatefulSetName() + "】 ready : \n" + e.getMessage());
        }
    }

    /**
     * Creates the job unless one with the same name already exists.
     *
     * @param info resource information
     */
    private void createJob(ChildResourceCreateInfo info) {
        log.info("createJob=>childResourceCreateInfo : 【{}】", info);
        Job job = client.batch().jobs()
                .inNamespace(info.getNamespace())
                .withName(info.getJobName()).get();
        //already exists
        if (job != null) {
            log.info("job 【{}】 already exists", job.getMetadata().getName());
            return;
        }
        //does not exist yet: create it
        JobDeployer deployer = new BaseJobDeployer();
        JobBuilder builder = deployer.deploy(info);
        job = builder.build();
        log.info("job is : 【{}】", job);
        client.batch().jobs().create(job);
        log.info("create job【{}】 successfully", job.getMetadata().getName());
    }

    /**
     * Blocks until the first pod of the job exists and is ready.
     * <p>
     * Both the wait for the pod to appear and the wait for readiness are bounded by 2 hours.
     *
     * @param info resource information
     * @throws OperatorException when the pod does not show up or does not become ready
     */
    private void waitUntilJobReady(ChildResourceCreateInfo info) {
        //this used to log the statefulSet name by mistake
        log.info("wait for job 【{}】 in namespace 【{}】 ready", info.getJobName(), info.getNamespace());
        try {
            final long deadline = System.nanoTime() + TimeUnit.HOURS.toNanos(POD_WAIT_TIMEOUT_HOURS);
            List<Pod> podList = listJobPods(info);
            while (CollectionUtil.isEmpty(podList)) {
                //the previous loop had no exit and spun forever when the pod was never created
                if (System.nanoTime() - deadline >= 0) {
                    throw new OperatorException("no pod of job 【" + info.getJobName() + "】 appeared within " + POD_WAIT_TIMEOUT_HOURS + " hours");
                }
                TimeUnit.SECONDS.sleep(POLL_INTERVAL_SECONDS);
                podList = listJobPods(info);
            }
            Pod pod = podList.get(0);
            client.pods().inNamespace(info.getNamespace())
                    .withName(pod.getMetadata().getName())
                    //wait for the Ready condition, at most 2 hours
                    .waitUntilReady(2, TimeUnit.HOURS);
            log.info("job 【{}】 in namespace 【{}】 is ready", info.getJobName(), info.getNamespace());
        } catch (OperatorException e) {
            throw e;
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                //keep the interrupt flag for the pool
                Thread.currentThread().interrupt();
            }
            log.error("wait until job 【{}】 ready error", info.getJobName(), e);
            throw new OperatorException("exception is thrown when waiting for job 【" + info.getJobName() + "】 ready : \n" + e.getMessage());
        }
    }

    /**
     * Lists the pods carrying the {@code job-name} label of the job.
     *
     * @param info resource information
     * @return pods of the job as returned by the API server
     */
    private List<Pod> listJobPods(ChildResourceCreateInfo info) {
        return client.pods().inNamespace(info.getNamespace())
                .withLabel(JOB_NAME, info.getJobName())
                .list().getItems();
    }

    /**
     * Waits until the master pod and every slave pod are in phase Running, then records their
     * ip and role in {@link #dtMap}.
     *
     * @param info resource information
     * @throws OperatorException when pods are missing, the wait times out or is interrupted
     */
    private void validateAndCollectPods(ChildResourceCreateInfo info) {
        log.info("validate pods status for 【{}】", info.getParentName());
        final long deadline = System.nanoTime() + TimeUnit.HOURS.toNanos(POD_WAIT_TIMEOUT_HOURS);
        Pod masterPod;
        List<Pod> slavePods;
        while (true) {
            masterPod = getMasterPod(info);
            slavePods = getSlavePods(info);

            // Fail loudly: returning silently here left dtMap empty and caused a
            // NullPointerException in the following steps.
            if (masterPod == null) {
                throw new OperatorException("can not find pod belongs to job 【" + info.getJobName() + "】");
            }
            if (CollectionUtil.isEmpty(slavePods)) {
                throw new OperatorException("can not find pod belongs to statefulSet 【" + info.getStatefulSetName() + "】");
            }

            if (isRunning(masterPod) && slavePods.stream().allMatch(AddActionHandler::isRunning)) {
                break;
            }
            if (System.nanoTime() - deadline >= 0) {
                throw new OperatorException("pods of 【" + info.getParentName() + "】 are not all running after " + POD_WAIT_TIMEOUT_HOURS + " hours");
            }
            //pause between polls instead of hammering the API server in a tight loop
            pauseBetweenPolls(info);
        }

        log.info("status checked 【{}】 all right", info.getParentName());
        collectChildPodInfo(info, masterPod, slavePods);
    }

    /**
     * @param pod pod to inspect
     * @return {@code true} when the pod reports phase {@code Running}
     */
    private static boolean isRunning(Pod pod) {
        return pod != null && pod.getStatus() != null && RUNNING.equals(pod.getStatus().getPhase());
    }

    /**
     * Sleeps for one poll interval; an interrupt is restored and reported as OperatorException.
     *
     * @param info resource information, used for the error message
     */
    private void pauseBetweenPolls(ChildResourceCreateInfo info) {
        try {
            TimeUnit.SECONDS.sleep(POLL_INTERVAL_SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new OperatorException("interrupted while waiting for pods of 【" + info.getParentName() + "】");
        }
    }

    /**
     * Stores ip and role of the master pod followed by the slave pods under the training uid.
     *
     * @param info resource information
     * @param masterPod pod of the job
     * @param slavePods pods of the statefulSet
     */
    private void collectChildPodInfo(ChildResourceCreateInfo info, Pod masterPod, List<Pod> slavePods) {
        log.info("collectChildPodInfo=>childResourceCreateInfo : 【{}】, masterPod : 【{}】, slavePods : 【{}】", info, masterPod, slavePods);
        List<PodInfo> podInfos = Lists.newArrayList();
        podInfos.add(PodInfo.builder()
                .ip(masterPod.getStatus().getPodIP())
                .role(MASTER)
                .build());
        for (Pod slavePod : slavePods) {
            podInfos.add(PodInfo.builder()
                    .ip(slavePod.getStatus().getPodIP())
                    .role(SLAVE)
                    .build());
        }
        //put replaces any previous entry, no separate remove needed
        dtMap.put(info.getOwnerReference().getUid(), podInfos);
    }

    /**
     * Returns the collected pod information of the training.
     *
     * @param info resource information
     * @return non-empty list of pod information
     * @throws OperatorException when nothing has been collected for this training
     */
    private List<PodInfo> requirePodInfos(ChildResourceCreateInfo info) {
        List<PodInfo> podInfos = dtMap.get(info.getOwnerReference().getUid());
        if (CollectionUtil.isEmpty(podInfos)) {
            throw new OperatorException("pod info of 【" + info.getParentName() + "】 has not been collected");
        }
        return podInfos;
    }

    /**
     * Removes the collected pod information of the training, if any.
     *
     * @param info resource information, may be {@code null}
     */
    private void forgetPodInfos(ChildResourceCreateInfo info) {
        if (info == null || info.getOwnerReference() == null) {
            return;
        }
        String uid = info.getOwnerReference().getUid();
        if (uid != null) {
            dtMap.remove(uid);
        }
    }

    /**
     * Chooses the container name by pod position: index 0 is the master, the rest are slaves
     * (this matches the order produced by {@link #getPods(ChildResourceCreateInfo)}).
     *
     * @param index position of the pod in the list
     * @return container name to address in that pod
     */
    private static String containerNameAt(int index) {
        return index < 1 ? MASTER_CONTAINER_NAME : SLAVE_CONTAINER_NAME;
    }

    /**
     * Sets up password-less ssh between all pods: uploads the bundled private key, a per-pod
     * public key (the {@code {{ip}}} placeholder replaced by the pod ip) and an authorized_keys
     * file containing all public keys, then fixes the file permissions.
     *
     * @param info resource information
     */
    private void sshAuthWithoutPass(ChildResourceCreateInfo info) {
        log.info("start to configure ssh no password environment for 【{}】 ", info.getParentName());
        File tempDir = Files.createTempDir();
        try (
                InputStream isRsa = getClass().getClassLoader().getResourceAsStream("key/id_rsa");
                InputStream isRsaPub = getClass().getClassLoader().getResourceAsStream("key/id_rsa.pub")
        ) {
            if (isRsa == null || isRsaPub == null) {
                throw new OperatorException("key/id_rsa or key/id_rsa.pub is missing on the classpath");
            }
            //id_rsa
            File tempIdRsa = FileUtil.createTempFile(tempDir);
            IOUtils.copy(isRsa, tempIdRsa);
            //id_rsa.pub
            File tempIdRsaPub = FileUtil.createTempFile(tempDir);
            IOUtils.copy(isRsaPub, tempIdRsaPub);
            List<String> pubLines = FileUtil.readLines(tempIdRsaPub, CHARSET);
            String pubKeyContent = pubLines.get(0);
            //one id_rsa.pub per pod, plus one authorized_keys holding all of them
            List<File> idRsaPubFiles = Lists.newArrayList();
            File tempAuthorizedKeys = FileUtil.createTempFile(tempDir);
            List<String> pubKeys = Lists.newArrayList();
            for (PodInfo podInfo : requirePodInfos(info)) {
                String podPubKeyContent = pubKeyContent.replace("{{ip}}", podInfo.getIp());
                File tempIdRsaPubOnPod = FileUtil.createTempFile(tempDir);
                FileUtil.writeLines(Collections.singletonList(podPubKeyContent), tempIdRsaPubOnPod, CHARSET);
                idRsaPubFiles.add(tempIdRsaPubOnPod);
                pubKeys.add(podPubKeyContent);
            }
            FileUtil.writeLines(pubKeys, tempAuthorizedKeys, CHARSET);

            //upload the three files to every pod
            List<Pod> pods = getPods(info);
            for (int i = 0; i < pods.size(); i++) {
                Pod pod = pods.get(i);
                String containerName = containerNameAt(i);
                //id_rsa
                podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempIdRsa, "/root/.ssh/id_rsa");
                //id_rsa.pub
                File tempIdRsaPubOnPod = idRsaPubFiles.get(i);
                podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempIdRsaPubOnPod, "/root/.ssh/id_rsa.pub");
                //authorized_keys
                podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempAuthorizedKeys, "/root/.ssh/authorized_keys");
                //permissions
                String chmodCmd = StrUtil.format("chmod 644 /root/.ssh/authorized_keys && chmod 600 /root/.ssh/id_rsa && chmod 644 /root/.ssh/id_rsa.pub");
                podApi.exec(info.getNamespace(), pod.getMetadata().getName(), containerName, chmodCmd);
            }
            log.info("configure ssh no password environment for 【{}】 successfully ", info.getParentName());
        } catch (Exception e) {
            log.error("sshAuthWithoutPass error for 【{}】", info.getParentName(), e);
            throw new OperatorException("exception is thrown when configure ssh no password environment for 【" + info.getParentName() + "】 : \n" + e.getMessage());
        } finally {
            //clean up the temporary files
            FileUtil.del(tempDir);
        }
    }

    /**
     * Writes a JSON array of {@code {ip, role}} objects and uploads it to every pod.
     *
     * @param info resource information
     */
    private void generateAndUploadHostFile(ChildResourceCreateInfo info) {
        log.info("start to configure hostfile for 【{}】 ", info.getParentName());
        File tempDir = Files.createTempDir();
        try {
            //build the hostfile
            JSONArray jsonArray = new JSONArray();
            for (PodInfo podInfo : requirePodInfos(info)) {
                JSONObject podJson = new JSONObject();
                podJson.put(IP, podInfo.getIp());
                podJson.put(ROLE, podInfo.getRole());
                jsonArray.add(podJson);
            }
            File tempHostFile = FileUtil.createTempFile(tempDir);
            FileUtil.writeLines(Collections.singletonList(jsonArray.toJSONString()), tempHostFile, CHARSET);
            //upload to the target path of every pod
            List<Pod> pods = getPods(info);
            for (int i = 0; i < pods.size(); i++) {
                Pod pod = pods.get(i);
                podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerNameAt(i), tempHostFile, HOSTFILE_TARGET_DIR);
            }

        } catch (Exception e) {
            log.error("generateAndUploadHostFile error for 【{}】", info.getParentName(), e);
            throw new OperatorException("exception is thrown when generate and upload hostfile for 【" + info.getParentName() + "】 : \n" + e.getMessage());
        } finally {
            //clean up the temporary files
            FileUtil.del(tempDir);
        }
    }

    /**
     * Creates the service; this releases the waiting loop of the job.
     *
     * @param info resource information
     */
    private void releaseInterLock(ChildResourceCreateInfo info) {
        log.info("release lock for 【{}】", info.getParentName());
        ServiceDeployer deployer = new BaseServiceDeployer();
        ServiceBuilder builder = deployer.deploy(info);
        Service svc = builder.build();
        client.services().create(svc);
        log.info("lock for 【{}】 released", info.getParentName());
    }

    /**
     * Deletes the custom resource of the training through the DistributeTrain client, when that
     * client is available.
     *
     * @param info resource information
     */
    private void recycleCr(ChildResourceCreateInfo info) {
        log.info("recycleCr=>childResourceCreateInfo : 【{}】", info);
        Optional.ofNullable(DistributeTrainClientHolder.getClient())
                .ifPresent(distributeTrainClient -> {
                    ObjectMeta metadata = new ObjectMeta();
                    metadata.setName(info.getParentName());
                    metadata.setNamespace(info.getNamespace());
                    DistributeTrain dt = new DistributeTrain(metadata, DistributeTrainSpec.builder()
                            .build());
                    distributeTrainClient.delete(dt);
                    log.info("recycle distribute train 【{}】", info.getParentName());
                });
    }

    /**
     * Sets replicas and readyReplicas of the in-memory status to the requested size.
     * The only call site in {@link #doAction(DistributeTrain)} is commented out, and nothing
     * here sends the status to the API server.
     *
     * @param info resource information
     * @param distributeTrain distributed training whose status object is updated
     */
    private void updateStatus(ChildResourceCreateInfo info, DistributeTrain distributeTrain) {
        log.info("updateStatus=>childResourceCreateInfo : 【{}】, distributeTrain : 【{}】", info, distributeTrain);
        if (distributeTrain.getStatus() == null) {
            distributeTrain.setStatus(new DistributeTrainStatus());
        }
        Integer size = distributeTrain.getSpec().getSize();
        distributeTrain.getStatus().setReplicas(size);
        distributeTrain.getStatus().setReadyReplicas(size);
    }

    /**
     * Registers a listener for the job. Currently this only logs; the watch call is disabled.
     *
     * @param info resource information
     */
    private void registerJobListener(ChildResourceCreateInfo info) {
        log.info("register listener for distribute train 【{}】", info.getParentName());
//        client.batch().jobs()
//                .inNamespace(info.getNamespace())
//                .withName(info.getJobName()).watch(null);
    }

    /**
     * Returns all pods of the training: the master pod first, then the slave pods.
     *
     * @param info resource information
     * @return pods of the training
     * @throws OperatorException when a pod is missing or the count is not slaveReplicas + 1
     */
    private List<Pod> getPods(ChildResourceCreateInfo info) {
        log.info("getPods=>childResourceCreateInfo : 【{}】", info);
        List<Pod> pods = Lists.newArrayList();
        pods.add(getMasterPod(info));
        pods.addAll(getSlavePods(info));
        if (CollectionUtil.hasNull(pods) || pods.size() != info.getSlaveReplicas() + 1) {
            throw new OperatorException("can not get pods in correct numbers");
        }
        return pods;
    }

    /**
     * Returns the pod of the master job.
     *
     * @param info resource information
     * @return first pod carrying the job label, or {@code null} when there is none
     */
    private Pod getMasterPod(ChildResourceCreateInfo info) {
        log.info("getMasterPod=>childResourceCreateInfo : 【{}】", info);
        List<Pod> masterPods = client.pods().inNamespace(info.getNamespace())
                .withLabel(JOB_LABEL, info.getJobName())
                .list().getItems();
        if (CollectionUtil.isEmpty(masterPods)) {
            return null;
        }
        return masterPods.get(0);
    }

    /**
     * Returns the pods of the slave statefulSet.
     *
     * @param info resource information
     * @return pods carrying the statefulSet label; an empty list (never {@code null}) when
     * there are none, so callers can iterate or {@code addAll} safely
     */
    private List<Pod> getSlavePods(ChildResourceCreateInfo info) {
        log.info("getSlavePods=>childResourceCreateInfo : 【{}】", info);
        List<Pod> slavePods = client.pods().inNamespace(info.getNamespace())
                .withLabel(STATEFULSET_LABEL, info.getStatefulSetName())
                .list().getItems();
        if (CollectionUtil.isEmpty(slavePods)) {
            return Collections.emptyList();
        }
        return slavePods;
    }

}
diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DeleteActionHandler.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DeleteActionHandler.java new file mode 100644 index 0000000..f0b1465 --- /dev/null +++ 
b/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DeleteActionHandler.java @@ -0,0 +1,88 @@
/**
 * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 * =============================================================
 */

package org.onebrain.operator.action.handler;

import cn.hutool.core.collection.CollectionUtil;
import io.fabric8.kubernetes.api.model.Service;
import io.fabric8.kubernetes.api.model.ServiceList;
import io.fabric8.kubernetes.api.model.apps.StatefulSet;
import io.fabric8.kubernetes.api.model.apps.StatefulSetList;
import io.fabric8.kubernetes.api.model.batch.Job;
import io.fabric8.kubernetes.api.model.batch.JobList;
import io.fabric8.kubernetes.client.KubernetesClient;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.constants.KubeConstants;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.redis.RedisService;
import org.onebrain.operator.redis.key.OperatorKey;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * @description Handler for the "delete" event of a distributed training
 * @date 2020-09-23
 */
@Component("deleteActionHandler")
@Slf4j
public class DeleteActionHandler implements DistributeTrainActionHandler {

    @Autowired
    private KubernetesClient client;

    @Autowired
    private RedisService redis;

    /**
     * Handles the delete event: removes the jobs, statefulSets and services labelled with the
     * training name in the training's namespace, then removes the redis record of the training.
     *
     * @param distributeTrain distributed training information
     */
    @Override
    public void handlerAction(DistributeTrain distributeTrain) {
        log.info("handlerAction=>distributeTrain : 【{}】", distributeTrain);
        final String ns = distributeTrain.getMetadata().getNamespace();
        final String trainName = distributeTrain.getMetadata().getName();
        //namespace + training name identify the child resources
        removeJobs(ns, trainName);
        removeStatefulSets(ns, trainName);
        removeServices(ns, trainName);
        //drop the training record kept in redis
        redis.del(OperatorKey.CR, distributeTrain.getMetadata().getUid());
        log.info("delete distributeTrain 【{}】 successfully", trainName);
    }

    /**
     * Deletes every job labelled with the training name.
     *
     * @param ns namespace of the training
     * @param trainName name of the training
     */
    private void removeJobs(String ns, String trainName) {
        JobList found = client.batch().jobs().inNamespace(ns).withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, trainName).list();
        if (CollectionUtil.isEmpty(found.getItems())) {
            return;
        }
        for (Job job : found.getItems()) {
            client.batch().jobs().delete(job);
        }
        log.info("delete job in distributeTrain 【{}】", trainName);
    }

    /**
     * Deletes every statefulSet labelled with the training name.
     *
     * @param ns namespace of the training
     * @param trainName name of the training
     */
    private void removeStatefulSets(String ns, String trainName) {
        StatefulSetList found = client.apps().statefulSets().inNamespace(ns).withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, trainName).list();
        if (CollectionUtil.isEmpty(found.getItems())) {
            return;
        }
        for (StatefulSet statefulSet : found.getItems()) {
            client.apps().statefulSets().delete(statefulSet);
        }
        log.info("delete statefulSet in distributeTrain 【{}】", trainName);
    }

    /**
     * Deletes every service labelled with the training name.
     *
     * @param ns namespace of the training
     * @param trainName name of the training
     */
    private void removeServices(String ns, String trainName) {
        ServiceList found = client.services().inNamespace(ns).withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, trainName).list();
        if (CollectionUtil.isEmpty(found.getItems())) {
            return;
        }
        for (Service service : found.getItems()) {
            client.services().delete(service);
        }
        log.info("delete svc in distributeTrain 【{}】", trainName);
    }
}
diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DistributeTrainActionHandler.java b/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DistributeTrainActionHandler.java new file mode 100644 index 0000000..70931e8 --- /dev/null +++ 
b/distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DistributeTrainActionHandler.java @@ -0,0 +1,33 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.action.handler; + +import org.onebrain.operator.crd.DistributeTrain; + +/** + * @description Event handler for distributed trainings; one implementation per event type + * @date 2020-09-23 + */ +public interface DistributeTrainActionHandler { + + /** + * Handles the event this handler is responsible for. + * @param distributeTrain distributed training information the event refers to + */ + void handlerAction(DistributeTrain distributeTrain); +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/DefaultPodExecListener.java b/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/DefaultPodExecListener.java new file mode 100644 index 0000000..c163c21 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/DefaultPodExecListener.java @@ -0,0 +1,85 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.api.pod; + +import io.fabric8.kubernetes.client.dsl.ExecListener; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; + +import java.util.concurrent.CountDownLatch; + +/** + * @description 默认命令执行监听器 + * @date 2020-09-23 + */ +@Slf4j +@Getter +public class DefaultPodExecListener implements ExecListener { + + /** + * pod名称 + */ + private String podName; + + /** + * 命名空间 + */ + private String namespace; + + /** + * 容器名称 + */ + private String containerName; + + /** + * 执行门栓 线程通信用 + */ + private CountDownLatch execLatch; + + public DefaultPodExecListener(String podName, String namespace, String containerName, CountDownLatch execLatch) { + this.podName = podName; + this.namespace = namespace; + this.containerName = containerName; + this.execLatch = execLatch; + } + + @Override + public void onOpen(Response response) { + log.debug("shell environment in pod '{}', namespace '{}' is opened", podName, namespace); + log.debug("onOpen: {}", response); + } + + @Override + public void onFailure(Throwable t, Response response) { + log.error("shell environment in pod '{}', namespace '{}' barfed", podName, namespace); + log.error("onFailure: {} {}", t.getMessage(), response); + if (execLatch != null) { + execLatch.countDown(); + } + } + + @Override + public void onClose(int code, String reason) { + log.debug("shell environment in pod '{}', namespace '{}' closed", podName, namespace); + log.debug("onClose: {} {}", code, reason); + if 
(execLatch != null) { + execLatch.countDown(); + } + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/PodApi.java b/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/PodApi.java new file mode 100644 index 0000000..a3a11f1 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/PodApi.java @@ -0,0 +1,177 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.api.pod; + +import cn.hutool.core.util.StrUtil; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.dsl.ExecWatch; +import lombok.extern.slf4j.Slf4j; +import org.apache.commons.io.FileUtils; +import org.onebrain.operator.context.KubeContext; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.io.File; +import java.io.IOException; +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * + * @description PodApi 操作pod 里的容器用于上传文件等操作吧 + * @date 2020-09-23 + */ +@Component +@Slf4j +public class PodApi { + + private static final Integer DEFAULT_LOG_LINES = 50; + + @Autowired + private KubeContext kubeContext; + + @Autowired + private KubernetesClient client; + /** + * 从Pod下载单个文件 + * @return File 临时文件,用后需要及时清理 + * **/ + public File copyFileFromPod(String namespace, String podName, String containerName, String filePath){ + try { + File tmpFile = File.createTempFile("copy-from-pod-", ""); + client.pods().inNamespace(namespace).withName(podName) + .inContainer(containerName) + .file(filePath) + .copy(tmpFile.toPath()); + + if(tmpFile.length() == 0){ + return null; + } + + return tmpFile; + } catch (IOException e) { + + log.error(" File copy error : 【{}】",e); + } + return null; + } + + /** + * 从Pod下载目录 + * @return File 临时文件,用后需要及时清理 + * **/ + public File copyFolderFromPod(String namespace, String podName, String containerName, String folderPath){ + final PipedInputStream stdoutInput = new PipedInputStream(); + final PipedOutputStream stdoutOutput = new PipedOutputStream(); + final PipedInputStream stderrInput = new PipedInputStream(); + final PipedOutputStream stderrOutput = new PipedOutputStream(); + final AtomicBoolean failed = new 
AtomicBoolean(false); + try { + stdoutInput.connect(stdoutOutput); + stderrInput.connect(stderrOutput); + + //去除路径上的/前缀 + if(folderPath.startsWith(StrUtil.SLASH)){ + folderPath = StrUtil.removePrefix(folderPath, StrUtil.SLASH); + } + + //监听器异步执行 + DefaultPodExecListener defaultPodExecListener = new DefaultPodExecListener(podName, namespace, containerName, null); + + StdPodExecListener stdPodExecListener = new StdPodExecListener(defaultPodExecListener, stdoutOutput, stderrOutput, failed); + + ExecWatch watch = client.pods().inNamespace(namespace) + .withName(podName).inContainer(containerName) + .writingOutput(stdoutOutput).writingError(stderrOutput) + .usingListener(stdPodExecListener) + .exec("tar", "cf", "-", "-C", folderPath, "."); + // execLatch.await(); + + } catch (IOException e) { + log.error("copyFolderFromPod:【{}】",e); + } + + File tmpFile = null; + + try { + tmpFile = File.createTempFile("copy-from-pod-", ".tar"); + + int length; + byte[] buffer = new byte[1024]; + while (!Thread.currentThread().isInterrupted() + && (length = stdoutInput.read(buffer)) != -1) { + + byte[] content = new byte[length]; + System.arraycopy(buffer, 0, content, 0, length); + + FileUtils.writeByteArrayToFile(tmpFile, content, true); + } + + while (!Thread.currentThread().isInterrupted() + && (length = stderrInput.read(buffer)) != -1) { + log.error(new String(buffer, 0, length)); + } + } catch (IOException e) { + if (!Thread.currentThread().isInterrupted()) { + log.error("Error while pumping stream. 【{}】", e); + } else { + log.error("Interrupted while pumping stream. 
【{}】", e); + } + } + + return tmpFile; + } + + /** + * 拷贝文件到pod + * @param namespace 命名空间 + * @param podName pod名称 + * @param containerName 容器名称 + * @param file 文件 + * @param targetDir 目标路径 + */ + public void copyToPod(String namespace, String podName, String containerName, File file, String targetDir){ + client.pods().inNamespace(namespace).withName(podName) + .inContainer(containerName) + .file(targetDir) + .upload(file.toPath()); + } + + /** + * 同步执行 + * @param namespace 命名空间 + * @param podName pod名称 + * @param containerName 容器名称 + * @param cmd 命令 + */ + public void exec(String namespace, String podName, String containerName, String cmd){ + try { + final CountDownLatch execLatch = new CountDownLatch(1); + ExecWatch execWatch = client.pods().inNamespace(namespace).withName(podName).inContainer(containerName) + .redirectingOutput() + .withTTY() //不展示输出 + .usingListener(new DefaultPodExecListener(namespace, podName, containerName, execLatch)) + .exec("sh", "-c", cmd); + execLatch.await(); + } catch (InterruptedException e) { + log.error(" PodApi execute cmd error : 【{}】",e); + } + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/StdPodExecListener.java b/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/StdPodExecListener.java new file mode 100644 index 0000000..bd0aa79 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/StdPodExecListener.java @@ -0,0 +1,83 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.api.pod; + +import io.fabric8.kubernetes.client.dsl.ExecListener; +import lombok.extern.slf4j.Slf4j; +import okhttp3.Response; + +import java.io.IOException; +import java.io.PipedOutputStream; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * @description 标准pod执行监听器 + * @date 2020-09-23 + */ +@Slf4j +public class StdPodExecListener implements ExecListener { + + private ExecListener defaultExecListener; + + private PipedOutputStream stdoutOutput; + + private PipedOutputStream stderrOutput; + + private AtomicBoolean failed; + + public StdPodExecListener(ExecListener defaultExecListener, PipedOutputStream stdoutOutput, PipedOutputStream stderrOutput, AtomicBoolean failed) { + this.defaultExecListener = defaultExecListener; + this.stdoutOutput = stdoutOutput; + this.stderrOutput = stderrOutput; + this.failed = failed; + } + + @Override + public void onOpen(Response response) { + log.info("onOpen=>response : 【{}】",response); + defaultExecListener.onOpen(response); + } + + @Override + public void onFailure(Throwable t, Response response) { + log.info("onFailure=> t :【{}】,response : 【{}】",t,response); + try { + failed.set(true); + stdoutOutput.close(); + stderrOutput.close(); + } catch (IOException e) { + log.error("Failed to close stdout and stderr pipes. 
【{}】", e); + } finally { + defaultExecListener.onFailure(t, response); + } + } + + @Override + public void onClose(int code, String reason) { + log.info("onClose=>code : 【{}】,reason : 【{}】",code,reason); + try { + stdoutOutput.close(); + stderrOutput.close(); + } catch (IOException e) { + log.error("Failed to close stdout and stderr pipes. 【{}】", e); + } finally { + defaultExecListener.onClose(code, reason); + } + } + +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/config/KubeConfig.java b/distribute-train-operator/src/main/java/org/onebrain/operator/config/KubeConfig.java new file mode 100644 index 0000000..f586440 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/config/KubeConfig.java @@ -0,0 +1,66 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.config; + +import cn.hutool.core.util.StrUtil; +import io.fabric8.kubernetes.client.KubernetesClient; +import org.onebrain.operator.context.KubeContext; +import org.onebrain.operator.properties.KubeProperties; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +/** + * @description k8s配置类 + * @date 2020-09-23 + */ +@Configuration +@EnableConfigurationProperties(KubeProperties.class) +public class KubeConfig { + + @Autowired + private KubeProperties kubeProperties; + + /** + * 注册k8s配置 + * @return + */ + @Bean + public KubeContext kubeContext() { + if (kubeProperties == null) { + return null; + } + + final String configSource = kubeProperties.getKubeconfig(); + if(StrUtil.isEmpty(configSource)){ + return null; + } + return new KubeContext(kubeProperties); + } + + /** + * 注册k8s客户端 + * @param kubeContext k8s配置 + * @return + */ + @Bean + public KubernetesClient kubernetesClient(KubeContext kubeContext){ + return kubeContext.getClient(); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/constants/CrdConstants.java b/distribute-train-operator/src/main/java/org/onebrain/operator/constants/CrdConstants.java new file mode 100644 index 0000000..e945ed0 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/constants/CrdConstants.java @@ -0,0 +1,34 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/**
 * Constants describing the DistributeTrain custom resource definition.
 */
public class CrdConstants {

    /** API group of the CRD. */
    public static final String CRD_GROUP = "onebrain.oneflow.org";

    /** Singular resource name. */
    public static final String CRD_SINGULAR_NAME = "distributetrain";

    /** Plural resource name. */
    public static final String CRD_PLURAL_NAME = "distributetrains";

    /** Full CRD name: plural name, a dot, then the group. */
    public static final String CRD_NAME = CRD_PLURAL_NAME + "." + CRD_GROUP;

    /** Resource kind. */
    public static final String CRD_KIND = "DistributeTrain";

    /** Resource scope. */
    public static final String CRD_SCOPE = "Namespaced";

    /** Short name of the resource. */
    public static final String CRD_SHORT_NAME = "dt";

    /** Version of the custom resource. */
    public static final String CRD_VERSION = "v1alpha1";

    /** API version of the CRD object itself. */
    public static final String CRD_API_VERSION = "apiextensions.k8s.io/v1beta1";
}
/**
 * Kubernetes related constants used by the operator.
 */
public class KubeConstants {

    // Label keys
    public static final String DISTRIBUTE_TRAIN_LABEL = "dt-name";
    public static final String STATEFULSET_LABEL = "dt-ss-name";
    public static final String JOB_LABEL = "dt-job-name";

    // Container names
    public static final String MASTER_CONTAINER_NAME = "distribute-train-master";
    public static final String SLAVE_CONTAINER_NAME = "distribute-train-slave";

    /** Name of the JVM system property holding the working directory. */
    public static final String USER_DIR_SYSTEM_PROPERTY = "user.dir";

    /** Back-off limit of zero: retries are not allowed. */
    public static final Integer BACKOFFLIMIT = 0;

    /** Charset name. */
    public static final String CHARSET = "utf-8";

    /** Name of the environment variable carrying the node count. */
    public static final String ENV_NODE_NUM = "NODE_NUM";

    /** Name of the shared-memory volume. */
    public static final String VOLUME_SHM = "dshm";
}
/**
 * Numeric constants.
 */
public class NumberConstant {

    public static final int NUMBER_0 = 0;
    public static final long LONG_NUMBER_0 = 0L;
    public static final int NUMBER_1 = 1;
    public static final int NUMBER_2 = 2;
    public static final int NUMBER_3 = 3;
    public static final int NUMBER_5 = 5;
    public static final int NUMBER_10 = 10;
    public static final int NUMBER_22 = 22;
    public static final int NUMBER_30 = 30;
    public static final int NUMBER_50 = 50;
    public static final int NUMBER_60 = 60;
    public static final long LONG_NUMBER_60 = 60L;

    /** Seconds in one hour. */
    public static final int HOUR_SECOND = NUMBER_60 * NUMBER_60;

    /** Seconds in one day. */
    public static final int DAY_SECOND = HOUR_SECOND * 24;

    /** Seconds in one week. */
    public static final int WEEK_SECOND = DAY_SECOND * 7;

    public static final int MAX_PAGE_SIZE = 2000;
    public static final int NUMBER_30000 = 30000;
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.context; + +import cn.hutool.core.util.StrUtil; +import com.fasterxml.jackson.core.JsonProcessingException; +import io.fabric8.kubernetes.api.model.HasMetadata; +import io.fabric8.kubernetes.client.Config; +import io.fabric8.kubernetes.client.DefaultKubernetesClient; +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.VersionInfo; +import io.fabric8.kubernetes.client.internal.SerializationUtils; +import io.fabric8.kubernetes.client.utils.Utils; +import lombok.Getter; +import lombok.extern.slf4j.Slf4j; +import org.onebrain.operator.properties.KubeProperties; +import org.springframework.beans.BeansException; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; + + +/** + * @description k8s上下文 + * @date 2020-09-23 + */ +@Slf4j +@Getter +public class KubeContext implements ApplicationContextAware { + + private static final String AUTO = "auto"; + + private ApplicationContext applicationContext; + + private KubernetesClient client; + + private Config config; + + + public KubeContext(KubeProperties kubeProperties) { + String configSource = kubeProperties.getKubeconfig(); + try { + if(AUTO.equals(configSource)){ + //在集群内部可自动侦测 + log.info("kubernetes client is in cluster mode"); + client = new DefaultKubernetesClient(); + config = client.getConfiguration(); + }else{ + if(configSource.startsWith(StrUtil.SLASH)){ + log.info("read kubeconfig from file 
system:{}", configSource); + System.setProperty(Config.KUBERNETES_KUBECONFIG_FILE, configSource); + }else{ + log.info("read kubeconfig from classpath:{}", configSource); + final String testKubeconfigFile = Utils.filePath(getClass().getResource(StrUtil.SLASH + configSource)); + //修改环境变量,重新指定kubeconfig读取位置 + System.setProperty(Config.KUBERNETES_KUBECONFIG_FILE, testKubeconfigFile); + } + client = new DefaultKubernetesClient(); + config = client.getConfiguration(); + } + + //打印集群信息 + log.info("ApiVersion : {}", client.getApiVersion()); + log.info("MasterUrl : {}", client.getMasterUrl()); + if(log.isDebugEnabled()){ + VersionInfo versionInfo = client.getVersion(); + log.debug("Version details of this Kubernetes cluster :-"); + log.debug("Major : {}", versionInfo.getMajor()); + log.debug("Minor : {}", versionInfo.getMinor()); + log.debug("GitVersion : {}", versionInfo.getGitVersion()); + log.debug("GitCommit : {}", versionInfo.getGitCommit()); + log.debug("BuildDate : {}", versionInfo.getBuildDate()); + log.debug("GitTreeState : {}", versionInfo.getGitTreeState()); + log.debug("Platform : {}", versionInfo.getPlatform()); + log.debug("GoVersion : {}", versionInfo.getGoVersion()); + } + }catch (Exception e){ + client = null; + log.error("初始化 K8sUtils 失败!", e); + e.printStackTrace(); + } + } + + /** + * 导出成yaml字符串 + * @param resource k8s元数据 + * @return + */ + public String convertToYaml(HasMetadata resource) { + try { + return SerializationUtils.dumpAsYaml(resource); + } catch (JsonProcessingException e) { + e.printStackTrace(); + throw new RuntimeException("can not transform resource to yaml"); + } + } + + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + this.applicationContext = applicationContext; + } + +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/controller/DistributeTrainController.java 
b/distribute-train-operator/src/main/java/org/onebrain/operator/controller/DistributeTrainController.java new file mode 100644 index 0000000..e34e4bc --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/controller/DistributeTrainController.java @@ -0,0 +1,131 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.controller; + +import io.fabric8.kubernetes.client.KubernetesClient; +import io.fabric8.kubernetes.client.dsl.MixedOperation; +import io.fabric8.kubernetes.client.dsl.Resource; +import io.fabric8.kubernetes.client.informers.ResourceEventHandler; +import io.fabric8.kubernetes.client.informers.SharedIndexInformer; +import io.fabric8.kubernetes.client.informers.cache.Lister; +import lombok.extern.slf4j.Slf4j; +import org.onebrain.operator.action.handler.DistributeTrainActionHandler; +import org.onebrain.operator.crd.DistributeTrain; +import org.onebrain.operator.crd.DistributeTrainList; +import org.onebrain.operator.crd.DoneableDistributeTrain; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.scheduling.annotation.Async; + +import java.util.concurrent.TimeUnit; + +/** + * @description 分布式训练控制器 + * @date 2020-06-16 + */ +@Slf4j +public class 
DistributeTrainController { + + @Autowired + private KubernetesClient client; + + /** + * 分布式训练informer + */ + private SharedIndexInformer distributeTrainSharedIndexInformer; + + /** + * 分布式训练k8s访问客户端 + */ + private MixedOperation> distributeTrainClient; + + /** + * 分布式训练lister + */ + private Lister distributeTrainLister; + + @Autowired + @Qualifier("addActionHandler") + private DistributeTrainActionHandler addActionHandler; + + @Autowired + @Qualifier("deleteActionHandler") + private DistributeTrainActionHandler deleteActionHandler; + + public DistributeTrainController(MixedOperation> distributeTrainClient, SharedIndexInformer distributeTrainSharedIndexInformer, String namespace) { + this.distributeTrainSharedIndexInformer = distributeTrainSharedIndexInformer; + this.distributeTrainClient = distributeTrainClient; + this.distributeTrainLister = new Lister<>(distributeTrainSharedIndexInformer.getIndexer()); + } + + /** + * 添加事件监听器 + */ + public void create() { + distributeTrainSharedIndexInformer.addEventHandler(new ResourceEventHandler() { + /** + * 处理添加事件 + * @param distributeTrain 分布式训练信息 + */ + @Override + public void onAdd(DistributeTrain distributeTrain) { + log.info("add distributeTrain named 【{}】 in namespace 【{}】", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace()); + addActionHandler.handlerAction(distributeTrain); + } + + /** + * 处理更内心事件 + * @param distributeTrain 旧的 分布式训练信息 + * @param newDistributeTrain 新的 分布式训练信息 + */ + @Override + public void onUpdate(DistributeTrain distributeTrain, DistributeTrain newDistributeTrain) { + log.info("update distributeTrain named 【{}】 in namespace 【{}】", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace()); + } + + /** + * 处理删除事件 + * @param distributeTrain 分布式训练信息 + * @param b 是否为未知事件 + */ + @Override + public void onDelete(DistributeTrain distributeTrain, boolean b) { + log.info("delete distributeTrain named 【{}】 in namespace 【{}】", 
distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace()); + deleteActionHandler.handlerAction(distributeTrain); + } + }); + } + + /** + * 运行 + */ + @Async + public void run() { + log.info("Starting DistributeTrain controller"); + try { + //分布式训练信息同步 + while (!distributeTrainSharedIndexInformer.hasSynced()){ + TimeUnit.SECONDS.sleep(1); + } + } catch (InterruptedException e) { + e.printStackTrace(); + log.error("run error:【{}】",e); + } + log.info("DistributeTrain controller is Running"); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrain.java b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrain.java new file mode 100644 index 0000000..6128263 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrain.java @@ -0,0 +1,47 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.crd; + +import io.fabric8.kubernetes.api.model.ObjectMeta; +import io.fabric8.kubernetes.client.CustomResource; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * @description 分布式训练 + * @date 2020-09-24 + */ +@Data +@NoArgsConstructor +public class DistributeTrain extends CustomResource { + + /** + * 分布式训练详细规格 + */ + private DistributeTrainSpec spec; + + /** + * 分布式训练状态 + */ + private DistributeTrainStatus status; + + public DistributeTrain(ObjectMeta objectMeta, DistributeTrainSpec spec) { + this.setMetadata(objectMeta); + this.spec = spec; + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainList.java b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainList.java new file mode 100644 index 0000000..3550d53 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainList.java @@ -0,0 +1,27 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.crd; + +import io.fabric8.kubernetes.client.CustomResourceList; + +/** + * @description CRD资源列表(分布式训练) + * @date 2020-09-24 + */ +public class DistributeTrainList extends CustomResourceList { +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainSpec.java b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainSpec.java new file mode 100644 index 0000000..fbeaa48 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainSpec.java @@ -0,0 +1,108 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.crd; + +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import io.fabric8.kubernetes.api.model.*; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +import java.util.List; +import java.util.Map; + +/** + * @description 分布式训练详细规格 + * @date 2020-09-23 + */ +@JsonDeserialize( + using = JsonDeserializer.None.class +) +@Data +@NoArgsConstructor +@AllArgsConstructor +@Builder +public class DistributeTrainSpec implements KubernetesResource { + + /** + * 镜像 + */ + private String image; + + /** + * 镜像拉取策略 + */ + private String imagePullPolicy; + /** + * 机器数 + */ + private Integer size; + + /** + * 环境参数 + */ + private List env; + + /** + * master 命令 + */ + private String masterCmd; + + /** + * slave命令 + */ + private String slaveCmd; + + /** + * master 资源节点限制 + */ + private ResourceRequirements masterResources; + + /** + * slave 资源节点限制 + */ + private ResourceRequirements slaveResources; + + /** + * 节点调度选择器 + */ + private Map nodeSelector; + + /** + * 初始化容器 + */ + private Container initContainer; + + /** + * 工作目录挂载 + */ + private Volume workspaceStorage; + + /** + * 数据集目录挂载 + */ + private Volume datasetStorage; + + /** + * 模型目录挂载 + */ + private Volume modelStorage; + +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainStatus.java b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainStatus.java new file mode 100644 index 0000000..2e8c68e --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainStatus.java @@ -0,0 +1,55 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.crd; + +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import io.fabric8.kubernetes.api.model.KubernetesResource; +import lombok.Data; + +/** + * @description 分布式训练状态 + * @date 2020-09-23 + */ +@JsonDeserialize( + using = JsonDeserializer.None.class +) +@Data +public class DistributeTrainStatus implements KubernetesResource { + + /** + * 副本数 + */ + private Integer replicas; + + /** + * 处在ready状态的副本数 + */ + private Integer readyReplicas; + + /** + * 成功数 + */ + private Integer success; + + /** + * 失败数 + */ + private Integer failed; + +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DoneableDistributeTrain.java b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DoneableDistributeTrain.java new file mode 100644 index 0000000..677bac6 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/crd/DoneableDistributeTrain.java @@ -0,0 +1,31 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.crd; + +import io.fabric8.kubernetes.api.builder.Function; +import io.fabric8.kubernetes.client.CustomResourceDoneable; + +/** + * @description CRD资源的修改Builder + * @date 2020-09-24 + */ +public class DoneableDistributeTrain extends CustomResourceDoneable { + public DoneableDistributeTrain(DistributeTrain resource, Function function) { + super(resource, function); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/enums/AccessModeEnum.java b/distribute-train-operator/src/main/java/org/onebrain/operator/enums/AccessModeEnum.java new file mode 100644 index 0000000..6ba7693 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/enums/AccessModeEnum.java @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
/**
 * @description Access modes of a PVC.
 * @date 2020-09-24
 */
public enum AccessModeEnum {

    /**
     * RWO is the most basic mode: readable and writable, but it can only be mounted by a single Pod.
     */
    RWO("ReadWriteOnce"),

    /**
     * Can be mounted read-only by multiple Pods.
     */
    ROX("ReadOnlyMany"),

    /**
     * The storage can be shared read-write by multiple Pods.
     * Not every storage backend supports all three modes; shared access is still
     * supported by relatively few, NFS being the commonly used one.
     * A PVC is usually bound to a PV based on two conditions: the storage size and the access mode.
     */
    RWX("ReadWriteMany");

    /**
     * Access-mode string associated with the constant.
     */
    private final String mode;

    AccessModeEnum(final String mode) {
        this.mode = mode;
    }

    /**
     * @return the access-mode string of this constant
     */
    public String getMode() {
        return this.mode;
    }
}
/**
 * @description Custom unchecked exception raised by the operator.
 *
 * <p>The message and cause are forwarded to {@link RuntimeException}, so
 * {@link #getMessage()}, {@link #getCause()} and printed stack traces all carry them.
 * @date 2020-09-24
 */
public class OperatorException extends RuntimeException {

    /**
     * Message text; also available through {@link #getMessage()}.
     */
    private final String msg;

    /**
     * @param msg   description of the failure
     * @param cause underlying cause, may be {@code null}
     */
    public OperatorException(String msg, Throwable cause) {
        // Hand both to the superclass: previously they were kept only in shadowing
        // fields, which left getMessage() null and the stack trace without a message.
        super(msg, cause);
        this.msg = msg;
    }

    /**
     * @param msg description of the failure
     */
    public OperatorException(String msg) {
        super(msg);
        this.msg = msg;
    }

    /**
     * @return the message this exception was created with
     */
    public String getMsg() {
        return msg;
    }
}
+ * ============================================================= + */ + +package org.onebrain.operator.properties; + +import lombok.Data; +import org.springframework.boot.context.properties.ConfigurationProperties; +import org.springframework.stereotype.Component; + +/** + * @description 属性配置 + * @date 2020-09-24 + */ +@Data +@ConfigurationProperties("k8s") +@Component +public class KubeProperties { + + private String kubeconfig; +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/redis/AbstractKeyPrefix.java b/distribute-train-operator/src/main/java/org/onebrain/operator/redis/AbstractKeyPrefix.java new file mode 100644 index 0000000..914bd4f --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/redis/AbstractKeyPrefix.java @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.redis; + +/** + * @description redis Key 前缀 + * @date 2020-09-23 + */ +public abstract class AbstractKeyPrefix { + + /** + * key模板 + */ + private static final String KEY_TEMPLATE = "Operator:%s"; + + /** + * 过期时间 + */ + private int expireSeconds; + + /** + * 前缀 + */ + private String prefix; + + public AbstractKeyPrefix(String prefix) {//0代表永不过期 + this(prefix,0); + } + + public AbstractKeyPrefix(String prefix, int expireSeconds) { + this.expireSeconds = expireSeconds; + this.prefix = prefix; + } + + /** + * 获取过期时间 + * @return + */ + public int getExpireSeconds() {//默认0代表永不过期 + return expireSeconds; + } + + /** + * 获取前缀 + * @return + */ + public String getPrefix() { + return String.format(KEY_TEMPLATE, prefix); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/redis/RedisService.java b/distribute-train-operator/src/main/java/org/onebrain/operator/redis/RedisService.java new file mode 100644 index 0000000..a4d767c --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/redis/RedisService.java @@ -0,0 +1,290 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.redis; + +import org.onebrain.operator.utils.FastjsonUtils; +import org.onebrain.operator.utils.RedisUtils; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import java.util.Set; + +/** + * @description redis服务 + * @date 2020-09-03 + */ +@Service +public class RedisService { + + @Autowired + private RedisUtils redisUtils; + + /** + * 真正key模板 + */ + private static final String REAL_KEY_TEMPLATE = "%s:%s"; + + /** + * 获取真正的key + * @param prefix 前缀 + * @param key key值 + * @return 放入redis里的key值 + */ + private String getRealKey(AbstractKeyPrefix prefix, String key){ + return String.format(REAL_KEY_TEMPLATE, prefix.getPrefix(), key); + } + + /** + * 实现命令:TTL key,以秒为单位,返回给定 key的剩余生存时间(TTL, time to live)。 + * @param prefix 前缀 + * @param key key值 + * @return 返回过期时间秒数 + */ + public long ttl(AbstractKeyPrefix prefix, String key) { + return redisUtils.ttl(getRealKey(prefix, key)); + } + + /** + * 实现命令:expire 设置过期时间,单位秒 + * @param prefix 前缀 + * @param key key值 + * @param timeout 期望过期时间 + */ + public void expire(AbstractKeyPrefix prefix, String key, long timeout) { + redisUtils.expire(getRealKey(prefix, key), timeout); + } + + /** + * 实现命令:INCR key,增加key一次 + * @param prefix 前缀 + * @param key key值 + * @param delta 增量 + * @return 计数值 + */ + public long incr(AbstractKeyPrefix prefix, String key, long delta) { + return redisUtils.incr(getRealKey(prefix, key), delta); + } + + /** + * 实现命令: key,减少key一次 + * @param prefix 前缀 + * @param key key值 + * @param delta 增量 + * @return 计数值 + */ + public long decr(AbstractKeyPrefix prefix, String key, long delta) { + String realKey = getRealKey(prefix, key); + if(delta < 0){ + //throw new RuntimeException("递减因子必须大于0"); + del(realKey); + return 0; + } + return redisUtils.decr(realKey, delta); + } + + /** + * 实现命令:KEYS pattern,查找所有符合给定模式 pattern的 key + * @param prefix key前缀 + * 
@return key集合 + */ + public Set keys(AbstractKeyPrefix prefix) { + String pattern = prefix.getPrefix(); + return redisUtils.keys(pattern + ":*"); + } + + /** + * 实现命令:KEYS pattern,查找所有符合给定模式 pattern的 key + * @param prefix key前缀 + * @param key key值 + * @return key集合 + */ + public Set keys(AbstractKeyPrefix prefix, String key) { + String pattern = prefix.getPrefix(); + return redisUtils.keys(pattern + ":" + key + ":*"); + } + + /** + * 实现命令:DEL key,删除一个key + * @param prefix key前缀 + * @param key key值 + */ + public void del(AbstractKeyPrefix prefix, String key) { + redisUtils.del(getRealKey(prefix, key)); + } + + /** + * 删除一个key + * @param realKey 真正的key + */ + public void del(String realKey) { + redisUtils.del(realKey); + } + + /** + * 实现命令:SET key value,设置一个key-value(将字符串值 value关联到 key) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + */ + public void set(AbstractKeyPrefix prefix, String key, String value) { + if(prefix.getExpireSeconds() <= 0){ + redisUtils.set(getRealKey(prefix, key), value); + }else{ + redisUtils.set(getRealKey(prefix, key), value, prefix.getExpireSeconds()); + } + } + + /** + * 实现命令:SET key value,设置一个key-value(将字符串值 value关联到 key) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + * @param 指定类型 + */ + public void set(AbstractKeyPrefix prefix, String key, T value) { + if(prefix.getExpireSeconds() <= 0){ + redisUtils.set(getRealKey(prefix, key), FastjsonUtils.convertObjectToJSON(value)); + }else{ + redisUtils.set(getRealKey(prefix, key), FastjsonUtils.convertObjectToJSON(value), prefix.getExpireSeconds()); + } + } + + + /** + * 实现命令:SET key value EX seconds,设置key-value和超时时间(秒) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + */ + public void set(AbstractKeyPrefix prefix, String key, String value, long timeout) { + redisUtils.set(getRealKey(prefix, key), value, timeout); + } + + /** + * 实现命令:SET key value EX seconds,设置key-value和超时时间(秒) + * @param prefix key前缀 + * @param key key值 + 
* @param value 值 + * @param timeout 过期时间 + * @param 指定类型 + */ + public void set(AbstractKeyPrefix prefix, String key, T value, long timeout) { + redisUtils.set(getRealKey(prefix, key), FastjsonUtils.convertObjectToJSON(value), timeout); + } + + /** + * 实现命令:SETNX key value,设置一个key-value(将字符串值 value关联到 key) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + * @return 是否设值成功 + */ + public Boolean setnx(AbstractKeyPrefix prefix, String key, String value){ + if(prefix.getExpireSeconds() <= 0){ + return redisUtils.setnx(getRealKey(prefix, key), value); + }else{ + return redisUtils.setnx(getRealKey(prefix, key), value, prefix.getExpireSeconds()); + } + } + + /** + * 实现命令:SETNX key value,设置一个key-value(将字符串值 value关联到 key) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + * @param 指定类型 + * @return 是否设值成功 + */ + public Boolean setnx(AbstractKeyPrefix prefix, String key, T value){ + if(prefix.getExpireSeconds() <= 0){ + return redisUtils.setnx(getRealKey(prefix, key), FastjsonUtils.convertObjectToJSON(value)); + }else{ + return redisUtils.setnx(getRealKey(prefix, key), FastjsonUtils.convertObjectToJSON(value), prefix.getExpireSeconds()); + } + } + + /** + * 实现命令:SETNX key value EX seconds,设置key-value和超时时间(秒) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + * @return 是否设值成功 + */ + public Boolean setnx(AbstractKeyPrefix prefix, String key, String value, long timeout) { + return redisUtils.setnx(getRealKey(prefix, key), value, timeout); + } + + /** + * 实现命令:SETNX key value EX seconds,设置key-value和超时时间(秒) + * @param prefix key前缀 + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + * @param 指定类型 + * @return 是否设值成功 + */ + public Boolean setnx(AbstractKeyPrefix prefix, String key, T value, long timeout) { + return redisUtils.setnx(getRealKey(prefix, key), FastjsonUtils.convertObjectToJSON(value), timeout); + } + + /** + * 实现命令:GET key,返回 key所关联的字符串值。 + * @param prefix key前缀 + * @param key key值 + * 
@return 值 + */ + public String get(AbstractKeyPrefix prefix, String key) { + return redisUtils.get(getRealKey(prefix, key)); + } + + /** + * 实现命令:GET key,返回 key所关联的字符串值。 + * @param prefix key前缀 + * @param key key值 + * @param 指定类型 + * @return 值 + */ + public T get(AbstractKeyPrefix prefix, String key, Class clazz) { + return redisUtils.get(getRealKey(prefix, key), clazz); + } + + /** + * 根据key获取值 + * @param lastKey 真正的key + * @param clazz 类型 + * @param 泛型 + * @return + */ + public T get(String lastKey, Class clazz) { + return redisUtils.get(lastKey, clazz); + } + + /** + * 实现命令:GET key,返回 key所关联的字符串值。 + * @param prefix key前缀 + * @param key key值 + * @return 是否存在 + */ + public Boolean exists(AbstractKeyPrefix prefix, String key) { + return redisUtils.exists(getRealKey(prefix, key)); + } + + +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/redis/key/OperatorKey.java b/distribute-train-operator/src/main/java/org/onebrain/operator/redis/key/OperatorKey.java new file mode 100644 index 0000000..8cdba91 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/redis/key/OperatorKey.java @@ -0,0 +1,45 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.redis.key; + +import org.onebrain.operator.redis.AbstractKeyPrefix; + +/** + * @description 由operator产生的cr的唯一标识 + * @date 2020-09-23 + */ +public class OperatorKey extends AbstractKeyPrefix { + + public OperatorKey(String prefix) { + super(prefix); + } + + public OperatorKey(String prefix, int expireSeconds) { + super(prefix, expireSeconds); + } + + /** + * 分布式训练 Key + */ + public static final OperatorKey CR = new OperatorKey("DistributeTrain"); + + /** + * 分布式训练Job Key + */ + public static final OperatorKey CR_JOB = new OperatorKey("DistributeTrain:Job"); +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/utils/DistributeTrainClientHolder.java b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/DistributeTrainClientHolder.java new file mode 100644 index 0000000..18e85c2 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/DistributeTrainClientHolder.java @@ -0,0 +1,41 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.utils; + +import io.fabric8.kubernetes.client.dsl.MixedOperation; +import io.fabric8.kubernetes.client.dsl.Resource; +import org.onebrain.operator.crd.DistributeTrain; +import org.onebrain.operator.crd.DistributeTrainList; +import org.onebrain.operator.crd.DoneableDistributeTrain; + +/** + * @description 分布式训练客户端持有器 + * @date 2020-09-23 + */ +public class DistributeTrainClientHolder { + + private static MixedOperation> distributeTrainClient; + + public static MixedOperation> getClient(){ + return distributeTrainClient; + } + + public static void setDistributeTrainClient(MixedOperation> client){ + distributeTrainClient = client; + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/utils/FastjsonUtils.java b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/FastjsonUtils.java new file mode 100644 index 0000000..e4af46a --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/FastjsonUtils.java @@ -0,0 +1,188 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.utils; + + +import com.alibaba.fastjson.JSON; +import com.alibaba.fastjson.JSONObject; +import com.alibaba.fastjson.serializer.SerializerFeature; + +import java.util.List; +import java.util.Map; + +/** + * @description json工具类 + * @date 2020-09-24 + */ +public class FastjsonUtils { + + private static final SerializerFeature[] FEATURES = { + // 输出空置字段 + SerializerFeature.WriteMapNullValue, + //日期类型用日期字符串 yyyy-MM-dd HH:mm:ss + SerializerFeature.WriteDateUseDateFormat, + // list字段如果为null,输出为[],而不是null + SerializerFeature.WriteNullListAsEmpty, + // 数值字段如果为null,输出为0,而不是null + SerializerFeature.WriteNullNumberAsZero, + // Boolean字段如果为null,输出为false,而不是null + SerializerFeature.WriteNullBooleanAsFalse, + // 字符类型字段如果为null,输出为"",而不是null + SerializerFeature.WriteNullStringAsEmpty + }; + + /** + * 将对象转为json + * @param object + * @return json的String + */ + public static String convertObjectToJSON(Object object) { + return JSON.toJSONString(object, FEATURES); + } + + /** + * 将对象转为json(无循环引用) + * @param object + * @return json的String + */ + public static String toJSONNoFeatures(Object object) { + return JSON.toJSONString(object, SerializerFeature.DisableCircularReferenceDetect); + } + + /** + * 将json转为对象 + * @param text + * @return 对象 + */ + public static Object toBean(String text) { + return JSON.parse(text); + } + + /** + * 将json转为对象 + * @param text 文本字符串 + * @param clazz 类型 + * @param 泛型 + * @return 泛型对象 + */ + public static T toBean(String text, Class clazz) { + return JSON.parseObject(text, clazz); + } + + /** + * 转换为数组 + * @param text 文本字符串 + * @return 泛型对象 + */ + public static Object[] toArray(String text) { + return toArray(text, null); + } + + /** + * 转换为数组 + * @param text 文本字符串 + * @param clazz 类型 + * @return + */ + public static Object[] toArray(String text, Class clazz) { + return JSON.parseArray(text, clazz).toArray(); + } + + /** + * 转换为List + * @param text 文本字符串 + 
* @param clazz 类型 + * @return + */ + public static List toList(String text, Class clazz) { + return JSON.parseArray(text, clazz); + } + + /** + * 将string转化为序列化的json字符串 + * @param text 文本字符串 + * @return json对象 + */ + public static Object textToJson(String text) { + Object objectJson = JSON.parse(text); + return objectJson; + } + + /** + * json字符串转化为map + * @param text json字符串 + * @return Map集合 + */ + public static Map stringToCollect(String text) { + Map m = (Map) JSONObject.parseObject(text); + return m; + } + + /** + * 转换JSON字符串为对象 + * @param jsonData json字符串 + * @param clazz 转换目标对象的类型 + * @return json对象 + */ + public static Object convertJsonToObject(String jsonData, Class clazz) { + return JSONObject.parseObject(jsonData, clazz); + } + + /** + * 将map转化为string + * @param m Map集合 + * @return 字符串 + */ + public static String collectToString(Map m) { + String s = JSONObject.toJSONString(m); + return s; + } + + /** + * json字符串转化为map + * + * @param text 字符串 + * @return Map 对象 + */ + public static Map stringToMap(String text) { + Map m = JSONObject.parseObject(text); + return m; + } + + /** + * 将map转化为string + * + * @param m Map集合 + * @return 字符串 + */ + public static String mapToString(Map m) { + String s = JSONObject.toJSONString(m); + return s; + } + + /** + * 把对象转换为指定对象 + * @param source 原对象 + * @param target 目标class + * @param 泛型 + * @return 泛型对象 + */ + public static T toObjectFromSource(Object source,Class target) { + return toBean(convertObjectToJSON(source), target); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/utils/IOUtils.java b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/IOUtils.java new file mode 100644 index 0000000..a2a6f5d --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/IOUtils.java @@ -0,0 +1,56 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.utils; + +import lombok.extern.slf4j.Slf4j; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; + + +/** + * @description IO工具类 + * @date 2020-09-24 + */ +@Slf4j +public class IOUtils { + + /** + * 将input流转换为文件 + * + * @param is 输入流 + * @param targetFile 目标文件 + */ + public static void copy(InputStream is, File targetFile) { + try (FileOutputStream fos = new FileOutputStream(targetFile)) { + byte[] b = new byte[1024]; + int readCount = is.read(b); + while (readCount != -1) { + // 写入数据 + fos.write(b, 0, readCount); + readCount = is.read(b); + } + is.close(); + fos.flush(); + } catch (IOException e) { + log.error("copy file error:【{}】", e); + } + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/utils/RedisUtils.java b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/RedisUtils.java new file mode 100644 index 0000000..d556e77 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/RedisUtils.java @@ -0,0 +1,289 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.utils; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.redis.core.StringRedisTemplate; +import org.springframework.stereotype.Component; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +/** + * @description 封装redis简单的key-value操作 + * @date 2020-09-23 + */ +@Component +public class RedisUtils { + + @Autowired + private StringRedisTemplate redisTemplate; + + /** + * 实现命令:TTL key,以秒为单位,返回给定 key的剩余生存时间(TTL, time to live)。 + * @param key key值 + * @return 返回过期时间秒数 + */ + public long ttl(String key) { + return redisTemplate.getExpire(key); + } + + /** + * 实现命令:expire 设置过期时间,单位秒 + * @param key key值 + * @param timeout 期望过期时间 + */ + public void expire(String key, long timeout) { + redisTemplate.expire(key, timeout, TimeUnit.SECONDS); + } + + /** + * 实现命令:INCR key,增加key一次 + * @param key key值 + * @param delta 增量 + * @return 计数值 + */ + public long incr(String key, long delta) { + return redisTemplate.opsForValue().increment(key, delta); + } + + /** + * 实现命令: key,减少key一次 + * @param key key值 + * @param delta 增量 + * @return 计数值 + */ + public long decr(String key, long delta) { + if(delta < 0){ + //throw new RuntimeException("递减因子必须大于0"); + del(key); + return 0; + } + return redisTemplate.opsForValue().increment(key, -delta); + } + + /** + * 实现命令:KEYS pattern,查找所有符合给定模式 pattern的 key + * @return key集合 + */ + public Set keys(String pattern) { + return 
redisTemplate.keys(pattern); + } + + /** + * 实现命令:DEL key,删除一个key + * @param key key值 + */ + public void del(String key) { + redisTemplate.delete(key); + } + + /** + * 实现命令:SET key value,设置一个key-value(将字符串值 value关联到 key) + * @param key key值 + * @param value 值 + */ + public void set(String key, String value) { + redisTemplate.opsForValue().set(key, value); + } + + /** + * 实现命令:SET key value,设置一个key-value(将字符串值 value关联到 key) + * @param key key值 + * @param value 值 + * @param 指定类型 + */ + public void set(String key, T value) { + redisTemplate.opsForValue().set(key, FastjsonUtils.convertObjectToJSON(value)); + } + + /** + * 实现命令:SET key value EX seconds,设置key-value和超时时间(秒) + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + */ + public void set(String key, String value, long timeout) { + redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS); + } + + /** + * 实现命令:SET key value EX seconds,设置key-value和超时时间(秒) + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + * @param 指定类型 + */ + public void set(String key, T value, long timeout) { + redisTemplate.opsForValue().set(key, FastjsonUtils.convertObjectToJSON(value), timeout, TimeUnit.SECONDS); + } + + /** + * 实现命令:SETNX key value,设置一个key-value(将字符串值 value关联到 key) + * @param key key值 + * @param value 值 + * @return 是否设值成功 + */ + public Boolean setnx(String key, String value){ + return redisTemplate.opsForValue().setIfAbsent(key, value); + } + + /** + * 实现命令:SETNX key value,设置一个key-value(将字符串值 value关联到 key) + * @param key key值 + * @param value 值 + * @param 指定类型 + * @return 是否设值成功 + */ + public Boolean setnx(String key, T value){ + return redisTemplate.opsForValue().setIfAbsent(key, FastjsonUtils.convertObjectToJSON(value)); + } + + /** + * 实现命令:SETNX key value EX seconds,设置key-value和超时时间(秒) + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + * @return 是否设值成功 + */ + public Boolean setnx(String key, String value, long timeout) { + return 
redisTemplate.opsForValue().setIfAbsent(key, value, timeout, TimeUnit.SECONDS); + } + + /** + * 实现命令:SETNX key value EX seconds,设置key-value和超时时间(秒) + * @param key key值 + * @param value 值 + * @param timeout 过期时间 + * @param 指定类型 + * @return 是否设值成功 + */ + public Boolean setnx(String key, T value, long timeout) { + return redisTemplate.opsForValue().setIfAbsent(key, FastjsonUtils.convertObjectToJSON(value), timeout, TimeUnit.SECONDS); + } + + /** + * 实现命令:GET key,返回 key所关联的字符串值。 + * @param key key值 + * @return 值 + */ + public String get(String key) { + return (String) redisTemplate.opsForValue().get(key); + } + + /** + * + * 根据key获取值 + * @param key 真正的key + * @param clazz 类型 + * @param 泛型 + * @return + */ + public T get(String key, Class clazz) { + String value = (String) redisTemplate.opsForValue().get(key); + return (T) FastjsonUtils.convertJsonToObject(value, clazz); + } + + /** + * 实现命令:GET key,返回 key所关联的字符串值。 + * @param key key值 + * @return 是否存在 + */ + public Boolean exists(String key) { + return redisTemplate.hasKey(key); + } + + /****----------------------------------Hash----------------------------------------****/ + + /** + * 实现命令:HSET key field value,将哈希表 key中的域 field的值设为 value + * + * @param key key + * @param field 域 + * @param value 值 + */ + public void hset(String key, String field, Object value) { + redisTemplate.opsForHash().put(key, field, value); + } + + /** + * 实现命令:HGET key field,返回哈希表 key中给定域 field的值 + * + * @param key key + * @param field 域 + * @return + */ + public String hget(String key, String field) { + return (String) redisTemplate.opsForHash().get(key, field); + } + + /** + * 实现命令:HDEL key field [field ...],删除哈希表 key 中的一个或多个指定域,不存在的域将被忽略。 + * + * @param key key + * @param fields 域 + */ + public void hdel(String key, Object... 
fields) { + redisTemplate.opsForHash().delete(key, fields); + } + + /** + * 实现命令:HGETALL key,返回哈希表 key中,所有的域和值。 + * + * @param key + * @return 域和值 + */ + public Map hgetall(String key) { + return redisTemplate.opsForHash().entries(key); + } + + /****----------------------------------List----------------------------------------****/ + + /** + * 实现命令:LPUSH key value,将一个值 value插入到列表 key的表头 + * + * @param key + * @param value + * @return 执行 LPUSH命令后,列表的长度。 + */ + public long lpush(String key, String value) { + return redisTemplate.opsForList().leftPush(key, value); + } + + /** + * 实现命令:LPOP key,移除并返回列表 key的头元素。 + * + * @param key + * @return 列表key的头元素。 + */ + public String lpop(String key) { + return (String) redisTemplate.opsForList().leftPop(key); + } + + /** + * 实现命令:RPUSH key value,将一个值 value插入到列表 key的表尾(最右边)。 + * + * @param key + * @param value + * @return 执行 LPUSH命令后,列表的长度。 + */ + public long rpush(String key, String value) { + return redisTemplate.opsForList().rightPush(key, value); + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/utils/SpringContextHolder.java b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/SpringContextHolder.java new file mode 100644 index 0000000..90e89c6 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/utils/SpringContextHolder.java @@ -0,0 +1,99 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.utils; + +import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.DisposableBean; +import org.springframework.context.ApplicationContext; +import org.springframework.context.ApplicationContextAware; +import org.springframework.stereotype.Component; + +/** + * @description 上下文工具类 + * @date 2020-09-24 + */ +@Component +@Slf4j +public class SpringContextHolder implements ApplicationContextAware, DisposableBean { + + public static ApplicationContext applicationContext = null; + + /** + * 从静态变量applicationContext中取得Bean, 自动转型为所赋值对象的类型. + * @param name bean名称 + * @param 类型 + * @return bean对象 + */ + @SuppressWarnings("unchecked") + public static T getBean(String name) { + assertContextInjected(); + return (T) applicationContext.getBean(name); + } + + /** + * 从静态变量applicationContext中取得Bean, 自动转型为所赋值对象的类型. + * @param requiredType bean类型 class + * @param 泛型 + * @return bean对象 + */ + public static T getBean(Class requiredType) { + assertContextInjected(); + return applicationContext.getBean(requiredType); + } + + /** + * 检查ApplicationContext不为空. + */ + private static void assertContextInjected() { + if (applicationContext == null) { + throw new IllegalStateException("applicaitonContext属性未注入, 请在applicationContext" + + ".xml中定义SpringContextHolder或在SpringBoot启动类中注册SpringContextHolder."); + } + } + + /** + * 清除SpringContextHolder中的ApplicationContext为Null. 
+ */ + private static void clearHolder() { + log.debug("清除SpringContextHolder中的ApplicationContext:" + + applicationContext); + applicationContext = null; + } + + /** + * 销毁回调函数 + */ + @Override + public void destroy() { + SpringContextHolder.clearHolder(); + } + + /** + * spring上下文设置 + * @param applicationContext + * @throws BeansException + */ + @Override + public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { + if (SpringContextHolder.applicationContext != null) { + log.warn("SpringContextHolder中的ApplicationContext被覆盖, 原有ApplicationContext为:" + SpringContextHolder.applicationContext); + } + SpringContextHolder.applicationContext = applicationContext; + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobHandler.java b/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobHandler.java new file mode 100644 index 0000000..a95ce71 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobHandler.java @@ -0,0 +1,111 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.watcher; + +import cn.hutool.core.collection.CollectionUtil; +import cn.hutool.core.util.StrUtil; +import io.fabric8.kubernetes.api.model.OwnerReference; +import io.fabric8.kubernetes.api.model.apps.StatefulSet; +import io.fabric8.kubernetes.api.model.batch.Job; +import io.fabric8.kubernetes.client.KubernetesClient; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; +import org.onebrain.operator.constants.KubeConstants; +import org.onebrain.operator.redis.RedisService; +import org.onebrain.operator.redis.key.OperatorKey; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.List; + +import static org.onebrain.operator.constants.CrdConstants.CRD_KIND; + +/** + * @description Job处理器 + * @date 2020-09-24 + */ +@Data +@Slf4j +@Component +public class JobHandler { + + public static final String FINISHED = "finished"; + public static final String PENDING = "pending"; + @Autowired + private RedisService redis; + + @Autowired + private KubernetesClient client; + + /** + * 处理Job + * + * @param job + */ + public void handleJob(Job job) { + log.info("handleJob=>job : 【{}】", job); + + //筛选出DistributeTrain下的job + List ownerReferences = job.getMetadata().getOwnerReferences(); + if (CollectionUtil.isEmpty(ownerReferences) || !CRD_KIND.equals(ownerReferences.get(0).getKind())) { + return; + } + + String key = job.getMetadata().getUid(); + if (StrUtil.equals(redis.get(OperatorKey.CR_JOB, key), FINISHED)) { + return; + } + + try { + redis.set(OperatorKey.CR_JOB, key, PENDING); + + final Integer parallelism = job.getSpec().getParallelism(); + final Integer backoffLimit = job.getSpec().getBackoffLimit(); + //成功 或者 失败达到最大次数 + if (job.getStatus() != null + && ((job.getStatus().getFailed() != null && job.getStatus().getFailed() + 1 >= backoffLimit) + || (job.getStatus().getSucceeded() != null && 
parallelism.equals(job.getStatus().getSucceeded())))) { + //得到DistributeTrain的Statefulset + String dtName = ownerReferences.get(0).getName(); + String namespace = job.getMetadata().getNamespace(); + + List statefulsetList = client.apps().statefulSets() + .inNamespace(namespace) + .withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, dtName) + .list().getItems(); + + if (CollectionUtil.isEmpty(statefulsetList)) { + log.info("jobWatcher: statefulset of 【{}】 not exists", dtName); + return; + } + + //缩容Statefulset的replica到0 + StatefulSet statefulSet = statefulsetList.get(0); + statefulSet.getSpec().setReplicas(0); + client.resource(statefulSet).createOrReplace(); + log.info("jobWatcher: reduce replicas of 【{}】 to zero", dtName); + + redis.set(OperatorKey.CR_JOB, key, "finished"); + } + + } catch (Exception e) { + redis.set(OperatorKey.CR_JOB, key, "error"); + log.error("handle job error:【{}】", e); + } + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobWatcher.java b/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobWatcher.java new file mode 100644 index 0000000..ddc708f --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobWatcher.java @@ -0,0 +1,71 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================= + */ + +package org.onebrain.operator.watcher; + +import io.fabric8.kubernetes.api.model.batch.Job; +import io.fabric8.kubernetes.client.KubernetesClientException; +import io.fabric8.kubernetes.client.Watcher; +import lombok.Data; +import lombok.extern.slf4j.Slf4j; + +/** + * @description Job监视器 + * @date 2020-09-24 + */ +@Data +@Slf4j +public class JobWatcher implements Watcher { + + private String namespace; + + private String jobName; + + private KubeWatcherManager manager; + + private JobHandler jobHandler; + + public JobWatcher(JobHandler jobHandler, KubeWatcherManager manager) { + this.manager = manager; + this.jobHandler = jobHandler; + } + + /** + * 接收事件进行处理 + * @param action 事件类型 + * @param job job信息 + */ + @Override + public void eventReceived(Action action, Job job) { + log.info("Job Event received: {} at {}", job.getMetadata().getUid(), job.getMetadata().getCreationTimestamp()); + jobHandler.handleJob(job); + } + + /** + * 关闭事件 + * @param e 客户端异常 + */ + @Override + public void onClose(KubernetesClientException e) { + log.debug("job watcher close"); + if (e != null) { + log.error(e.getMessage()); + log.info("restart new job watcher thread"); + manager.putNewWatcher(); + } + } +} diff --git a/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/KubeWatcherManager.java b/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/KubeWatcherManager.java new file mode 100644 index 0000000..3394913 --- /dev/null +++ b/distribute-train-operator/src/main/java/org/onebrain/operator/watcher/KubeWatcherManager.java @@ -0,0 +1,120 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + +package org.onebrain.operator.watcher; + +import io.fabric8.kubernetes.client.KubernetesClient; +import lombok.extern.slf4j.Slf4j; +import org.onebrain.operator.context.KubeContext; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Component; + +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * @description 监视器的管理器 + * @date 2020-09-24 + */ +@Slf4j +@Component +public class KubeWatcherManager { + + /** + * 监视队列 + */ + private static final LinkedBlockingQueue watchQueue = new LinkedBlockingQueue<>(1000); + + /** + * 单例线程池 + */ + private ThreadPoolExecutor pool = new ThreadPoolExecutor(1, 1, 1, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), new ThreadFactory() { + private final AtomicInteger mThreadNum = new AtomicInteger(1); + @Override + public Thread newThread(Runnable r) { + return new Thread(r, "job-watcher-" + mThreadNum.getAndIncrement()); + } + }); + + @Autowired + private KubeContext kubeContext; + + @Autowired + private JobHandler jobHandler; + + /** + * 第一次启动时 + */ + public void startWatching(){ + JobWatchHolder jobWatchHolder = new JobWatchHolder(); + pool.execute(jobWatchHolder); + putNewWatcher(); + } + + /** + * 监听指定job + * @param jobWatcher + */ + public void watch(JobWatcher jobWatcher){ + KubernetesClient client = 
kubeContext.getClient(); + //监听指定job + client.batch().jobs() + .inAnyNamespace().watch(jobWatcher); + } + + /** + * 加入新watcher + */ + public void putNewWatcher(){ + try { + JobWatcher jobWatcher = new JobWatcher(jobHandler, this); + watchQueue.put(jobWatcher); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + /** + * Job监视器持有者 + */ + class JobWatchHolder implements Runnable { + + @Override + public void run() { + while(true){ + try { + //无监视器时阻塞 + JobWatcher jobWatcher = watchQueue.take(); + + //启动监视器 + try{ + watch(jobWatcher); + }catch (Exception e){ + //出错不影响其他listener + log.error("JobWatchHolder watch error:【{}】",e); + } + + } catch (InterruptedException e) { + log.error("JobWatchHolder run error:【{}】",e); + } + } + } + } +} diff --git a/distribute-train-operator/src/main/resources/key/id_rsa b/distribute-train-operator/src/main/resources/key/id_rsa new file mode 100644 index 0000000..a7bbbc5 --- /dev/null +++ b/distribute-train-operator/src/main/resources/key/id_rsa @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEA06ZOLQq4pzBZL+bybsxdl4PzYg3jB4kRVc771nm5Y8JenDAT +hlOTz6+nGH4EDT63J7oNj4JYLufsONKYhJkya8p0btWeKHqz5LgEfLGwz/FTMRH5 +WTCZCZUa/3i9gQeKK/CKEned1h4l2w4agrYrnXHpnuNSw6HSlTpX8FgaQGfmTkL3 +XtzSCeY9F2fXGOm9fMfVmv5I5uP6B4TmKwtWPvx3a/1MDgHbmtoaCqYP/JmzWHyi +mc9l2ilX3kTPxh57oRtW9N3FATc8/OCYkNt4vDUTRVB4drODaR5TgUbFtkBVGcFR +f7MrQo4Krd2g8rtEv7PaWN/wlNle5ANXJ/oL3wIDAQABAoIBADiqC8APYMSSMy6Z +/EohuOT51M1pvmCkF9oLYm1XhYTp4v6Z+IA8HBS8iFYMVvVc1xhxvXOwh/925E2K +RH8rrM4jE+0gkAlyYHtZsQnZYOcrSwSWNVXlpvNj0iiXoNTMufdtnOm40K8kvynY +qsxYDXFHsC5z2hK6XnDJgAW+8LhRHCizWwxc0dSN9r33VGry0rgndUZsj2ZBf7u5 +rdslZKvRzMymXct7CIQQ3s5IUO3qbaj7TIzMIo14bmHgD3zlBQ66ESCX1o5A+hPq +1gfUNqUPBtJhsNJg4YYJ/bGgGhBxAxam8jWz3DFZEuYHr6fCDIhLJzL5ulxoQS2z +vJYBwsECgYEA8JGfw004BxqcBVxqBveestsCVGIWDtb+Zx4OI+uBAmYMXd2WCzxv +XxgQJ/IrpNx6FAXZ/bFdE0HRZWR6H07wtNgABuBgd0tAfcH8sw2CJkTO/0N2Xr6/ +O4kh3yHNMy/wAxnktISf1hE/ElEdPI6slhwGDQObRdXxaqBEq+Tjc28CgYEA4TnM 
+rCaJ8aMaUE0nvVzrev3VTLp4f1qOcPUOnrHDdyrPs1SjYzmAOC72X/FylJZmtkvh +coMQUKVQgiBn1dTtnALANq705b1S+0U07m6+dGJ7LWchOY2tFPiIsx3SZvNJeEKJ +38PsaFi2eDcDP8cKriNoAoby8TbqjqiyHgDX9pECfxww9IfuhKJQe/gk3Ef0vKQ5 +BgzdcbhLeYScAQw0jOm7C7f0P6ERc/uw1jPYLUUkkSnHhcQ1BLM9A0zeeXExzwNi +TJ6BrMxOBUC3euWAr7/MUHWZckWoFMDlURLU4zccZwP2BNcis5hibQG4f7SZA6CT +qCHeSlPkvmXAYkvChuUCgYEA0DNlL9KkfBqBja/1R4jpKhYSIs7R6zCkMmlm7W54 +ueV6gVWBgI08KTPIj2KcwBzUsDovG3NrFpHrfY9FTZd7W1fzpdlQDDxaxGryhmMb +bm1HXu5R+WktkhA6FhJAWOkXhrNDzvXHyaIQc8qvFzsBdX7HfGaRmEhixiPOHAw9 +l/ECgYEAwNywUARR9HtmgoyrwifrzIkMo6jcmLNEIzi2kJ4OQQxW5eKj5JgSV0ND +QUoAIWDAhHQd3ygSfbeShcvtcw+zoF92iOVFn0SLiSe1TgA5ggzC/VJUnInO7zx7 +8Sj8Zk5tHrVmTlelEA2Nbq5H7/U1Q33c1AWbw8yxqD/JRxudHKA= +-----END RSA PRIVATE KEY----- \ No newline at end of file diff --git a/distribute-train-operator/src/main/resources/key/id_rsa.pub b/distribute-train-operator/src/main/resources/key/id_rsa.pub new file mode 100644 index 0000000..a356ff4 --- /dev/null +++ b/distribute-train-operator/src/main/resources/key/id_rsa.pub @@ -0,0 +1 @@ +ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDTpk4tCrinMFkv5vJuzF2Xg/NiDeMHiRFVzvvWebljwl6cMBOGU5PPr6cYfgQNPrcnug2Pglgu5+w40piEmTJrynRu1Z4oerPkuAR8sbDP8VMxEflZMJkJlRr/eL2BB4or8IoSd53WHiXbDhqCtiudceme41LDodKVOlfwWBpAZ+ZOQvde3NIJ5j0XZ9cY6b18x9Wa/kjm4/oHhOYrC1Y+/Hdr/UwOAdua2hoKpg/8mbNYfKKZz2XaKVfeRM/GHnuhG1b03cUBNzz84JiQ23i8NRNFUHh2s4NpHlOBRsW2QFUZwVF/sytCjgqt3aDyu0S/s9pY3/CU2V7kA1cn+gvf root@{{ip}} \ No newline at end of file diff --git a/distribute-train-operator/src/main/resources/kubeconfig b/distribute-train-operator/src/main/resources/kubeconfig new file mode 100644 index 0000000..b7510e3 --- /dev/null +++ b/distribute-train-operator/src/main/resources/kubeconfig @@ -0,0 +1,19 @@ +apiVersion: v1 +clusters: +- cluster: + certificate-authority-data: {} + server: {} + name: kubernetes +contexts: +- context: + cluster: kubernetes + user: kubernetes-admin + name: kubernetes-admin@kubernetes +current-context: kubernetes-admin@kubernetes +kind: 
#!/bin/bash
# Container bootstrap ("pretreatment") script.
#
# Makes sure an SSH server is installed and running and that the `jq` and
# `nslookup` tools are available. Every command and its output (stdout and
# stderr) is appended to pretreatment.log in the current working directory.

LOG_FILE=pretreatment.log

# Run a command, recording the command line and all of its output in the log.
run_logged() {
    echo "$*" >> "$LOG_FILE"
    "$@" >> "$LOG_FILE" 2>&1
}

# Pick exactly one package manager. The release file is checked first (as the
# original script did); images that ship neither release file fall back to
# whichever tool is actually installed, so apt and yum are never both run.
if [ -f "/etc/redhat-release" ]; then
    PKG_MANAGER=yum
elif command -v apt-get >/dev/null 2>&1; then
    PKG_MANAGER=apt
elif command -v yum >/dev/null 2>&1; then
    PKG_MANAGER=yum
else
    PKG_MANAGER=none
fi

APT_INDEX_REFRESHED=0

# Install the given packages non-interactively with the detected package manager.
install_packages() {
    case "$PKG_MANAGER" in
        apt)
            # Fresh images have no package index; refresh it once before the first install.
            if [ "$APT_INDEX_REFRESHED" -eq 0 ]; then
                run_logged apt-get update
                APT_INDEX_REFRESHED=1
            fi
            DEBIAN_FRONTEND=noninteractive run_logged apt-get install -y "$@"
            ;;
        yum)
            run_logged yum install -y "$@"
            ;;
        *)
            echo "no supported package manager found, cannot install: $*" >> "$LOG_FILE"
            return 1
            ;;
    esac
}

# ensure_command <command> <apt package> <yum package>
# Install the package that provides <command> unless the command already exists.
ensure_command() {
    if command -v "$1" >/dev/null 2>&1; then
        echo "exists $1" >> "$LOG_FILE"
        return 0
    fi
    if [ "$PKG_MANAGER" = "yum" ]; then
        install_packages "$3"
    else
        install_packages "$2"
    fi
}

# --- SSH server -------------------------------------------------------------
# The server package is called openssh-server on both distribution families.
if [ ! -f "/etc/init.d/ssh" ] && ! command -v sshd >/dev/null 2>&1; then
    install_packages openssh-server
fi

if [ -f "/etc/init.d/ssh" ]; then
    # Debian/Ubuntu ship an init script.
    run_logged /etc/init.d/ssh start
else
    # No init script (e.g. Red Hat family): create host keys and start the
    # daemon directly. sshd must be invoked through an absolute path.
    SSHD_BIN=$(command -v sshd || echo /usr/sbin/sshd)
    run_logged ssh-keygen -A
    run_logged "$SSHD_BIN"
fi

# --- Helper tools -----------------------------------------------------------
# nslookup lives in dnsutils (apt) / bind-utils (yum).
ensure_command jq jq jq
ensure_command nslookup dnsutils bind-utils
a/distribute-train-operator/src/test/java/org/onebrain/operator/DistributeTrainOperatorApplicationTests.java b/distribute-train-operator/src/test/java/org/onebrain/operator/DistributeTrainOperatorApplicationTests.java new file mode 100644 index 0000000..a64de1f --- /dev/null +++ b/distribute-train-operator/src/test/java/org/onebrain/operator/DistributeTrainOperatorApplicationTests.java @@ -0,0 +1,43 @@ +/** + * Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================= + */ + + +package org.onebrain.operator; + +import org.onebrain.operator.api.pod.PodApi; +import org.onebrain.operator.constants.KubeConstants; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.io.File; +import java.net.URISyntaxException; +import java.net.URL; + +@SpringBootTest +public class DistributeTrainOperatorApplicationTests { + + @Autowired + private PodApi podApi; + +// @Test + public void contextLoads() throws URISyntaxException { + final URL url = getClass().getClassLoader().getResource("key/id_rsa"); + File file = new File(url.toURI()); + podApi.copyToPod("default", "distribute-train-test-job-sv2dj", KubeConstants.MASTER_CONTAINER_NAME, file, "/root/.ssh/id_rsa"); + } + +}