Browse Source

update distribute-train-operator

tags/v0.3.0
之江实验室 3 years ago
parent
commit
2e80c376f6
53 changed files with 5048 additions and 0 deletions
  1. +26
    -0
      distribute-train-operator/README.md
  2. +65
    -0
      distribute-train-operator/docs/crds/distribute-train-cr.yaml
  3. +61
    -0
      distribute-train-operator/docs/crds/distribute-train-crd.yaml
  4. +47
    -0
      distribute-train-operator/docs/deploy/distribute-train-operator_deploy.yaml
  5. +150
    -0
      distribute-train-operator/pom.xml
  6. +35
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/DistributeTrainOperatorApplication.java
  7. +199
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/DistributeTrainOperatorManager.java
  8. +58
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/OperatorRunner.java
  9. +44
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/PodInfo.java
  10. +41
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/AbstractResourceCreateInfo.java
  11. +227
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ChildResourceCreateInfo.java
  12. +35
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/JobDeployer.java
  13. +33
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ServiceDeployer.java
  14. +33
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/StatefulSetDeployer.java
  15. +246
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseJobDeployer.java
  16. +73
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseServiceDeployer.java
  17. +246
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseStatefulSetDeployer.java
  18. +614
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/AddActionHandler.java
  19. +88
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DeleteActionHandler.java
  20. +33
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DistributeTrainActionHandler.java
  21. +85
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/DefaultPodExecListener.java
  22. +177
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/PodApi.java
  23. +83
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/StdPodExecListener.java
  24. +66
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/config/KubeConfig.java
  25. +34
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/constants/CrdConstants.java
  26. +40
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/constants/KubeConstants.java
  27. +43
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/constants/NumberConstant.java
  28. +117
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/context/KubeContext.java
  29. +131
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/controller/DistributeTrainController.java
  30. +47
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrain.java
  31. +27
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainList.java
  32. +108
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainSpec.java
  33. +55
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainStatus.java
  34. +31
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/crd/DoneableDistributeTrain.java
  35. +56
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/enums/AccessModeEnum.java
  36. +49
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/exception/OperatorException.java
  37. +34
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/properties/KubeProperties.java
  38. +65
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/redis/AbstractKeyPrefix.java
  39. +290
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/redis/RedisService.java
  40. +45
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/redis/key/OperatorKey.java
  41. +41
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/utils/DistributeTrainClientHolder.java
  42. +188
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/utils/FastjsonUtils.java
  43. +56
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/utils/IOUtils.java
  44. +289
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/utils/RedisUtils.java
  45. +99
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/utils/SpringContextHolder.java
  46. +111
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobHandler.java
  47. +71
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobWatcher.java
  48. +120
    -0
      distribute-train-operator/src/main/java/org/onebrain/operator/watcher/KubeWatcherManager.java
  49. +27
    -0
      distribute-train-operator/src/main/resources/key/id_rsa
  50. +1
    -0
      distribute-train-operator/src/main/resources/key/id_rsa.pub
  51. +19
    -0
      distribute-train-operator/src/main/resources/kubeconfig
  52. +46
    -0
      distribute-train-operator/src/main/resources/shell/pretreatment
  53. +43
    -0
      distribute-train-operator/src/test/java/org/onebrain/operator/DistributeTrainOperatorApplicationTests.java

+ 26
- 0
distribute-train-operator/README.md View File

@@ -0,0 +1,26 @@
# 之江天枢-分布式训练 operator
该模块是分布式训练CRD的控制器,管理分布式训练容器生命周期,为分布式训练容器注入其他容器ip。

## 源码部署

### 准备环境
安装如下软件环境。
- OpenJDK:1.8+
- Redis: 3.0+
- Maven: 3.0+

### 下载源码
``` bash
git clone https://codeup.teambition.com/zhejianglab/distribute-train-operator.git
# 进入项目根目录
cd distribute-train-operator
```

### 构建
``` bash
# 构建,生成的 jar 包位于 ./target/distribute-train-operator-1.0.jar
mvn clean compile package
```

### 部署
部署过程参看文档:[部署 分布式训练operator](http://tianshu.org.cn/?/course/1.html)

+ 65
- 0
distribute-train-operator/docs/crds/distribute-train-cr.yaml View File

@@ -0,0 +1,65 @@
apiVersion: onebrain.oneflow.org/v1alpha1
kind: DistributeTrain
metadata:
name: dt-resnet50
namespace: resnet50
labels:
key: value
spec:
size: 3
image: {{IMAGE}}
imagePullPolicy: IfNotPresent
masterCmd: export NODE_IPS=`cat /home/hostfile.json |jq -r '.[]|.ip'|paste -d "," -s` && cd /workspace/Classification/cnns && rm -rf core.* && rm -rf ./output/snapshots/* && python3 of_cnn_train_val.py --train_data_dir=$DATA_ROOT/train --train_data_part_num=$TRAIN_DATA_PART_NUM --val_data_dir=$DATA_ROOT/validation --val_data_part_num=$VAL_DATA_PART_NUM --num_nodes=$NODE_NUM --node_ips="$NODE_IPS" --gpu_num_per_node=$GPU_NUM_PER_NODE --model_update="momentum" --learning_rate=0.256 --loss_print_every_n_iter=1 --batch_size_per_device=64 --val_batch_size_per_device=64 --num_epoch=1 --model="resnet50" --model_save_dir=/model
masterResources:
requests:
nvidia.com/gpu: 2
memory: "16Gi"
cpu: "2"
limits:
nvidia.com/gpu: 2
memory: "16Gi"
cpu: "2"
slaveCmd: export NODE_IPS=`cat /home/hostfile.json |jq -r '.[]|.ip'|paste -d "," -s` && cd /workspace/Classification/cnns && rm -rf core.* && rm -rf ./output/snapshots/* && python3 of_cnn_train_val.py --train_data_dir=$DATA_ROOT/train --train_data_part_num=$TRAIN_DATA_PART_NUM --val_data_dir=$DATA_ROOT/validation --val_data_part_num=$VAL_DATA_PART_NUM --num_nodes=$NODE_NUM --node_ips="$NODE_IPS" --gpu_num_per_node=$GPU_NUM_PER_NODE --model_update="momentum" --learning_rate=0.256 --loss_print_every_n_iter=1 --batch_size_per_device=64 --val_batch_size_per_device=64 --num_epoch=1 --model="resnet50" --model_save_dir=/model
slaveResources:
requests:
nvidia.com/gpu: 2
memory: "16Gi"
cpu: "2"
limits:
nvidia.com/gpu: 2
memory: "16Gi"
cpu: "2"
nodeSelector:
kubernetes.io/hostname: node02
env:
- name: ENABLE_USER_OP
value: 'True'
- name: DATA_ROOT
value: '/dataset'
- name: NODE_NUM
value: 3
- name: GPU_NUM_PER_NODE
value: 2
- name: ONEFLOW_DEBUG_MODE
value: ""
- name: TRAIN_DATA_PART_NUM
value: 6
- name: VAL_DATA_PART_NUM
value: 6
- name: NCCL_DEBUG
value: INFO
datasetStorage:
name: pvc-dataset
nfs:
path: {{DATASET}}
server: {{NFS}}
workspaceStorage:
name: pvc-workspace
nfs:
path: /nfs/resnet50/workspace
server: {{WORKSPACE}}
modelStorage:
name: pvc-model
nfs:
path: /nfs/resnet50/model
server: {{MODEL}}

+ 61
- 0
distribute-train-operator/docs/crds/distribute-train-crd.yaml View File

@@ -0,0 +1,61 @@
---
apiVersion: apiextensions.k8s.io/v1beta1
kind: CustomResourceDefinition
metadata:
name: distributetrains.onebrain.oneflow.org
spec:
group: onebrain.oneflow.org
names:
kind: DistributeTrain
singular: distributetrain
plural: distributetrains
shortNames:
- dt
scope: Namespaced
subresources:
status: {}
version: v1alpha1
validation:
openAPIV3Schema:
properties:
apiVersion:
type: string
kind:
type: string
metadata:
type: object
spec:
properties:
image:
type: string
imagePullPolicy:
type: string
size:
format: int32
type: integer
masterCmd:
type: string
slaveCmd:
type: string
masterResources:
type: object
slaveResources:
type: object
nodeSelector:
type: object
initContainer:
type: object
datasetStorage:
type: object
workspaceStorage:
type: object
modelStorage:
type: object
required:
- image
- imagePullPolicy
- size
- masterCmd
- slaveCmd
- workspaceStorage
type: object

+ 47
- 0
distribute-train-operator/docs/deploy/distribute-train-operator_deploy.yaml View File

@@ -0,0 +1,47 @@
kind: Deployment
apiVersion: apps/v1
metadata:
name: distribute-train-operator
namespace: test-ns
labels:
name: distribute-train-operator
spec:
replicas: 1
selector:
matchLabels:
name: distribute-train-operator
template:
metadata:
labels:
name: distribute-train-operator
spec:
containers:
- name: distribute-train-operator
image: {{IMAGE}}
ports:
- containerPort: 8080
protocol: TCP
volumeMounts:
- mountPath: /root/config
name: config-volume
env:
- name: JAR_BALL
value: "distribute-train-operator-1.0.jar --k8s.kubeconfig=/root/config --spring.redis.host=192.168.1.104"
imagePullPolicy: IfNotPresent
volumes:
- name: config-volume
hostPath:
path: /root/.kube/config
restartPolicy: Always
terminationGracePeriodSeconds: 30
securityContext:
runAsUser: 0
schedulerName: default-scheduler
strategy:
type: RollingUpdate
rollingUpdate:
maxUnavailable: 1
maxSurge: 1
revisionHistoryLimit: 7
progressDeadlineSeconds: 600


+ 150
- 0
distribute-train-operator/pom.xml View File

@@ -0,0 +1,150 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 https://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>2.2.5.RELEASE</version>
</parent>

<groupId>org.onebrain</groupId>
<artifactId>distribute-train-operator</artifactId>
<version>1.0</version>
<name>distribute-train-operator</name>
<description>distribute-train operator</description>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<project.reporting.outputEncoding>UTF-8</project.reporting.outputEncoding>
<java.version>1.8</java.version>
<fabric.io.version>4.9.0</fabric.io.version>
</properties>

<dependencies>
<!-- web -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>

<!-- k8s -->
<dependency>
<groupId>io.fabric8</groupId>
<artifactId>kubernetes-client</artifactId>
<version>${fabric.io.version}</version>
</dependency>
<dependency>
<groupId>io.fabric8</groupId>
<artifactId>kubernetes-assertions</artifactId>
<version>4.0.0</version>
<scope>test</scope>
</dependency>

<!-- configuration processor -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-configuration-processor</artifactId>
</dependency>

<!-- redis -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<dependency>
<groupId>redis.clients</groupId>
<artifactId>jedis</artifactId>
</dependency>

<!-- common jars -->
<dependency>
<groupId>commons-io</groupId>
<artifactId>commons-io</artifactId>
<version>2.6</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-compress</artifactId>
<version>1.19</version>
</dependency>
<dependency>
<groupId>commons-codec</groupId>
<artifactId>commons-codec</artifactId>
</dependency>

<!-- tools -->
<dependency>
<groupId>cn.hutool</groupId>
<artifactId>hutool-all</artifactId>
<version>5.1.1</version>
</dependency>

<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
<version>27.0.1-jre</version>
</dependency>

<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>1.2.54</version>
</dependency>

<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
</dependencies>

<build>
<plugins>
<plugin>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-maven-plugin</artifactId>
</plugin>
<!-- 打包时跳过测试 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<configuration>
<skip>true</skip>
</configuration>
</plugin>
</plugins>
</build>

<repositories>
<repository>
<id>public</id>
<name>aliyun nexus</name>
<url>http://maven.aliyun.com/nexus/content/groups/public/</url>
<releases>
<enabled>true</enabled>
</releases>
</repository>
</repositories>

<pluginRepositories>
<pluginRepository>
<id>public</id>
<name>aliyun nexus</name>
<url>http://maven.aliyun.com/nexus/content/groups/public/</url>
<releases>
<enabled>true</enabled>
</releases>
<snapshots>
<enabled>false</enabled>
</snapshots>
</pluginRepository>
</pluginRepositories>

</project>

+ 35
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/DistributeTrainOperatorApplication.java View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator;

import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
/**
 * @description Spring Boot entry point of the distribute-train operator; asynchronous
 *              method execution is switched on through {@code @EnableAsync}
 * @date 2020-09-03
 */
@EnableAsync
@SpringBootApplication
public class DistributeTrainOperatorApplication {

    /**
     * Boots the Spring application context with this class as the primary source.
     *
     * @param args command line arguments, forwarded unchanged to Spring Boot
     */
    public static void main(String[] args) {
        // Instance form of SpringApplication.run(Class, String...)
        SpringApplication application = new SpringApplication(DistributeTrainOperatorApplication.class);
        application.run(args);
    }

}

+ 199
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/DistributeTrainOperatorManager.java View File

@@ -0,0 +1,199 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action;

import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.google.common.collect.Maps;
import io.fabric8.kubernetes.api.model.apiextensions.*;
import io.fabric8.kubernetes.client.KubernetesClient;
import io.fabric8.kubernetes.client.dsl.MixedOperation;
import io.fabric8.kubernetes.client.dsl.Resource;
import io.fabric8.kubernetes.client.dsl.base.CustomResourceDefinitionContext;
import io.fabric8.kubernetes.client.informers.SharedIndexInformer;
import io.fabric8.kubernetes.client.informers.SharedInformerFactory;
import io.fabric8.kubernetes.client.internal.SerializationUtils;
import io.fabric8.kubernetes.internal.KubernetesDeserializer;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.controller.DistributeTrainController;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.crd.DistributeTrainList;
import org.onebrain.operator.crd.DoneableDistributeTrain;
import org.onebrain.operator.utils.DistributeTrainClientHolder;
import org.onebrain.operator.utils.SpringContextHolder;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.support.BeanDefinitionBuilder;
import org.springframework.beans.factory.support.DefaultListableBeanFactory;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.stereotype.Component;

import java.util.Map;

import static org.onebrain.operator.constants.CrdConstants.*;

/**
 * @description Main manager of the operator: makes sure the DistributeTrain CRD exists in the
 *              cluster, then wires the shared informer and the controller together
 * @date 2020-09-23
 */
@Component
@Slf4j
public class DistributeTrainOperatorManager {

    public static final String NAMESPACE_DEFAULT = "default";
    public static final String TYPE_STRING = "string";
    public static final String TYPE_INTEGER = "integer";
    public static final String TYPE_OBJECT = "object";
    public static final String TYPE_ARRAY = "array";
    public static final String FORMAT_INT_32 = "int32";

    @Autowired
    private KubernetesClient client;

    /** CRD resolved (looked up or freshly created) by {@link #createCrdIfNotExists()}. */
    private CustomResourceDefinition crd;

    /** Namespace resolved by {@link #createCrdIfNotExists()}; later handed to the controller. */
    private String namespace;

    /**
     * Resolves the working namespace, looks the CRD up by name and creates it when it is missing,
     * then registers the custom kind with the Kubernetes deserializer.
     *
     * @throws JsonProcessingException if the created CRD cannot be dumped as YAML for logging
     */
    public void createCrdIfNotExists() throws JsonProcessingException {
        String resolvedNamespace = client.getNamespace();
        if (resolvedNamespace == null) {
            log.info("No namespace found via config, assuming default.");
            resolvedNamespace = NAMESPACE_DEFAULT;
        }
        this.namespace = resolvedNamespace;
        log.info("Using namespace : {}", resolvedNamespace);

        // Reuse the CRD when the cluster already has it, otherwise create it now
        CustomResourceDefinition definition = client.customResourceDefinitions().withName(CRD_NAME).get();
        if (definition == null) {
            definition = createCrd();
        }

        // Register the custom kind so the Kubernetes deserializer can map it to DistributeTrain
        KubernetesDeserializer.registerCustomKind(CRD_GROUP + StrUtil.SLASH + CRD_VERSION, CRD_KIND, DistributeTrain.class);
        this.crd = definition;
    }

    /**
     * Builds the CRD from the constants in {@code CrdConstants} and submits it to the cluster.
     *
     * @return the CRD as returned by the create call
     * @throws JsonProcessingException if the created CRD cannot be dumped as YAML for logging
     */
    private CustomResourceDefinition createCrd() throws JsonProcessingException {
        Map<String, JSONSchemaProps> schemaProperties = buildCrdProperties();
        log.info("crd props map is : 【{}】", schemaProperties);

        CustomResourceDefinition requested = new CustomResourceDefinitionBuilder()
                .withApiVersion(CRD_API_VERSION)
                .withNewMetadata()
                    .withName(CRD_NAME)
                .endMetadata()
                .withNewSpec()
                    .withGroup(CRD_GROUP)
                    .withVersion(CRD_VERSION)
                    .withScope(CRD_SCOPE)
                    .withNewNames()
                        .withKind(CRD_KIND)
                        .withSingular(CRD_SINGULAR_NAME)
                        .withPlural(CRD_PLURAL_NAME)
                        .withShortNames(CRD_SHORT_NAME)
                    .endNames()
                    .withNewValidation()
                        .withNewOpenAPIV3Schema()
                            .addToProperties(schemaProperties)
                        .endOpenAPIV3Schema()
                    .endValidation()
                .endSpec()
                .build();

        CustomResourceDefinition created = client.customResourceDefinitions().create(requested);
        log.info("create crd successfully : \n{}", SerializationUtils.dumpAsYaml(created));
        return created;
    }

    /**
     * Creates the typed client and the shared informer for DistributeTrain resources, registers
     * the controller bean by hand, then starts the informers and runs the controller.
     * Reads the {@code crd} and {@code namespace} fields set by {@link #createCrdIfNotExists()}.
     */
    public void initInformer(){
        CustomResourceDefinitionContext crdContext = new CustomResourceDefinitionContext.Builder()
                .withVersion(CRD_VERSION)
                .withScope(CRD_SCOPE)
                .withGroup(CRD_GROUP)
                .withPlural(CRD_PLURAL_NAME)
                .build();

        SharedInformerFactory informers = client.informers();

        MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> trainClient =
                client.customResources(this.crd, DistributeTrain.class, DistributeTrainList.class, DoneableDistributeTrain.class);
        SharedIndexInformer<DistributeTrain> trainInformer =
                informers.sharedIndexInformerForCustomResource(crdContext, DistributeTrain.class, DistributeTrainList.class, 10 * 60 * 1000);

        // Keep the typed client reachable through a static holder
        DistributeTrainClientHolder.setDistributeTrainClient(trainClient);

        // Register the controller in the IoC container by hand
        registerController(trainClient, trainInformer);

        // Fetch the container-managed controller instance
        DistributeTrainController controller = SpringContextHolder.getBean(DistributeTrainController.class);
        // Attach the informer listeners
        controller.create();
        informers.startAllRegisteredInformers();
        // Original author's note on this call: "wait until ready"
        controller.run();
    }

    /**
     * Registers a {@link DistributeTrainController} bean definition whose constructor arguments
     * are the typed client, the shared informer and the resolved namespace, in that order.
     *
     * @param trainClient   typed client for DistributeTrain resources
     * @param trainInformer shared informer for DistributeTrain resources
     */
    private void registerController(
            MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> trainClient,
            SharedIndexInformer<DistributeTrain> trainInformer) {
        BeanDefinitionBuilder definition = BeanDefinitionBuilder.genericBeanDefinition(DistributeTrainController.class);
        ConfigurableApplicationContext context = (ConfigurableApplicationContext) SpringContextHolder.applicationContext;
        DefaultListableBeanFactory beanFactory = (DefaultListableBeanFactory) context.getBeanFactory();

        definition.addConstructorArgValue(trainClient);
        definition.addConstructorArgValue(trainInformer);
        definition.addConstructorArgValue(namespace);
        beanFactory.registerBeanDefinition("org.onebrain.operator.controller.DistributeTrainController", definition.getRawBeanDefinition());
    }

    /**
     * Builds the top-level OpenAPI v3 schema properties of the CRD.
     *
     * @return map of property name to schema for {@code apiVersion}, {@code kind},
     *         {@code metadata} and {@code spec}
     */
    private Map<String, JSONSchemaProps> buildCrdProperties(){
        // Reusable primitive schemas
        JSONSchemaProps stringSchema = new JSONSchemaPropsBuilder().withType(TYPE_STRING).build();
        JSONSchemaProps int32Schema = new JSONSchemaPropsBuilder().withType(TYPE_INTEGER).withFormat(FORMAT_INT_32).build();
        JSONSchemaProps objectSchema = new JSONSchemaPropsBuilder().withType(TYPE_OBJECT).build();
        JSONSchemaProps arraySchema = new JSONSchemaPropsBuilder()
                .withType(TYPE_ARRAY)
                .withNewItems()
                .endItems()
                .build();

        // Validation rules for the spec section, including its required fields
        JSONSchemaProps specSchema = new JSONSchemaPropsBuilder()
                .addToProperties("image", stringSchema)
                .addToProperties("imagePullPolicy", stringSchema)
                .addToProperties("size", int32Schema)
                .addToProperties("env", arraySchema)
                .addToProperties("masterCmd", stringSchema)
                .addToProperties("slaveCmd", stringSchema)
                .addToProperties("masterResources", objectSchema)
                .addToProperties("slaveResources", objectSchema)
                .addToProperties("nodeSelector", objectSchema)
                .addToProperties("initContainer", objectSchema)
                .addToProperties("datasetStorage", objectSchema)
                .addToProperties("workspaceStorage", objectSchema)
                .addToProperties("modelStorage", objectSchema)
                .withType(TYPE_OBJECT)
                .addToRequired("image", "imagePullPolicy", "size", "masterCmd", "slaveCmd", "workspaceStorage")
                .build();

        Map<String, JSONSchemaProps> topLevel = Maps.newHashMap();
        topLevel.put("apiVersion", stringSchema);
        topLevel.put("kind", stringSchema);
        topLevel.put("metadata", objectSchema);
        topLevel.put("spec", specSchema);
        return topLevel;
    }
}

+ 58
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/OperatorRunner.java View File

@@ -0,0 +1,58 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action;

import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.watcher.KubeWatcherManager;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;

/**
 * @description Start-up hook of the operator, invoked by Spring Boot as an ApplicationRunner
 * @date 2020-09-23
 */
@Slf4j
@Component
public class OperatorRunner implements ApplicationRunner {

    @Autowired
    private DistributeTrainOperatorManager operatorManager;

    @Autowired
    private KubeWatcherManager watcherManager;

    /**
     * Runs the operator start-up sequence: ensure the CRD exists, start the job watcher,
     * then initialise the informer.
     *
     * @param args application arguments (not used by this runner)
     * @throws Exception propagated from any of the start-up steps
     */
    @Override
    public void run(final ApplicationArguments args) throws Exception {
        // Step 1: create the CRD unless the cluster already has it
        operatorManager.createCrdIfNotExists();

        // Step 2: start the job watcher
        watcherManager.startWatching();
        log.info("job watcher is running");

        // Step 3: set up the informer
        operatorManager.initInformer();
    }
}

+ 44
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/PodInfo.java View File

@@ -0,0 +1,44 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * @description Value holder describing a pod by its IP address and its role
 * @date 2020-09-23
 */
@Builder
@Data
@AllArgsConstructor
@NoArgsConstructor
public class PodInfo {

    /** IP address of the pod. */
    private String ip;

    /** Role of the pod. */
    private String role;
}

+ 41
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/AbstractResourceCreateInfo.java View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer;

import cn.hutool.core.util.RandomUtil;
import lombok.Data;
import lombok.experimental.Accessors;

/**
 * @description Abstract base class for objects that carry the information needed to create resources
 * @date 2020-04-30
 */
@Accessors(chain = true)
@Data
public abstract class AbstractResourceCreateInfo {

    /**
     * Produces a random string by delegating to hutool's {@code RandomUtil.randomString}.
     *
     * @param digits number of characters to generate
     * @return the generated random string
     */
    protected static String getRandomStr(Integer digits){
        final String generated = RandomUtil.randomString(digits);
        return generated;
    }
}

+ 227
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ChildResourceCreateInfo.java View File

@@ -0,0 +1,227 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import io.fabric8.kubernetes.api.model.*;
import lombok.Data;
import lombok.experimental.Accessors;
import org.onebrain.operator.constants.KubeConstants;
import org.onebrain.operator.constants.NumberConstant;
import org.onebrain.operator.crd.DistributeTrain;

import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;

/**
* @description 暂存创建子资源所需的信息
* @date 2020-06-16
*/
@Data
@Accessors(chain = true)
public class ChildResourceCreateInfo extends AbstractResourceCreateInfo {

public static final String SLAVE_TEMPLATE = "{}-slave-{}";
public static final String MASTER_TEMPLATE = "{}-master-{}";
public static final String SVC_TEMPLATE = "{}-svc";
/**
* 父级名称(分布式训练名称)
*/
private String parentName;

/**
* job名称
*/
private String jobName;

/**
* statefullSet名称
*/
private String statefulSetName;

/**
* 服务名称
*/
private String svcName;

/**
* 命名空间
*/
private String namespace;

/**
* 镜像
*/
private String image;

/**
* 镜像拉取策略
*/
private String imagePullPolicy;

/**
* 标签
*/
private Map<String, String> labels;

/**
* master副本数
*/
private Integer masterReplicas;

/**
* slave副本数
*/
private Integer slaveReplicas;

/**
* master命令
*/
private String masterCmd;

/**
* slave命令
*/
private String slaveCmd;

/**
* master 资源节点限制
*/
private ResourceRequirements masterResources;

/**
* slave 资源节点限制
*/
private ResourceRequirements slaveResources;

/**
* 节点调度选择器
*/
private Map<String, String> nodeSelector;

/**
* 初始化容器
*/
private Container initContainer;

/**
* 工作目录挂载
*/
private Volume workspaceVolume;

/**
* 数据集目录挂载
*/
private Volume datasetVolume;

/**
* 模型目录挂载
*/
private Volume modelVolume;

/**
* 环境变量
*/
private List<EnvVar> env;

/**
* 拥有者信息
*/
private OwnerReference ownerReference;

/**
 * Converts a distributed-training custom resource into the description used
 * to create its child Kubernetes resources.
 *
 * @param distributeTrain the distributed-training custom resource
 * @return the populated ChildResourceCreateInfo
 */
public static ChildResourceCreateInfo fromCr(DistributeTrain distributeTrain){
    ChildResourceCreateInfo created = new ChildResourceCreateInfo();

    // Owner reference pointing back at the custom resource.
    created.generateOwnerReference(distributeTrain);

    // Namespace, parent name and the child-resource names derived from them.
    created.setNamespace(distributeTrain.getMetadata().getNamespace());
    created.setParentName(distributeTrain.getMetadata().getName());
    created.generateResoureName();

    // Labels are taken over from the custom resource metadata unchanged.
    created.setLabels(distributeTrain.getMetadata().getLabels());

    // Image and pull policy.
    created.setImage(distributeTrain.getSpec().getImage())
            .setImagePullPolicy(distributeTrain.getSpec().getImagePullPolicy());

    // Replicas: a single master, the remaining (size - 1) nodes are slaves.
    Integer totalSize = distributeTrain.getSpec().getSize();
    created.setMasterReplicas(NumberConstant.NUMBER_1);
    created.setSlaveReplicas(totalSize - NumberConstant.NUMBER_1);

    // Command lines.
    created.setMasterCmd(distributeTrain.getSpec().getMasterCmd())
            .setSlaveCmd(distributeTrain.getSpec().getSlaveCmd());

    // Volumes: each one is only set when the spec declares it.
    Volume workspace = distributeTrain.getSpec().getWorkspaceStorage();
    if (workspace != null) {
        created.setWorkspaceVolume(workspace);
    }
    Volume dataset = distributeTrain.getSpec().getDatasetStorage();
    if (dataset != null) {
        created.setDatasetVolume(dataset);
    }
    Volume model = distributeTrain.getSpec().getModelStorage();
    if (model != null) {
        created.setModelVolume(model);
    }

    // Resource requirements for the master and slave groups.
    ResourceRequirements masterRequirements = distributeTrain.getSpec().getMasterResources();
    if (masterRequirements != null) {
        created.setMasterResources(masterRequirements);
    }
    ResourceRequirements slaveRequirements = distributeTrain.getSpec().getSlaveResources();
    if (slaveRequirements != null) {
        created.setSlaveResources(slaveRequirements);
    }

    // Environment variables: any user-supplied entry named ENV_NODE_NUM is dropped.
    List<EnvVar> declaredEnv = distributeTrain.getSpec().getEnv();
    if (CollectionUtil.isNotEmpty(declaredEnv)) {
        created.setEnv(declaredEnv.stream()
                .filter(e -> !KubeConstants.ENV_NODE_NUM.equals(e.getName()))
                .collect(Collectors.toList()));
    }

    // Node scheduling.
    created.setNodeSelector(distributeTrain.getSpec().getNodeSelector());

    // Init container.
    created.setInitContainer(distributeTrain.getSpec().getInitContainer());

    return created;
}

/**
 * Derives the StatefulSet, Job and Service names from the parent name.
 * The StatefulSet and Job names share one random suffix; the Service name
 * is built from the parent name only.
 */
private void generateResoureName(){
    final String randomSuffix = getRandomStr(NumberConstant.NUMBER_5);
    final String parent = this.parentName;

    this.statefulSetName = StrUtil.format(SLAVE_TEMPLATE, parent, randomSuffix);
    this.jobName = StrUtil.format(MASTER_TEMPLATE, parent, randomSuffix);
    this.svcName = StrUtil.format(SVC_TEMPLATE, parent);
}

/**
 * Builds the owner reference from the custom resource's api version, kind,
 * name and uid, and stores it on this instance.
 *
 * @param distributeTrain the distributed-training custom resource
 */
private void generateOwnerReference(DistributeTrain distributeTrain){
    OwnerReferenceBuilder ownerBuilder = new OwnerReferenceBuilder();
    ownerBuilder.withApiVersion(distributeTrain.getApiVersion());
    ownerBuilder.withKind(distributeTrain.getKind());
    ownerBuilder.withName(distributeTrain.getMetadata().getName());
    ownerBuilder.withNewUid(distributeTrain.getMetadata().getUid());
    this.ownerReference = ownerBuilder.build();
}
}

+ 35
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/JobDeployer.java View File

@@ -0,0 +1,35 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer;

import io.fabric8.kubernetes.api.model.batch.JobBuilder;

/**
* @description Job deployer interface; defines the deploy method that implementations provide.
* T must be a subtype of AbstractResourceCreateInfo.
* @date 2020-09-23
*/
public interface JobDeployer<T extends AbstractResourceCreateInfo> {

/**
* Builds the Job definition.
* @param info resource information
* @return the Job builder
*/
JobBuilder deploy(T info);
}

+ 33
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/ServiceDeployer.java View File

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer;

import io.fabric8.kubernetes.api.model.ServiceBuilder;

/**
* @description Service deployer interface.
* T must be a subtype of AbstractResourceCreateInfo.
* @date 2020-09-23
*/
public interface ServiceDeployer<T extends AbstractResourceCreateInfo> {
/**
* Builds the Service definition.
* @param info resource information
* @return the Service builder
*/
ServiceBuilder deploy(T info);
}

+ 33
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/StatefulSetDeployer.java View File

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer;

import io.fabric8.kubernetes.api.model.apps.StatefulSetBuilder;

/**
* @description StatefulSet deployer interface.
* T must be a subtype of AbstractResourceCreateInfo.
* @date 2020-09-23
*/
public interface StatefulSetDeployer<T extends AbstractResourceCreateInfo> {
/**
* Builds the StatefulSet definition.
* @param info resource information
* @return the StatefulSet builder
*/
StatefulSetBuilder deploy(T info);
}

+ 246
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseJobDeployer.java View File

@@ -0,0 +1,246 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer.impl;

import cn.hutool.core.collection.CollectionUtil;
import com.google.common.collect.Lists;
import io.fabric8.kubernetes.api.model.CapabilitiesBuilder;
import io.fabric8.kubernetes.api.model.Container;
import io.fabric8.kubernetes.api.model.ContainerPortBuilder;
import io.fabric8.kubernetes.api.model.EnvVar;
import io.fabric8.kubernetes.api.model.EnvVarBuilder;
import io.fabric8.kubernetes.api.model.SecurityContextBuilder;
import io.fabric8.kubernetes.api.model.Volume;
import io.fabric8.kubernetes.api.model.VolumeBuilder;
import io.fabric8.kubernetes.api.model.VolumeMount;
import io.fabric8.kubernetes.api.model.VolumeMountBuilder;
import io.fabric8.kubernetes.api.model.batch.JobBuilder;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.JobDeployer;
import org.onebrain.operator.constants.KubeConstants;

import java.util.*;

import static org.onebrain.operator.constants.NumberConstant.LONG_NUMBER_0;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_1;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_22;

/**
* @description Job部署器
* @date 2020-09-23
*/
public class BaseJobDeployer implements JobDeployer<ChildResourceCreateInfo> {

public static final String PVC_WORKSPACE = "pvc-workspace";
public static final String SSH = "ssh";
public static final String WORKSPACE = "/workspace";
public static final String PVC_DATASET = "pvc-dataset";
public static final String DATASET = "/dataset";
public static final String PVC_MODEL = "pvc-model";
public static final String MODEL = "/model";
public static final String MEMORY = "Memory";
public static final String DEV_SHM = "/dev/shm";
public static final String BIN_BASH = "/bin/bash";
public static final String IPC_LOCK = "IPC_LOCK";
public static final String RESTART_POLICY_NEVER = "Never";

/**
 * Assembles the Job that runs the master container.
 *
 * @param info resource information
 * @return the Job builder holding the complete Job definition
 */
@Override
public JobBuilder deploy(ChildResourceCreateInfo info) {

    // Container, volumes, and the mounts derived from those volumes.
    Container masterContainer = buildContainer(info);
    List<Volume> podVolumes = buildVolumes(info);
    masterContainer.setVolumeMounts(buildVolumeMounts(podVolumes));

    // Entry point is "/bin/bash -c <steps joined by &&>".
    masterContainer.setCommand(Collections.singletonList(BIN_BASH));
    // Step 1 blocks until the pretreatment file has been copied into the pod, then runs it.
    String waitForPretreatment = "while [ ! -f /home/pretreatment ]; do echo pretreatment not exist >> pretreatment.log; sleep 1;done && chmod a+x /home/pretreatment && bash /home/pretreatment ";
    // Step 2 blocks until the service name resolves, i.e. the service has been created.
    String waitForService = "until nslookup " + info.getSvcName() + "; do sleep 5; done";
    // Step 3 is the user-supplied master command.
    List<String> startupSteps = Arrays.asList(waitForPretreatment, waitForService, info.getMasterCmd());
    masterContainer.setArgs(Arrays.asList("-c", CollectionUtil.join(startupSteps, " && ")));

    // Security context: allow privilege escalation and add the IPC_LOCK capability.
    masterContainer.setSecurityContext(new SecurityContextBuilder()
            .withAllowPrivilegeEscalation(true)
            .withCapabilities(new CapabilitiesBuilder()
                    .withAdd(Collections.singletonList(IPC_LOCK))
                    .build())
            .build());

    // User-defined labels; an empty map when none are given.
    Map<String,String> customizeLabels;
    if (CollectionUtil.isNotEmpty(info.getLabels())) {
        customizeLabels = info.getLabels();
    } else {
        customizeLabels = new HashMap<>();
    }

    JobBuilder jobBuilder = new JobBuilder();

    // Job metadata.
    jobBuilder.withNewMetadata()
            .withName(info.getJobName())
            .withNamespace(info.getNamespace())
            .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
            .addToLabels(customizeLabels)
            .addToOwnerReferences(info.getOwnerReference())
            .endMetadata();

    // Job spec and pod template.
    jobBuilder.withNewSpec()
            // one pod in parallel
            .withParallelism(NUMBER_1)
            // one completion in total
            .withCompletions(NUMBER_1)
            // retry limit on failure
            .withBackoffLimit(KubeConstants.BACKOFFLIMIT)
            .withNewTemplate()
                .withNewMetadata()
                    .withName(info.getJobName())
                    .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                    .addToLabels(KubeConstants.JOB_LABEL, info.getJobName())
                    .addToLabels(customizeLabels)
                .endMetadata()
                .withNewSpec()
                    // zero grace period: terminate immediately when asked to stop
                    .withTerminationGracePeriodSeconds(LONG_NUMBER_0)
                    .addToContainers(masterContainer)
                    .addToVolumes(podVolumes.toArray(new Volume[0]))
                    .withRestartPolicy(RESTART_POLICY_NEVER)
                .endSpec()
            .endTemplate()
            .endSpec();

    // Optional init container.
    Container initContainer = info.getInitContainer();
    if (initContainer != null) {
        jobBuilder.editSpec()
                .editTemplate()
                .editSpec()
                .addToInitContainers(initContainer)
                .endSpec()
                .endTemplate()
                .endSpec();
    }

    // Optional node selector for pinned scheduling.
    if (CollectionUtil.isNotEmpty(info.getNodeSelector())) {
        jobBuilder = jobBuilder.editSpec()
                .editTemplate().editSpec()
                .addToNodeSelector(info.getNodeSelector())
                .endSpec().endTemplate()
                .endSpec();
    }

    return jobBuilder;
}




/**
 * Builds the master container: name, image, SSH port, environment and
 * (when present) resource requirements.
 *
 * @param info resource information
 * @return the container definition
 */
private Container buildContainer(ChildResourceCreateInfo info){
    Container master = new Container();

    // Name and image.
    master.setName(KubeConstants.MASTER_CONTAINER_NAME);
    master.setImage(info.getImage());
    master.setImagePullPolicy(info.getImagePullPolicy());

    // Exposed port.
    master.setPorts(Arrays.asList(new ContainerPortBuilder()
            .withContainerPort(NUMBER_22)
            .withName(SSH)
            .build()));

    // Environment: the node count (slaves + masters) comes first, user entries follow.
    EnvVar nodeNum = new EnvVarBuilder()
            .withName(KubeConstants.ENV_NODE_NUM)
            .withValue(String.valueOf(info.getSlaveReplicas() + info.getMasterReplicas()))
            .build();
    List<EnvVar> envVars = Lists.newArrayList(nodeNum);
    List<EnvVar> userEnv = info.getEnv();
    if (userEnv != null) {
        envVars.addAll(userEnv);
    }
    master.setEnv(envVars);

    // Resource requirements, only when configured.
    if (info.getMasterResources() != null) {
        master.setResources(info.getMasterResources());
    }
    return master;
}

/**
 * Collects the volumes of the pod: the configured workspace, dataset and
 * model volumes (each only when present), followed by the in-memory shm volume.
 *
 * @param info resource information
 * @return the list of volumes
 */
private List<Volume> buildVolumes(ChildResourceCreateInfo info){
    List<Volume> collected = new LinkedList<>();

    // Optional volumes, in fixed order; absent ones are skipped.
    List<Volume> optionalVolumes = Arrays.asList(
            info.getWorkspaceVolume(),
            info.getDatasetVolume(),
            info.getModelVolume());
    for (Volume candidate : optionalVolumes) {
        if (candidate != null) {
            collected.add(candidate);
        }
    }

    // The shm volume is always present.
    Volume shm = new VolumeBuilder()
            .withName(KubeConstants.VOLUME_SHM)
            .withNewEmptyDir()
            .withMedium(MEMORY)
            .endEmptyDir()
            .build();
    collected.add(shm);

    return collected;
}

/**
 * Builds the volume mounts for the given volumes. Volumes named
 * pvc-workspace, pvc-dataset and pvc-model are mounted at /workspace,
 * /dataset and /model; any other volume is ignored. A mount of the shm
 * volume at /dev/shm is always appended last.
 *
 * @param volumes the volumes of the pod
 * @return the list of volume mounts
 */
private List<VolumeMount> buildVolumeMounts(List<Volume> volumes) {
    List<VolumeMount> mounts = new LinkedList<>();
    for (Volume volume : volumes) {
        String volumeName = volume.getName();
        String mountPath;
        if (PVC_WORKSPACE.equals(volumeName)) {
            mountPath = WORKSPACE;
        } else if (PVC_DATASET.equals(volumeName)) {
            mountPath = DATASET;
        } else if (PVC_MODEL.equals(volumeName)) {
            mountPath = MODEL;
        } else {
            // not one of the known PVC volumes: nothing to mount here
            continue;
        }
        mounts.add(new VolumeMountBuilder()
                .withName(volumeName)
                .withMountPath(mountPath)
                .build());
    }

    mounts.add(new VolumeMountBuilder()
            .withName(KubeConstants.VOLUME_SHM)
            .withMountPath(DEV_SHM)
            .build());
    return mounts;
}
}

+ 73
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseServiceDeployer.java View File

@@ -0,0 +1,73 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer.impl;

import cn.hutool.core.collection.CollectionUtil;
import io.fabric8.kubernetes.api.model.IntOrString;
import io.fabric8.kubernetes.api.model.ServiceBuilder;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.ServiceDeployer;
import org.onebrain.operator.constants.KubeConstants;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.onebrain.operator.constants.NumberConstant.NUMBER_22;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_30000;

/**
 * @description Service deployer
 * @date 2020-09-23
 */
public class BaseServiceDeployer implements ServiceDeployer<ChildResourceCreateInfo> {

    public static final String WEB_SSH = "web-ssh";
    public static final String NONE = "None";

    /**
     * Builds the Service definition: cluster IP "None", one port named
     * "web-ssh", selecting by the distribute-train label of the parent.
     *
     * @param info resource information
     * @return the Service builder holding the complete Service definition
     */
    @Override
    public ServiceBuilder deploy(ChildResourceCreateInfo info) {

        // User-defined labels; an empty map when none are given.
        Map<String,String> customizeLabels;
        if (CollectionUtil.isNotEmpty(info.getLabels())) {
            customizeLabels = info.getLabels();
        } else {
            customizeLabels = new HashMap<>();
        }

        ServiceBuilder serviceBuilder = new ServiceBuilder();

        // Metadata.
        serviceBuilder.withNewMetadata()
                .withName(info.getSvcName())
                .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                .addToLabels(customizeLabels)
                .withNamespace(info.getNamespace())
                .addToOwnerReferences(info.getOwnerReference())
                .endMetadata();

        // Spec: port, cluster IP and the selector matching the distribute-train label.
        serviceBuilder.withNewSpec()
                .addNewPort()
                    .withPort(NUMBER_30000)
                    .withTargetPort(new IntOrString(NUMBER_22))
                    .withName(WEB_SSH)
                .endPort()
                .withClusterIP(NONE)
                .withSelector(Collections.singletonMap(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName()))
                .endSpec();

        return serviceBuilder;
    }
}

+ 246
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/deployer/impl/BaseStatefulSetDeployer.java View File

@@ -0,0 +1,246 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.deployer.impl;

import cn.hutool.core.collection.CollectionUtil;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import io.fabric8.kubernetes.api.model.CapabilitiesBuilder;
import io.fabric8.kubernetes.api.model.Container;
import io.fabric8.kubernetes.api.model.ContainerPortBuilder;
import io.fabric8.kubernetes.api.model.EnvVar;
import io.fabric8.kubernetes.api.model.EnvVarBuilder;
import io.fabric8.kubernetes.api.model.LabelSelector;
import io.fabric8.kubernetes.api.model.SecurityContextBuilder;
import io.fabric8.kubernetes.api.model.Volume;
import io.fabric8.kubernetes.api.model.VolumeBuilder;
import io.fabric8.kubernetes.api.model.VolumeMount;
import io.fabric8.kubernetes.api.model.VolumeMountBuilder;
import io.fabric8.kubernetes.api.model.apps.StatefulSetBuilder;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.StatefulSetDeployer;
import org.onebrain.operator.constants.KubeConstants;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static org.onebrain.operator.constants.NumberConstant.LONG_NUMBER_0;
import static org.onebrain.operator.constants.NumberConstant.LONG_NUMBER_60;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_22;

/**
* @description StatefullSet部署器
* @date 2020-09-23
*/
public class BaseStatefulSetDeployer implements StatefulSetDeployer<ChildResourceCreateInfo> {

public static final String SSH = "ssh";
public static final String PVC_WORKSPACE = "pvc-workspace";
public static final String WORKSPACE = "/workspace";
public static final String PVC_DATASET = "pvc-dataset";
public static final String DATASET = "/dataset";
public static final String PVC_MODEL = "pvc-model";
public static final String MODEL = "/model";
public static final String MEMORY = "Memory";
public static final String DEV_SHM = "/dev/shm";
public static final String BIN_BASH = "/bin/bash";
public static final String IPC_LOCK = "IPC_LOCK";

/**
 * Assembles the StatefulSet that runs the slave containers.
 *
 * @param info resource information
 * @return the StatefulSet builder holding the complete StatefulSet definition
 */
@Override
public StatefulSetBuilder deploy(ChildResourceCreateInfo info) {
    // Selector matching the pods of this StatefulSet.
    LabelSelector labelSelector = new LabelSelector();
    labelSelector.setMatchLabels(ImmutableMap.of(KubeConstants.STATEFULSET_LABEL, info.getStatefulSetName()));
    // Volumes.
    List<Volume> volumes = buildVolumes(info);
    // Container.
    Container container = buildContainer(info);
    // Mounts derived from the volumes.
    List<VolumeMount> volumeMounts = buildVolumeMounts(volumes);

    container.setVolumeMounts(volumeMounts);

    // Startup: wait for the pretreatment file and run it, wait until the service
    // name resolves, then run the user-supplied slave command.
    List<String> cmdLines = Arrays.asList("while [ ! -f /home/pretreatment ]; do echo pretreatment not exist >> pretreatment.log; sleep 1;done && chmod a+x /home/pretreatment && bash /home/pretreatment ", "until nslookup " + info.getSvcName() + "; do sleep 5; done", info.getSlaveCmd());
    container.setCommand(Collections.singletonList(BIN_BASH));
    container.setArgs(Arrays.asList("-c", CollectionUtil.join(cmdLines, " && ")));

    // Security context: allow privilege escalation and add the IPC_LOCK capability.
    container.setSecurityContext(new SecurityContextBuilder()
            .withAllowPrivilegeEscalation(true)
            .withCapabilities(new CapabilitiesBuilder()
                    .withAdd(Collections.singletonList(IPC_LOCK))
                    .build())
            .build());

    // User-defined labels; an empty map when none are given.
    Map<String,String> customizeLabels = CollectionUtil.isNotEmpty(info.getLabels())? info.getLabels(): new HashMap<>();

    StatefulSetBuilder builder = new StatefulSetBuilder();
    builder.withNewMetadata()
                .withName(info.getStatefulSetName())
                .withNamespace(info.getNamespace())
                .addToOwnerReferences(info.getOwnerReference())
                .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
            .endMetadata()
            .withNewSpec()
                .withSelector(labelSelector)
                // NOTE(review): serviceName is set to the StatefulSet name, not to
                // info.getSvcName() - confirm this is intended.
                .withServiceName(info.getStatefulSetName())
                .withReplicas(info.getSlaveReplicas())
                .withNewTemplate()
                    .withNewMetadata()
                        .withName(info.getStatefulSetName())
                        .addToLabels(KubeConstants.DISTRIBUTE_TRAIN_LABEL, info.getParentName())
                        .addToLabels(KubeConstants.STATEFULSET_LABEL, info.getStatefulSetName())
                        .addToLabels(customizeLabels)
                    .endMetadata()
                    .withNewSpec()
                        // The grace period was previously set twice in a row (0, then 60);
                        // only the last value took effect, so it is set once here.
                        .withTerminationGracePeriodSeconds(LONG_NUMBER_60)
                        .addToContainers(container)
                        .addToVolumes(volumes.toArray(new Volume[0]))
                    .endSpec()
                .endTemplate()
            .endSpec();

    // Optional init container.
    Container initContainer = info.getInitContainer();
    if (initContainer != null) {
        builder.editSpec()
                .editTemplate()
                .editSpec()
                .addToInitContainers(initContainer)
                .endSpec()
                .endTemplate()
                .endSpec();
    }

    // Optional node selector for pinned scheduling.
    if(CollectionUtil.isNotEmpty(info.getNodeSelector())){
        builder = builder.editSpec()
                .editTemplate().editSpec()
                .addToNodeSelector(info.getNodeSelector())
                .endSpec().endTemplate()
                .endSpec();
    }

    return builder;
}

/**
 * Builds the slave container: name, image, SSH port, environment and
 * (when present) resource requirements.
 *
 * @param info resource information
 * @return the container definition
 */
private Container buildContainer(ChildResourceCreateInfo info) {
    Container slave = new Container();

    // Name and image.
    slave.setName(KubeConstants.SLAVE_CONTAINER_NAME);
    slave.setImage(info.getImage());
    slave.setImagePullPolicy(info.getImagePullPolicy());

    // Exposed port.
    slave.setPorts(Arrays.asList(new ContainerPortBuilder()
            .withContainerPort(NUMBER_22)
            .withName(SSH)
            .build()));

    // Environment: the node count (slaves + masters) comes first, user entries follow.
    EnvVar nodeNum = new EnvVarBuilder()
            .withName(KubeConstants.ENV_NODE_NUM)
            .withValue(String.valueOf(info.getSlaveReplicas() + info.getMasterReplicas()))
            .build();
    List<EnvVar> envVars = Lists.newArrayList(nodeNum);
    List<EnvVar> userEnv = info.getEnv();
    if (userEnv != null) {
        envVars.addAll(userEnv);
    }
    slave.setEnv(envVars);

    // Resource requirements, only when configured.
    if (info.getSlaveResources() != null) {
        slave.setResources(info.getSlaveResources());
    }

    return slave;
}

/**
 * Collects the volumes of the pod: the configured workspace, dataset and
 * model volumes (each only when present), followed by the in-memory shm volume.
 *
 * @param info resource information
 * @return the list of volumes
 */
private List<Volume> buildVolumes(ChildResourceCreateInfo info) {
    // Start from a fresh list. The previous code called buildVolumes(info)
    // here, recursing without end until a StackOverflowError.
    List<Volume> volumes = new LinkedList<>();
    Optional.ofNullable(info.getWorkspaceVolume()).ifPresent(v-> volumes.add(v));
    Optional.ofNullable(info.getDatasetVolume()).ifPresent(v-> volumes.add(v));
    Optional.ofNullable(info.getModelVolume()).ifPresent(v-> volumes.add(v));

    // The shm volume is always present.
    volumes.add(new VolumeBuilder()
            .withName(KubeConstants.VOLUME_SHM)
            .withNewEmptyDir()
            .withMedium(MEMORY)
            .endEmptyDir()
            .build());

    return volumes;
}

/**
 * Builds the volume mounts for the given volumes. Volumes named
 * pvc-workspace, pvc-dataset and pvc-model are mounted at /workspace,
 * /dataset and /model; any other volume is ignored. A mount of the shm
 * volume at /dev/shm is always appended last.
 *
 * @param volumes the volumes of the pod
 * @return the list of volume mounts
 */
private List<VolumeMount> buildVolumeMounts(List<Volume> volumes) {
    List<VolumeMount> mounts = new LinkedList<>();
    for (Volume volume : volumes) {
        String volumeName = volume.getName();
        String mountPath;
        if (PVC_WORKSPACE.equals(volumeName)) {
            mountPath = WORKSPACE;
        } else if (PVC_DATASET.equals(volumeName)) {
            mountPath = DATASET;
        } else if (PVC_MODEL.equals(volumeName)) {
            mountPath = MODEL;
        } else {
            // not one of the known PVC volumes: nothing to mount here
            continue;
        }
        mounts.add(new VolumeMountBuilder()
                .withName(volumeName)
                .withMountPath(mountPath)
                .build());
    }

    mounts.add(new VolumeMountBuilder()
            .withName(KubeConstants.VOLUME_SHM)
            .withMountPath(DEV_SHM)
            .build());
    return mounts;
}
}

+ 614
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/AddActionHandler.java View File

@@ -0,0 +1,614 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.handler;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.io.FileUtil;
import cn.hutool.core.util.ObjectUtil;
import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONArray;
import com.alibaba.fastjson.JSONObject;
import com.google.common.collect.Lists;
import com.google.common.io.Files;
import io.fabric8.kubernetes.api.model.ObjectMeta;
import io.fabric8.kubernetes.api.model.Pod;
import io.fabric8.kubernetes.api.model.Service;
import io.fabric8.kubernetes.api.model.ServiceBuilder;
import io.fabric8.kubernetes.api.model.apps.StatefulSet;
import io.fabric8.kubernetes.api.model.apps.StatefulSetBuilder;
import io.fabric8.kubernetes.api.model.batch.Job;
import io.fabric8.kubernetes.api.model.batch.JobBuilder;
import io.fabric8.kubernetes.client.KubernetesClient;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.action.PodInfo;
import org.onebrain.operator.action.deployer.ChildResourceCreateInfo;
import org.onebrain.operator.action.deployer.JobDeployer;
import org.onebrain.operator.action.deployer.ServiceDeployer;
import org.onebrain.operator.action.deployer.StatefulSetDeployer;
import org.onebrain.operator.action.deployer.impl.BaseJobDeployer;
import org.onebrain.operator.action.deployer.impl.BaseServiceDeployer;
import org.onebrain.operator.action.deployer.impl.BaseStatefulSetDeployer;
import org.onebrain.operator.api.pod.PodApi;
import org.onebrain.operator.constants.KubeConstants;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.crd.DistributeTrainSpec;
import org.onebrain.operator.crd.DistributeTrainStatus;
import org.onebrain.operator.exception.OperatorException;
import org.onebrain.operator.redis.RedisService;
import org.onebrain.operator.redis.key.OperatorKey;
import org.onebrain.operator.utils.DistributeTrainClientHolder;
import org.onebrain.operator.utils.IOUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.core.io.ClassPathResource;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.InputStream;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static org.onebrain.operator.constants.KubeConstants.CHARSET;
import static org.onebrain.operator.constants.KubeConstants.JOB_LABEL;
import static org.onebrain.operator.constants.KubeConstants.MASTER_CONTAINER_NAME;
import static org.onebrain.operator.constants.KubeConstants.SLAVE_CONTAINER_NAME;
import static org.onebrain.operator.constants.KubeConstants.STATEFULSET_LABEL;
import static org.onebrain.operator.constants.NumberConstant.NUMBER_2;

/**
* @description 分布式训练添加事件的处理器
* @date 2020-09-23
*/
@Component("addActionHandler")
@Slf4j
public class AddActionHandler implements DistributeTrainActionHandler {

public static final String JOB_WATCHER = "job-watcher-";
public static final String PRETREATMENT = "pretreatment";
public static final String JOB_NAME = "job-name";
public static final String RUNNING = "Running";
public static final String MASTER = "master";
public static final String SLAVE = "slave";
public static final String PRETREATMENT_TARGET_DIR = "/home/pretreatment";
public static final String IP = "ip";
public static final String ROLE = "role";
public static final String HOSTFILE_TARGET_DIR = "/home/hostfile.json";
@Autowired
private KubernetesClient client;

@Autowired
private PodApi podApi;

/**
* String 训练uid List pod信息
*/
private Map<String, List<PodInfo>> dtMap = new ConcurrentHashMap();

@Autowired
private RedisService redis;

/**
* Thread pool that runs the add-event tasks.
* Core size 5, maximum size 10, idle keep-alive 10 seconds, work queue capacity 1.
* Threads are named "job-watcher-" followed by a counter starting at 1.
* Rejection uses DiscardOldestPolicy: when the pool is saturated, the oldest
* queued task is dropped to make room for the new one.
* NOTE(review): a dropped task means that add event is never processed - confirm this is acceptable.
*/
private ThreadPoolExecutor pool = new ThreadPoolExecutor(5, 10, 10, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), new ThreadFactory() {
// Sequence number appended to each new thread's name.
private final AtomicInteger mThreadNum = new AtomicInteger(1);

@Override
public Thread newThread(Runnable r) {
return new Thread(r, JOB_WATCHER + mThreadNum.getAndIncrement());
}
}, new ThreadPoolExecutor.DiscardOldestPolicy());

/**
 * Runnable that processes one distributed-training add event by delegating
 * to doAction of the enclosing handler.
 */
class HandlerActionTask implements Runnable {

    /** The custom resource this task processes. */
    private final DistributeTrain distributeTrain;

    public HandlerActionTask(DistributeTrain distributeTrain) {
        this.distributeTrain = distributeTrain;
    }

    @Override
    public void run() {
        AddActionHandler.this.doAction(this.distributeTrain);
    }
}

/**
 * Runs the full creation workflow for one distributed-training custom resource.
 * A uid already recorded in redis is skipped. On any failure the redis record
 * is removed and the resources created so far are recycled.
 *
 * @param distributeTrain the distributed-training custom resource
 */
public void doAction(DistributeTrain distributeTrain) {
    log.info("doAction=>distributeTrain : 【{}】", distributeTrain);
    ChildResourceCreateInfo info = null;
    try {
        // Duplicate check in redis, keyed by the uid of the custom resource.
        if (null != redis.get(OperatorKey.CR, distributeTrain.getMetadata().getUid())) {
            log.info("distribute train 【{}】 in namespace 【{}】 already exists", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace());
            return;
        } else {
            // Record the uid in redis as "being processed".
            redis.set(OperatorKey.CR, distributeTrain.getMetadata().getUid(), System.currentTimeMillis());
        }

        // Validate the parameters, then extract and generate what is needed.
        validateParams(distributeTrain);
        info = ChildResourceCreateInfo.fromCr(distributeTrain);
        // Create the StatefulSet with (size - 1) replicas.
        createStatefulSet(info);
        // Wait until every StatefulSet pod is ready.
        waitUntilStatefulSetReady(info);
        // Create the Job; at this point it loops waiting.
        createJob(info);
        // Wait until the Job is ready.
        waitUntilJobReady(info);
        // Copy /home/pretreatment into the pods.
        copyPretreatmentShell(info);
        // Collect the IPs of the StatefulSet and Job pods.
        validateAndCollectPods(info);
        // Generate keys and auth files locally and copy them to ~/.ssh on every node.
        sshAuthWithoutPass(info);
        // Generate the hostfile locally and copy it to every node.
        generateAndUploadHostFile(info);
        // Release the Job from its waiting loop.
        releaseInterLock(info);
        // Status update (disabled).
        //updateStatus(info, distributeTrain);
        // Register the listener for the Job.
        registerJobListener(info);

        log.info("all parts of【{}】 are ready", info.getParentName());
    } catch (Exception e) {
        // The placeholder takes the resource name; the trailing exception is logged with its stack trace.
        log.error("doAction error for distribute train 【{}】", distributeTrain.getMetadata().getName(), e);
        // Remove the redis record.
        redis.del(OperatorKey.CR, distributeTrain.getMetadata().getUid());
        // Recycle whatever resources were created.
        if (info != null) {
            recycleCr(info);
        }
    }
}

/**
 * Handles a distributed-training add event by submitting it to the thread pool.
 *
 * @param distributeTrain the distributed-training custom resource
 */
@Override
public void handlerAction(DistributeTrain distributeTrain) {
    log.info("handlerAction=>distributeTrain : 【{}】", distributeTrain);
    HandlerActionTask handlerActionTask = new HandlerActionTask(distributeTrain);
    pool.execute(handlerActionTask);
}

/**
 * Validates the parameters of the custom resource.
 *
 * @param distributeTrain the distributed-training custom resource
 * @throws OperatorException when the spec is missing, size is missing or
 *         less than 2, or either command line is empty
 */
private void validateParams(DistributeTrain distributeTrain) {
    log.info("validateParams=>distributeTrain : 【{}】", distributeTrain);
    if (distributeTrain.getSpec() == null) {
        throw new OperatorException("spec must not be null");
    }
    Integer size = distributeTrain.getSpec().getSize();
    // A missing size is rejected here rather than failing on unboxing.
    if (size == null || size < NUMBER_2) {
        throw new OperatorException("size must be greater than 1");
    }
    String masterCmd = distributeTrain.getSpec().getMasterCmd();
    String slaveCmd = distributeTrain.getSpec().getSlaveCmd();
    if (StrUtil.isEmpty(slaveCmd) || StrUtil.isEmpty(masterCmd)) {
        throw new OperatorException("cmd lines must not be empty");
    }
}

/**
 * Copies the pretreatment script into every pod of the training.
 * The script is first written from the classpath to the working directory
 * if it is not there yet.
 *
 * @param info resource information
 * @throws OperatorException when writing or copying the script fails
 */
private void copyPretreatmentShell(ChildResourceCreateInfo info) {
    log.info("start to copy pretreatment for 【{}】 ", info.getParentName());
    try {
        String path = System.getProperty(KubeConstants.USER_DIR_SYSTEM_PROPERTY) + File.separator + PRETREATMENT;
        if (!FileUtil.exist(path)) {
            // Close the classpath stream explicitly once it has been written out.
            try (InputStream shell = new ClassPathResource("/shell/pretreatment").getInputStream()) {
                FileUtil.writeFromStream(shell, path);
            }
        }
        File pretreatment = new File(path);
        // Upload to the target path in each pod.
        List<Pod> pods = getPods(info);
        for (int i = 0; i < pods.size(); i++) {
            Pod pod = pods.get(i);
            // The first pod is taken to be the master.
            String containerName = i < 1 ? MASTER_CONTAINER_NAME : SLAVE_CONTAINER_NAME;
            podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, pretreatment, PRETREATMENT_TARGET_DIR);
        }
    } catch (Exception e) {
        // The placeholder takes the parent name; the trailing exception is logged with its stack trace.
        log.error("copy pretreatment shell error for 【{}】", info.getParentName(), e);
        throw new OperatorException("exception is thrown when copy pretreatment for 【" + info.getParentName() + "】 : \n" + e.getMessage());
    }
}

/**
 * Creates the StatefulSet unless one with the same name already exists
 * in the namespace.
 *
 * @param info resource information
 */
private void createStatefulSet(ChildResourceCreateInfo info) {
    log.info("createStatefulSet=>childResourceCreateInfo : 【{}】", info);
    StatefulSet statefulSet = client.apps().statefulSets()
            .inNamespace(info.getNamespace())
            .withName(info.getStatefulSetName()).get();
    // Already exists: nothing to do.
    if (statefulSet != null) {
        log.info("statefulSet 【{}】 already exists", statefulSet.getMetadata().getName());
        return;
    }
    // Does not exist: build and create it (parameterized type instead of the raw deployer).
    StatefulSetDeployer<ChildResourceCreateInfo> deployer = new BaseStatefulSetDeployer();
    StatefulSetBuilder builder = deployer.deploy(info);
    statefulSet = builder.build();
    client.apps().statefulSets().create(statefulSet);
    log.info("create statefulSet【{}】 successfully", statefulSet.getMetadata().getName());
}

/**
 * Blocks until every replica of the StatefulSet is ready, for at most two hours.
 *
 * @param info resource information of the distributed training
 * @throws OperatorException if waiting fails, times out or is interrupted
 */
private void waitUntilStatefulSetReady(ChildResourceCreateInfo info) {
    log.info("wait for statefulSet 【{}】 in namespace 【{}】 ready", info.getStatefulSetName(), info.getNamespace());
    try {
        client.apps().statefulSets()
                .inNamespace(info.getNamespace())
                .withName(info.getStatefulSetName())
                // Block until all pods are ready; waits two hours at most.
                // The resource or its status may not be populated yet, so guard both.
                .waitUntilCondition(c ->
                                c != null
                                        && c.getStatus() != null
                                        && c.getStatus().getReplicas() != null
                                        && ObjectUtil.equal(c.getStatus().getReplicas(), c.getStatus().getReadyReplicas()),
                        NUMBER_2, TimeUnit.HOURS);
        log.info("statefulSet 【{}】 in namespace 【{}】 is ready", info.getStatefulSetName(), info.getNamespace());
    } catch (Exception e) {
        if (e instanceof InterruptedException) {
            // Keep the interrupt visible to the code that owns this thread.
            Thread.currentThread().interrupt();
        }
        log.error("wait until statefulSet ready error:【{}】", e.getMessage(), e);
        throw new OperatorException("exception is thrown when waiting for statefulSet 【" + info.getStatefulSetName() + "】 ready : \n" + e.getMessage());
    }
}

/**
 * Creates the Job of the distributed training unless one with the same name
 * already exists in the namespace.
 *
 * @param info resource information describing the job
 */
private void createJob(ChildResourceCreateInfo info) {
    log.info("createJob=>childResourceCreateInfo : 【{}】", info);
    Job existing = client.batch().jobs()
            .inNamespace(info.getNamespace())
            .withName(info.getJobName())
            .get();
    if (existing == null) {
        // Not present yet: build it from the resource information and submit it.
        JobDeployer deployer = new BaseJobDeployer();
        Job created = deployer.deploy(info).build();
        log.info("job is : 【{}】", created);
        client.batch().jobs().create(created);
        log.info("create job【{}】 successfully", created.getMetadata().getName());
    } else {
        // Already present: nothing to do.
        log.info("job 【{}】 already exists", existing.getMetadata().getName());
    }
}

/**
 * Blocks until the pod created by the job is ready.
 *
 * <p>The method first polls every two seconds (for at most two hours) until
 * the job has produced a pod, then waits up to two more hours for the first
 * such pod to become ready.
 *
 * @param info resource information of the distributed training
 * @throws OperatorException if no pod appears in time, or if waiting fails or is interrupted
 */
private void waitUntilJobReady(ChildResourceCreateInfo info) {
    log.info("wait for job 【{}】 in namespace 【{}】 ready", info.getJobName(), info.getNamespace());
    try {
        // Bound the polling so a job that never gets a pod cannot block this thread forever.
        long deadline = System.nanoTime() + TimeUnit.HOURS.toNanos(NUMBER_2);
        List<Pod> podList;
        while (true) {
            podList = client.pods().inNamespace(info.getNamespace())
                    .withLabel(JOB_NAME, info.getJobName())
                    .list().getItems();
            if (!CollectionUtil.isEmpty(podList)) {
                break;
            }
            if (System.nanoTime() - deadline >= 0) {
                throw new OperatorException("no pod was created for job 【" + info.getJobName() + "】 within " + NUMBER_2 + " hours");
            }
            TimeUnit.SECONDS.sleep(2);
        }
        Pod pod = podList.get(0);
        client.pods().inNamespace(info.getNamespace())
                .withName(pod.getMetadata().getName())
                // Wait for the Ready condition, two hours at most.
                .waitUntilReady(2, TimeUnit.HOURS);
        log.info("job 【{}】 in namespace 【{}】 is ready", info.getJobName(), info.getNamespace());
    } catch (Exception e) {
        if (e instanceof InterruptedException) {
            // Keep the interrupt visible to the code that owns this thread.
            Thread.currentThread().interrupt();
        }
        log.error(e.getMessage(), e);
        throw new OperatorException("exception is thrown when waiting for job 【" + info.getJobName() + "】 ready : \n" + e.getMessage());
    }
}

/**
 * Waits until the master pod and all slave pods are in the running phase and
 * then records their basic information via {@code collectChildPodInfo}.
 *
 * <p>If the master pod or the slave pods cannot be found at all, the method
 * logs the fact and returns without collecting anything. There is no overall
 * timeout; the pods are re-queried every two seconds until all are running.
 *
 * @param info resource information of the distributed training
 * @throws OperatorException if the thread is interrupted while waiting
 */
private void validateAndCollectPods(ChildResourceCreateInfo info) {
    // Check that every pod is running normally.
    log.info("validate pods status for 【{}】", info.getParentName());
    Pod masterPod;
    List<Pod> slavePods;

    while (true) {
        // Pod of the master (job).
        masterPod = getMasterPod(info);

        // All pods of the slaves (statefulSet).
        slavePods = getSlavePods(info);

        if (masterPod == null) {
            log.info("can not find pod belongs to job 【{}】", info.getJobName());
            return;
        }
        if (CollectionUtil.isEmpty(slavePods)) {
            log.info("can not find pod belongs to statefulSet 【{}】", info.getStatefulSetName());
            return;
        }

        boolean isMasterRunning = RUNNING.equals(masterPod.getStatus().getPhase());
        boolean isAllSlaveRunning = true;
        for (Pod slavePod : slavePods) {
            if (!RUNNING.equals(slavePod.getStatus().getPhase())) {
                isAllSlaveRunning = false;
                break;
            }
        }
        if (isMasterRunning && isAllSlaveRunning) {
            break;
        }

        // Back off between polls instead of hammering the API server in a tight loop.
        try {
            TimeUnit.SECONDS.sleep(2);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new OperatorException("interrupted while waiting for pods of 【" + info.getParentName() + "】 running");
        }
    }

    log.info("status checked 【{}】 all right", info.getParentName());
    collectChildPodInfo(info, masterPod, slavePods);
}

/**
 * Records the IP and role of the master pod and of every slave pod under the
 * UID of the owning resource, replacing any entry stored earlier for that UID.
 *
 * @param info      resource information of the distributed training
 * @param masterPod pod acting as master; stored first in the list
 * @param slavePods pods acting as slaves; stored after the master in iteration order
 */
private void collectChildPodInfo(ChildResourceCreateInfo info, Pod masterPod, List<Pod> slavePods) {
    log.info("collectChildPodInfo=>childResourceCreateInfo : 【{}】, masterPod : 【{}】, slavePods : 【{}】", info, masterPod, slavePods);
    String ownerUid = info.getOwnerReference().getUid();
    // Drop whatever was recorded for this owner before.
    dtMap.remove(ownerUid);

    List<PodInfo> collected = Lists.newArrayList();
    collected.add(PodInfo.builder()
            .ip(masterPod.getStatus().getPodIP())
            .role(MASTER)
            .build());
    slavePods.forEach(slave -> collected.add(PodInfo.builder()
            .ip(slave.getStatus().getPodIP())
            .role(SLAVE)
            .build()));
    dtMap.put(ownerUid, collected);
}

/**
 * Configures password-less ssh between all pods of a distributed training.
 *
 * <p>The key pair is read from the classpath resources {@code key/id_rsa} and
 * {@code key/id_rsa.pub}. For every recorded pod the {@code {{ip}}}
 * placeholder of the public key is replaced with the pod IP; all resulting
 * public keys are combined into one {@code authorized_keys} file. The private
 * key, the per-pod public key and the combined {@code authorized_keys} are
 * then uploaded to {@code /root/.ssh} of every pod and their permissions fixed.
 *
 * @param info resource information of the distributed training
 * @throws OperatorException if the keys are missing, pod information has not
 *                           been collected, or an upload/exec step fails
 */
private void sshAuthWithoutPass(ChildResourceCreateInfo info) {
    log.info("start to configure ssh no password environment for 【{}】 ", info.getParentName());
    File tempDir = Files.createTempDir();
    try (
            InputStream isRsa = getClass().getClassLoader().getResourceAsStream("key/id_rsa");
            InputStream isRsaPub = getClass().getClassLoader().getResourceAsStream("key/id_rsa.pub")
    ) {
        // getResourceAsStream returns null for a missing resource; fail with a clear message.
        if (isRsa == null || isRsaPub == null) {
            throw new OperatorException("ssh key resources key/id_rsa and key/id_rsa.pub must exist on the classpath");
        }
        // id_rsa
        File tempIdRsa = FileUtil.createTempFile(tempDir);
        IOUtils.copy(isRsa, tempIdRsa);
        // id_rsa.pub
        File tempIdRsaPub = FileUtil.createTempFile(tempDir);
        IOUtils.copy(isRsaPub, tempIdRsaPub);
        List<String> pubLines = FileUtil.readLines(tempIdRsaPub, CHARSET);
        if (CollectionUtil.isEmpty(pubLines)) {
            throw new OperatorException("ssh public key resource key/id_rsa.pub is empty");
        }
        String pubKeyContent = pubLines.get(0);

        List<PodInfo> podInfos = dtMap.get(info.getOwnerReference().getUid());
        if (podInfos == null) {
            throw new OperatorException("pod information of 【" + info.getParentName() + "】 has not been collected");
        }

        // Customise id_rsa.pub per pod and assemble one authorized_keys containing all keys.
        List<File> idRsaPubFiles = Lists.newArrayList();
        File tempAuthorizedKeys = FileUtil.createTempFile(tempDir);
        List<String> pubKeys = Lists.newArrayList();
        for (PodInfo podInfo : podInfos) {
            String podPubKeyContent = pubKeyContent.replace("{{ip}}", podInfo.getIp());
            File tempIdRsaPubOnPod = FileUtil.createTempFile(tempDir);
            FileUtil.writeLines(Collections.singletonList(podPubKeyContent), tempIdRsaPubOnPod, CHARSET);
            idRsaPubFiles.add(tempIdRsaPubOnPod);
            pubKeys.add(podPubKeyContent);
        }
        FileUtil.writeLines(pubKeys, tempAuthorizedKeys, CHARSET);

        // Fetch all pods and upload the three files to each of them.
        List<Pod> pods = getPods(info);
        // Public-key files are matched to pods by index, so both lists must have the same size.
        if (pods.size() != idRsaPubFiles.size()) {
            throw new OperatorException("number of pods does not match the collected pod information");
        }
        for (int i = 0; i < pods.size(); i++) {
            Pod pod = pods.get(i);
            // The first pod returned by getPods is the master.
            String containerName = i < 1 ? MASTER_CONTAINER_NAME : SLAVE_CONTAINER_NAME;
            // Upload id_rsa.
            podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempIdRsa, "/root/.ssh/id_rsa");
            // Upload id_rsa.pub.
            File tempIdRsaPubOnPod = idRsaPubFiles.get(i);
            podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempIdRsaPubOnPod, "/root/.ssh/id_rsa.pub");
            // Upload authorized_keys.
            podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempAuthorizedKeys, "/root/.ssh/authorized_keys");
            // Fix the permissions of the uploaded files.
            String chmodCmd = "chmod 644 /root/.ssh/authorized_keys && chmod 600 /root/.ssh/id_rsa && chmod 644 /root/.ssh/id_rsa.pub";
            podApi.exec(info.getNamespace(), pod.getMetadata().getName(), containerName, chmodCmd);
        }
        log.info("configure ssh no password environment for 【{}】 successfully ", info.getParentName());
    } catch (Exception e) {
        // Pass the message for the placeholder and the exception itself for the stack trace.
        log.error("sshAuthWithoutPass error:【{}】", e.getMessage(), e);
        throw new OperatorException("exception is thrown when configure ssh no password environment for 【" + info.getParentName() + "】 : \n" + e.getMessage());
    } finally {
        // Clean up the temporary files.
        FileUtil.del(tempDir);
    }
}

/**
 * Generates the hostfile and uploads it to every pod of the distributed training.
 *
 * <p>The hostfile is a single line containing a JSON array with one object
 * per recorded pod, holding its IP and role.
 *
 * @param info resource information of the distributed training
 * @throws OperatorException if pod information has not been collected or the
 *                           file cannot be generated or uploaded
 */
private void generateAndUploadHostFile(ChildResourceCreateInfo info) {
    log.info("start to configure hostfile for 【{}】 ", info.getParentName());
    File tempDir = Files.createTempDir();
    try {
        // Generate the hostfile content.
        List<PodInfo> podInfos = dtMap.get(info.getOwnerReference().getUid());
        if (podInfos == null) {
            throw new OperatorException("pod information of 【" + info.getParentName() + "】 has not been collected");
        }
        JSONArray jsonArray = new JSONArray();
        for (PodInfo podInfo : podInfos) {
            JSONObject podJson = new JSONObject();
            podJson.put(IP, podInfo.getIp());
            podJson.put(ROLE, podInfo.getRole());
            jsonArray.add(podJson);
        }
        File tempHostFile = FileUtil.createTempFile(tempDir);
        FileUtil.writeLines(Collections.singletonList(jsonArray.toJSONString()), tempHostFile, CHARSET);
        // Upload to the target directory of every pod.
        List<Pod> pods = getPods(info);
        for (int i = 0; i < pods.size(); i++) {
            Pod pod = pods.get(i);
            // The first pod returned by getPods is the master.
            String containerName = i < 1 ? MASTER_CONTAINER_NAME : SLAVE_CONTAINER_NAME;
            podApi.copyToPod(info.getNamespace(), pod.getMetadata().getName(), containerName, tempHostFile, HOSTFILE_TARGET_DIR);
        }

    } catch (Exception e) {
        // Pass the message for the placeholder and the exception itself for the stack trace.
        log.error("generateAndUploadHostFile error:【{}】", e.getMessage(), e);
        throw new OperatorException("exception is thrown when generate and upload hostfile for 【" + info.getParentName() + "】 : \n" + e.getMessage());
    } finally {
        // Clean up the temporary files.
        FileUtil.del(tempDir);
    }
}

/**
 * Creates the service of the distributed training, which releases the interlock.
 *
 * @param info resource information of the distributed training
 */
private void releaseInterLock(ChildResourceCreateInfo info) {
    log.info("release lock for 【{}】", info.getParentName());
    ServiceDeployer serviceDeployer = new BaseServiceDeployer();
    Service service = serviceDeployer.deploy(info).build();
    client.services().create(service);
    log.info("lock for 【{}】 released", info.getParentName());
}

/**
 * Recycles (deletes) the custom resource of the distributed training.
 * Nothing is deleted when no custom-resource client is available.
 *
 * @param info resource information naming the custom resource and its namespace
 */
private void recycleCr(ChildResourceCreateInfo info) {
    log.info("recycleCr=>childResourceCreateInfo : 【{}】", info);
    Optional.ofNullable(DistributeTrainClientHolder.getClient())
            .ifPresent(crClient -> {
                // Only name and namespace are set to identify the resource; the spec is left empty.
                ObjectMeta target = new ObjectMeta();
                target.setName(info.getParentName());
                target.setNamespace(info.getNamespace());
                DistributeTrainSpec emptySpec = DistributeTrainSpec.builder().build();
                crClient.delete(new DistributeTrain(target, emptySpec));
                log.info("recycle distribute train 【{}】", info.getParentName());
            });
}

/**
 * Updates the status of the distributed training: both the replica count and
 * the ready-replica count are set to the size declared in the spec. A status
 * object is created first when the resource has none.
 *
 * @param info            resource information (only logged)
 * @param distributeTrain the distributed-training resource whose status is updated
 */
private void updateStatus(ChildResourceCreateInfo info, DistributeTrain distributeTrain) {
    log.info("updateStatus=>childResourceCreateInfo : 【{}】, distributeTrain : 【{}】", info, distributeTrain);
    if (distributeTrain.getStatus() == null) {
        distributeTrain.setStatus(new DistributeTrainStatus());
    }
    Integer declaredSize = distributeTrain.getSpec().getSize();
    DistributeTrainStatus status = distributeTrain.getStatus();
    status.setReplicas(declaredSize);
    status.setReadyReplicas(declaredSize);
}

/**
* Registers a listener for the job of the distributed training.
*
* <p>NOTE: the watch registration below is commented out, so this method
* currently only writes a log line.
*
* @param info resource information of the distributed training
*/
private void registerJobListener(ChildResourceCreateInfo info) {
log.info("register listener for distribute train 【{}】", info.getParentName());
// Disabled watch registration, kept for reference:
// client.batch().jobs()
// .inNamespace(info.getNamespace())
// .withName(info.getJobName()).watch(null);
}

/**
 * Returns all pods that belong to the distributed training: the master pod
 * first, followed by the slave pods.
 *
 * @param info resource information of the distributed training
 * @return the master pod followed by all slave pods
 * @throws OperatorException if a pod is missing or the number of pods differs
 *                           from the number of slave replicas plus one
 */
private List<Pod> getPods(ChildResourceCreateInfo info) {
    log.info("getPods=>childResourceCreateInfo : 【{}】", info);
    List<Pod> pods = Lists.newArrayList();
    pods.add(getMasterPod(info));
    // getSlavePods returns null when no slave pod exists; addAll(null) would throw a
    // NullPointerException and bypass the descriptive error below.
    List<Pod> slavePods = getSlavePods(info);
    if (slavePods != null) {
        pods.addAll(slavePods);
    }
    if (CollectionUtil.hasNull(pods) || pods.size() != info.getSlaveReplicas() + 1) {
        throw new OperatorException("can not get pods in correct numbers");
    }
    return pods;
}

/**
 * Looks up the pod of the master node, selected by the job label.
 *
 * @param info resource information of the distributed training
 * @return the first pod carrying the job label, or {@code null} if there is none
 */
private Pod getMasterPod(ChildResourceCreateInfo info) {
    log.info("getMasterPod=>childResourceCreateInfo : 【{}】", info);
    List<Pod> candidates = client.pods()
            .inNamespace(info.getNamespace())
            .withLabel(JOB_LABEL, info.getJobName())
            .list()
            .getItems();
    return CollectionUtil.isEmpty(candidates) ? null : candidates.get(0);
}

/**
 * Looks up all pods of the slave nodes, selected by the statefulSet label.
 *
 * @param info resource information of the distributed training
 * @return the pods carrying the statefulSet label, or {@code null} if there are none
 */
private List<Pod> getSlavePods(ChildResourceCreateInfo info) {
    log.info("getSlavePods=>childResourceCreateInfo : 【{}】", info);
    List<Pod> found = client.pods()
            .inNamespace(info.getNamespace())
            .withLabel(STATEFULSET_LABEL, info.getStatefulSetName())
            .list()
            .getItems();
    // Callers rely on null (not an empty list) to signal "no slave pods".
    return CollectionUtil.isEmpty(found) ? null : found;
}

}

+ 88
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DeleteActionHandler.java View File

@@ -0,0 +1,88 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.handler;

import cn.hutool.core.collection.CollectionUtil;
import io.fabric8.kubernetes.api.model.Service;
import io.fabric8.kubernetes.api.model.ServiceList;
import io.fabric8.kubernetes.api.model.apps.StatefulSet;
import io.fabric8.kubernetes.api.model.apps.StatefulSetList;
import io.fabric8.kubernetes.api.model.batch.Job;
import io.fabric8.kubernetes.api.model.batch.JobList;
import io.fabric8.kubernetes.client.KubernetesClient;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.constants.KubeConstants;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.redis.RedisService;
import org.onebrain.operator.redis.key.OperatorKey;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

/**
 * @description Handler for delete events of a distributed training
 * @date 2020-09-23
 */
@Component("deleteActionHandler")
@Slf4j
public class DeleteActionHandler implements DistributeTrainActionHandler {

    @Autowired
    private KubernetesClient client;

    @Autowired
    private RedisService redis;

    /**
     * Handles a delete event: removes the jobs, statefulSets and services
     * labelled with the name of the distributed training in its namespace,
     * then removes the record kept in redis under the resource UID.
     *
     * @param distributeTrain the distributed training that was deleted
     */
    @Override
    public void handlerAction(DistributeTrain distributeTrain) {
        log.info("handlerAction=>distributeTrain : 【{}】", distributeTrain);
        // Namespace plus parent name (the distributed-training name) identify the child resources.
        final String ns = distributeTrain.getMetadata().getNamespace();
        final String dtName = distributeTrain.getMetadata().getName();

        // Delete the jobs.
        JobList jobs = client.batch().jobs()
                .inNamespace(ns)
                .withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, dtName)
                .list();
        if (CollectionUtil.isNotEmpty(jobs.getItems())) {
            jobs.getItems().forEach(job -> client.batch().jobs().delete(job));
            log.info("delete job in distributeTrain 【{}】", dtName);
        }

        // Delete the statefulSets.
        StatefulSetList statefulSets = client.apps().statefulSets()
                .inNamespace(ns)
                .withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, dtName)
                .list();
        if (CollectionUtil.isNotEmpty(statefulSets.getItems())) {
            statefulSets.getItems().forEach(statefulSet -> client.apps().statefulSets().delete(statefulSet));
            log.info("delete statefulSet in distributeTrain 【{}】", dtName);
        }

        // Delete the services.
        ServiceList services = client.services()
                .inNamespace(ns)
                .withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, dtName)
                .list();
        if (CollectionUtil.isNotEmpty(services.getItems())) {
            services.getItems().forEach(service -> client.services().delete(service));
            log.info("delete svc in distributeTrain 【{}】", dtName);
        }

        // Delete the distributed-training record kept in redis.
        redis.del(OperatorKey.CR, distributeTrain.getMetadata().getUid());
        log.info("delete distributeTrain 【{}】 successfully", dtName);
    }
}

+ 33
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/action/handler/DistributeTrainActionHandler.java View File

@@ -0,0 +1,33 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.action.handler;

import org.onebrain.operator.crd.DistributeTrain;

/**
* @description Event handler for distributed trainings; each implementation
* handles one kind of event raised for a {@link DistributeTrain} resource.
* @date 2020-09-23
*/
public interface DistributeTrainActionHandler {

/**
* Handles the event this handler is responsible for.
* @param distributeTrain the distributed training the event refers to
*/
void handlerAction(DistributeTrain distributeTrain);
}

+ 85
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/DefaultPodExecListener.java View File

@@ -0,0 +1,85 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.api.pod;

import io.fabric8.kubernetes.client.dsl.ExecListener;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;

import java.util.concurrent.CountDownLatch;

/**
* @description 默认命令执行监听器
* @date 2020-09-23
*/
@Slf4j
@Getter
public class DefaultPodExecListener implements ExecListener {

/**
* pod名称
*/
private String podName;

/**
* 命名空间
*/
private String namespace;

/**
* 容器名称
*/
private String containerName;

/**
* 执行门栓 线程通信用
*/
private CountDownLatch execLatch;

public DefaultPodExecListener(String podName, String namespace, String containerName, CountDownLatch execLatch) {
this.podName = podName;
this.namespace = namespace;
this.containerName = containerName;
this.execLatch = execLatch;
}

@Override
public void onOpen(Response response) {
log.debug("shell environment in pod '{}', namespace '{}' is opened", podName, namespace);
log.debug("onOpen: {}", response);
}

@Override
public void onFailure(Throwable t, Response response) {
log.error("shell environment in pod '{}', namespace '{}' barfed", podName, namespace);
log.error("onFailure: {} {}", t.getMessage(), response);
if (execLatch != null) {
execLatch.countDown();
}
}

@Override
public void onClose(int code, String reason) {
log.debug("shell environment in pod '{}', namespace '{}' closed", podName, namespace);
log.debug("onClose: {} {}", code, reason);
if (execLatch != null) {
execLatch.countDown();
}
}
}

+ 177
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/PodApi.java View File

@@ -0,0 +1,177 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.api.pod;

import cn.hutool.core.util.StrUtil;
import io.fabric8.kubernetes.client.KubernetesClient;
import io.fabric8.kubernetes.client.dsl.ExecWatch;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.io.FileUtils;
import org.onebrain.operator.context.KubeContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.io.File;
import java.io.IOException;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicBoolean;

/**
*
* @description PodApi 操作pod 里的容器用于上传文件等操作吧
* @date 2020-09-23
*/
@Component
@Slf4j
public class PodApi {

private static final Integer DEFAULT_LOG_LINES = 50;

@Autowired
private KubeContext kubeContext;

@Autowired
private KubernetesClient client;
/**
* 从Pod下载单个文件
* @return File 临时文件,用后需要及时清理
* **/
public File copyFileFromPod(String namespace, String podName, String containerName, String filePath){
try {
File tmpFile = File.createTempFile("copy-from-pod-", "");
client.pods().inNamespace(namespace).withName(podName)
.inContainer(containerName)
.file(filePath)
.copy(tmpFile.toPath());

if(tmpFile.length() == 0){
return null;
}

return tmpFile;
} catch (IOException e) {

log.error(" File copy error : 【{}】",e);
}
return null;
}

/**
* 从Pod下载目录
* @return File 临时文件,用后需要及时清理
* **/
public File copyFolderFromPod(String namespace, String podName, String containerName, String folderPath){
final PipedInputStream stdoutInput = new PipedInputStream();
final PipedOutputStream stdoutOutput = new PipedOutputStream();
final PipedInputStream stderrInput = new PipedInputStream();
final PipedOutputStream stderrOutput = new PipedOutputStream();
final AtomicBoolean failed = new AtomicBoolean(false);
try {
stdoutInput.connect(stdoutOutput);
stderrInput.connect(stderrOutput);

//去除路径上的/前缀
if(folderPath.startsWith(StrUtil.SLASH)){
folderPath = StrUtil.removePrefix(folderPath, StrUtil.SLASH);
}

//监听器异步执行
DefaultPodExecListener defaultPodExecListener = new DefaultPodExecListener(podName, namespace, containerName, null);

StdPodExecListener stdPodExecListener = new StdPodExecListener(defaultPodExecListener, stdoutOutput, stderrOutput, failed);

ExecWatch watch = client.pods().inNamespace(namespace)
.withName(podName).inContainer(containerName)
.writingOutput(stdoutOutput).writingError(stderrOutput)
.usingListener(stdPodExecListener)
.exec("tar", "cf", "-", "-C", folderPath, ".");
// execLatch.await();

} catch (IOException e) {
log.error("copyFolderFromPod:【{}】",e);
}

File tmpFile = null;

try {
tmpFile = File.createTempFile("copy-from-pod-", ".tar");

int length;
byte[] buffer = new byte[1024];
while (!Thread.currentThread().isInterrupted()
&& (length = stdoutInput.read(buffer)) != -1) {

byte[] content = new byte[length];
System.arraycopy(buffer, 0, content, 0, length);

FileUtils.writeByteArrayToFile(tmpFile, content, true);
}

while (!Thread.currentThread().isInterrupted()
&& (length = stderrInput.read(buffer)) != -1) {
log.error(new String(buffer, 0, length));
}
} catch (IOException e) {
if (!Thread.currentThread().isInterrupted()) {
log.error("Error while pumping stream. 【{}】", e);
} else {
log.error("Interrupted while pumping stream. 【{}】", e);
}
}

return tmpFile;
}

/**
* 拷贝文件到pod
* @param namespace 命名空间
* @param podName pod名称
* @param containerName 容器名称
* @param file 文件
* @param targetDir 目标路径
*/
public void copyToPod(String namespace, String podName, String containerName, File file, String targetDir){
client.pods().inNamespace(namespace).withName(podName)
.inContainer(containerName)
.file(targetDir)
.upload(file.toPath());
}

/**
* 同步执行
* @param namespace 命名空间
* @param podName pod名称
* @param containerName 容器名称
* @param cmd 命令
*/
public void exec(String namespace, String podName, String containerName, String cmd){
try {
final CountDownLatch execLatch = new CountDownLatch(1);
ExecWatch execWatch = client.pods().inNamespace(namespace).withName(podName).inContainer(containerName)
.redirectingOutput()
.withTTY() //不展示输出
.usingListener(new DefaultPodExecListener(namespace, podName, containerName, execLatch))
.exec("sh", "-c", cmd);
execLatch.await();
} catch (InterruptedException e) {
log.error(" PodApi execute cmd error : 【{}】",e);
}
}
}

+ 83
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/api/pod/StdPodExecListener.java View File

@@ -0,0 +1,83 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.api.pod;

import io.fabric8.kubernetes.client.dsl.ExecListener;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Response;

import java.io.IOException;
import java.io.PipedOutputStream;
import java.util.concurrent.atomic.AtomicBoolean;

/**
* @description 标准pod执行监听器
* @date 2020-09-23
*/
@Slf4j
public class StdPodExecListener implements ExecListener {

private ExecListener defaultExecListener;

private PipedOutputStream stdoutOutput;

private PipedOutputStream stderrOutput;

private AtomicBoolean failed;

public StdPodExecListener(ExecListener defaultExecListener, PipedOutputStream stdoutOutput, PipedOutputStream stderrOutput, AtomicBoolean failed) {
this.defaultExecListener = defaultExecListener;
this.stdoutOutput = stdoutOutput;
this.stderrOutput = stderrOutput;
this.failed = failed;
}

@Override
public void onOpen(Response response) {
log.info("onOpen=>response : 【{}】",response);
defaultExecListener.onOpen(response);
}

@Override
public void onFailure(Throwable t, Response response) {
log.info("onFailure=> t :【{}】,response : 【{}】",t,response);
try {
failed.set(true);
stdoutOutput.close();
stderrOutput.close();
} catch (IOException e) {
log.error("Failed to close stdout and stderr pipes. 【{}】", e);
} finally {
defaultExecListener.onFailure(t, response);
}
}

@Override
public void onClose(int code, String reason) {
log.info("onClose=>code : 【{}】,reason : 【{}】",code,reason);
try {
stdoutOutput.close();
stderrOutput.close();
} catch (IOException e) {
log.error("Failed to close stdout and stderr pipes. 【{}】", e);
} finally {
defaultExecListener.onClose(code, reason);
}
}

}

+ 66
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/config/KubeConfig.java View File

@@ -0,0 +1,66 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.config;

import cn.hutool.core.util.StrUtil;
import io.fabric8.kubernetes.client.KubernetesClient;
import org.onebrain.operator.context.KubeContext;
import org.onebrain.operator.properties.KubeProperties;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;

/**
 * @description Spring configuration of the Kubernetes context and client
 * @date 2020-09-23
 */
@Configuration
@EnableConfigurationProperties(KubeProperties.class)
public class KubeConfig {

    @Autowired
    private KubeProperties kubeProperties;

    /**
     * Registers the Kubernetes context.
     *
     * @return the context built from the properties, or {@code null} when the
     *         properties are missing or no kubeconfig is configured
     */
    @Bean
    public KubeContext kubeContext() {
        boolean configured = kubeProperties != null
                && !StrUtil.isEmpty(kubeProperties.getKubeconfig());
        return configured ? new KubeContext(kubeProperties) : null;
    }

    /**
     * Registers the Kubernetes client.
     *
     * @param kubeContext the Kubernetes context
     * @return the client held by the context
     */
    @Bean
    public KubernetesClient kubernetesClient(KubeContext kubeContext){
        return kubeContext.getClient();
    }
}

+ 34
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/constants/CrdConstants.java View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.constants;

/**
* @description Constants describing the DistributeTrain custom resource definition (CRD)
* @date 2020-09-23
*/
public class CrdConstants {
// API group of the custom resource.
public static final String CRD_GROUP = "onebrain.oneflow.org";
// Singular resource name.
public static final String CRD_SINGULAR_NAME = "distributetrain";
// Plural resource name.
public static final String CRD_PLURAL_NAME = "distributetrains";
// Full CRD name, composed as "<plural>.<group>".
public static final String CRD_NAME = CRD_PLURAL_NAME + "." + CRD_GROUP;
// Kind of the custom resource.
public static final String CRD_KIND = "DistributeTrain";
// Scope of the custom resource.
public static final String CRD_SCOPE = "Namespaced";
// Short name of the custom resource.
public static final String CRD_SHORT_NAME = "dt";
// Version of the custom resource.
public static final String CRD_VERSION = "v1alpha1";
// apiVersion value for the CRD object itself (presumably used when registering the CRD — verify against callers).
public static final String CRD_API_VERSION = "apiextensions.k8s.io/v1beta1";
}

+ 40
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/constants/KubeConstants.java View File

@@ -0,0 +1,40 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.constants;

/**
* @description Kubernetes-related constants
* @date 2020-09-23
*/
public class KubeConstants {

// Label key carrying the name of the distributed training.
public static final String DISTRIBUTE_TRAIN_LABEL = "dt-name";
// Label key carrying the statefulSet name.
public static final String STATEFULSET_LABEL = "dt-ss-name";
// Label key carrying the job name.
public static final String JOB_LABEL = "dt-job-name";
// Container name of the master.
public static final String MASTER_CONTAINER_NAME = "distribute-train-master";
// Container name of the slaves.
public static final String SLAVE_CONTAINER_NAME = "distribute-train-slave";
// Name of the JVM system property holding the working directory.
public final static String USER_DIR_SYSTEM_PROPERTY = "user.dir";
// No retries allowed.
public static final Integer BACKOFFLIMIT = 0;

// Charset name used when reading and writing text files.
public static final String CHARSET = "utf-8";

// Name of an environment variable (presumably the number of nodes — verify against the deployers).
public static final String ENV_NODE_NUM = "NODE_NUM";

// Name of a volume (presumably the shared-memory volume — verify against the deployers).
public static final String VOLUME_SHM = "dshm";
}

+ 43
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/constants/NumberConstant.java View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.constants;

/**
* @Description Numeric constants
* @Date 2020-6-9
*/
public class NumberConstant {

// Plain numeric literals, named after their value.
public final static int NUMBER_0 = 0;
public final static long LONG_NUMBER_0 = 0L;
public final static int NUMBER_1 = 1;
public final static int NUMBER_2 = 2;
public final static int NUMBER_3 = 3;
public final static int NUMBER_5 = 5;
public final static int NUMBER_10 = 10;
public final static int NUMBER_22 = 22;
public final static int NUMBER_30 = 30;
public final static int NUMBER_50 = 50;
public final static int NUMBER_60 = 60;
public final static long LONG_NUMBER_60 = 60L;
// Number of seconds in one hour.
public final static int HOUR_SECOND = 60 * 60;
// Number of seconds in one day.
public final static int DAY_SECOND = 60 * 60 * 24;
// Number of seconds in one week.
public final static int WEEK_SECOND = 60 * 60 * 24 * 7;
// Upper bound for a page size (name-based reading — verify against callers).
public final static int MAX_PAGE_SIZE = 2000;
public final static int NUMBER_30000 = 30000;
}

+ 117
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/context/KubeContext.java View File

@@ -0,0 +1,117 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.context;

import cn.hutool.core.util.StrUtil;
import com.fasterxml.jackson.core.JsonProcessingException;
import io.fabric8.kubernetes.api.model.HasMetadata;
import io.fabric8.kubernetes.client.Config;
import io.fabric8.kubernetes.client.DefaultKubernetesClient;
import io.fabric8.kubernetes.client.KubernetesClient;
import io.fabric8.kubernetes.client.VersionInfo;
import io.fabric8.kubernetes.client.internal.SerializationUtils;
import io.fabric8.kubernetes.client.utils.Utils;
import lombok.Getter;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.properties.KubeProperties;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;


/**
 * @description Kubernetes context: creates the fabric8 client from the configured kubeconfig source
 *              and exposes the client, its configuration and the Spring application context
 * @date 2020-09-23
 */
@Slf4j
@Getter
public class KubeContext implements ApplicationContextAware {

    /** Kubeconfig value that selects in-cluster auto-detection. */
    private static final String AUTO = "auto";

    private ApplicationContext applicationContext;

    /** Kubernetes client; {@code null} when initialization failed. */
    private KubernetesClient client;

    private Config config;


    /**
     * Builds the Kubernetes client.
     * <ul>
     *   <li>{@code "auto"}: the client detects its configuration itself (in-cluster mode);</li>
     *   <li>a value starting with {@code /}: used as a kubeconfig path on the file system;</li>
     *   <li>anything else: resolved as a classpath resource.</li>
     * </ul>
     * Any failure is logged and leaves {@link #client} set to {@code null}; it is not rethrown.
     *
     * @param kubeProperties properties carrying the kubeconfig source
     */
    public KubeContext(KubeProperties kubeProperties) {
        String configSource = kubeProperties.getKubeconfig();
        try {
            if (AUTO.equals(configSource)) {
                // inside the cluster the configuration can be auto-detected
                log.info("kubernetes client is in cluster mode");
            } else if (configSource.startsWith(StrUtil.SLASH)) {
                log.info("read kubeconfig from file system:{}", configSource);
                System.setProperty(Config.KUBERNETES_KUBECONFIG_FILE, configSource);
            } else {
                log.info("read kubeconfig from classpath:{}", configSource);
                final String testKubeconfigFile = Utils.filePath(getClass().getResource(StrUtil.SLASH + configSource));
                // override the system property so the kubeconfig is read from the resolved location
                System.setProperty(Config.KUBERNETES_KUBECONFIG_FILE, testKubeconfigFile);
            }
            client = new DefaultKubernetesClient();
            config = client.getConfiguration();

            // print cluster information
            log.info("ApiVersion : {}", client.getApiVersion());
            log.info("MasterUrl : {}", client.getMasterUrl());
            if (log.isDebugEnabled()) {
                VersionInfo versionInfo = client.getVersion();
                log.debug("Version details of this Kubernetes cluster :-");
                log.debug("Major : {}", versionInfo.getMajor());
                log.debug("Minor : {}", versionInfo.getMinor());
                log.debug("GitVersion : {}", versionInfo.getGitVersion());
                log.debug("GitCommit : {}", versionInfo.getGitCommit());
                log.debug("BuildDate : {}", versionInfo.getBuildDate());
                log.debug("GitTreeState : {}", versionInfo.getGitTreeState());
                log.debug("Platform : {}", versionInfo.getPlatform());
                log.debug("GoVersion : {}", versionInfo.getGoVersion());
            }
        } catch (Exception e) {
            client = null;
            // the logger already records the stack trace; no printStackTrace() needed
            log.error("初始化 K8sUtils 失败!", e);
        }
    }

    /**
     * Dumps a resource as a YAML string.
     *
     * @param resource Kubernetes resource to serialize
     * @return the YAML representation
     * @throws RuntimeException if serialization fails; the original exception is attached as the cause
     */
    public String convertToYaml(HasMetadata resource) {
        try {
            return SerializationUtils.dumpAsYaml(resource);
        } catch (JsonProcessingException e) {
            // keep the cause so callers can see why serialization failed
            throw new RuntimeException("can not transform resource to yaml", e);
        }
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
    }

}

+ 131
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/controller/DistributeTrainController.java View File

@@ -0,0 +1,131 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.controller;

import io.fabric8.kubernetes.client.KubernetesClient;
import io.fabric8.kubernetes.client.dsl.MixedOperation;
import io.fabric8.kubernetes.client.dsl.Resource;
import io.fabric8.kubernetes.client.informers.ResourceEventHandler;
import io.fabric8.kubernetes.client.informers.SharedIndexInformer;
import io.fabric8.kubernetes.client.informers.cache.Lister;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.action.handler.DistributeTrainActionHandler;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.crd.DistributeTrainList;
import org.onebrain.operator.crd.DoneableDistributeTrain;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.scheduling.annotation.Async;

import java.util.concurrent.TimeUnit;

/**
 * @description Controller for DistributeTrain custom resources: registers informer event
 *              handlers and waits for the informer's initial sync
 * @date 2020-06-16
 */
@Slf4j
public class DistributeTrainController {

    @Autowired
    private KubernetesClient client;

    /**
     * Informer for DistributeTrain resources
     */
    private SharedIndexInformer<DistributeTrain> distributeTrainSharedIndexInformer;

    /**
     * Kubernetes access client for DistributeTrain resources
     */
    private MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> distributeTrainClient;

    /**
     * Lister built on the informer's indexer
     */
    private Lister<DistributeTrain> distributeTrainLister;

    @Autowired
    @Qualifier("addActionHandler")
    private DistributeTrainActionHandler addActionHandler;

    @Autowired
    @Qualifier("deleteActionHandler")
    private DistributeTrainActionHandler deleteActionHandler;

    /**
     * @param distributeTrainClient              client for DistributeTrain resources
     * @param distributeTrainSharedIndexInformer informer for DistributeTrain resources
     * @param namespace                          not used by this constructor
     */
    public DistributeTrainController(MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> distributeTrainClient, SharedIndexInformer<DistributeTrain> distributeTrainSharedIndexInformer, String namespace) {
        this.distributeTrainSharedIndexInformer = distributeTrainSharedIndexInformer;
        this.distributeTrainClient = distributeTrainClient;
        this.distributeTrainLister = new Lister<>(distributeTrainSharedIndexInformer.getIndexer());
    }

    /**
     * Registers the event handler on the informer.
     */
    public void create() {
        distributeTrainSharedIndexInformer.addEventHandler(new ResourceEventHandler<DistributeTrain>() {
            /**
             * Handles add events by delegating to the add action handler.
             * @param distributeTrain the added DistributeTrain
             */
            @Override
            public void onAdd(DistributeTrain distributeTrain) {
                log.info("add distributeTrain named 【{}】 in namespace 【{}】", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace());
                addActionHandler.handlerAction(distributeTrain);
            }

            /**
             * Handles update events; updates are only logged.
             * @param distributeTrain    the old DistributeTrain
             * @param newDistributeTrain the new DistributeTrain
             */
            @Override
            public void onUpdate(DistributeTrain distributeTrain, DistributeTrain newDistributeTrain) {
                log.info("update distributeTrain named 【{}】 in namespace 【{}】", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace());
            }

            /**
             * Handles delete events by delegating to the delete action handler.
             * @param distributeTrain the deleted DistributeTrain
             * @param b               flag supplied by the informer (per the original note: whether this is an unknown event)
             */
            @Override
            public void onDelete(DistributeTrain distributeTrain, boolean b) {
                log.info("delete distributeTrain named 【{}】 in namespace 【{}】", distributeTrain.getMetadata().getName(), distributeTrain.getMetadata().getNamespace());
                deleteActionHandler.handlerAction(distributeTrain);
            }
        });
    }

    /**
     * Blocks, polling once per second, until the informer reports it has synced.
     * If the waiting thread is interrupted, the interrupt flag is restored and the method
     * returns without reporting the controller as running.
     */
    @Async
    public void run() {
        log.info("Starting DistributeTrain controller");
        try {
            // wait for the DistributeTrain informer to sync
            while (!distributeTrainSharedIndexInformer.hasSynced()) {
                TimeUnit.SECONDS.sleep(1);
            }
        } catch (InterruptedException e) {
            // restore the interrupt status so the owning executor can observe it
            Thread.currentThread().interrupt();
            log.error("run error:【{}】", e.getMessage(), e);
            return;
        }
        log.info("DistributeTrain controller is Running");
    }
}

+ 47
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrain.java View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.crd;

import io.fabric8.kubernetes.api.model.ObjectMeta;
import io.fabric8.kubernetes.client.CustomResource;
import lombok.Data;
import lombok.NoArgsConstructor;

/**
 * @description DistributeTrain custom resource (spec plus status)
 * @date 2020-09-24
 */
@Data
@NoArgsConstructor
public class DistributeTrain extends CustomResource {

    /** Detailed specification of the distributed training. */
    private DistributeTrainSpec spec;

    /** Status of the distributed training. */
    private DistributeTrainStatus status;

    /**
     * Creates a resource with the given metadata and spec; the status is left unset.
     *
     * @param objectMeta resource metadata
     * @param spec       detailed specification
     */
    public DistributeTrain(ObjectMeta objectMeta, DistributeTrainSpec spec) {
        super();
        setMetadata(objectMeta);
        this.spec = spec;
    }
}

+ 27
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainList.java View File

@@ -0,0 +1,27 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.crd;

import io.fabric8.kubernetes.client.CustomResourceList;

/**
 * @description List type for the DistributeTrain custom resource
 * @date 2020-09-24
 */
public class DistributeTrainList extends CustomResourceList<DistributeTrain> {
}

+ 108
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainSpec.java View File

@@ -0,0 +1,108 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.crd;

import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.fabric8.kubernetes.api.model.*;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.util.List;
import java.util.Map;

/**
* @description Detailed specification of a distributed training
* @date 2020-09-23
*/
@JsonDeserialize(
using = JsonDeserializer.None.class
)
@Data
@NoArgsConstructor
@AllArgsConstructor
@Builder
public class DistributeTrainSpec implements KubernetesResource {

/**
* Container image
*/
private String image;

/**
* Image pull policy
*/
private String imagePullPolicy;
/**
* Number of machines
*/
private Integer size;

/**
* Environment variables
*/
private List<EnvVar> env;

/**
* Master command
*/
private String masterCmd;

/**
* Slave command
*/
private String slaveCmd;

/**
* Resource limits for the master node
*/
private ResourceRequirements masterResources;

/**
* Resource limits for the slave nodes
*/
private ResourceRequirements slaveResources;

/**
* Node scheduling selector
*/
private Map<String,String> nodeSelector;

/**
* Init container
*/
private Container initContainer;

/**
* Workspace directory mount
*/
private Volume workspaceStorage;

/**
* Dataset directory mount
*/
private Volume datasetStorage;

/**
* Model directory mount
*/
private Volume modelStorage;

}

+ 55
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/crd/DistributeTrainStatus.java View File

@@ -0,0 +1,55 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.crd;

import com.fasterxml.jackson.databind.JsonDeserializer;
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
import io.fabric8.kubernetes.api.model.KubernetesResource;
import lombok.Data;

/**
* @description Status of a distributed training
* @date 2020-09-23
*/
@JsonDeserialize(
using = JsonDeserializer.None.class
)
@Data
public class DistributeTrainStatus implements KubernetesResource {

/**
* Number of replicas
*/
private Integer replicas;

/**
* Number of replicas in the ready state
*/
private Integer readyReplicas;

/**
* Number of successes
*/
private Integer success;

/**
* Number of failures
*/
private Integer failed;

}

+ 31
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/crd/DoneableDistributeTrain.java View File

@@ -0,0 +1,31 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.crd;

import io.fabric8.kubernetes.api.builder.Function;
import io.fabric8.kubernetes.client.CustomResourceDoneable;

/**
 * @description Doneable (edit builder) for the DistributeTrain custom resource
 * @date 2020-09-24
 */
public class DoneableDistributeTrain extends CustomResourceDoneable<DistributeTrain> {

    /**
     * Passes both arguments straight to the {@code CustomResourceDoneable} constructor.
     *
     * @param resource the resource being edited
     * @param function function forwarded to the superclass
     */
    public DoneableDistributeTrain(
            DistributeTrain resource,
            Function<DistributeTrain, DistributeTrain> function) {
        super(resource, function);
    }
}

+ 56
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/enums/AccessModeEnum.java View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.enums;

/**
 * @description Access modes of a PVC
 * @date 2020-09-24
 */
public enum AccessModeEnum {

    /**
     * RWO is the most basic mode: readable and writable, but mountable by a single Pod only.
     */
    RWO("ReadWriteOnce"),

    /**
     * Can be mounted read-only by multiple Pods.
     */
    ROX("ReadOnlyMany"),

    /**
     * Storage that can be shared read-write by multiple Pods.
     * Not every storage type supports all three modes; the shared mode in particular has
     * limited support, NFS being the commonly used one.
     * A PVC is usually bound to a PV on two criteria: storage size and access mode.
     */
    RWX("ReadWriteMany");

    /**
     * Mode string
     */
    private final String mode;

    AccessModeEnum(String modeValue) {
        mode = modeValue;
    }

    /**
     * @return the mode string of this constant
     */
    public String getMode() {
        return this.mode;
    }
}

+ 49
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/exception/OperatorException.java View File

@@ -0,0 +1,49 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.exception;

import lombok.Getter;
import lombok.extern.slf4j.Slf4j;

/**
 * @description Custom unchecked exception of the operator.
 *              The message and cause are handed to {@link RuntimeException}, so
 *              {@link #getMessage()}, {@link #getCause()} and stack traces carry them.
 * @date 2020-09-24
 */
public class OperatorException extends RuntimeException {

    /**
     * Message; also available through {@link #getMessage()}
     */
    private final String msg;

    /**
     * @param msg   description of the failure
     * @param cause underlying cause, returned by {@link #getCause()}
     */
    public OperatorException(String msg, Throwable cause) {
        super(msg, cause);
        this.msg = msg;
    }

    /**
     * @param msg description of the failure
     */
    public OperatorException(String msg) {
        super(msg);
        this.msg = msg;
    }

    /**
     * @return the message given at construction
     */
    public String getMsg() {
        return msg;
    }
}

+ 34
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/properties/KubeProperties.java View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.properties;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
 * @description Configuration properties bound from the "k8s" prefix
 * @date 2020-09-24
 */
@Component
@ConfigurationProperties("k8s")
@Data
public class KubeProperties {

    /**
     * Kubeconfig source (bound from "k8s.kubeconfig")
     */
    private String kubeconfig;
}

+ 65
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/redis/AbstractKeyPrefix.java View File

@@ -0,0 +1,65 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.redis;

/**
 * @description Prefix of a redis key, together with its expiry
 * @date 2020-09-23
 */
public abstract class AbstractKeyPrefix {

    /**
     * Template wrapped around the raw prefix
     */
    private static final String KEY_TEMPLATE = "Operator:%s";

    /**
     * Raw prefix as passed to the constructor
     */
    private final String rawPrefix;

    /**
     * Expiry in seconds; 0 means the key never expires
     */
    private final int ttlSeconds;

    /**
     * Creates a prefix that never expires (expiry 0).
     * @param prefix raw prefix
     */
    public AbstractKeyPrefix(String prefix) {
        this.ttlSeconds = 0;
        this.rawPrefix = prefix;
    }

    /**
     * @param prefix        raw prefix
     * @param expireSeconds expiry in seconds
     */
    public AbstractKeyPrefix(String prefix, int expireSeconds) {
        this.ttlSeconds = expireSeconds;
        this.rawPrefix = prefix;
    }

    /**
     * Gets the expiry.
     * @return expiry in seconds; 0 (the default) means never expires
     */
    public int getExpireSeconds() {
        return this.ttlSeconds;
    }

    /**
     * Gets the prefix.
     * @return the raw prefix formatted with the key template
     */
    public String getPrefix() {
        final String formatted = String.format(KEY_TEMPLATE, this.rawPrefix);
        return formatted;
    }
}

+ 290
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/redis/RedisService.java View File

@@ -0,0 +1,290 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.redis;

import org.onebrain.operator.utils.FastjsonUtils;
import org.onebrain.operator.utils.RedisUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;

import java.util.Set;

/**
 * @description Redis service: builds prefixed keys and delegates every operation to RedisUtils
 * @date 2020-09-03
 */
@Service
public class RedisService {

    /**
     * Template of the real key: prefix, colon, key
     */
    private static final String REAL_KEY_TEMPLATE = "%s:%s";

    @Autowired
    private RedisUtils redisUtils;

    /**
     * Builds the real key.
     * @param prefix key prefix
     * @param key key value
     * @return the key actually stored in redis
     */
    private String getRealKey(AbstractKeyPrefix prefix, String key) {
        final String namespace = prefix.getPrefix();
        return String.format(REAL_KEY_TEMPLATE, namespace, key);
    }

    /**
     * TTL key: remaining time to live of the given key, in seconds.
     * @param prefix key prefix
     * @param key key value
     * @return remaining seconds until expiry
     */
    public long ttl(AbstractKeyPrefix prefix, String key) {
        final String realKey = getRealKey(prefix, key);
        return redisUtils.ttl(realKey);
    }

    /**
     * EXPIRE: sets the expiry of a key, in seconds.
     * @param prefix key prefix
     * @param key key value
     * @param timeout desired expiry
     */
    public void expire(AbstractKeyPrefix prefix, String key, long timeout) {
        final String realKey = getRealKey(prefix, key);
        redisUtils.expire(realKey, timeout);
    }

    /**
     * INCR key: increments the key by the given delta.
     * @param prefix key prefix
     * @param key key value
     * @param delta increment
     * @return counter value
     */
    public long incr(AbstractKeyPrefix prefix, String key, long delta) {
        final String realKey = getRealKey(prefix, key);
        return redisUtils.incr(realKey, delta);
    }

    /**
     * Decrements the key by the given delta.
     * A negative delta does not throw: the key is deleted and 0 is returned.
     * @param prefix key prefix
     * @param key key value
     * @param delta decrement
     * @return counter value
     */
    public long decr(AbstractKeyPrefix prefix, String key, long delta) {
        final String realKey = getRealKey(prefix, key);
        if (delta >= 0) {
            return redisUtils.decr(realKey, delta);
        }
        del(realKey);
        return 0;
    }

    /**
     * KEYS pattern: finds all keys under the given prefix.
     * @param prefix key prefix
     * @return set of keys
     */
    public Set<String> keys(AbstractKeyPrefix prefix) {
        return redisUtils.keys(prefix.getPrefix() + ":*");
    }

    /**
     * KEYS pattern: finds all keys under the given prefix and key.
     * @param prefix key prefix
     * @param key key value
     * @return set of keys
     */
    public Set<String> keys(AbstractKeyPrefix prefix, String key) {
        return redisUtils.keys(prefix.getPrefix() + ":" + key + ":*");
    }

    /**
     * DEL key: deletes one key.
     * @param prefix key prefix
     * @param key key value
     */
    public void del(AbstractKeyPrefix prefix, String key) {
        final String realKey = getRealKey(prefix, key);
        redisUtils.del(realKey);
    }

    /**
     * Deletes one key.
     * @param realKey the real key, used as-is
     */
    public void del(String realKey) {
        redisUtils.del(realKey);
    }

    /**
     * SET key value: stores a string value; the prefix's expiry is applied when it is positive.
     * @param prefix key prefix
     * @param key key value
     * @param value value
     */
    public void set(AbstractKeyPrefix prefix, String key, String value) {
        final int expireSeconds = prefix.getExpireSeconds();
        final String realKey = getRealKey(prefix, key);
        if (expireSeconds > 0) {
            redisUtils.set(realKey, value, expireSeconds);
        } else {
            redisUtils.set(realKey, value);
        }
    }

    /**
     * SET key value: stores a value serialized to JSON; the prefix's expiry is applied when it is positive.
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @param <T> value type
     */
    public <T> void set(AbstractKeyPrefix prefix, String key, T value) {
        final int expireSeconds = prefix.getExpireSeconds();
        final String realKey = getRealKey(prefix, key);
        final String json = FastjsonUtils.convertObjectToJSON(value);
        if (expireSeconds > 0) {
            redisUtils.set(realKey, json, expireSeconds);
        } else {
            redisUtils.set(realKey, json);
        }
    }


    /**
     * SET key value EX seconds: stores a string value with an explicit timeout (seconds).
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @param timeout expiry
     */
    public void set(AbstractKeyPrefix prefix, String key, String value, long timeout) {
        final String realKey = getRealKey(prefix, key);
        redisUtils.set(realKey, value, timeout);
    }

    /**
     * SET key value EX seconds: stores a value serialized to JSON with an explicit timeout (seconds).
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @param timeout expiry
     * @param <T> value type
     */
    public <T> void set(AbstractKeyPrefix prefix, String key, T value, long timeout) {
        final String realKey = getRealKey(prefix, key);
        final String json = FastjsonUtils.convertObjectToJSON(value);
        redisUtils.set(realKey, json, timeout);
    }

    /**
     * SETNX key value: stores a string value; the prefix's expiry is applied when it is positive.
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @return whether the value was set
     */
    public Boolean setnx(AbstractKeyPrefix prefix, String key, String value) {
        final int expireSeconds = prefix.getExpireSeconds();
        final String realKey = getRealKey(prefix, key);
        if (expireSeconds > 0) {
            return redisUtils.setnx(realKey, value, expireSeconds);
        }
        return redisUtils.setnx(realKey, value);
    }

    /**
     * SETNX key value: stores a value serialized to JSON; the prefix's expiry is applied when it is positive.
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @param <T> value type
     * @return whether the value was set
     */
    public <T> Boolean setnx(AbstractKeyPrefix prefix, String key, T value) {
        final int expireSeconds = prefix.getExpireSeconds();
        final String realKey = getRealKey(prefix, key);
        final String json = FastjsonUtils.convertObjectToJSON(value);
        if (expireSeconds > 0) {
            return redisUtils.setnx(realKey, json, expireSeconds);
        }
        return redisUtils.setnx(realKey, json);
    }

    /**
     * SETNX key value EX seconds: stores a string value with an explicit timeout (seconds).
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @param timeout expiry
     * @return whether the value was set
     */
    public Boolean setnx(AbstractKeyPrefix prefix, String key, String value, long timeout) {
        final String realKey = getRealKey(prefix, key);
        return redisUtils.setnx(realKey, value, timeout);
    }

    /**
     * SETNX key value EX seconds: stores a value serialized to JSON with an explicit timeout (seconds).
     * @param prefix key prefix
     * @param key key value
     * @param value value
     * @param timeout expiry
     * @param <T> value type
     * @return whether the value was set
     */
    public <T> Boolean setnx(AbstractKeyPrefix prefix, String key, T value, long timeout) {
        final String realKey = getRealKey(prefix, key);
        final String json = FastjsonUtils.convertObjectToJSON(value);
        return redisUtils.setnx(realKey, json, timeout);
    }

    /**
     * GET key: returns the string value associated with the key.
     * @param prefix key prefix
     * @param key key value
     * @return value
     */
    public String get(AbstractKeyPrefix prefix, String key) {
        final String realKey = getRealKey(prefix, key);
        return redisUtils.get(realKey);
    }

    /**
     * GET key: returns the value associated with the key as the requested type.
     * @param prefix key prefix
     * @param key key value
     * @param clazz target type
     * @param <T> target type
     * @return value
     */
    public <T> T get(AbstractKeyPrefix prefix, String key, Class<T> clazz) {
        final String realKey = getRealKey(prefix, key);
        return redisUtils.get(realKey, clazz);
    }

    /**
     * Gets a value by its real key.
     * @param lastKey the real key, used as-is
     * @param clazz target type
     * @param <T> target type
     * @return value
     */
    public <T> T get(String lastKey, Class<T> clazz) {
        return redisUtils.get(lastKey, clazz);
    }

    /**
     * Checks whether the key exists (delegates to RedisUtils.exists).
     * @param prefix key prefix
     * @param key key value
     * @return whether the key exists
     */
    public Boolean exists(AbstractKeyPrefix prefix, String key) {
        final String realKey = getRealKey(prefix, key);
        return redisUtils.exists(realKey);
    }


}

+ 45
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/redis/key/OperatorKey.java View File

@@ -0,0 +1,45 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.redis.key;

import org.onebrain.operator.redis.AbstractKeyPrefix;

/**
 * @description Unique identifiers (redis key prefixes) for CRs produced by the operator
 * @date 2020-09-23
 */
public class OperatorKey extends AbstractKeyPrefix {

    /**
     * DistributeTrain key
     */
    public static final OperatorKey CR = new OperatorKey("DistributeTrain");

    /**
     * DistributeTrain job key
     */
    public static final OperatorKey CR_JOB = new OperatorKey("DistributeTrain:Job");

    /**
     * @param prefix raw prefix, forwarded to the superclass
     */
    public OperatorKey(String prefix) {
        super(prefix);
    }

    /**
     * @param prefix        raw prefix, forwarded to the superclass
     * @param expireSeconds expiry in seconds, forwarded to the superclass
     */
    public OperatorKey(String prefix, int expireSeconds) {
        super(prefix, expireSeconds);
    }
}

+ 41
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/utils/DistributeTrainClientHolder.java View File

@@ -0,0 +1,41 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.utils;

import io.fabric8.kubernetes.client.dsl.MixedOperation;
import io.fabric8.kubernetes.client.dsl.Resource;
import org.onebrain.operator.crd.DistributeTrain;
import org.onebrain.operator.crd.DistributeTrainList;
import org.onebrain.operator.crd.DoneableDistributeTrain;

/**
 * @description Static holder of the DistributeTrain client
 * @date 2020-09-23
 */
public class DistributeTrainClientHolder {

    /** The held client; null until {@link #setDistributeTrainClient} has been called. */
    private static MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> sharedClient;

    /**
     * @return the client previously stored, or null if none was stored
     */
    public static MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> getClient() {
        return DistributeTrainClientHolder.sharedClient;
    }

    /**
     * Stores the client for later retrieval through {@link #getClient()}.
     * @param client the client to hold
     */
    public static void setDistributeTrainClient(MixedOperation<DistributeTrain, DistributeTrainList, DoneableDistributeTrain, Resource<DistributeTrain, DoneableDistributeTrain>> client) {
        DistributeTrainClientHolder.sharedClient = client;
    }
}

+ 188
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/utils/FastjsonUtils.java View File

@@ -0,0 +1,188 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.utils;


import com.alibaba.fastjson.JSON;
import com.alibaba.fastjson.JSONObject;
import com.alibaba.fastjson.serializer.SerializerFeature;

import java.util.List;
import java.util.Map;

/**
 * @description JSON helper methods built on fastjson
 * @date 2020-09-24
 */
public class FastjsonUtils {

    /** Serializer features applied by {@link #convertObjectToJSON(Object)}. */
    private static final SerializerFeature[] FEATURES = {
            // write fields whose value is null
            SerializerFeature.WriteMapNullValue,
            // write dates as formatted strings (yyyy-MM-dd HH:mm:ss)
            SerializerFeature.WriteDateUseDateFormat,
            // a null list field is written as [] instead of null
            SerializerFeature.WriteNullListAsEmpty,
            // a null numeric field is written as 0 instead of null
            SerializerFeature.WriteNullNumberAsZero,
            // a null Boolean field is written as false instead of null
            SerializerFeature.WriteNullBooleanAsFalse,
            // a null string field is written as "" instead of null
            SerializerFeature.WriteNullStringAsEmpty
    };

    /**
     * Serializes an object to JSON using {@link #FEATURES}.
     *
     * @param object object to serialize
     * @return the JSON string
     */
    public static String convertObjectToJSON(Object object) {
        return JSON.toJSONString(object, FEATURES);
    }

    /**
     * Serializes an object to JSON with circular-reference detection disabled.
     *
     * @param object object to serialize
     * @return the JSON string
     */
    public static String toJSONNoFeatures(Object object) {
        return JSON.toJSONString(object, SerializerFeature.DisableCircularReferenceDetect);
    }

    /**
     * Parses JSON text into an untyped object.
     *
     * @param text JSON text
     * @return the parsed object
     */
    public static Object toBean(String text) {
        return JSON.parse(text);
    }

    /**
     * Parses JSON text into an instance of the given class.
     *
     * @param text  JSON text
     * @param clazz target type
     * @param <T>   target type parameter
     * @return the parsed object
     */
    public static <T> T toBean(String text, Class<T> clazz) {
        return JSON.parseObject(text, clazz);
    }

    /**
     * Parses a JSON array into an object array without an element type.
     *
     * @param text JSON text
     * @return the parsed elements; empty when the text yields no list
     */
    public static <T> Object[] toArray(String text) {
        return toArray(text, null);
    }

    /**
     * Parses a JSON array into an object array.
     *
     * @param text  JSON text
     * @param clazz element type
     * @return the parsed elements; empty when the text yields no list
     */
    public static <T> Object[] toArray(String text, Class<T> clazz) {
        List<T> parsed = JSON.parseArray(text, clazz);
        // parseArray can return null (for example for null text); return an empty
        // array in that case instead of failing with a NullPointerException.
        return parsed == null ? new Object[0] : parsed.toArray();
    }

    /**
     * Parses a JSON array into a typed list.
     *
     * @param text  JSON text
     * @param clazz element type
     * @return the parsed list, as returned by fastjson
     */
    public static <T> List<T> toList(String text, Class<T> clazz) {
        return JSON.parseArray(text, clazz);
    }

    /**
     * Parses text into a JSON object; same parsing call as {@link #toBean(String)}.
     *
     * @param text JSON text
     * @return the parsed JSON object
     */
    public static Object textToJson(String text) {
        return JSON.parse(text);
    }

    /**
     * Parses a JSON string into a map.
     *
     * @param text JSON string
     * @return the parsed map; the key/value types are not verified (unchecked cast)
     */
    @SuppressWarnings("unchecked")
    public static <K, V> Map<K, V> stringToCollect(String text) {
        return (Map<K, V>) JSONObject.parseObject(text);
    }

    /**
     * Parses a JSON string into an object of the given class.
     *
     * @param jsonData JSON string
     * @param clazz    target type
     * @return the parsed object
     */
    public static Object convertJsonToObject(String jsonData, Class<?> clazz) {
        return JSONObject.parseObject(jsonData, clazz);
    }

    /**
     * Serializes a map to a JSON string.
     *
     * @param m map to serialize
     * @return the JSON string
     */
    public static <K, V> String collectToString(Map<K, V> m) {
        return JSONObject.toJSONString(m);
    }

    /**
     * Parses a JSON string into a raw map.
     *
     * @param text JSON string
     * @return the parsed map
     */
    public static Map stringToMap(String text) {
        return JSONObject.parseObject(text);
    }

    /**
     * Serializes a raw map to a JSON string.
     *
     * @param m map to serialize
     * @return the JSON string
     */
    public static String mapToString(Map m) {
        return JSONObject.toJSONString(m);
    }

    /**
     * Converts an object to another type by serializing it to JSON and parsing it back.
     *
     * @param source source object
     * @param target target class
     * @param <T>    target type parameter
     * @return the converted object
     */
    public static <T> T toObjectFromSource(Object source, Class<T> target) {
        return toBean(convertObjectToJSON(source), target);
    }
}

+ 56
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/utils/IOUtils.java View File

@@ -0,0 +1,56 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.utils;

import lombok.extern.slf4j.Slf4j;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;


/**
* @description IO工具类
* @date 2020-09-24
*/
@Slf4j
public class IOUtils {

/**
* 将input流转换为文件
*
* @param is 输入流
* @param targetFile 目标文件
*/
public static void copy(InputStream is, File targetFile) {
try (FileOutputStream fos = new FileOutputStream(targetFile)) {
byte[] b = new byte[1024];
int readCount = is.read(b);
while (readCount != -1) {
// 写入数据
fos.write(b, 0, readCount);
readCount = is.read(b);
}
is.close();
fos.flush();
} catch (IOException e) {
log.error("copy file error:【{}】", e);
}
}
}

+ 289
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/utils/RedisUtils.java View File

@@ -0,0 +1,289 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.utils;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;

/**
 * @description Thin wrapper around StringRedisTemplate for simple key-value, hash and list operations
 * @date 2020-09-23
 */
@Component
public class RedisUtils {

    @Autowired
    private StringRedisTemplate redisTemplate;

    /**
     * TTL key: returns the remaining time to live of a key.
     *
     * @param key key
     * @return the expiry value reported by the template for the key
     */
    public long ttl(String key) {
        Long remaining = redisTemplate.getExpire(key);
        return remaining;
    }

    /**
     * EXPIRE: sets the expiry of a key, in seconds.
     *
     * @param key     key
     * @param timeout expiry in seconds
     */
    public void expire(String key, long timeout) {
        redisTemplate.expire(key, timeout, TimeUnit.SECONDS);
    }

    /**
     * Increments the counter stored at a key by the given step.
     *
     * @param key   key
     * @param delta step to add
     * @return the counter value after the increment
     */
    public long incr(String key, long delta) {
        Long counter = redisTemplate.opsForValue().increment(key, delta);
        return counter;
    }

    /**
     * Decrements the counter stored at a key by the given step.
     * A negative step is not applied: the key is deleted and 0 is returned.
     *
     * @param key   key
     * @param delta step to subtract
     * @return the counter value after the decrement, or 0 when the step is negative
     */
    public long decr(String key, long delta) {
        if (delta >= 0) {
            Long counter = redisTemplate.opsForValue().increment(key, -delta);
            return counter;
        }
        del(key);
        return 0;
    }

    /**
     * KEYS pattern: finds all keys matching the pattern.
     *
     * @param pattern key pattern
     * @return the matching keys
     */
    public Set<String> keys(String pattern) {
        Set<String> matched = redisTemplate.keys(pattern);
        return matched;
    }

    /**
     * DEL key: deletes a key.
     *
     * @param key key
     */
    public void del(String key) {
        redisTemplate.delete(key);
    }

    /**
     * SET key value: stores a string value under a key.
     *
     * @param key   key
     * @param value value
     */
    public void set(String key, String value) {
        redisTemplate.opsForValue().set(key, value);
    }

    /**
     * SET key value: stores an object under a key as its JSON representation.
     *
     * @param key   key
     * @param value object to serialize and store
     * @param <T>   value type
     */
    public <T> void set(String key, T value) {
        String json = FastjsonUtils.convertObjectToJSON(value);
        redisTemplate.opsForValue().set(key, json);
    }

    /**
     * SET key value EX seconds: stores a string value with an expiry.
     *
     * @param key     key
     * @param value   value
     * @param timeout expiry in seconds
     */
    public void set(String key, String value, long timeout) {
        redisTemplate.opsForValue().set(key, value, timeout, TimeUnit.SECONDS);
    }

    /**
     * SET key value EX seconds: stores an object as JSON with an expiry.
     *
     * @param key     key
     * @param value   object to serialize and store
     * @param timeout expiry in seconds
     * @param <T>     value type
     */
    public <T> void set(String key, T value, long timeout) {
        String json = FastjsonUtils.convertObjectToJSON(value);
        redisTemplate.opsForValue().set(key, json, timeout, TimeUnit.SECONDS);
    }

    /**
     * SETNX key value: stores a string value only if the key is absent.
     *
     * @param key   key
     * @param value value
     * @return whether the value was set
     */
    public Boolean setnx(String key, String value) {
        Boolean stored = redisTemplate.opsForValue().setIfAbsent(key, value);
        return stored;
    }

    /**
     * SETNX key value: stores an object as JSON only if the key is absent.
     *
     * @param key   key
     * @param value object to serialize and store
     * @param <T>   value type
     * @return whether the value was set
     */
    public <T> Boolean setnx(String key, T value) {
        String json = FastjsonUtils.convertObjectToJSON(value);
        return redisTemplate.opsForValue().setIfAbsent(key, json);
    }

    /**
     * SETNX with expiry: stores a string value with an expiry only if the key is absent.
     *
     * @param key     key
     * @param value   value
     * @param timeout expiry in seconds
     * @return whether the value was set
     */
    public Boolean setnx(String key, String value, long timeout) {
        Boolean stored = redisTemplate.opsForValue().setIfAbsent(key, value, timeout, TimeUnit.SECONDS);
        return stored;
    }

    /**
     * SETNX with expiry: stores an object as JSON with an expiry only if the key is absent.
     *
     * @param key     key
     * @param value   object to serialize and store
     * @param timeout expiry in seconds
     * @param <T>     value type
     * @return whether the value was set
     */
    public <T> Boolean setnx(String key, T value, long timeout) {
        String json = FastjsonUtils.convertObjectToJSON(value);
        return redisTemplate.opsForValue().setIfAbsent(key, json, timeout, TimeUnit.SECONDS);
    }

    /**
     * GET key: returns the string value stored under a key.
     *
     * @param key key
     * @return the stored value
     */
    public String get(String key) {
        String stored = redisTemplate.opsForValue().get(key);
        return stored;
    }

    /**
     * Reads the value stored under a key and parses it from JSON into the given type.
     *
     * @param key   key
     * @param clazz target type
     * @param <T>   target type parameter
     * @return the parsed object
     */
    public <T> T get(String key, Class<T> clazz) {
        String json = redisTemplate.opsForValue().get(key);
        Object bean = FastjsonUtils.convertJsonToObject(json, clazz);
        return (T) bean;
    }

    /**
     * Checks whether a key exists.
     *
     * @param key key
     * @return whether the key exists
     */
    public Boolean exists(String key) {
        Boolean present = redisTemplate.hasKey(key);
        return present;
    }

    /****----------------------------------Hash----------------------------------------****/

    /**
     * HSET key field value: sets a field of the hash stored at key.
     *
     * @param key   key
     * @param field hash field
     * @param value value
     */
    public void hset(String key, String field, Object value) {
        redisTemplate.opsForHash().put(key, field, value);
    }

    /**
     * HGET key field: returns the value of a field of the hash stored at key.
     *
     * @param key   key
     * @param field hash field
     * @return the field value
     */
    public String hget(String key, String field) {
        Object raw = redisTemplate.opsForHash().get(key, field);
        return (String) raw;
    }

    /**
     * HDEL key field [field ...]: deletes one or more fields of the hash stored at key.
     *
     * @param key    key
     * @param fields hash fields to delete
     */
    public void hdel(String key, Object... fields) {
        redisTemplate.opsForHash().delete(key, fields);
    }

    /**
     * HGETALL key: returns all fields and values of the hash stored at key.
     *
     * @param key key
     * @return the fields and their values
     */
    public Map<Object, Object> hgetall(String key) {
        Map<Object, Object> entries = redisTemplate.opsForHash().entries(key);
        return entries;
    }

    /****----------------------------------List----------------------------------------****/

    /**
     * LPUSH key value: inserts a value at the head of the list stored at key.
     *
     * @param key   key
     * @param value value
     * @return the list length after the push
     */
    public long lpush(String key, String value) {
        Long length = redisTemplate.opsForList().leftPush(key, value);
        return length;
    }

    /**
     * LPOP key: removes and returns the head element of the list stored at key.
     *
     * @param key key
     * @return the head element
     */
    public String lpop(String key) {
        String head = redisTemplate.opsForList().leftPop(key);
        return head;
    }

    /**
     * RPUSH key value: inserts a value at the tail of the list stored at key.
     *
     * @param key   key
     * @param value value
     * @return the list length after the push
     */
    public long rpush(String key, String value) {
        Long length = redisTemplate.opsForList().rightPush(key, value);
        return length;
    }
}

+ 99
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/utils/SpringContextHolder.java View File

@@ -0,0 +1,99 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.utils;

import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * @description Static access point to the Spring application context
 * @date 2020-09-24
 */
@Component
@Slf4j
public class SpringContextHolder implements ApplicationContextAware, DisposableBean {

    public static ApplicationContext applicationContext = null;

    /**
     * Stores the application context in the static field, warning if one was already stored.
     *
     * @param applicationContext the context to store
     * @throws BeansException declared by the implemented interface
     */
    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        ApplicationContext previous = SpringContextHolder.applicationContext;
        if (previous != null) {
            log.warn("SpringContextHolder中的ApplicationContext被覆盖, 原有ApplicationContext为:" + previous);
        }
        SpringContextHolder.applicationContext = applicationContext;
    }

    /**
     * Destroy callback: resets the stored application context to null.
     */
    @Override
    public void destroy() {
        log.debug("清除SpringContextHolder中的ApplicationContext:"
                + applicationContext);
        applicationContext = null;
    }

    /**
     * Looks up a bean by name in the stored context and casts it to the type expected by the caller.
     *
     * @param name bean name
     * @param <T>  expected bean type
     * @return the bean
     */
    @SuppressWarnings("unchecked")
    public static <T> T getBean(String name) {
        assertContextInjected();
        Object bean = applicationContext.getBean(name);
        return (T) bean;
    }

    /**
     * Looks up a bean by type in the stored context.
     *
     * @param requiredType bean class
     * @param <T>          bean type
     * @return the bean
     */
    public static <T> T getBean(Class<T> requiredType) {
        assertContextInjected();
        T bean = applicationContext.getBean(requiredType);
        return bean;
    }

    /**
     * Fails with an IllegalStateException when no application context has been stored.
     */
    private static void assertContextInjected() {
        if (applicationContext != null) {
            return;
        }
        throw new IllegalStateException("applicaitonContext属性未注入, 请在applicationContext" +
                ".xml中定义SpringContextHolder或在SpringBoot启动类中注册SpringContextHolder.");
    }
}

+ 111
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobHandler.java View File

@@ -0,0 +1,111 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.watcher;

import cn.hutool.core.collection.CollectionUtil;
import cn.hutool.core.util.StrUtil;
import io.fabric8.kubernetes.api.model.OwnerReference;
import io.fabric8.kubernetes.api.model.apps.StatefulSet;
import io.fabric8.kubernetes.api.model.batch.Job;
import io.fabric8.kubernetes.client.KubernetesClient;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.constants.KubeConstants;
import org.onebrain.operator.redis.RedisService;
import org.onebrain.operator.redis.key.OperatorKey;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.List;

import static org.onebrain.operator.constants.CrdConstants.CRD_KIND;

/**
* @description Job处理器
* @date 2020-09-24
*/
@Data
@Slf4j
@Component
public class JobHandler {

public static final String FINISHED = "finished";
public static final String PENDING = "pending";
@Autowired
private RedisService redis;

@Autowired
private KubernetesClient client;

/**
* 处理Job
*
* @param job
*/
public void handleJob(Job job) {
log.info("handleJob=>job : 【{}】", job);

//筛选出DistributeTrain下的job
List<OwnerReference> ownerReferences = job.getMetadata().getOwnerReferences();
if (CollectionUtil.isEmpty(ownerReferences) || !CRD_KIND.equals(ownerReferences.get(0).getKind())) {
return;
}

String key = job.getMetadata().getUid();
if (StrUtil.equals(redis.get(OperatorKey.CR_JOB, key), FINISHED)) {
return;
}

try {
redis.set(OperatorKey.CR_JOB, key, PENDING);

final Integer parallelism = job.getSpec().getParallelism();
final Integer backoffLimit = job.getSpec().getBackoffLimit();
//成功 或者 失败达到最大次数
if (job.getStatus() != null
&& ((job.getStatus().getFailed() != null && job.getStatus().getFailed() + 1 >= backoffLimit)
|| (job.getStatus().getSucceeded() != null && parallelism.equals(job.getStatus().getSucceeded())))) {
//得到DistributeTrain的Statefulset
String dtName = ownerReferences.get(0).getName();
String namespace = job.getMetadata().getNamespace();

List<StatefulSet> statefulsetList = client.apps().statefulSets()
.inNamespace(namespace)
.withLabel(KubeConstants.DISTRIBUTE_TRAIN_LABEL, dtName)
.list().getItems();

if (CollectionUtil.isEmpty(statefulsetList)) {
log.info("jobWatcher: statefulset of 【{}】 not exists", dtName);
return;
}

//缩容Statefulset的replica到0
StatefulSet statefulSet = statefulsetList.get(0);
statefulSet.getSpec().setReplicas(0);
client.resource(statefulSet).createOrReplace();
log.info("jobWatcher: reduce replicas of 【{}】 to zero", dtName);

redis.set(OperatorKey.CR_JOB, key, "finished");
}

} catch (Exception e) {
redis.set(OperatorKey.CR_JOB, key, "error");
log.error("handle job error:【{}】", e);
}
}
}

+ 71
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/watcher/JobWatcher.java View File

@@ -0,0 +1,71 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.watcher;

import io.fabric8.kubernetes.api.model.batch.Job;
import io.fabric8.kubernetes.client.KubernetesClientException;
import io.fabric8.kubernetes.client.Watcher;
import lombok.Data;
import lombok.extern.slf4j.Slf4j;

/**
* @description Job监视器
* @date 2020-09-24
*/
@Data
@Slf4j
public class JobWatcher implements Watcher<Job> {

private String namespace;

private String jobName;

private KubeWatcherManager manager;

private JobHandler jobHandler;

public JobWatcher(JobHandler jobHandler, KubeWatcherManager manager) {
this.manager = manager;
this.jobHandler = jobHandler;
}

/**
* 接收事件进行处理
* @param action 事件类型
* @param job job信息
*/
@Override
public void eventReceived(Action action, Job job) {
log.info("Job Event received: {} at {}", job.getMetadata().getUid(), job.getMetadata().getCreationTimestamp());
jobHandler.handleJob(job);
}

/**
* 关闭事件
* @param e 客户端异常
*/
@Override
public void onClose(KubernetesClientException e) {
log.debug("job watcher close");
if (e != null) {
log.error(e.getMessage());
log.info("restart new job watcher thread");
manager.putNewWatcher();
}
}
}

+ 120
- 0
distribute-train-operator/src/main/java/org/onebrain/operator/watcher/KubeWatcherManager.java View File

@@ -0,0 +1,120 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.onebrain.operator.watcher;

import io.fabric8.kubernetes.client.KubernetesClient;
import lombok.extern.slf4j.Slf4j;
import org.onebrain.operator.context.KubeContext;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;

import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

/**
* @description 监视器的管理器
* @date 2020-09-24
*/
@Slf4j
@Component
public class KubeWatcherManager {

/**
* 监视队列
*/
private static final LinkedBlockingQueue<JobWatcher> watchQueue = new LinkedBlockingQueue<>(1000);

/**
* 单例线程池
*/
private ThreadPoolExecutor pool = new ThreadPoolExecutor(1, 1, 1, TimeUnit.SECONDS, new LinkedBlockingQueue<>(1), new ThreadFactory() {
private final AtomicInteger mThreadNum = new AtomicInteger(1);
@Override
public Thread newThread(Runnable r) {
return new Thread(r, "job-watcher-" + mThreadNum.getAndIncrement());
}
});

@Autowired
private KubeContext kubeContext;

@Autowired
private JobHandler jobHandler;

/**
* 第一次启动时
*/
public void startWatching(){
JobWatchHolder jobWatchHolder = new JobWatchHolder();
pool.execute(jobWatchHolder);
putNewWatcher();
}

/**
* 监听指定job
* @param jobWatcher
*/
public void watch(JobWatcher jobWatcher){
KubernetesClient client = kubeContext.getClient();
//监听指定job
client.batch().jobs()
.inAnyNamespace().watch(jobWatcher);
}

/**
* 加入新watcher
*/
public void putNewWatcher(){
try {
JobWatcher jobWatcher = new JobWatcher(jobHandler, this);
watchQueue.put(jobWatcher);
} catch (InterruptedException e) {
e.printStackTrace();
}
}

/**
* Job监视器持有者
*/
class JobWatchHolder implements Runnable {

@Override
public void run() {
while(true){
try {
//无监视器时阻塞
JobWatcher jobWatcher = watchQueue.take();

//启动监视器
try{
watch(jobWatcher);
}catch (Exception e){
//出错不影响其他listener
log.error("JobWatchHolder watch error:【{}】",e);
}

} catch (InterruptedException e) {
log.error("JobWatchHolder run error:【{}】",e);
}
}
}
}
}

+ 27
- 0
distribute-train-operator/src/main/resources/key/id_rsa View File

@@ -0,0 +1,27 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA06ZOLQq4pzBZL+bybsxdl4PzYg3jB4kRVc771nm5Y8JenDAT
hlOTz6+nGH4EDT63J7oNj4JYLufsONKYhJkya8p0btWeKHqz5LgEfLGwz/FTMRH5
WTCZCZUa/3i9gQeKK/CKEned1h4l2w4agrYrnXHpnuNSw6HSlTpX8FgaQGfmTkL3
XtzSCeY9F2fXGOm9fMfVmv5I5uP6B4TmKwtWPvx3a/1MDgHbmtoaCqYP/JmzWHyi
mc9l2ilX3kTPxh57oRtW9N3FATc8/OCYkNt4vDUTRVB4drODaR5TgUbFtkBVGcFR
f7MrQo4Krd2g8rtEv7PaWN/wlNle5ANXJ/oL3wIDAQABAoIBADiqC8APYMSSMy6Z
/EohuOT51M1pvmCkF9oLYm1XhYTp4v6Z+IA8HBS8iFYMVvVc1xhxvXOwh/925E2K
RH8rrM4jE+0gkAlyYHtZsQnZYOcrSwSWNVXlpvNj0iiXoNTMufdtnOm40K8kvynY
qsxYDXFHsC5z2hK6XnDJgAW+8LhRHCizWwxc0dSN9r33VGry0rgndUZsj2ZBf7u5
rdslZKvRzMymXct7CIQQ3s5IUO3qbaj7TIzMIo14bmHgD3zlBQ66ESCX1o5A+hPq
1gfUNqUPBtJhsNJg4YYJ/bGgGhBxAxam8jWz3DFZEuYHr6fCDIhLJzL5ulxoQS2z
vJYBwsECgYEA8JGfw004BxqcBVxqBveestsCVGIWDtb+Zx4OI+uBAmYMXd2WCzxv
XxgQJ/IrpNx6FAXZ/bFdE0HRZWR6H07wtNgABuBgd0tAfcH8sw2CJkTO/0N2Xr6/
O4kh3yHNMy/wAxnktISf1hE/ElEdPI6slhwGDQObRdXxaqBEq+Tjc28CgYEA4TnM
rCaJ8aMaUE0nvVzrev3VTLp4f1qOcPUOnrHDdyrPs1SjYzmAOC72X/FylJZmtkvh
coMQUKVQgiBn1dTtnALANq705b1S+0U07m6+dGJ7LWchOY2tFPiIsx3SZvNJeEKJ
38PsaFi2eDcDP8cKriNoAoby8TbqjqiyHgDX9pECfxww9IfuhKJQe/gk3Ef0vKQ5
BgzdcbhLeYScAQw0jOm7C7f0P6ERc/uw1jPYLUUkkSnHhcQ1BLM9A0zeeXExzwNi
TJ6BrMxOBUC3euWAr7/MUHWZckWoFMDlURLU4zccZwP2BNcis5hibQG4f7SZA6CT
qCHeSlPkvmXAYkvChuUCgYEA0DNlL9KkfBqBja/1R4jpKhYSIs7R6zCkMmlm7W54
ueV6gVWBgI08KTPIj2KcwBzUsDovG3NrFpHrfY9FTZd7W1fzpdlQDDxaxGryhmMb
bm1HXu5R+WktkhA6FhJAWOkXhrNDzvXHyaIQc8qvFzsBdX7HfGaRmEhixiPOHAw9
l/ECgYEAwNywUARR9HtmgoyrwifrzIkMo6jcmLNEIzi2kJ4OQQxW5eKj5JgSV0ND
QUoAIWDAhHQd3ygSfbeShcvtcw+zoF92iOVFn0SLiSe1TgA5ggzC/VJUnInO7zx7
8Sj8Zk5tHrVmTlelEA2Nbq5H7/U1Q33c1AWbw8yxqD/JRxudHKA=
-----END RSA PRIVATE KEY-----

+ 1
- 0
distribute-train-operator/src/main/resources/key/id_rsa.pub View File

@@ -0,0 +1 @@
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDTpk4tCrinMFkv5vJuzF2Xg/NiDeMHiRFVzvvWebljwl6cMBOGU5PPr6cYfgQNPrcnug2Pglgu5+w40piEmTJrynRu1Z4oerPkuAR8sbDP8VMxEflZMJkJlRr/eL2BB4or8IoSd53WHiXbDhqCtiudceme41LDodKVOlfwWBpAZ+ZOQvde3NIJ5j0XZ9cY6b18x9Wa/kjm4/oHhOYrC1Y+/Hdr/UwOAdua2hoKpg/8mbNYfKKZz2XaKVfeRM/GHnuhG1b03cUBNzz84JiQ23i8NRNFUHh2s4NpHlOBRsW2QFUZwVF/sytCjgqt3aDyu0S/s9pY3/CU2V7kA1cn+gvf root@{{ip}}

+ 19
- 0
distribute-train-operator/src/main/resources/kubeconfig View File

@@ -0,0 +1,19 @@
apiVersion: v1
clusters:
- cluster:
certificate-authority-data: {}
server: {}
name: kubernetes
contexts:
- context:
cluster: kubernetes
user: kubernetes-admin
name: kubernetes-admin@kubernetes
current-context: kubernetes-admin@kubernetes
kind: Config
preferences: {}
users:
- name: kubernetes-admin
user:
client-certificate-data: {}
client-key-data: {}

+ 46
- 0
distribute-train-operator/src/main/resources/shell/pretreatment View File

@@ -0,0 +1,46 @@
#!/bin/bash
# Bootstrap script: makes sure an SSH server is installed and started and that
# jq and nslookup are available. Command output is appended to pretreatment.log
# in the current working directory.

# Install an SSH server only when the /etc/init.d/ssh init script is missing.
if [ ! -f "/etc/init.d/ssh" ]; then
    # No redhat-release marker: use apt.
    if [ ! -f "/etc/redhat-release" ]; then
        echo 'apt install -y openssh-server' >> pretreatment.log
        apt update >> pretreatment.log
        apt install -y openssh-server >> pretreatment.log
    fi
    # No lsb-release marker: use yum. The SSH server package is openssh-server.
    if [ ! -f "/etc/lsb-release" ]; then
        echo 'yum install -y openssh-server' >> pretreatment.log
        # NOTE(review): runs without -y, so it may wait for confirmation — confirm intent.
        yum update >> pretreatment.log
        yum install -y openssh-server >> pretreatment.log
    fi
fi

echo '/etc/init.d/ssh start' >> pretreatment.log
/etc/init.d/ssh start >> pretreatment.log

# yum-based image: install jq and nslookup when missing.
if [ -f "/etc/redhat-release" ]; then
    if command -v jq >/dev/null 2>&1; then
        echo 'exists jq' >> pretreatment.log
    else
        echo 'yum install jq' >> pretreatment.log
        yum install -y jq >> pretreatment.log
    fi
    if command -v nslookup >/dev/null 2>&1; then
        echo 'exists nslookup' >> pretreatment.log
    else
        # nslookup is provided by the bind-utils package here.
        echo 'yum install bind-utils' >> pretreatment.log
        yum install -y bind-utils >> pretreatment.log
    fi
fi

# apt-based image: install jq and nslookup when missing.
if [ -f "/etc/lsb-release" ]; then
    if command -v jq >/dev/null 2>&1; then
        echo 'exists jq' >> pretreatment.log
    else
        echo 'apt install jq' >> pretreatment.log
        apt install -y jq >> pretreatment.log
    fi
    if command -v nslookup >/dev/null 2>&1; then
        echo 'exists nslookup' >> pretreatment.log
    else
        echo 'apt install dnsutils' >> pretreatment.log
        apt install -y dnsutils >> pretreatment.log
    fi
fi

+ 43
- 0
distribute-train-operator/src/test/java/org/onebrain/operator/DistributeTrainOperatorApplicationTests.java View File

@@ -0,0 +1,43 @@
/**
* Copyright 2020 Zhejiang Lab & The OneFlow Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/


package org.onebrain.operator;

import org.onebrain.operator.api.pod.PodApi;
import org.onebrain.operator.constants.KubeConstants;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.context.SpringBootTest;

import java.io.File;
import java.net.URISyntaxException;
import java.net.URL;

/**
 * Spring Boot test class for the operator application.
 */
@SpringBootTest
public class DistributeTrainOperatorApplicationTests {

    @Autowired
    private PodApi podApi;

    /**
     * Copies the bundled private key resource into a fixed pod through {@link PodApi}.
     * The {@code @Test} annotation is commented out.
     *
     * @throws URISyntaxException when the resource URL cannot be converted to a URI
     */
    // @Test
    public void contextLoads() throws URISyntaxException {
        ClassLoader loader = getClass().getClassLoader();
        URL keyLocation = loader.getResource("key/id_rsa");
        File privateKey = new File(keyLocation.toURI());
        podApi.copyToPod("default", "distribute-train-test-job-sv2dj", KubeConstants.MASTER_CONTAINER_NAME, privateKey, "/root/.ssh/id_rsa");
    }

}

Loading…
Cancel
Save