Browse Source

update dubhe-server

tags/v0.3.0
之江实验室 4 years ago
parent
commit
38e4deca84
100 changed files with 4388 additions and 1296 deletions
  1. +0
    -1
      dubhe-server/.gitignore
  2. +37
    -6
      dubhe-server/README.md
  3. +4
    -5
      dubhe-server/common/pom.xml
  4. +9
    -12
      dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermission.java
  5. +47
    -0
      dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermissionMethod.java
  6. +6
    -5
      dubhe-server/common/src/main/java/org/dubhe/annotation/EnumValue.java
  7. +66
    -0
      dubhe-server/common/src/main/java/org/dubhe/annotation/FlagValidator.java
  8. +5
    -5
      dubhe-server/common/src/main/java/org/dubhe/aspect/LogAspect.java
  9. +116
    -0
      dubhe-server/common/src/main/java/org/dubhe/aspect/PermissionAspect.java
  10. +1
    -1
      dubhe-server/common/src/main/java/org/dubhe/base/BaseImageDTO.java
  11. +77
    -0
      dubhe-server/common/src/main/java/org/dubhe/base/BaseService.java
  12. +2
    -2
      dubhe-server/common/src/main/java/org/dubhe/base/BaseVO.java
  13. +62
    -0
      dubhe-server/common/src/main/java/org/dubhe/base/DataContext.java
  14. +3
    -1
      dubhe-server/common/src/main/java/org/dubhe/base/MagicNumConstant.java
  15. +2
    -2
      dubhe-server/common/src/main/java/org/dubhe/base/PageQueryBase.java
  16. +19
    -20
      dubhe-server/common/src/main/java/org/dubhe/base/ScheduleTaskHandler.java
  17. +0
    -65
      dubhe-server/common/src/main/java/org/dubhe/config/KaptchaConfig.java
  18. +14
    -17
      dubhe-server/common/src/main/java/org/dubhe/config/MetaHandlerConfig.java
  19. +7
    -274
      dubhe-server/common/src/main/java/org/dubhe/config/MybatisPlusConfig.java
  20. +62
    -0
      dubhe-server/common/src/main/java/org/dubhe/config/RecycleConfig.java
  21. +18
    -23
      dubhe-server/common/src/main/java/org/dubhe/config/TrainJobConfig.java
  22. +17
    -4
      dubhe-server/common/src/main/java/org/dubhe/config/TrainPoolConfig.java
  23. +4
    -2
      dubhe-server/common/src/main/java/org/dubhe/constant/NumberConstant.java
  24. +7
    -6
      dubhe-server/common/src/main/java/org/dubhe/constant/PermissionConstant.java
  25. +22
    -17
      dubhe-server/common/src/main/java/org/dubhe/constant/StringConstant.java
  26. +3
    -1
      dubhe-server/common/src/main/java/org/dubhe/constant/SymbolConstant.java
  27. +1
    -4
      dubhe-server/common/src/main/java/org/dubhe/constant/UserAuxiliaryInfoConstant.java
  28. +52
    -0
      dubhe-server/common/src/main/java/org/dubhe/domain/dto/CommonPermissionDataDTO.java
  29. +21
    -26
      dubhe-server/common/src/main/java/org/dubhe/domain/entity/LogInfo.java
  30. +0
    -3
      dubhe-server/common/src/main/java/org/dubhe/domain/entity/Menu.java
  31. +0
    -5
      dubhe-server/common/src/main/java/org/dubhe/domain/entity/Role.java
  32. +73
    -0
      dubhe-server/common/src/main/java/org/dubhe/dto/GlobalRequestRecordDTO.java
  33. +13
    -1
      dubhe-server/common/src/main/java/org/dubhe/dto/callback/BaseK8sPodCallbackCreateDTO.java
  34. +6
    -3
      dubhe-server/common/src/main/java/org/dubhe/enums/BizEnum.java
  35. +8
    -4
      dubhe-server/common/src/main/java/org/dubhe/enums/BizNfsEnum.java
  36. +1
    -1
      dubhe-server/common/src/main/java/org/dubhe/enums/DatasetTypeEnum.java
  37. +46
    -30
      dubhe-server/common/src/main/java/org/dubhe/enums/LogEnum.java
  38. +72
    -0
      dubhe-server/common/src/main/java/org/dubhe/enums/OperationTypeEnum.java
  39. +54
    -0
      dubhe-server/common/src/main/java/org/dubhe/enums/RecycleResourceEnum.java
  40. +61
    -13
      dubhe-server/common/src/main/java/org/dubhe/enums/TrainJobStatusEnum.java
  41. +1
    -0
      dubhe-server/common/src/main/java/org/dubhe/exception/BaseErrorCode.java
  42. +42
    -0
      dubhe-server/common/src/main/java/org/dubhe/exception/DataSequenceException.java
  43. +1
    -2
      dubhe-server/common/src/main/java/org/dubhe/exception/NotebookBizException.java
  44. +144
    -144
      dubhe-server/common/src/main/java/org/dubhe/exception/handler/GlobalExceptionHandler.java
  45. +74
    -0
      dubhe-server/common/src/main/java/org/dubhe/filter/BaseLogFilter.java
  46. +49
    -0
      dubhe-server/common/src/main/java/org/dubhe/filter/ConsoleLogFilter.java
  47. +0
    -60
      dubhe-server/common/src/main/java/org/dubhe/filter/FileLogFilter.java
  48. +34
    -0
      dubhe-server/common/src/main/java/org/dubhe/filter/GlobalRequestLogFilter.java
  49. +113
    -0
      dubhe-server/common/src/main/java/org/dubhe/interceptor/MySqlInterceptor.java
  50. +457
    -0
      dubhe-server/common/src/main/java/org/dubhe/interceptor/PaginationInterceptor.java
  51. +19
    -1
      dubhe-server/common/src/main/java/org/dubhe/utils/DateUtil.java
  52. +54
    -0
      dubhe-server/common/src/main/java/org/dubhe/utils/FileUtil.java
  53. +48
    -10
      dubhe-server/common/src/main/java/org/dubhe/utils/HttpClientUtils.java
  54. +2
    -2
      dubhe-server/common/src/main/java/org/dubhe/utils/HttpUtils.java
  55. +46
    -0
      dubhe-server/common/src/main/java/org/dubhe/utils/IOUtil.java
  56. +1
    -1
      dubhe-server/common/src/main/java/org/dubhe/utils/JwtUtils.java
  57. +3
    -3
      dubhe-server/common/src/main/java/org/dubhe/utils/K8sNameTool.java
  58. +307
    -0
      dubhe-server/common/src/main/java/org/dubhe/utils/LocalFileUtil.java
  59. +279
    -214
      dubhe-server/common/src/main/java/org/dubhe/utils/LogUtil.java
  60. +2
    -2
      dubhe-server/common/src/main/java/org/dubhe/utils/MathUtils.java
  61. +20
    -12
      dubhe-server/common/src/main/java/org/dubhe/utils/MinioUtil.java
  62. +11
    -142
      dubhe-server/common/src/main/java/org/dubhe/utils/NfsUtil.java
  63. +108
    -46
      dubhe-server/common/src/main/java/org/dubhe/utils/RedisUtils.java
  64. +2
    -1
      dubhe-server/common/src/main/java/org/dubhe/utils/ReflectionUtils.java
  65. +2
    -2
      dubhe-server/common/src/main/java/org/dubhe/utils/RegexUtil.java
  66. +50
    -1
      dubhe-server/common/src/main/java/org/dubhe/utils/SqlUtil.java
  67. +37
    -0
      dubhe-server/common/src/main/java/org/dubhe/utils/StringUtils.java
  68. +18
    -27
      dubhe-server/common/src/main/java/org/dubhe/utils/TimeTransferUtil.java
  69. +2
    -3
      dubhe-server/common/src/main/java/org/dubhe/utils/UniqueKeyGenerator.java
  70. +1
    -2
      dubhe-server/common/src/main/java/org/dubhe/utils/WrapperHelp.java
  71. +2
    -2
      dubhe-server/common/src/test/java/org/dubhe/HttpUtilsTest.java
  72. +71
    -0
      dubhe-server/deploy.sh
  73. +1
    -0
      dubhe-server/dubhe-admin/pom.xml
  74. +54
    -30
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/HarborImagePushAsync.java
  75. +144
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/StopTrainJobAsync.java
  76. +115
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainAlgorithmUploadAsync.java
  77. +417
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainJobAsync.java
  78. +23
    -14
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TransactionAsyncManager.java
  79. +148
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/GlobalFilter.java
  80. +101
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/RequestBodyWrapper.java
  81. +126
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/ResponseBodyWrapper.java
  82. +1
    -1
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TimestampConverter.java
  83. +47
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/ModelQueryMapper.java
  84. +6
    -6
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/NoteBookMapper.java
  85. +2
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtImageMapper.java
  86. +2
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmMapper.java
  87. +4
    -4
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmUsageMapper.java
  88. +2
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainJobMapper.java
  89. +2
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainMapper.java
  90. +2
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainParamMapper.java
  91. +36
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/BaseTrainJobDTO.java
  92. +39
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/ModelQueryDTO.java
  93. +1
    -2
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookListQueryDTO.java
  94. +1
    -1
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookQueryDTO.java
  95. +2
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDTO.java
  96. +39
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDeleteDTO.java
  97. +3
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageQueryDTO.java
  98. +47
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUpdateDTO.java
  99. +2
    -2
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUploadDTO.java
  100. +3
    -0
      dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmCreateDTO.java

+ 0
- 1
dubhe-server/.gitignore View File

@@ -48,6 +48,5 @@ output/
.classpath
logs/
/dubhe-k8s/src/main/resources/kubeconfig
/dubhe-k8s/src/main/resources
*.log
/dubhe-admin/kubeconfig

+ 37
- 6
dubhe-server/README.md View File

@@ -1,13 +1,22 @@
# 一站式开发平台-服务端
# 之江天枢-服务端

## 本地开发
**之江天枢一站式人工智能开源平台**(简称:**之江天枢**),包括海量数据处理、交互式模型构建(包含Notebook和模型可视化)、AI模型高效训练。多维度产品形态满足从开发者到大型企业的不同需求,将提升人工智能技术的研发效率、扩大算法模型的应用范围,进一步构建人工智能生态“朋友圈”。

## 源码部署

### 准备环境
安装如下软件环境。
- OpenJDK:1.8+
- Redis: 3.0+
- Maven: 3.0+
- MYSQL: 5.5.0+
- MYSQL: 5.7.0+

### 下载源码
``` bash
git clone https://codeup.teambition.com/zhejianglab/dubhe-server.git
# 进入项目根目录
cd dubhe-server
```

### 创建DB
在MySQL中依次执行如下sql文件
@@ -20,14 +29,35 @@ sql/v1/02-Dubhe-DML.sql
### 配置
根据实际情况修改如下配置文件。
```
dubhe-admin/src/main/resources/config/application-dev.yml
dubhe-admin/src/main/resources/config/application-prod.yml
```

### 启动:
### 构建
``` bash
# 构建,生成的 jar 包位于 ./dubhe-admin/target/dubhe-admin-1.0.jar
mvn clean compile package
```
mvn spring-boot:run

### 启动
``` bash
# 指定启动环境为 prod
## admin模块
java -jar ./dubhe-admin/target/dubhe-admin-1.0-exec.jar --spring.profiles.active=prod

## task模块
java -jar ./dubhe-task/target/dubhe-task-1.0.jar --spring.profiles.active=prod
```

## 本地开发

### 必要条件:
导入maven项目,下载所需的依赖包
mysql下创建数据库dubhe,初始化数据脚本
安装redis

### 启动:
mvn spring-boot:run

## 代码结构:
```
├── common 公共模块
@@ -49,6 +79,7 @@ mvn spring-boot:run
├── dubhe-data 数据处理模块
├── dubhe-model 模型管理模块
├── dubhe-system 系统管理
├── dubhe-task 定时任务模块
```

## docker服务器


+ 4
- 5
dubhe-server/common/pom.xml View File

@@ -29,11 +29,6 @@
<artifactId>guava</artifactId>
<version>21.0</version>
</dependency>
<dependency>
<groupId>com.github.penggle</groupId>
<artifactId>kaptcha</artifactId>
<version>${kaptcha.version}</version>
</dependency>
<!-- shiro -->
<dependency>
<groupId>org.apache.shiro</groupId>
@@ -95,6 +90,10 @@
<artifactId>commons-compress</artifactId>
<version>1.20</version>
</dependency>
<dependency>
<groupId>com.github.whvcse</groupId>
<artifactId>easy-captcha</artifactId>
</dependency>
</dependencies>
<build>
<plugins>


+ 9
- 12
dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermission.java View File

@@ -14,31 +14,28 @@
* limitations under the License.
* =============================================================
*/

package org.dubhe.annotation;

import java.lang.annotation.*;

/**
* 数据权限过滤Mapper拦截
*
* @date 2020-06-22
* @description 数据权限注解
* @date 2020-09-24
*/
import java.lang.annotation.*;

@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface DataPermission {

/**
* 不需要数据权限的方法名
* 只在类的注解上使用,代表方法的数据权限类型
* @return
*/
String[] ignores() default {};
String permission() default "";

/**
* 只在方法的注解上使用,代表方法的数据权限类型,如果不加注解,只会识别带"select"方法名的方法
*
* 不需要数据权限的方法名
* @return
*/
String[] permission() default {};

String[] ignoresMethod() default {};
}

+ 47
- 0
dubhe-server/common/src/main/java/org/dubhe/annotation/DataPermissionMethod.java View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.annotation;

import org.dubhe.enums.DatasetTypeEnum;

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

/**
* @description 数据权限方法注解
* @date 2020-09-24
*/
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
public @interface DataPermissionMethod {

/**
* 是否需要拦截标识 true: 不拦截 false: 拦截
*
* @return 拦截标识
*/
boolean interceptFlag() default false;

/**
* 数据类型
*
* @return 数据集类型
*/
DatasetTypeEnum dataType() default DatasetTypeEnum.PRIVATE;
}

+ 6
- 5
dubhe-server/common/src/main/java/org/dubhe/annotation/EnumValue.java View File

@@ -17,19 +17,20 @@

package org.dubhe.annotation;

import javax.validation.Constraint;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import javax.validation.Payload;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import javax.validation.Constraint;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import javax.validation.Payload;

/**
* @date: 2020-05-21
* @description 接口枚举类检测标注类
* @date 2020-05-21
*/
@Target({ ElementType.METHOD, ElementType.FIELD, ElementType.ANNOTATION_TYPE })
@Retention(RetentionPolicy.RUNTIME)


+ 66
- 0
dubhe-server/common/src/main/java/org/dubhe/annotation/FlagValidator.java View File

@@ -0,0 +1,66 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.annotation;

import javax.validation.Constraint;
import javax.validation.ConstraintValidator;
import javax.validation.ConstraintValidatorContext;
import javax.validation.Payload;
import java.lang.annotation.*;
import java.util.Arrays;

/**
* @description 自定义状态校验注解(传入值是否在指定状态范围内)
* @date 2020-09-18
*/
@Target({ElementType.FIELD, ElementType.PARAMETER})
@Retention(RetentionPolicy.RUNTIME)
@Constraint(validatedBy = FlagValidator.Validator.class)
@Documented
public @interface FlagValidator {

String[] value() default {};

String message() default "flag value is invalid";

Class<?>[] groups() default {};

Class<? extends Payload>[] payload() default {};

/**
* @description 校验传入值是否在默认值范围校验逻辑
* @date 2020-09-18
*/
class Validator implements ConstraintValidator<FlagValidator, Integer> {

private String[] values;

@Override
public void initialize(FlagValidator flagValidator) {
this.values = flagValidator.value();
}

@Override
public boolean isValid(Integer value, ConstraintValidatorContext constraintValidatorContext) {
if (value == null) {
//当状态为空时,使用默认值
return false;
}
return Arrays.stream(values).anyMatch(value::equals);
}
}
}

+ 5
- 5
dubhe-server/common/src/main/java/org/dubhe/aspect/LogAspect.java View File

@@ -15,8 +15,7 @@
*/
package org.dubhe.aspect;

import java.util.UUID;

import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
@@ -28,10 +27,11 @@ import org.slf4j.MDC;
import org.springframework.stereotype.Component;
import org.springframework.util.StringUtils;

import lombok.extern.slf4j.Slf4j;
import java.util.UUID;

/**
* @date 2020/04/10
* @description 日志切面
* @date 2020-04-10
*/
@Component
@Aspect
@@ -54,7 +54,7 @@ public class LogAspect {
public void taskAspect() {
}

@Pointcut(" serviceAspect() || taskAspect() ")
@Pointcut(" serviceAspect() ")
public void aroundAspect() {
}



+ 116
- 0
dubhe-server/common/src/main/java/org/dubhe/aspect/PermissionAspect.java View File

@@ -0,0 +1,116 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.aspect;

import org.aspectj.lang.ProceedingJoinPoint;
import org.aspectj.lang.annotation.Around;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Pointcut;
import org.aspectj.lang.reflect.MethodSignature;
import org.dubhe.annotation.DataPermissionMethod;
import org.dubhe.base.BaseService;
import org.dubhe.base.DataContext;
import org.dubhe.domain.dto.CommonPermissionDataDTO;
import org.dubhe.enums.DatasetTypeEnum;
import org.dubhe.enums.LogEnum;
import org.dubhe.utils.JwtUtils;
import org.dubhe.utils.LogUtil;
import org.springframework.stereotype.Component;

import java.lang.reflect.Method;
import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

/**
* @description 数据权限切面
* @date 2020-09-24
*/
@Aspect
@Component
public class PermissionAspect {

/**
* 公共数据的有用户ID
*/
public static final Long PUBLIC_DATA_USER_ID = 0L;

/**
* 基于注解的切面方法
*/
@Pointcut("@annotation(org.dubhe.annotation.DataPermissionMethod)")
private void cutMethod() {

}

/**
*环绕通知
* @param joinPoint 切入参数对象
* @return 返回方法结果集
* @throws Throwable
*/
@Around("cutMethod()")
public Object around(ProceedingJoinPoint joinPoint) throws Throwable {
// 获取方法传入参数
Object[] params = joinPoint.getArgs();
DataPermissionMethod dataPermissionMethod = getDeclaredAnnotation(joinPoint);

if (!Objects.isNull(JwtUtils.getCurrentUserDto()) && !Objects.isNull(dataPermissionMethod)) {
Set<Long> ids = new HashSet<>();
ids.add(JwtUtils.getCurrentUserDto().getId());
CommonPermissionDataDTO commonPermissionDataDTO = CommonPermissionDataDTO.builder().type(dataPermissionMethod.interceptFlag()).resourceUserIds(ids).build();
if (DatasetTypeEnum.PUBLIC.equals(dataPermissionMethod.dataType())) {
ids.add(PUBLIC_DATA_USER_ID);
commonPermissionDataDTO.setResourceUserIds(ids);
}
DataContext.set(commonPermissionDataDTO);
}
// 执行源方法
try {
return joinPoint.proceed(params);
} finally {
// 模拟进行验证
BaseService.removeContext();
}
}

/**
* 获取方法中声明的注解
*
* @param joinPoint 切入参数对象
* @return DataPermissionMethod 方法注解类型
*/
public DataPermissionMethod getDeclaredAnnotation(ProceedingJoinPoint joinPoint){
// 获取方法名
String methodName = joinPoint.getSignature().getName();
// 反射获取目标类
Class<?> targetClass = joinPoint.getTarget().getClass();
// 拿到方法对应的参数类型
Class<?>[] parameterTypes = ((MethodSignature) joinPoint.getSignature()).getParameterTypes();
// 根据类、方法、参数类型(重载)获取到方法的具体信息
Method objMethod = null;
try {
objMethod = targetClass.getMethod(methodName, parameterTypes);
} catch (NoSuchMethodException e) {
LogUtil.error(LogEnum.BIZ_DATASET,"获取注解方法参数异常 error:{}",e);
}
// 拿到方法定义的注解信息
DataPermissionMethod annotation = objMethod.getDeclaredAnnotation(DataPermissionMethod.class);
// 返回
return annotation;
}
}

+ 1
- 1
dubhe-server/common/src/main/java/org/dubhe/base/BaseImageDTO.java View File

@@ -25,7 +25,7 @@ import java.io.Serializable;

/**
* @description 镜像基础类DTO
* @date: 2020-07-14
* @date 2020-07-14
*/
@Data
@Accessors(chain = true)


+ 77
- 0
dubhe-server/common/src/main/java/org/dubhe/base/BaseService.java View File

@@ -0,0 +1,77 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.base;

import org.dubhe.constant.PermissionConstant;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.domain.entity.Role;
import org.dubhe.exception.BaseErrorCode;
import org.dubhe.exception.BusinessException;
import org.dubhe.utils.JwtUtils;
import org.springframework.util.CollectionUtils;

import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;

/**
* @description 服务层基础数据公共方法类
* @date 2020-03-27
*/
public class BaseService {

private BaseService (){}

/**
* 校验是否具有管理员权限
*/
public static void checkAdminPermission() {
if(!isAdmin()){
throw new BusinessException(BaseErrorCode.DATASET_ADMIN_PERMISSION_ERROR);
}
}

/**
* 校验是否是管理管理员
*
* @return 校验标识
*/
public static Boolean isAdmin() {
UserDTO currentUserDto = JwtUtils.getCurrentUserDto();
if (currentUserDto != null && !CollectionUtils.isEmpty(currentUserDto.getRoles())) {
List<Role> roles = currentUserDto.getRoles();
List<Role> roleList = roles.stream().
filter(a -> a.getId().compareTo(PermissionConstant.ADMIN_USER_ID) == 0)
.collect(Collectors.toList());
if (!CollectionUtils.isEmpty(roleList)) {
return true;
}
}
return false;
}


/**
* 清除本地线程数据权限数据
*/
public static void removeContext(){
if( !Objects.isNull(DataContext.get())){
DataContext.remove();
}
}

}

+ 2
- 2
dubhe-server/common/src/main/java/org/dubhe/base/BaseVO.java View File

@@ -24,8 +24,8 @@ import java.io.Serializable;
import java.sql.Timestamp;

/**
* @description: VO基础类
* @date: 2020-05-22
* @description VO基础类
* @date 2020-05-22
*/
@Data
public class BaseVO implements Serializable {


+ 62
- 0
dubhe-server/common/src/main/java/org/dubhe/base/DataContext.java View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.base;


import org.dubhe.domain.dto.CommonPermissionDataDTO;

/**
* @description 共享上下文数据集信息
* @date 2020-04-10
*/
public class DataContext {

/**
* 私有化构造参数
*/
private DataContext() {
}

private static final ThreadLocal<CommonPermissionDataDTO> CONTEXT = new ThreadLocal<>();

/**
* 存放数据集信息
*
* @param datasetVO
*/
public static void set(CommonPermissionDataDTO datasetVO) {
CONTEXT.set(datasetVO);
}

/**
* 获取用户信息
*
* @return
*/
public static CommonPermissionDataDTO get() {
return CONTEXT.get();
}

/**
* 清除当前线程内引用,防止内存泄漏
*/
public static void remove() {
CONTEXT.remove();
}

}

+ 3
- 1
dubhe-server/common/src/main/java/org/dubhe/base/MagicNumConstant.java View File

@@ -18,7 +18,8 @@
package org.dubhe.base;

/**
* @date: 2020-05-14
* @description 常用常量类
* @date 2020-05-14
*/
public final class MagicNumConstant {

@@ -86,6 +87,7 @@ public final class MagicNumConstant {

public static final long TWELVE_LONG = 12L;
public static final long SIXTY_LONG = 60L;
public static final long THOUSAND_LONG = 1000L;
public static final long TEN_THOUSAND_LONG = 10000L;
public static final long ONE_ZERO_ONE_ZERO_ONE_ZERO_LONG = 101010L;
public static final long NINE_ZERO_NINE_ZERO_NINE_ZERO_LONG = 909090L;


+ 2
- 2
dubhe-server/common/src/main/java/org/dubhe/base/PageQueryBase.java View File

@@ -26,8 +26,8 @@ import org.dubhe.constant.NumberConstant;
import javax.validation.constraints.Min;

/**
* @description: 分页基类
* @date: 2020-05-8
* @description 分页基类
* @date 2020-05-08
*/
@Data
@Accessors(chain = true)


dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborProjectNameSyncTask.java → dubhe-server/common/src/main/java/org/dubhe/base/ScheduleTaskHandler.java View File

@@ -14,32 +14,31 @@
* limitations under the License.
* =============================================================
*/

package org.dubhe.task;
package org.dubhe.base;

import org.dubhe.enums.LogEnum;
import org.dubhe.service.PtImageService;
import org.dubhe.utils.LogUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Component;

/**
* @description 从harbor同步projectName
* @date 2020-6-23
**/
@Component
public class HarborProjectNameSyncTask {
* @description 定时任务处理器, 主要做日志标识
* @date 2020-08-13
*/
public class ScheduleTaskHandler {


public static void process(Handler handler) {
LogUtil.startScheduleTrace();
try {
handler.run();
} catch (Exception e) {
LogUtil.error(LogEnum.BIZ_SYS, "There is something wrong in schedule task handler :{}", e);
} finally {
LogUtil.cleanTrace();
}
}

@Autowired
private PtImageService ptImageService;

/**
* 每天晚上11点开始同步
**/
@Scheduled(cron = "0 0 23 * * ?")
public void syncProjectName() {
LogUtil.info(LogEnum.BIZ_TRAIN, "开始到harbor同步projectName到harbor_project表。。。。。");
ptImageService.harborImageNameSync();
public interface Handler {
void run();
}
}

+ 0
- 65
dubhe-server/common/src/main/java/org/dubhe/config/KaptchaConfig.java View File

@@ -1,65 +0,0 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.config;

import com.google.code.kaptcha.impl.DefaultKaptcha;
import com.google.code.kaptcha.util.Config;
import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;

import java.util.Properties;

/**
* @description 验证码配置
* @date 2020-02-23
*/
@Component
public class KaptchaConfig {
private final static String CODE_LENGTH = "4";
private final static String SESSION_KEY = "verification_session_key";

@Bean
public DefaultKaptcha defaultKaptcha() {
DefaultKaptcha defaultKaptcha = new DefaultKaptcha();
Properties properties = new Properties();

// 设置边框
properties.setProperty("kaptcha.border", "yes");
// 设置边框颜色
properties.setProperty("kaptcha.border.color", "105,179,90");
// 设置字体颜色
properties.setProperty("kaptcha.textproducer.font.color", "blue");
// 设置图片宽度
properties.setProperty("kaptcha.image.width", "108");
// 设置图片高度
properties.setProperty("kaptcha.image.height", "28");
// 设置字体尺寸
properties.setProperty("kaptcha.textproducer.font.size", "26");
// 设置session key
properties.setProperty("kaptcha.session.key", SESSION_KEY);
// 设置验证码长度
properties.setProperty("kaptcha.textproducer.char.length", CODE_LENGTH);
// 设置字体
properties.setProperty("kaptcha.textproducer.font.names", "宋体,楷体,黑体");
//去噪点
properties.setProperty("kaptcha.noise.impl", "com.google.code.kaptcha.impl.NoNoise");
Config config = new Config(properties);
defaultKaptcha.setConfig(config);
return defaultKaptcha;
}
}

+ 14
- 17
dubhe-server/common/src/main/java/org/dubhe/config/MetaHandlerConfig.java View File

@@ -1,12 +1,12 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -30,21 +30,17 @@ import java.util.Objects;

/**
* @description 处理新增和更新的基础数据填充,配合BaseEntity和MyBatisPlusConfig使用
* @date 2020-6-10
* @date 2020-06-10
*/
@Component
public class MetaHandlerConfig implements MetaObjectHandler {


private final String LOCK_USER_ID = "LOCK_USER_ID";


/**
* 新增数据执行
*
* @param metaObject
* @param metaObject 基础数据
*/

@Override
public void insertFill(MetaObject metaObject) {
if (Objects.isNull(getFieldValByName(StringConstant.CREATE_TIME, metaObject))) {
@@ -53,13 +49,14 @@ public class MetaHandlerConfig implements MetaObjectHandler {
if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_TIME, metaObject))) {
this.setFieldValByName(StringConstant.UPDATE_TIME, DateUtil.getCurrentTimestamp(), metaObject);
}
synchronized (LOCK_USER_ID){
if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_USER_ID, metaObject))) {
this.setFieldValByName(StringConstant.UPDATE_USER_ID, getUserId(), metaObject);
}
if (Objects.isNull(getFieldValByName(StringConstant.CREATE_USER_ID, metaObject))) {
this.setFieldValByName(StringConstant.CREATE_USER_ID, getUserId(), metaObject);
}
if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_USER_ID, metaObject))) {
this.setFieldValByName(StringConstant.UPDATE_USER_ID, getUserId(), metaObject);
}
if (Objects.isNull(getFieldValByName(StringConstant.CREATE_USER_ID, metaObject))) {
this.setFieldValByName(StringConstant.CREATE_USER_ID, getUserId(), metaObject);
}
if (Objects.isNull(getFieldValByName(StringConstant.ORIGIN_USER_ID, metaObject))) {
this.setFieldValByName(StringConstant.ORIGIN_USER_ID, getUserId(), metaObject);
}
if (Objects.isNull(getFieldValByName(StringConstant.DELETED, metaObject))) {
this.setFieldValByName(StringConstant.DELETED, SwitchEnum.getBooleanValue(SwitchEnum.OFF.getValue()), metaObject);
@@ -69,7 +66,7 @@ public class MetaHandlerConfig implements MetaObjectHandler {
/**
* 更新数据执行
*
* @param metaObject
* @param metaObject 基础数据
*/
@Override
public void updateFill(MetaObject metaObject) {


+ 7
- 274
dubhe-server/common/src/main/java/org/dubhe/config/MybatisPlusConfig.java View File

@@ -17,294 +17,27 @@

package org.dubhe.config;

import com.baomidou.mybatisplus.core.override.MybatisMapperProxy;
import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.parser.SqlParserHelper;
import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler;
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantSqlParser;
import com.google.common.collect.Sets;
import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.LongValue;
import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
import net.sf.jsqlparser.expression.operators.relational.InExpression;
import net.sf.jsqlparser.schema.Column;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.shiro.UnavailableSecurityManagerException;
import org.dubhe.annotation.DataPermission;
import org.dubhe.base.MagicNumConstant;
import org.dubhe.constant.PermissionConstant;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.domain.entity.Role;
import org.dubhe.enums.LogEnum;
import org.dubhe.utils.JwtUtils;
import org.dubhe.utils.LogUtil;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ApplicationListener;
import org.dubhe.interceptor.PaginationInterceptor;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.transaction.annotation.EnableTransactionManagement;
import org.springframework.util.CollectionUtils;

import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.*;
import java.util.stream.Collectors;

/**
* @description MybatisPlus配置类
* @date 2020-06-24
* @description mybatis plus拦截器
* @date 2020-06-10
*/
@EnableTransactionManagement
@Configuration
public class MybatisPlusConfig implements ApplicationListener<ContextRefreshedEvent>, ApplicationContextAware {

/**
* 以此字段作为租户实现数据隔离
*/
private static final String TENANT_ID_COLUMN = "create_user_id";
/**
* 以0作为公共数据的标识
*/
private static final long PUBLIC_TENANT_ID = MagicNumConstant.ZERO;
private static final Set<Long> PUBLIC_TENANT_ID_SET = new HashSet<Long>() {{
add(PUBLIC_TENANT_ID);
}};
private static final String PACKAGE_SEPARATOR = ".";
private static final Set<String> SELECT_PERMISSION = new HashSet<String>() {{
add(PermissionConstant.SELECT);
}};
private static final Set<String> UPDATE_DELETE_PERMISSION = new HashSet<String>() {{
add(PermissionConstant.UPDATE);
add(PermissionConstant.DELETE);
}};
public class MybatisPlusConfig {

private static final String SELECT_STR = "select";
/**
* 优先级高于dataFilters,如果ignore,则不进行sql注入
*/
private Map<String, Set<String>> dataFilters = new HashMap<>();

private ApplicationContext applicationContext;
public Set<Long> tenantId;

/**
* mybatis plus 分页插件
* 其中增加了通过多租户实现了数据权限功能
* 注入 MybatisPlus 分页拦截器
*
* @return
* @return 自定义MybatisPlus分页拦截器
*/
@Bean
public PaginationInterceptor paginationInterceptor() {
PaginationInterceptor paginationInterceptor = new PaginationInterceptor();
List<ISqlParser> sqlParserList = new ArrayList<>();
TenantSqlParser tenantSqlParser = new TenantSqlParser();
tenantSqlParser.setTenantHandler(new TenantHandler() {
@Override
public Expression getTenantId(boolean where) {
Set<Long> tenants = tenantId;

final boolean multipleTenantIds = tenants.size() > MagicNumConstant.ONE;
if (multipleTenantIds) {
return multipleTenantIdCondition(tenants);
} else {
return singleTenantIdCondition(tenants);
}
}

private Expression singleTenantIdCondition(Set<Long> tenants) {
return new LongValue((Long) tenants.toArray()[0]);
}

private Expression multipleTenantIdCondition(Set<Long> tenants) {
final InExpression inExpression = new InExpression();
inExpression.setLeftExpression(new Column(getTenantIdColumn()));
final ExpressionList itemsList = new ExpressionList();
final List<Expression> inValues = new ArrayList<>(tenants.size());
tenants.forEach(i ->
inValues.add(new LongValue(i))
);
itemsList.setExpressions(inValues);
inExpression.setRightItemsList(itemsList);
return inExpression;
}

@Override
public String getTenantIdColumn() {
return TENANT_ID_COLUMN;
}

@Override
public boolean doTableFilter(String tableName) {
return false;
}
});
sqlParserList.add(tenantSqlParser);
paginationInterceptor.setSqlParserList(sqlParserList);
paginationInterceptor.setSqlParserFilter(metaObject -> {
MappedStatement ms = SqlParserHelper.getMappedStatement(metaObject);
String method = ms.getId();
if (!dataFilters.containsKey(method) || isAdmin()) {
return true;
}
Set<String> permission = dataFilters.get(method);
tenantId = getTenantId(permission);
return false;
});
return paginationInterceptor;
}

/**
* 判断用户是否是管理员
* 如果未登录,无法请求任何接口,所以不会到该层,因此匿名认为是定时任务,给予admin权限。
*
* @return 判断用户是否是管理员
*/
private boolean isAdmin() {
UserDTO user;
try {
user = JwtUtils.getCurrentUserDto();
} catch (UnavailableSecurityManagerException e) {
return true;
}
if (Objects.isNull(user)) {
return true;
}
List<Role> roles;
if ((roles = user.getRoles()) == null) {
return false;
}
Set<String> permissions = roles.stream().map(Role::getPermission).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(permissions)) {
return false;
}
return user.getId() == PermissionConstant.ANONYMOUS_USER || user.getId() == PermissionConstant.ADMIN_USER_ID;
}

/**
* 如果是管理员,在前一步isAdmin已过滤;
* 如果是匿名用户,在shiro层被过滤;
* 因此只会是无角色、权限用户或普通用户
*
* @return Set<Long> 租户ID集合
*/
private Set<Long> getTenantId(Set<String> permission) {
UserDTO user = JwtUtils.getCurrentUserDto();
List<Role> roles;
if (Objects.isNull(user) || (roles = user.getRoles()) == null) {
if (permission.contains(PermissionConstant.SELECT)) {
return PUBLIC_TENANT_ID_SET;
}
return Collections.EMPTY_SET;
}
Set<String> permissions = roles.stream().map(Role::getPermission).collect(Collectors.toSet());
if (CollectionUtils.isEmpty(permissions)) {
if (permission.contains(PermissionConstant.SELECT)) {
return PUBLIC_TENANT_ID_SET;
}
return Collections.EMPTY_SET;
}
if (permission.contains(PermissionConstant.SELECT)) {
return new HashSet<Long>() {{
add(PUBLIC_TENANT_ID);
add(user.getId());
}};
}
return new HashSet<Long>() {{
add(user.getId());
}};
}

/**
* 设置上下文
* #需要通过上下文 获取SpringBean
*
* @param applicationContext spring上下文
* @throws BeansException 找不到bean异常
*/
@Override
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
this.applicationContext = applicationContext;
}

@Override
public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) {
Class<? extends Annotation> annotationClass = DataPermission.class;
Map<String, Object> beanWithAnnotation = applicationContext.getBeansWithAnnotation(annotationClass);
Set<Map.Entry<String, Object>> entitySet = beanWithAnnotation.entrySet();
for (Map.Entry<String, Object> entry : entitySet) {
Proxy proxy = (Proxy) entry.getValue();
Class clazz = getMapperClass(proxy);
populateDataFilters(clazz);
}
}

/**
* 根据mapper对应代理对象获取Class
*
* @param proxy mapper对应代理对象
* @return
*/
private Class getMapperClass(Proxy proxy) {
try {
Field field = proxy.getClass().getSuperclass().getDeclaredField("h");
field.setAccessible(true);
MybatisMapperProxy mapperProxy = (MybatisMapperProxy) field.get(proxy);
field = mapperProxy.getClass().getDeclaredField("mapperInterface");
field.setAccessible(true);
return (Class) field.get(mapperProxy);
} catch (NoSuchFieldException | IllegalAccessException e) {
LogUtil.error(LogEnum.BIZ_DATASET, "reflect error", e);
}
return null;
}

/**
* 填充数据权限过滤,处理那些需要排除的方法
*
* @param clazz 需要处理的类(mapper)
*/
private void populateDataFilters(Class clazz) {
if (clazz == null) {
return;
}
Method[] methods = clazz.getMethods();
DataPermission dataPermission = AnnotationUtils.findAnnotation((Class<?>) clazz, DataPermission.class);
Set<String> ignores = Sets.newHashSet(dataPermission.ignores());
for (Method method : methods) {
if (ignores.contains(method.getName())) {
continue;
}
Set<String> permission = getDataPermission(method);
dataFilters.put(clazz.getName() + PACKAGE_SEPARATOR + method.getName(), permission);
}
}

/**
* 获取方法上权限注解
* 权限注解包含
* 1.用户拥有指定权限才可以执行该方法:比如 PermissionConstant.SELECT 表示用户必须拥有select权限,才可以使用该方法
* 2.方法权限校验排除:比如 ignores = {"insert"} 表示insert方法不做权限处理
*
* @param method 方法对象
* @return
*/
private Set<String> getDataPermission(Method method) {
DataPermission dataPermission = AnnotationUtils.findAnnotation(method, DataPermission.class);
// 无注解时以方法名判断
if (dataPermission == null) {
if (method.getName().contains(SELECT_STR)) {
return SELECT_PERMISSION;
}
return UPDATE_DELETE_PERMISSION;
}
return Sets.newHashSet(dataPermission.permission());
}

}


+ 62
- 0
dubhe-server/common/src/main/java/org/dubhe/config/RecycleConfig.java View File

@@ -0,0 +1,62 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.config;

import lombok.Data;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
* @description 垃圾回收机制配置常量
* @date 2020-09-21
*/
@Data
@Component
@ConfigurationProperties(prefix = "recycle.timeout")
public class RecycleConfig {

/**
* 回收无效文件的默认有效时长
*/
private Integer date;

/**
* 用户上传文件至临时路径下后文件最大有效时长,以小时为单位
*/
private Integer fileValid;

/**
* 用户删除某一算法后,其算法文件最大有效时长,以天为单位
*/
private Integer algorithmValid;

/**
* 用户删除某一模型后,其模型文件最大有效时长,以天为单位
*/
private Integer modelValid;

/**
* 用户删除训练任务后,其训练管理文件最大有效时长,以天为单位
*/
private Integer trainValid;

/**
* 删除服务器无效文件(大文件)
* 示例:rsync --delete-before -d /空目录 /需要回收的源目录
*/
public static final String DEL_COMMAND = "ssh %s@%s \"mkdir -p %s; rsync --delete-before -d %s %s; rmdir %s %s\"";
}

dubhe-server/common/src/main/java/org/dubhe/constant/TrainJobConstant.java → dubhe-server/common/src/main/java/org/dubhe/config/TrainJobConfig.java View File

@@ -15,77 +15,71 @@
* =============================================================
*/

package org.dubhe.constant;
package org.dubhe.config;

import lombok.Data;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.stereotype.Component;

/**
* @description 训练常量
* @create: 2020-05-12
* @date 2020-05-12
*/
@Component
@Data
public class TrainJobConstant {
@ConfigurationProperties(prefix = "train-job")
public class TrainJobConfig {

@Value("${train-job.namespace}")
private String namespace;

@Value("${train-job.version-label}")
private String versionLabel;

@Value("${train-job.separator}")
private String separator;

@Value("${train-job.pod-name}")
private String podName;

@Value("${train-job.python-format}")
private String pythonFormat;

@Value("${train-job.manage}")
private String manage;

@Value("${train-job.out-path}")
private String outPath;

@Value("${train-job.log-path}")
private String logPath;

@Value("${train-job.visualized-log-path}")
private String visualizedLogPath;

@Value("${train-job.docker-dataset-path}")
private String dockerDatasetPath;

@Value("${train-job.docker-train-path}")
private String dockerTrainPath;

@Value("${train-job.docker-out-path}")
private String dockerOutPath;

@Value("${train-job.docker-log-path}")
private String dockerLogPath;

@Value("${train-job.docker-dataset}")
private String dockerDataset;

@Value("${train-job.docker-visualized-log-path}")
private String dockerModelPath;

private String dockerValDatasetPath;

private String loadValDatasetKey;

private String dockerVisualizedLogPath;

@Value("${train-job.load-path}")
private String loadPath;

@Value("${train-job.load-key}")
private String loadKey;

@Value("${train-job.eight}")
private String eight;

@Value("${train-job.plus-eight}")
private String plusEight;

private String nodeIps;

private String nodeNum;

private String gpuNumPerNode;

public static final String TRAIN_ID = "trainId";

public static final String TRAIN_VERSION = "trainVersion";
@@ -97,4 +91,5 @@ public class TrainJobConstant {
public static final String CREATE_TIME = "createTime";

public static final String ALGORITHM_NAME = "algorithmName";

}

dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TrainPoolConfig.java → dubhe-server/common/src/main/java/org/dubhe/config/TrainPoolConfig.java View File

@@ -16,9 +16,13 @@
*/
package org.dubhe.config;

import org.dubhe.enums.LogEnum;
import org.dubhe.utils.LogUtil;
import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.annotation.AsyncConfigurer;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;

import java.util.concurrent.Executor;
@@ -29,7 +33,7 @@ import java.util.concurrent.ThreadPoolExecutor;
* @date 2020-07-17
*/
@Configuration
public class TrainPoolConfig {
public class TrainPoolConfig implements AsyncConfigurer {

@Value("${basepool.corePoolSize:40}")
private Integer corePoolSize;
@@ -44,8 +48,9 @@ public class TrainPoolConfig {
* 训练任务异步处理线程池
* @return Executor 线程实例
*/
@Bean
public Executor trainJobAsyncExecutor() {
@Bean("trainExecutor")
@Override
public Executor getAsyncExecutor() {
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor();
//核心线程数
taskExecutor.setCorePoolSize(corePoolSize);
@@ -57,10 +62,18 @@ public class TrainPoolConfig {
//配置队列大小
taskExecutor.setQueueCapacity(blockQueueSize);
//配置线程池前缀
taskExecutor.setThreadNamePrefix("async-train-job-");
taskExecutor.setThreadNamePrefix("async-train-");
//拒绝策略
taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy());
taskExecutor.initialize();
return taskExecutor;
}

@Override
public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() {
LogUtil.error(LogEnum.BIZ_TRAIN, "开始捕获训练管理异步任务异常信息-----》》》");
return (ex, method, params) -> {
LogUtil.error(LogEnum.BIZ_TRAIN, "训练管理方法名{}的异步任务执行失败,参数信息:{},异常信息:{}", method.getName(), params, ex);
};
}
}

+ 4
- 2
dubhe-server/common/src/main/java/org/dubhe/constant/NumberConstant.java View File

@@ -18,8 +18,8 @@
package org.dubhe.constant;

/**
* @Description 数字常量
* @Date 2020-6-9
* @description 数字常量
* @date 2020-06-09
*/
public class NumberConstant {

@@ -32,6 +32,8 @@ public class NumberConstant {
public final static int NUMBER_30 = 30;
public final static int NUMBER_50 = 50;
public final static int NUMBER_60 = 60;
public final static int NUMBER_1024 = 1024;
public final static int NUMBER_1000 = 1000;
public final static int HOUR_SECOND = 60 * 60;
public final static int DAY_SECOND = 60 * 60 * 24;
public final static int WEEK_SECOND = 60 * 60 * 24 * 7;


+ 7
- 6
dubhe-server/common/src/main/java/org/dubhe/constant/PermissionConstant.java View File

@@ -21,8 +21,8 @@ import lombok.Data;
import org.springframework.stereotype.Component;

/**
* @description: 权限常量
* @since: 2020-05-25 14:39
* @description 权限常量
* @date 2020-05-25
*/
@Component
@Data
@@ -32,9 +32,10 @@ public class PermissionConstant {
* 超级用户
*/
public static final long ADMIN_USER_ID = 1L;
public static final long ANONYMOUS_USER = -1L;
public static final String SELECT = "select";
public static final String UPDATE = "update";
public static final String DELETE = "delete";

/**
* 数据集模块类型
*/
public static final Integer RESOURCE_DATA_MODEL = 1;

}

+ 22
- 17
dubhe-server/common/src/main/java/org/dubhe/constant/StringConstant.java View File

@@ -23,24 +23,29 @@ package org.dubhe.constant;
*/
public final class StringConstant {

public static final String MSIE = "MSIE";
public static final String MOZILLA = "Mozilla";
public static final String MSIE = "MSIE";
public static final String MOZILLA = "Mozilla";
public static final String REQUEST_METHOD_GET = "GET";

/**
* 公共字段
*/
public static final String CREATE_TIME = "createTime";
public static final String UPDATE_TIME = "updateTime";
public static final String UPDATE_USER_ID = "updateUserId";
public static final String CREATE_USER_ID = "createUserId";
public static final String DELETED = "deleted";
public static final String UTF8 = "utf-8";
/**
* 公共字段
*/
public static final String CREATE_TIME = "createTime";
public static final String UPDATE_TIME = "updateTime";
public static final String UPDATE_USER_ID = "updateUserId";
public static final String CREATE_USER_ID = "createUserId";
public static final String ORIGIN_USER_ID = "originUserId";
public static final String DELETED = "deleted";
public static final String UTF8 = "utf-8";
public static final String JSON_REQUEST = "application/json";
public static final String K8S_CALLBACK_URI = "/api/k8s/callback/pod";
public static final String MULTIPART = "multipart/form-data";

/**
* 测试环境
*/
public static final String PROFILE_ACTIVE_TEST = "test";
/**
* 测试环境
*/
public static final String PROFILE_ACTIVE_TEST = "test";

private StringConstant() {
}
private StringConstant() {
}
}

+ 3
- 1
dubhe-server/common/src/main/java/org/dubhe/constant/SymbolConstant.java View File

@@ -19,7 +19,7 @@ package org.dubhe.constant;

/**
* @description 符号常量
* @Date 2020-5-29
* @date 2020-5-29
*/
public class SymbolConstant {
public static final String SLASH = "/";
@@ -40,6 +40,8 @@ public class SymbolConstant {
public static final String DOUBLE_MARK= "\"\"";
public static final String MARK= "\"";

public static final String FLAG_EQUAL = "=";

private SymbolConstant() {
}



+ 1
- 4
dubhe-server/common/src/main/java/org/dubhe/constant/UserAuxiliaryInfoConstant.java View File

@@ -17,15 +17,12 @@

package org.dubhe.constant;

import org.springframework.stereotype.Component;

import lombok.Data;

/**
* @description 算法用途
* @date: 2020-06-23
* @date 2020-06-23
*/
@Component
@Data
public class UserAuxiliaryInfoConstant {



+ 52
- 0
dubhe-server/common/src/main/java/org/dubhe/domain/dto/CommonPermissionDataDTO.java View File

@@ -0,0 +1,52 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.domain.dto;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

import java.io.Serializable;
import java.util.Set;

/**
* @description 公共权限信息DTO
* @date 2020-09-24
*/
@AllArgsConstructor
@NoArgsConstructor
@Builder
@Data
public class CommonPermissionDataDTO implements Serializable {

/**
* 资源拥有者ID
*/
private Long id;

/**
* 公共类型
*/
private Boolean type;
/**
* 资源所属用户ids
*/
private Set<Long> resourceUserIds;


}

+ 21
- 26
dubhe-server/common/src/main/java/org/dubhe/domain/entity/LogInfo.java View File

@@ -16,15 +16,14 @@
*/

package org.dubhe.domain.entity;
import java.io.Serializable;

import org.dubhe.base.MagicNumConstant;

import com.alibaba.fastjson.annotation.JSONField;

import cn.hutool.core.date.DateUtil;
import com.alibaba.fastjson.annotation.JSONField;
import lombok.Data;
import lombok.experimental.Accessors;
import org.dubhe.base.MagicNumConstant;

import java.io.Serializable;

/**
* @description 日志对象封装类
@@ -34,31 +33,27 @@ import lombok.experimental.Accessors;
@Accessors(chain = true)
public class LogInfo implements Serializable {

@JSONField(ordinal = MagicNumConstant.ONE)
private String traceId;
private static final long serialVersionUID = 5250395474667395607L;

@JSONField(ordinal = MagicNumConstant.ONE)
private String traceId;

@JSONField(ordinal = MagicNumConstant.TWO)
private String type;
@JSONField(ordinal = MagicNumConstant.TWO)
private String type;

@JSONField(ordinal = MagicNumConstant.THREE)
private String level;
@JSONField(ordinal = MagicNumConstant.THREE)
private String level;

@JSONField(ordinal = MagicNumConstant.FOUR)
private String cName;
@JSONField(ordinal = MagicNumConstant.FOUR)
private String location;

@JSONField(ordinal = MagicNumConstant.FIVE)
private String mName;
@JSONField(ordinal = MagicNumConstant.SIX)
private String line;
@JSONField(ordinal = MagicNumConstant.SEVEN)
private String time = DateUtil.now();
@JSONField(ordinal = MagicNumConstant.FIVE)
private String time = DateUtil.now();

@JSONField(ordinal = MagicNumConstant.EIGHT)
private Object info;
@JSONField(ordinal = MagicNumConstant.SIX)
private Object info;

public void setInfo(Object info) {
this.info = info;
}
public void setInfo(Object info) {
this.info = info;
}
}

+ 0
- 3
dubhe-server/common/src/main/java/org/dubhe/domain/entity/Menu.java View File

@@ -22,12 +22,9 @@ import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.dubhe.base.BaseEntity;
import org.hibernate.validator.constraints.Length;

import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import java.io.Serializable;
import java.sql.Timestamp;
import java.util.Objects;

/**


+ 0
- 5
dubhe-server/common/src/main/java/org/dubhe/domain/entity/Role.java View File

@@ -22,12 +22,7 @@ import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;
import org.dubhe.base.BaseEntity;
import org.hibernate.validator.constraints.Length;

import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotEmpty;
import java.io.Serializable;
import java.sql.Timestamp;
import java.util.Objects;
import java.util.Set;



+ 73
- 0
dubhe-server/common/src/main/java/org/dubhe/dto/GlobalRequestRecordDTO.java View File

@@ -0,0 +1,73 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.dto;

import lombok.Data;

/**
* @description 全局请求日志信息
* @date 2020-08-13
*/
@Data
public class GlobalRequestRecordDTO {
/**
* 客户主机地址
*/
private String clientHost;
/**
* 请求地址
*/
private String uri;
/**
* 授权信息
*/
private String authorization;
/**
* 用户名
*/
private String username;
/**
* form参数
*/
private String params;
/**
* 返回值类型
*/
private String contentType;
/**
* 返回状态
*/
private Integer status;
/**
* 时间耗费
*/
private Long timeCost;
/**
* 请求方式
*/
private String method;
/**
* 请求体body参数
*/
private String requestBody;
/**
* 返回值json数据
*/
private String responseBody;

}

+ 13
- 1
dubhe-server/common/src/main/java/org/dubhe/dto/callback/BaseK8sPodCallbackCreateDTO.java View File

@@ -44,6 +44,14 @@ public class BaseK8sPodCallbackCreateDTO {
@NotEmpty(message = "podName 不能为空!")
private String podName;

@ApiModelProperty(required = true,value = "k8s pod parent type")
@NotEmpty(message = "podParentType 不能为空!")
private String podParentType;

@ApiModelProperty(required = true,value = "k8s pod parent name")
@NotEmpty(message = "podParentName 不能为空!")
private String podParentName;

@ApiModelProperty(value = "k8s pod phase",notes = "对应PodPhaseEnum")
@NotEmpty(message = "phase 不能为空!")
private String phase;
@@ -55,10 +63,12 @@ public class BaseK8sPodCallbackCreateDTO {

}

public BaseK8sPodCallbackCreateDTO(String namespace,String resourceName,String podName,String phase,String messages){
public BaseK8sPodCallbackCreateDTO(String namespace,String resourceName,String podName,String podParentType,String podParentName,String phase,String messages){
this.namespace = namespace;
this.resourceName = resourceName;
this.podName = podName;
this.podParentType = podParentType;
this.podParentName = podParentName;
this.phase = phase;
this.messages = messages;
}
@@ -69,6 +79,8 @@ public class BaseK8sPodCallbackCreateDTO {
"namespace='" + namespace + '\'' +
", resourceName='" + resourceName + '\'' +
", podName='" + podName + '\'' +
", podParentType='" + podParentType + '\'' +
", podParentName='" + podParentName + '\'' +
", phase='" + phase + '\'' +
", messages=" + messages +
'}';


+ 6
- 3
dubhe-server/common/src/main/java/org/dubhe/enums/BizEnum.java View File

@@ -23,9 +23,8 @@ import java.util.HashMap;
import java.util.Map;

/**
* @desc: 业务模块
*
* @date 2020.05.25
* @description 业务模块
* @date 2020-05-25
*/
@Getter
public enum BizEnum {
@@ -38,6 +37,10 @@ public enum BizEnum {
* 算法管理
*/
ALGORITHM("算法管理","algorithm",1),
/**
* 模型管理
*/
MODEL("模型管理","model",2),
;

/**


+ 8
- 4
dubhe-server/common/src/main/java/org/dubhe/enums/BizNfsEnum.java View File

@@ -21,8 +21,8 @@ import java.util.HashMap;
import java.util.Map;

/**
* @desc: 业务NFS路径枚举
* @date 2020.05.13
* @description 业务NFS路径枚举
* @date 2020-05-13
*/
public enum BizNfsEnum {
/**
@@ -33,6 +33,10 @@ public enum BizNfsEnum {
* 算法管理 NFS 路径命名
*/
ALGORITHM(BizEnum.ALGORITHM, "algorithm-manage"),
/**
* 模型管理 NFS 路径命名
*/
MODEL(BizEnum.MODEL, "model"),
;

BizNfsEnum(BizEnum bizEnum, String bizNfsPath) {
@@ -81,11 +85,11 @@ public enum BizNfsEnum {
return bizNfsPath;
}

public BizEnum getBizEnum(){
public BizEnum getBizEnum() {
return bizEnum;
}

public String getBizCode() {
return bizEnum == null ? null :bizEnum.getBizCode();
return bizEnum == null ? null : bizEnum.getBizCode();
}
}

dubhe-server/dubhe-data/src/main/java/org/dubhe/data/constant/DatasetTypeEnum.java → dubhe-server/common/src/main/java/org/dubhe/enums/DatasetTypeEnum.java View File

@@ -15,7 +15,7 @@
* =============================================================
*/

package org.dubhe.data.constant;
package org.dubhe.enums;

import lombok.Getter;


+ 46
- 30
dubhe-server/common/src/main/java/org/dubhe/enums/LogEnum.java View File

@@ -26,36 +26,52 @@ import lombok.Getter;
@Getter
public enum LogEnum {

// 系统报错日志
SYS_ERR,
// 用户请求日志
REST_REQ,
// 训练模块
BIZ_TRAIN,
// 系统模块
BIZ_SYS,
// 模型模块
BIZ_MODEL,
// 数据集模块
BIZ_DATASET,
// k8s模块
BIZ_K8S,
//note book
NOTE_BOOK,
//NFS UTILS
NFS_UTIL;
// 系统报错日志
SYS_ERR,
// 用户请求日志
REST_REQ,
//全局请求日志
GLOBAL_REQ,
// 训练模块
BIZ_TRAIN,
// 系统模块
BIZ_SYS,
// 模型模块
BIZ_MODEL,
// 数据集模块
BIZ_DATASET,
// k8s模块
BIZ_K8S,
//note book
NOTE_BOOK,
//NFS UTILS
NFS_UTIL,
//localFileUtil
LOCAL_FILE_UTIL,
//FILE UTILS
FILE_UTIL,
//FILE UTILS
UPLOAD_TEMP,
//STATE MACHINE
STATE_MACHINE,
//全局垃圾回收
GARBAGE_RECYCLE,
//DATA_SEQUENCE
DATA_SEQUENCE,
//IO UTIL
IO_UTIL;

/**
* 判断日志类型不能为空
*
* @param logType 日志类型
* @return boolean 返回类型
*/
public static boolean isLogType(LogEnum logType) {
/**
* 判断日志类型不能为空
*
* @param logType 日志类型
* @return boolean 返回类型
*/
public static boolean isLogType(LogEnum logType) {

if (logType != null) {
return true;
}
return false;
}
if (logType != null) {
return true;
}
return false;
}
}

+ 72
- 0
dubhe-server/common/src/main/java/org/dubhe/enums/OperationTypeEnum.java View File

@@ -0,0 +1,72 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.enums;

import lombok.Getter;
import lombok.ToString;

/**
* @Description 操作类型枚举
* @Date 2020-08-24
*/
@ToString
@Getter
public enum OperationTypeEnum {
/**
* SELECT 查询类型
*/
SELECT("select", "查询"),

/**
* UPDATE 修改类型
*/
UPDATE("update", "修改"),

/**
* DELETE 删除类型
*/
DELETE("delete", "删除"),

/**
* LIMIT 禁止操作类型
*/
LIMIT("limit", "禁止操作"),

/**
* INSERT 新增类型
*/
INSERT("insert", "新增类型"),

;

/**
* 操作类型值
*/
private String type;

/**
* 操作类型备注
*/
private String desc;

OperationTypeEnum(String type, String desc) {
this.type = type;
this.desc = desc;
}

}

+ 54
- 0
dubhe-server/common/src/main/java/org/dubhe/enums/RecycleResourceEnum.java View File

@@ -0,0 +1,54 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.enums;

import lombok.Getter;

import java.util.HashSet;
import java.util.Set;

/**
* @description 资源回收枚举类
* @date 2020-10-10
*/
@Getter
public enum RecycleResourceEnum {

/**
* 数据集文件回收
*/
DATASET_RECYCLE_FILE("datasetRecycleFile", "数据集文件回收"),
/**
* 数据集版本文件回收
*/
DATASET_RECYCLE_VERSION_FILE("datasetRecycleVersionFile", "数据集版本文件回收"),

;

private String className;

private String message;

RecycleResourceEnum(String className, String message) {
this.className = className;
this.message = message;
}


}

+ 61
- 13
dubhe-server/common/src/main/java/org/dubhe/enums/TrainJobStatusEnum.java View File

@@ -19,8 +19,8 @@ package org.dubhe.enums;

import lombok.Getter;

import java.util.Arrays;
import java.util.List;
import java.util.HashSet;
import java.util.Set;

/**
* @description 训练任务枚举类
@@ -72,7 +72,13 @@ public enum TrainJobStatusEnum {
this.message = message;
}

public static TrainJobStatusEnum get(String msg) {
/**
* 根据信息获取枚举类对象
*
* @param msg 信息
* @return 枚举类对象
*/
public static TrainJobStatusEnum getByMessage(String msg) {
for (TrainJobStatusEnum statusEnum : values()) {
if (statusEnum.message.equalsIgnoreCase(msg)) {
return statusEnum;
@@ -81,21 +87,63 @@ public enum TrainJobStatusEnum {
return UNKNOWN;
}

/**
* 回调状态转换 若是DELETED则转换为STOP,避免状态不统一
* @param phase k8s pod phase
* @return
*/
public static TrainJobStatusEnum transferStatus(String phase) {
TrainJobStatusEnum enums = getByMessage(phase);
if (enums != DELETED) {
return enums;
}
return STOP;
}

/**
* 根据状态获取枚举类对象
*
* @param status 状态
* @return 枚举类对象
*/
public static TrainJobStatusEnum getByStatus(Integer status) {
for (TrainJobStatusEnum statusEnum : values()) {
if (statusEnum.status.equals(status)) {
return statusEnum;
}
}
return UNKNOWN;
}


/**
* 结束状态枚举集合
*/
public static final Set<TrainJobStatusEnum> END_TRAIN_JOB_STATUS;

static {
END_TRAIN_JOB_STATUS = new HashSet<>();
END_TRAIN_JOB_STATUS.add(SUCCEEDED);
END_TRAIN_JOB_STATUS.add(FAILED);
END_TRAIN_JOB_STATUS.add(STOP);
END_TRAIN_JOB_STATUS.add(CREATE_FAILED);
END_TRAIN_JOB_STATUS.add(DELETED);
}

public static boolean isEnd(String msg) {
List<String> endList = Arrays.asList("SUCCEEDED", "FAILED", "STOP", "CREATE_FAILED");
return endList.stream().anyMatch(s -> s.equalsIgnoreCase(msg));
return END_TRAIN_JOB_STATUS.contains(getByMessage(msg));
}

public static boolean isEnd(Integer num) {
List<Integer> endList = Arrays.asList(2, 3, 4, 7);
return endList.stream().anyMatch(s -> s.equals(num));
public static boolean isEnd(Integer status) {
return END_TRAIN_JOB_STATUS.contains(getByStatus(status));
}

public static boolean checkStopStatus(Integer num) {
return SUCCEEDED.getStatus().equals(num) ||
FAILED.getStatus().equals(num) ||
STOP.getStatus().equals(num) ||
CREATE_FAILED.getStatus().equals(num) ||
DELETED.getStatus().equals(num);
return isEnd(num);
}

public static boolean checkRunStatus(Integer num) {
return PENDING.getStatus().equals(num) ||
RUNNING.getStatus().equals(num);
}
}

+ 1
- 0
dubhe-server/common/src/main/java/org/dubhe/exception/BaseErrorCode.java View File

@@ -54,6 +54,7 @@ public enum BaseErrorCode implements ErrorCode {
SYSTEM_USER_CANNOT_DELETE(20014, "系统默认用户不可删除!"),
SYSTEM_ROLE_CANNOT_DELETE(20015, "系统默认角色不可删除!"),

DATASET_ADMIN_PERMISSION_ERROR(1310,"无此权限,请联系管理员"),

;



+ 42
- 0
dubhe-server/common/src/main/java/org/dubhe/exception/DataSequenceException.java View File

@@ -0,0 +1,42 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.exception;

import lombok.Getter;

/**
* @description 获取序列异常
* @date 2020-09-23
*/
@Getter
public class DataSequenceException extends BusinessException {

private static final long serialVersionUID = 1L;

public DataSequenceException(String msg) {
super(msg);
}

public DataSequenceException(String msg, Throwable cause) {
super(msg,cause);
}

public DataSequenceException(Throwable cause) {
super(cause);
}
}

+ 1
- 2
dubhe-server/common/src/main/java/org/dubhe/exception/NotebookBizException.java View File

@@ -20,8 +20,7 @@ package org.dubhe.exception;
import lombok.Getter;

/**
* @description: Notebook 业务处理异常
*
* @description Notebook 业务处理异常
* @date 2020.04.27
*/
@Getter


+ 144
- 144
dubhe-server/common/src/main/java/org/dubhe/exception/handler/GlobalExceptionHandler.java View File

@@ -17,8 +17,8 @@

package org.dubhe.exception.handler;

import java.util.Objects;
import lombok.extern.slf4j.Slf4j;
import org.apache.ibatis.exceptions.IbatisException;
import org.apache.shiro.ShiroException;
import org.apache.shiro.authc.AuthenticationException;
import org.apache.shiro.authc.IncorrectCredentialsException;
@@ -27,11 +27,7 @@ import org.apache.shiro.authc.UnknownAccountException;
import org.dubhe.base.DataResponseBody;
import org.dubhe.base.ResponseCode;
import org.dubhe.enums.LogEnum;
import org.dubhe.exception.BusinessException;
import org.dubhe.exception.CaptchaException;
import org.dubhe.exception.LoginException;
import org.dubhe.exception.NotebookBizException;
import org.dubhe.exception.UnauthorizedException;
import org.dubhe.exception.*;
import org.dubhe.utils.LogUtil;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
@@ -41,7 +37,7 @@ import org.springframework.web.bind.MethodArgumentNotValidException;
import org.springframework.web.bind.annotation.ExceptionHandler;
import org.springframework.web.bind.annotation.RestControllerAdvice;

import lombok.extern.slf4j.Slf4j;
import java.util.Objects;

/**
* @description 处理异常
@@ -51,140 +47,144 @@ import lombok.extern.slf4j.Slf4j;
@RestControllerAdvice
public class GlobalExceptionHandler {

/**
* 处理所有不可知的异常
*/
@ExceptionHandler(Throwable.class)
public ResponseEntity<DataResponseBody> handleException(Throwable e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
return buildResponseEntity(HttpStatus.INTERNAL_SERVER_ERROR,
new DataResponseBody(ResponseCode.ERROR, e.getMessage()));
}

/**
* UnauthorizedException
*/
@ExceptionHandler(UnauthorizedException.class)
public ResponseEntity<DataResponseBody> badCredentialsException(UnauthorizedException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
String message = "坏的凭证".equals(e.getMessage()) ? "用户名或密码不正确" : e.getMessage();
return buildResponseEntity(HttpStatus.UNAUTHORIZED, new DataResponseBody(ResponseCode.ERROR, message));
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = BusinessException.class)
public ResponseEntity<DataResponseBody> badRequestException(BusinessException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
return buildResponseEntity(HttpStatus.OK, e.getResponseBody());
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = AuthenticationException.class)
public ResponseEntity<DataResponseBody> badRequestException(AuthenticationException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问"));
}

/**
* shiro 异常捕捉
*/
@ExceptionHandler(value = ShiroException.class)
public ResponseEntity<DataResponseBody> accountException(ShiroException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
ResponseEntity<DataResponseBody> responseEntity;
if (e instanceof IncorrectCredentialsException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "密码不正确"));
} else if (e instanceof UnknownAccountException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "此账户不存在"));
}

else if (e instanceof LockedAccountException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "未知的账号"));
}

else if (e instanceof UnknownAccountException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "账户已被禁用"));
}

else {
responseEntity = buildResponseEntity(HttpStatus.OK,
new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问"));
}
return responseEntity;
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = LoginException.class)
public ResponseEntity<DataResponseBody> loginException(LoginException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
return buildResponseEntity(HttpStatus.UNAUTHORIZED, e.getResponseBody());
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = CaptchaException.class)
public ResponseEntity<DataResponseBody> captchaException(CaptchaException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
return buildResponseEntity(HttpStatus.OK, e.getResponseBody());
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = NotebookBizException.class)
public ResponseEntity<DataResponseBody> captchaException(NotebookBizException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
return buildResponseEntity(HttpStatus.OK, e.getResponseBody());
}

/**
* 处理所有接口数据验证异常
*/
@ExceptionHandler(MethodArgumentNotValidException.class)
public ResponseEntity<DataResponseBody> handleMethodArgumentNotValidException(MethodArgumentNotValidException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
String[] str = Objects.requireNonNull(e.getBindingResult().getAllErrors().get(0).getCodes())[1].split("\\.");
String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage();
String msg = "不能为空";
if (msg.equals(message)) {
message = str[1] + ":" + message;
}
return buildResponseEntity(HttpStatus.BAD_REQUEST, new DataResponseBody(ResponseCode.ERROR, message));
}

@ExceptionHandler(BindException.class)
public ResponseEntity<DataResponseBody> bindException(BindException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, e);
ObjectError error = e.getAllErrors().get(0);
return buildResponseEntity(HttpStatus.BAD_REQUEST,
new DataResponseBody(ResponseCode.ERROR, error.getDefaultMessage()));
}

/**
* 统一返回
*
* @param httpStatus
* @param responseBody
* @return
*/
private ResponseEntity<DataResponseBody> buildResponseEntity(HttpStatus httpStatus, DataResponseBody responseBody) {
return new ResponseEntity<>(responseBody, httpStatus);
}
/**
* 处理所有不可知的异常
*/
@ExceptionHandler(Throwable.class)
public ResponseEntity<DataResponseBody> handleException(Throwable e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.INTERNAL_SERVER_ERROR,
new DataResponseBody(ResponseCode.ERROR, e.getMessage()));
}

/**
* UnauthorizedException
*/
@ExceptionHandler(UnauthorizedException.class)
public ResponseEntity<DataResponseBody> badCredentialsException(UnauthorizedException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
String message = "坏的凭证".equals(e.getMessage()) ? "用户名或密码不正确" : e.getMessage();
return buildResponseEntity(HttpStatus.UNAUTHORIZED, new DataResponseBody(ResponseCode.ERROR, message));
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = BusinessException.class)
public ResponseEntity<DataResponseBody> badRequestException(BusinessException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.OK, e.getResponseBody());
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = IbatisException.class)
public ResponseEntity<DataResponseBody> persistenceException(IbatisException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, e.getMessage()));
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = AuthenticationException.class)
public ResponseEntity<DataResponseBody> badRequestException(AuthenticationException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问"));
}

/**
* shiro 异常捕捉
*/
@ExceptionHandler(value = ShiroException.class)
public ResponseEntity<DataResponseBody> accountException(ShiroException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
ResponseEntity<DataResponseBody> responseEntity;
if (e instanceof IncorrectCredentialsException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "密码不正确"));
} else if (e instanceof UnknownAccountException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "此账户不存在"));
} else if (e instanceof LockedAccountException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "未知的账号"));
} else if (e instanceof UnknownAccountException) {
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "账户已被禁用"));
} else {
responseEntity = buildResponseEntity(HttpStatus.OK,
new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问"));
}
return responseEntity;
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = LoginException.class)
public ResponseEntity<DataResponseBody> loginException(LoginException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.UNAUTHORIZED, e.getResponseBody());
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = CaptchaException.class)
public ResponseEntity<DataResponseBody> captchaException(CaptchaException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.OK, e.getResponseBody());
}

/**
* 处理自定义异常
*/
@ExceptionHandler(value = NotebookBizException.class)
public ResponseEntity<DataResponseBody> captchaException(NotebookBizException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
return buildResponseEntity(HttpStatus.OK, e.getResponseBody());
}

/**
* 处理所有接口数据验证异常
*/
@ExceptionHandler(MethodArgumentNotValidException.class)
public ResponseEntity<DataResponseBody> handleMethodArgumentNotValidException(MethodArgumentNotValidException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
String[] str = Objects.requireNonNull(e.getBindingResult().getAllErrors().get(0).getCodes())[1].split("\\.");
String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage();
String msg = "不能为空";
if (msg.equals(message)) {
message = str[1] + ":" + message;
}
return buildResponseEntity(HttpStatus.BAD_REQUEST, new DataResponseBody(ResponseCode.ERROR, message));
}

@ExceptionHandler(BindException.class)
public ResponseEntity<DataResponseBody> bindException(BindException e) {
// 打印堆栈信息
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e);
ObjectError error = e.getAllErrors().get(0);
return buildResponseEntity(HttpStatus.BAD_REQUEST,
new DataResponseBody(ResponseCode.ERROR, error.getDefaultMessage()));
}

/**
* 统一返回
*
* @param httpStatus
* @param responseBody
* @return
*/
private ResponseEntity<DataResponseBody> buildResponseEntity(HttpStatus httpStatus, DataResponseBody responseBody) {
return new ResponseEntity<>(responseBody, httpStatus);
}
}

+ 74
- 0
dubhe-server/common/src/main/java/org/dubhe/filter/BaseLogFilter.java View File

@@ -0,0 +1,74 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.filter;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.filter.AbstractMatcherFilter;
import ch.qos.logback.core.spi.FilterReply;
import cn.hutool.core.util.StrUtil;
import org.slf4j.Marker;

/**
* @description 自定义日志过滤器
* @date 2020-07-21
*/
public class BaseLogFilter extends AbstractMatcherFilter<ILoggingEvent> {

Level level;

/**
* 重写decide方法
*
* @param iLoggingEvent event to decide upon.
* @return FilterReply
*/
@Override
public FilterReply decide(ILoggingEvent iLoggingEvent) {
if (!isStarted()) {
return FilterReply.NEUTRAL;
}
final String msg = iLoggingEvent.getMessage();
//自定义级别
if (checkLevel(iLoggingEvent) && msg != null && msg.startsWith(StrUtil.DELIM_START) && msg.endsWith(StrUtil.DELIM_END)) {
final Marker marker = iLoggingEvent.getMarker();
if (marker != null && this.getName() != null && this.getName().contains(marker.getName())) {
return onMatch;
}
}

return onMismatch;
}

protected boolean checkLevel(ILoggingEvent iLoggingEvent) {
return this.level != null
&& iLoggingEvent.getLevel() != null
&& iLoggingEvent.getLevel().toInt() == this.level.toInt();
}

public void setLevel(Level level) {
this.level = level;
}

@Override
public void start() {
if (this.level != null) {
super.start();
}
}
}

+ 49
- 0
dubhe-server/common/src/main/java/org/dubhe/filter/ConsoleLogFilter.java View File

@@ -0,0 +1,49 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.filter;

import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.spi.FilterReply;
import org.dubhe.utils.LogUtil;
import org.slf4j.MarkerFactory;

/**
* @description 自定义日志过滤器
* @date 2020-07-21
*/
public class ConsoleLogFilter extends BaseLogFilter {

@Override
public FilterReply decide(ILoggingEvent iLoggingEvent) {
if (!isStarted()) {
return FilterReply.NEUTRAL;
}
return checkLevel(iLoggingEvent) ? onMatch : onMismatch;
}

protected boolean checkLevel(ILoggingEvent iLoggingEvent) {


return this.level != null
&& iLoggingEvent.getLevel() != null
&& iLoggingEvent.getLevel().toInt() >= this.level.toInt()
&& !MarkerFactory.getMarker(LogUtil.K8S_CALLBACK_LEVEL).equals(iLoggingEvent.getMarker())
&& !MarkerFactory.getMarker(LogUtil.SCHEDULE_LEVEL).equals(iLoggingEvent.getMarker())
&& !"log4jdbc.log4j2".equals(iLoggingEvent.getLoggerName());
}
}

+ 0
- 60
dubhe-server/common/src/main/java/org/dubhe/filter/FileLogFilter.java View File

@@ -1,60 +0,0 @@
/**
* Copyright 2019-2020 Zheng Jie
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.dubhe.filter;

import ch.qos.logback.classic.Level;
import ch.qos.logback.classic.spi.ILoggingEvent;
import ch.qos.logback.core.filter.AbstractMatcherFilter;
import ch.qos.logback.core.spi.FilterReply;

/**
* @description 自定义日志过滤器
* @date 2020-07-21
*/
public class FileLogFilter extends AbstractMatcherFilter<ILoggingEvent> {

Level level;

/**
* 重写decide方法
*
* @param iLoggingEvent event to decide upon.
* @return FilterReply
*/
@Override
public FilterReply decide(ILoggingEvent iLoggingEvent) {
if (!isStarted()) {
return FilterReply.NEUTRAL;
}
if (iLoggingEvent.getLevel().equals(level) && iLoggingEvent.getMessage() != null
&& iLoggingEvent.getMessage().startsWith("{") && iLoggingEvent.getMessage().endsWith("}")) {
return onMatch;
}
return onMismatch;

}

public void setLevel(Level level) {
this.level = level;
}

@Override
public void start() {
if (this.level != null) {
super.start();
}
}
}

+ 34
- 0
dubhe-server/common/src/main/java/org/dubhe/filter/GlobalRequestLogFilter.java View File

@@ -0,0 +1,34 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.filter;

import ch.qos.logback.classic.spi.ILoggingEvent;

/**
* @description 全局请求 日志过滤器
* @date 2020-08-13
*/
public class GlobalRequestLogFilter extends BaseLogFilter {


@Override
public boolean checkLevel(ILoggingEvent iLoggingEvent) {
return this.level != null;
}

}

+ 113
- 0
dubhe-server/common/src/main/java/org/dubhe/interceptor/MySqlInterceptor.java View File

@@ -0,0 +1,113 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.interceptor;

import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.DefaultReflectorFactory;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.dubhe.annotation.DataPermission;
import org.dubhe.base.DataContext;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.enums.OperationTypeEnum;
import org.dubhe.utils.JwtUtils;
import org.dubhe.utils.SqlUtil;
import org.springframework.stereotype.Component;

import java.lang.reflect.Field;
import java.sql.Connection;
import java.util.Arrays;
import java.util.Objects;
import java.util.Properties;


/**
* @description mybatis拦截器
* @date 2020-06-10
*/
@Component
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})})
public class MySqlInterceptor implements Interceptor {
@Override
public Object intercept(Invocation invocation) throws Throwable {

StatementHandler statementHandler = (StatementHandler) invocation.getTarget();
MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY,
SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory());
/*
* 先拦截到RoutingStatementHandler,里面有个StatementHandler类型的delegate变量,其实现类是BaseStatementHandler,
* 然后就到BaseStatementHandler的成员变量mappedStatement
*/
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
//id为执行的mapper方法的全路径名,如com.uv.dao.UserDao.selectPageVo
String id = mappedStatement.getId();
//sql语句类型 select、delete、insert、update
String sqlCommandType = mappedStatement.getSqlCommandType().toString();
BoundSql boundSql = statementHandler.getBoundSql();

//获取到原始sql语句
String sql = boundSql.getSql();
String mSql = sql;

//注解逻辑判断 添加注解了才拦截
Class<?> classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf(".")));
String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1, mappedStatement.getId().length());
UserDTO currentUserDto = JwtUtils.getCurrentUserDto();

//获取类注解 获取需要忽略拦截的方法名称
DataPermission dataAnnotation = classType.getAnnotation(DataPermission.class);
if (!Objects.isNull(dataAnnotation)) {

String[] ignores = dataAnnotation.ignoresMethod();
//校验拦截忽略方法名 忽略新增方法 忽略回调/定时方法
if ((!Objects.isNull(ignores) && Arrays.asList(ignores).contains(mName))
|| OperationTypeEnum.INSERT.getType().equals(sqlCommandType.toLowerCase())
|| Objects.isNull(currentUserDto)
|| (!Objects.isNull(DataContext.get()) && DataContext.get().getType())
) {
return invocation.proceed();
} else {
//拦截所有sql操作类型
mSql = SqlUtil.buildTargetSql(sql, SqlUtil.getResourceIds());
}
}

//通过反射修改sql语句
Field field = boundSql.getClass().getDeclaredField("sql");
field.setAccessible(true);
field.set(boundSql, mSql);
return invocation.proceed();
}

@Override
public Object plugin(Object target) {
if (target instanceof StatementHandler) {
return Plugin.wrap(target, this);
} else {
return target;
}
}

@Override
public void setProperties(Properties properties) {

}

}

+ 457
- 0
dubhe-server/common/src/main/java/org/dubhe/interceptor/PaginationInterceptor.java View File

@@ -0,0 +1,457 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.interceptor;

import com.baomidou.mybatisplus.annotation.DbType;
import com.baomidou.mybatisplus.core.MybatisDefaultParameterHandler;
import com.baomidou.mybatisplus.core.metadata.IPage;
import com.baomidou.mybatisplus.core.metadata.OrderItem;
import com.baomidou.mybatisplus.core.parser.ISqlParser;
import com.baomidou.mybatisplus.core.parser.SqlInfo;
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.StringUtils;
import com.baomidou.mybatisplus.extension.handlers.AbstractSqlParserHandler;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory;
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel;
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect;
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils;
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils;
import net.sf.jsqlparser.JSQLParserException;
import net.sf.jsqlparser.parser.CCJSqlParserUtil;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.statement.select.*;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.logging.Log;
import org.apache.ibatis.logging.LogFactory;
import org.apache.ibatis.mapping.*;
import org.apache.ibatis.plugin.*;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.apache.ibatis.session.Configuration;
import org.dubhe.annotation.DataPermission;
import org.dubhe.base.DataContext;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.enums.OperationTypeEnum;
import org.dubhe.utils.JwtUtils;
import org.dubhe.utils.SqlUtil;

import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.util.*;
import java.util.stream.Collectors;

/**
* @description MybatisPlus 分页拦截器
* @date 2020-10-09
*/
@Intercepts({@Signature(
type = StatementHandler.class,
method = "prepare",
args = {Connection.class, Integer.class}
)})
public class PaginationInterceptor extends AbstractSqlParserHandler implements Interceptor {
protected static final Log logger = LogFactory.getLog(PaginationInterceptor.class);

/**
* COUNT SQL 解析
*/
protected ISqlParser countSqlParser;
/**
* 溢出总页数,设置第一页
*/
protected boolean overflow = false;
/**
* 单页限制 500 条,小于 0 如 -1 不受限制
*/
protected long limit = 500L;
/**
* 数据类型
*/
private DbType dbType;
/**
* 方言
*/
private IDialect dialect;
/**
* 方言类型
*/
@Deprecated
protected String dialectType;
/**
* 方言实现类
*/
@Deprecated
protected String dialectClazz;

public PaginationInterceptor() {
}

/**
* 构建分页sql
*
* @param originalSql 原生sql
* @param page 分页参数
* @return 构建后 sql
*/
public static String concatOrderBy(String originalSql, IPage<?> page) {
if (CollectionUtils.isNotEmpty(page.orders())) {
try {
List<OrderItem> orderList = page.orders();
Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql);
List orderByElements;
List orderByElementsReturn;
if (selectStatement.getSelectBody() instanceof PlainSelect) {
PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
orderByElements = plainSelect.getOrderByElements();
orderByElementsReturn = addOrderByElements(orderList, orderByElements);
plainSelect.setOrderByElements(orderByElementsReturn);
return plainSelect.toString();
}

if (selectStatement.getSelectBody() instanceof SetOperationList) {
SetOperationList setOperationList = (SetOperationList) selectStatement.getSelectBody();
orderByElements = setOperationList.getOrderByElements();
orderByElementsReturn = addOrderByElements(orderList, orderByElements);
setOperationList.setOrderByElements(orderByElementsReturn);
return setOperationList.toString();
}

if (selectStatement.getSelectBody() instanceof WithItem) {
return originalSql;
}

return originalSql;
} catch (JSQLParserException var7) {
logger.error("failed to concat orderBy from IPage, exception=", var7);
}
}

return originalSql;
}

/**
* 添加分页排序规则
*
* @param orderList 分页规则
* @param orderByElements 分页排序元素
* @return 分页规则
*/
private static List<OrderByElement> addOrderByElements(List<OrderItem> orderList, List<OrderByElement> orderByElements) {
orderByElements = CollectionUtils.isEmpty(orderByElements) ? new ArrayList(orderList.size()) : orderByElements;
List<OrderByElement> orderByElementList = (List) orderList.stream().filter((item) -> {
return StringUtils.isNotBlank(item.getColumn());
}).map((item) -> {
OrderByElement element = new OrderByElement();
element.setExpression(new Column(item.getColumn()));
element.setAsc(item.isAsc());
element.setAscDescPresent(true);
return element;
}).collect(Collectors.toList());
((List) orderByElements).addAll(orderByElementList);
return (List) orderByElements;
}

/**
* 执行sql查询逻辑
*
* @param invocation mybatis 调用类
* @return
* @throws Throwable
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget());
MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
this.sqlParser(metaObject);
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement");
if (SqlCommandType.SELECT == mappedStatement.getSqlCommandType() && StatementType.CALLABLE != mappedStatement.getStatementType()) {
BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql");
Object paramObj = boundSql.getParameterObject();
IPage<?> page = null;
if (paramObj instanceof IPage) {
page = (IPage) paramObj;
} else if (paramObj instanceof Map) {
Iterator var8 = ((Map) paramObj).values().iterator();

while (var8.hasNext()) {
Object arg = var8.next();
if (arg instanceof IPage) {
page = (IPage) arg;
break;
}
}
}

if (null != page && page.getSize() >= 0L) {
if (this.limit > 0L && this.limit <= page.getSize()) {
this.handlerLimit(page);
}

String originalSql = boundSql.getSql();

//注解逻辑判断 添加注解了才拦截
Class<?> classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf(".")));
String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1, mappedStatement.getId().length());
UserDTO currentUserDto = JwtUtils.getCurrentUserDto();

String sqlCommandType = mappedStatement.getSqlCommandType().toString();
//获取类注解 获取需要忽略拦截的方法名称
DataPermission dataAnnotation = classType.getAnnotation(DataPermission.class);
if (!Objects.isNull(dataAnnotation)) {

String[] ignores = dataAnnotation.ignoresMethod();
//校验拦截忽略方法名 忽略新增方法 忽略回调/定时方法
if (!((!Objects.isNull(ignores) && Arrays.asList(ignores).contains(mName))
|| OperationTypeEnum.INSERT.getType().equals(sqlCommandType.toLowerCase())
|| Objects.isNull(currentUserDto)
|| (!Objects.isNull(DataContext.get()) && DataContext.get().getType()))


) {
originalSql = SqlUtil.buildTargetSql(originalSql, SqlUtil.getResourceIds());
}
}

Connection connection = (Connection) invocation.getArgs()[0];
if (page.isSearchCount() && !page.isHitCount()) {
SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), this.countSqlParser, originalSql);
this.queryTotal(sqlInfo.getSql(), mappedStatement, boundSql, page, connection);
if (page.getTotal() <= 0L) {
return null;
}
}

DbType dbType = Optional.ofNullable(this.dbType).orElse(JdbcUtils.getDbType(connection.getMetaData().getURL()));
IDialect dialect = Optional.ofNullable(this.dialect).orElse(DialectFactory.getDialect(dbType));
String buildSql = concatOrderBy(originalSql, page);
DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize());
Configuration configuration = mappedStatement.getConfiguration();
List<ParameterMapping> mappings = new ArrayList(boundSql.getParameterMappings());
Map<String, Object> additionalParameters = (Map) metaObject.getValue("delegate.boundSql.additionalParameters");
model.consumers(mappings, configuration, additionalParameters);
metaObject.setValue("delegate.boundSql.sql", model.getDialectSql());
metaObject.setValue("delegate.boundSql.parameterMappings", mappings);
return invocation.proceed();
} else {
return invocation.proceed();
}
} else {
return invocation.proceed();
}
}

/**
* 处理分页数量
*
* @param page 分页参数
*/
protected void handlerLimit(IPage<?> page) {
page.setSize(this.limit);
}

/**
* 查询总数量
*
* @param sql sql语句
* @param mappedStatement 映射语句包装类
* @param boundSql sql包装类
* @param page 分页参数
* @param connection JDBC连接包装类
*/
protected void queryTotal(String sql, MappedStatement mappedStatement, BoundSql boundSql, IPage<?> page, Connection connection) {
try {
PreparedStatement statement = connection.prepareStatement(sql);
Throwable var7 = null;

try {
DefaultParameterHandler parameterHandler = new MybatisDefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql);
parameterHandler.setParameters(statement);
long total = 0L;
ResultSet resultSet = statement.executeQuery();
Throwable var12 = null;

try {
if (resultSet.next()) {
total = resultSet.getLong(1);
}
} catch (Throwable var37) {
var12 = var37;
throw var37;
} finally {
if (resultSet != null) {
if (var12 != null) {
try {
resultSet.close();
} catch (Throwable var36) {
var12.addSuppressed(var36);
}
} else {
resultSet.close();
}
}

}

page.setTotal(total);
if (this.overflow && page.getCurrent() > page.getPages()) {
this.handlerOverflow(page);
}
} catch (Throwable var39) {
var7 = var39;
throw var39;
} finally {
if (statement != null) {
if (var7 != null) {
try {
statement.close();
} catch (Throwable var35) {
var7.addSuppressed(var35);
}
} else {
statement.close();
}
}

}

} catch (Exception var41) {
throw ExceptionUtils.mpe("Error: Method queryTotal execution error of sql : \n %s \n", var41, new Object[]{sql});
}
}

/**
* 设置默认当前页
*
* @param page 分页参数
*/
protected void handlerOverflow(IPage<?> page) {
page.setCurrent(1L);
}

/**
* MybatisPlus拦截器实现自定义插件
*
* @param target 拦截目标对象
* @return
*/
@Override
public Object plugin(Object target) {
return target instanceof StatementHandler ? Plugin.wrap(target, this) : target;
}

/**
* MybatisPlus拦截器实现自定义属性设置
*
* @param prop 属性参数
*/
@Override
public void setProperties(Properties prop) {
String dialectType = prop.getProperty("dialectType");
String dialectClazz = prop.getProperty("dialectClazz");
if (StringUtils.isNotBlank(dialectType)) {
this.setDialectType(dialectType);
}

if (StringUtils.isNotBlank(dialectClazz)) {
this.setDialectClazz(dialectClazz);
}

}

/**
* 设置数据源类型
*
* @param dialectType 数据源类型
*/
@Deprecated
public void setDialectType(String dialectType) {
this.setDbType(DbType.getDbType(dialectType));
}


/**
* 设置方言实现类配置
*
* @param dialectClazz 方言实现类
*/
@Deprecated
public void setDialectClazz(String dialectClazz) {
this.setDialect(DialectFactory.getDialect(dialectClazz));
}

/**
* 设置获取总数的sql解析器
*
* @param countSqlParser 总数的sql解析器
* @return 自定义MybatisPlus拦截器
*/
public PaginationInterceptor setCountSqlParser(final ISqlParser countSqlParser) {
this.countSqlParser = countSqlParser;
return this;
}

/**
* 溢出总页数,设置第一页
*
* @param overflow 溢出总页数
* @return 自定义MybatisPlus拦截器
*/
public PaginationInterceptor setOverflow(final boolean overflow) {
this.overflow = overflow;
return this;
}

/**
* 设置分页规则
*
* @param limit 分页数量
* @return 自定义MybatisPlus拦截器
*/
public PaginationInterceptor setLimit(final long limit) {
this.limit = limit;
return this;
}

/**
* 设置数据类型
*
* @param dbType 数据类型
* @return 自定义MybatisPlus拦截器
*/
public PaginationInterceptor setDbType(final DbType dbType) {
this.dbType = dbType;
return this;
}

/**
* 设置方言
*
* @param dialect 方言
* @return 自定义MybatisPlus拦截器
*/
public PaginationInterceptor setDialect(final IDialect dialect) {
this.dialect = dialect;
return this;
}


}


+ 19
- 1
dubhe-server/common/src/main/java/org/dubhe/utils/DateUtil.java View File

@@ -18,16 +18,25 @@
package org.dubhe.utils;

import java.sql.Timestamp;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneId;
import java.util.Date;

/**
* @description 日期工具类
* @date 2020-6-10
* @date 2020-06-10
*/
public class DateUtil {

private DateUtil(){

}


/**
* 获取当前时间戳
*
@@ -77,4 +86,13 @@ public class DateUtil {
return (milli-l1);
}

/**
* @return 当前字符串时间yyyy-MM-dd HH:mm:ss SSS
*/
public static String getCurrentTimeStr(){
Date date = new Date();
DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss SSS");
return dateFormat.format(date);
}

}

+ 54
- 0
dubhe-server/common/src/main/java/org/dubhe/utils/FileUtil.java View File

@@ -19,10 +19,12 @@ package org.dubhe.utils;

import cn.hutool.core.codec.Base64;
import cn.hutool.core.io.IoUtil;
import cn.hutool.core.util.CharsetUtil;
import cn.hutool.core.util.IdUtil;
import cn.hutool.poi.excel.BigExcelWriter;
import cn.hutool.poi.excel.ExcelUtil;
import org.apache.poi.util.IOUtils;
import org.dubhe.enums.LogEnum;
import org.dubhe.exception.BusinessException;
import org.springframework.web.multipart.MultipartFile;

@@ -352,4 +354,56 @@ public class FileUtil extends cn.hutool.core.io.FileUtil {
return getMd5(getByte(file));
}


/**
* 生成文件
* @param filePath 文件绝对路径
* @param content 文件内容
* @param append 文件是否是追加
* @return
*/
public static boolean generateFile(String filePath,String content,boolean append){
File file = new File(filePath);
FileOutputStream outputStream = null;
try {
if (!file.exists()){
file.createNewFile();
}
outputStream = new FileOutputStream(file,append);
outputStream.write(content.getBytes(CharsetUtil.defaultCharset()));
outputStream.flush();
}catch (IOException e) {
LogUtil.error(LogEnum.FILE_UTIL,e);
return false;
}finally {
if (outputStream != null){
try {
outputStream.close();
} catch (IOException e) {
LogUtil.error(LogEnum.FILE_UTIL,e);
}
}
}
return true;
}


/**
* 压缩文件目录
*
* @param zipDir 待压缩文件夹路径
* @param zipFile 压缩完成zip文件绝对路径
* @return
*/
public static boolean zipPath(String zipDir,String zipFile) {
if (zipDir == null) {
return false;
}
File zip = new File(zipFile);
cn.hutool.core.util.ZipUtil.zip(zip, CharsetUtil.defaultCharset(), true,
(f) -> !f.isDirectory(),
new File(zipDir).listFiles());
return true;
}

}

+ 48
- 10
dubhe-server/common/src/main/java/org/dubhe/utils/HttpClientUtils.java View File

@@ -20,6 +20,7 @@ import org.apache.commons.io.IOUtils;
import org.dubhe.enums.LogEnum;



import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
@@ -31,13 +32,13 @@ import java.io.InputStreamReader;
import java.net.URL;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import org.apache.commons.codec.binary.Base64;
import static org.dubhe.constant.StringConstant.UTF8;
import static org.dubhe.constant.SymbolConstant.BLANK;

/**
* @description: httpClient工具类,不校验SSL证书
* @date: 2020-5-21
* @description httpClient工具类,不校验SSL证书
* @date 2020-05-21
*/
public class HttpClientUtils {

@@ -45,7 +46,7 @@ public class HttpClientUtils {
InputStream inputStream = null;
BufferedReader bufferedReader = null;
InputStreamReader inputStreamReader = null;
StringBuilder stringBuider = new StringBuilder();
StringBuilder stringBuilder = new StringBuilder();
String result = BLANK;
HttpsURLConnection con = null;
try {
@@ -59,10 +60,10 @@ public class HttpClientUtils {

String str = null;
while ((str = bufferedReader.readLine()) != null) {
stringBuider.append(str);
stringBuilder.append(str);
}

result = stringBuider.toString();
result = stringBuilder.toString();
LogUtil.info(LogEnum.BIZ_SYS,"Request path:{}, SUCCESS, result:{}", path, result);

} catch (Exception e) {
@@ -74,17 +75,54 @@ public class HttpClientUtils {

return result;
}
public static String sendHttpsDelete(String path,String username,String password) {
InputStream inputStream = null;
BufferedReader bufferedReader = null;
InputStreamReader inputStreamReader = null;
StringBuilder stringBuilder = new StringBuilder();
String result = BLANK;
HttpsURLConnection con = null;
try {
con = getConnection(path);
String input =username+ ":" +password;
String encoding=Base64.encodeBase64String(input.getBytes());
con.setRequestProperty(JwtUtils.AUTH_HEADER, "Basic " + encoding);
con.setRequestMethod("DELETE");
con.connect();
/**将返回的输入流转换成字符串**/
inputStream = con.getInputStream();
inputStreamReader = new InputStreamReader(inputStream, UTF8);
bufferedReader = new BufferedReader(inputStreamReader);

private static void closeResource(BufferedReader bufferedReader,InputStreamReader inputStreamReader,InputStream inputStream,HttpsURLConnection con) {
String str = null;
while ((str = bufferedReader.readLine()) != null) {
stringBuilder.append(str);
}

IOUtils.closeQuietly(bufferedReader);
result = stringBuilder.toString();
LogUtil.info(LogEnum.BIZ_SYS,"Request path:{}, SUCCESS, result:{}", path, result);

if (inputStreamReader != null) {
IOUtils.closeQuietly(inputStreamReader);
} catch (Exception e) {
LogUtil.error(LogEnum.BIZ_SYS,"Request path:{}, ERROR, exception:{}", path, e);
return result;
} finally {
closeResource(bufferedReader,inputStreamReader,inputStream,con);
}

return result;
}

private static void closeResource(BufferedReader bufferedReader,InputStreamReader inputStreamReader,InputStream inputStream,HttpsURLConnection con) {
if (inputStream != null) {
IOUtils.closeQuietly(inputStream);
}

if (inputStreamReader != null) {
IOUtils.closeQuietly(inputStreamReader);
}
if (bufferedReader != null) {
IOUtils.closeQuietly(bufferedReader);
}
if (con != null) {
con.disconnect();
}


+ 2
- 2
dubhe-server/common/src/main/java/org/dubhe/utils/HttpUtils.java View File

@@ -20,8 +20,8 @@ package org.dubhe.utils;
import lombok.extern.slf4j.Slf4j;

/**
* @description: HttpUtil
* @date 2020.04.30
* @description HttpUtil
* @date 2020-04-30
*/
@Slf4j
public class HttpUtils {


+ 46
- 0
dubhe-server/common/src/main/java/org/dubhe/utils/IOUtil.java View File

@@ -0,0 +1,46 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.utils;

import org.dubhe.enums.LogEnum;

import java.io.Closeable;
import java.io.IOException;

/**
* @description IO流操作工具类
* @date 2020-10-14
*/
public class IOUtil {

/**
* 循环的依次关闭流
*
* @param closeableList 要被关闭的流集合
*/
public static void close(Closeable... closeableList) {
for (Closeable closeable : closeableList) {
try {
if (closeable != null) {
closeable.close();
}
} catch (IOException e) {
LogUtil.error(LogEnum.IO_UTIL, "关闭流异常,异常信息:{}", e);
}
}
}
}

+ 1
- 1
dubhe-server/common/src/main/java/org/dubhe/utils/JwtUtils.java View File

@@ -139,7 +139,7 @@ public class JwtUtils {
public static boolean isTokenExpired(String token) {
Date now = Calendar.getInstance().getTime();
DecodedJWT jwt = JWT.decode(token);
return jwt.getExpiresAt().before(now);
return jwt.getExpiresAt() == null || jwt.getExpiresAt().before(now);
}

/**


+ 3
- 3
dubhe-server/common/src/main/java/org/dubhe/utils/K8sNameTool.java View File

@@ -67,12 +67,12 @@ public class K8sNameTool {
}

/**
* 生成 Notebook的NameSpace
* 生成 Notebook的Namespace
*
* @param userId
* @return namespace
*/
public String generateNameSpace(long userId) {
public String generateNamespace(long userId) {
return this.k8sNameConfig.getNamespace() + SEPARATOR + userId;
}

@@ -96,7 +96,7 @@ public class K8sNameTool {
* @param namespace
* @return Long
*/
public Long getUserIdFromNameSpace(String namespace) {
public Long getUserIdFromNamespace(String namespace) {
if (StringUtils.isEmpty(namespace) || !namespace.contains(this.k8sNameConfig.getNamespace() + SEPARATOR)) {
return null;
}


+ 307
- 0
dubhe-server/common/src/main/java/org/dubhe/utils/LocalFileUtil.java View File

@@ -0,0 +1,307 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.utils;

import cn.hutool.core.util.StrUtil;
import lombok.Getter;
import org.apache.commons.compress.archivers.zip.ZipArchiveEntry;
import org.apache.commons.compress.archivers.zip.ZipFile;
import org.apache.commons.io.IOUtils;
import org.dubhe.base.MagicNumConstant;
import org.dubhe.config.NfsConfig;
import org.dubhe.enums.LogEnum;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Component;
import org.springframework.util.FileCopyUtils;

import java.io.*;
import java.util.Enumeration;
import java.util.zip.ZipEntry;

/**
* @description 本地文件操作工具类
* @date 2020-08-19
*/
@Component
@Getter
public class LocalFileUtil {

@Autowired
private NfsConfig nfsConfig;

private static final String FILE_SEPARATOR = File.separator;

private static final String ZIP = ".zip";

private static final String CHARACTER_GBK = "GBK";

private static final String OS_NAME = "os.name";

private static final String WINDOWS = "Windows";

@Value("${k8s.nfs-root-path}")
private String nfsRootPath;

@Value("${k8s.nfs-root-windows-path}")
private String nfsRootWindowsPath;

/**
* windows 与 linux 的路径兼容
*
* @param path linux下的路径
* @return path 兼容windows后的路径
*/
private String compatiblePath(String path) {
if (path == null) {
return null;
}
if (System.getProperties().getProperty(OS_NAME).contains(WINDOWS)) {
path = path.replace(nfsRootPath, StrUtil.SLASH);
path = path.replace(StrUtil.SLASH, FILE_SEPARATOR);
path = nfsRootWindowsPath + path;
}
return path;
}


/**
* 本地解压zip包并删除压缩文件
*
* @param sourcePath zip源文件 例如:/abc/z.zip
* @param targetPath 解压后的目标文件夹 例如:/abc/
* @return boolean
*/
public boolean unzipLocalPath(String sourcePath, String targetPath) {
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) {
return false;
}
if (!sourcePath.toLowerCase().endsWith(ZIP)) {
return false;
}
//绝对路径
String sourceAbsolutePath = nfsConfig.getRootDir() + sourcePath;
String targetPathAbsolutePath = nfsConfig.getRootDir() + targetPath;
ZipFile zipFile = null;
InputStream in = null;
OutputStream out = null;
File sourceFile = new File(compatiblePath(sourceAbsolutePath));
File targetFileDir = new File(compatiblePath(targetPathAbsolutePath));
if (!targetFileDir.exists()) {
boolean targetMkdir = targetFileDir.mkdirs();
if (!targetMkdir) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}failed to create target folder before decompression", sourceAbsolutePath);
}
}
try {
zipFile = new ZipFile(sourceFile);
//判断压缩文件编码方式,并重新获取文件对象
try {
zipFile.close();
zipFile = new ZipFile(sourceFile, CHARACTER_GBK);
} catch (Exception e) {
zipFile.close();
zipFile = new ZipFile(sourceFile);
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}the encoding mode of decompressed compressed file is changed to UTF-8:{}", sourceAbsolutePath, e);
}
ZipEntry entry;
Enumeration enumeration = zipFile.getEntries();
while (enumeration.hasMoreElements()) {
entry = (ZipEntry) enumeration.nextElement();
String entryName = entry.getName();
File fileDir;
if (entry.isDirectory()) {
fileDir = new File(targetPathAbsolutePath + entry.getName());
if (!fileDir.exists()) {
boolean fileMkdir = fileDir.mkdirs();
if (!fileMkdir) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to create folder {} while decompressing {}", fileDir, sourceAbsolutePath);
}
}
} else {
//若文件夹未创建则创建文件夹
if (entryName.contains(FILE_SEPARATOR)) {
String zipDirName = entryName.substring(MagicNumConstant.ZERO, entryName.lastIndexOf(FILE_SEPARATOR));
fileDir = new File(targetPathAbsolutePath + zipDirName);
if (!fileDir.exists()) {
boolean fileMkdir = fileDir.mkdirs();
if (!fileMkdir) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to create folder {} while decompressing {}", fileDir, sourceAbsolutePath);
}
}
}
in = zipFile.getInputStream((ZipArchiveEntry) entry);
out = new FileOutputStream(new File(targetPathAbsolutePath, entryName));
IOUtils.copyLarge(in, out);
in.close();
out.close();
}
}
boolean deleteZipFile = sourceFile.delete();
if (!deleteZipFile) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}compressed file deletion failed after decompression", sourceAbsolutePath);
}
return true;
} catch (IOException e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}decompression failed: {}", sourceAbsolutePath, e);
return false;
} finally {
//关闭未关闭的io流
closeIoFlow(sourceAbsolutePath, zipFile, in, out);
}

}

/**
* 关闭未关闭的io流
*
* @param sourceAbsolutePath 源路径
* @param zipFile 压缩文件对象
* @param in 输入流
* @param out 输出流
*/
private void closeIoFlow(String sourceAbsolutePath, ZipFile zipFile, InputStream in, OutputStream out) {
if (in != null) {
try {
in.close();
} catch (IOException e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}input stream shutdown failed: {}", sourceAbsolutePath, e);
}
}
if (out != null) {
try {
out.close();
} catch (IOException e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}output stream shutdown failed: {}", sourceAbsolutePath, e);
}
}
if (zipFile != null) {
try {
zipFile.close();
} catch (IOException e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}input stream shutdown failed: {}", sourceAbsolutePath, e);
}
}
}

/**
* NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况
*
* 通过本地文件复制方式
*
* @param sourcePath 需要复制的文件目录 例如:/abc/def
* @param targetPath 需要放置的目标目录 例如:/abc/dd
* @return boolean
*/
public boolean copyPath(String sourcePath, String targetPath) {
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) {
return false;
}
sourcePath = formatPath(sourcePath);
targetPath = formatPath(targetPath);
try {
return copyLocalPath(nfsConfig.getRootDir() + sourcePath, nfsConfig.getRootDir() + targetPath);
} catch (Exception e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, " failed to Copy file original path: {} ,target path: {} ,copyPath: {}", sourcePath, targetPath, e);
return false;
}
}

/**
* 复制文件到指定目录下 单个文件
*
* @param sourcePath 需要复制的文件 例如:/abc/def/cc.txt
* @param targetPath 需要放置的目标目录 例如:/abc/dd
* @return boolean
*/
private boolean copyLocalFile(String sourcePath, String targetPath) {
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) {
return false;
}
sourcePath = formatPath(sourcePath);
targetPath = formatPath(targetPath);
try (InputStream input = new FileInputStream(sourcePath);
FileOutputStream output = new FileOutputStream(targetPath)) {
FileCopyUtils.copy(input, output);
return true;
} catch (IOException e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, " failed to copy file original path: {} ,target path: {} ,copyLocalFile:{} ", sourcePath, targetPath, e);
return false;
}
}


/**
* 复制文件 到指定目录下 多个文件 包含目录与文件并存情况
*
* @param sourcePath 需要复制的文件目录 例如:/abc/def
* @param targetPath 需要放置的目标目录 例如:/abc/dd
* @return boolean
*/
private boolean copyLocalPath(String sourcePath, String targetPath) {
if (!StringUtils.isEmpty(sourcePath) && !StringUtils.isEmpty(targetPath)) {
sourcePath = formatPath(sourcePath);
if (sourcePath.endsWith(FILE_SEPARATOR)) {
sourcePath = sourcePath.substring(MagicNumConstant.ZERO, sourcePath.lastIndexOf(FILE_SEPARATOR));
}
targetPath = formatPath(targetPath);
File sourceFile = new File(sourcePath);
if (sourceFile.exists()) {
File[] files = sourceFile.listFiles();
if (files != null && files.length != 0) {
for (File file : files) {
try {
if (file.isDirectory()) {
File fileDir = new File(targetPath + FILE_SEPARATOR + file.getName());
if (!fileDir.exists()) {
fileDir.mkdirs();
}
copyLocalPath(sourcePath + FILE_SEPARATOR + file.getName(), targetPath + FILE_SEPARATOR + file.getName());
}
if (file.isFile()) {
File fileTargetPath = new File(targetPath);
if (!fileTargetPath.exists()) {
fileTargetPath.mkdirs();
}
copyLocalFile(file.getAbsolutePath(), targetPath + FILE_SEPARATOR + file.getName());
}
} catch (Exception e) {
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to copy folder original path: {} , target path : {} ,copyLocalPath: {}", sourcePath, targetPath, e);
return false;
}
}
}
return true;
}
}
return false;
}

/**
* 替换路径中多余的 "/"
*
* @param path
* @return String
*/
public String formatPath(String path) {
if (!StringUtils.isEmpty(path)) {
return path.replaceAll("///*", FILE_SEPARATOR);
}
return path;
}

}

+ 279
- 214
dubhe-server/common/src/main/java/org/dubhe/utils/LogUtil.java View File

@@ -23,10 +23,10 @@ import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.exception.ExceptionUtils;
import org.dubhe.aspect.LogAspect;
import org.dubhe.base.MagicNumConstant;
import org.dubhe.constant.SymbolConstant;
import org.dubhe.domain.entity.LogInfo;
import org.dubhe.enums.LogEnum;
import org.slf4j.MDC;
import org.slf4j.MarkerFactory;
import org.slf4j.helpers.MessageFormatter;

import java.util.Arrays;
@@ -38,218 +38,283 @@ import java.util.UUID;
*/
@Slf4j
public class LogUtil {
/**
* info级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
*/
public static void info(LogEnum logType, Object... object) {

logHandle(logType, Level.INFO, object);
}

/**
* debug级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
*/
public static void debug(LogEnum logType, Object... object) {
logHandle(logType, Level.DEBUG, object);
}

/**
* error级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
*/
public static void error(LogEnum logType, Object... object) {
errorObjectHandle(object);
logHandle(logType, Level.ERROR, object);
}

/**
* warn级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
*/
public static void warn(LogEnum logType, Object... object) {
logHandle(logType, Level.WARN, object);
}

/**
* trace级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
*/
public static void trace(LogEnum logType, Object... object) {
logHandle(logType, Level.TRACE, object);
}

/**
* 日志处理
*
* @param logType 日志类型
* @param level 日志级别
* @param object 打印的日志参数
*/
private static void logHandle(LogEnum logType, Level level, Object[] object) {

LogInfo logInfo = generateLogInfo(logType, level, object);
String logInfoJsonStr = logJsonStringLengthLimit(logInfo);
switch (Level.toLevel(logInfo.getLevel()).levelInt) {
case Level.TRACE_INT:
log.trace(logInfoJsonStr);
break;
case Level.DEBUG_INT:
log.debug(logInfoJsonStr);
break;
case Level.INFO_INT:
log.info(logInfoJsonStr);
break;
case Level.WARN_INT:
log.warn(logInfoJsonStr);
break;
case Level.ERROR_INT:
log.error(logInfoJsonStr);
break;
default:
}

}

/**
* 日志信息组装的内部方法
*
* @param logType 日志类型
* @param level 日志级别
* @param object 打印的日志参数
* @return LogInfo日志对象信息
*/
private static LogInfo generateLogInfo(LogEnum logType, Level level, Object[] object) {
LogInfo logInfo = new LogInfo();
// 日志类型检测
if (!LogEnum.isLogType(logType)) {
level = Level.ERROR;
object = new Object[MagicNumConstant.ONE];
object[MagicNumConstant.ZERO] = String.valueOf("logType【").concat(String.valueOf(logType))
.concat("】is error !");
logType = LogEnum.SYS_ERR;
}

// 获取trace_id
if (StringUtils.isEmpty(MDC.get(LogAspect.TRACE_ID))) {
MDC.put(LogAspect.TRACE_ID, UUID.randomUUID().toString());
}
// 设置logInfo的level,type,traceId属性
logInfo.setLevel(level.levelStr).setType(logType.toString()).setTraceId(MDC.get(LogAspect.TRACE_ID));
// 设置logInfo的堆栈信息
setLogStackInfo(logInfo);
// 设置logInfo的info信息
setLogInfo(logInfo, object);
// 截取loginfo的长度并转换成json字符串
return logInfo;

}

/**
* 设置loginfo的堆栈信息
*
* @param logInfo 日志对象
*/
private static void setLogStackInfo(LogInfo logInfo) {
StackTraceElement[] elements = Thread.currentThread().getStackTrace();
if (elements.length >= MagicNumConstant.SIX) {
logInfo.setCName(elements[MagicNumConstant.FIVE].getClassName())
.setMName(elements[MagicNumConstant.FIVE].getMethodName())
.setLine(String.valueOf(elements[MagicNumConstant.FIVE].getLineNumber()));
}
}

/**
* 限制log日志的长度并转换成json
*
* @param logInfo 日志对象
* @return String 日志对象Json字符串
*/
private static String logJsonStringLengthLimit(LogInfo logInfo) {
try {
String jsonString = JSON.toJSONString(logInfo.getInfo());
if (jsonString.length() > MagicNumConstant.TEN_THOUSAND) {
jsonString = jsonString.substring(MagicNumConstant.ZERO, MagicNumConstant.TEN_THOUSAND);
}

logInfo.setInfo(jsonString);
jsonString = JSON.toJSONString(logInfo);
jsonString = jsonString.replace(SymbolConstant.BACKSLASH_MARK, SymbolConstant.MARK)
.replace(SymbolConstant.DOUBLE_MARK, SymbolConstant.MARK)
.replace(SymbolConstant.BRACKETS, SymbolConstant.BLANK);

return jsonString;

} catch (Exception e) {
logInfo.setLevel(Level.ERROR.levelStr).setType(LogEnum.SYS_ERR.toString())
.setInfo("cannot serialize exception: " + ExceptionUtils.getStackTrace(e));
return JSON.toJSONString(logInfo);
}
}

/**
* 设置日志对象的info信息
*
* @param logInfo 日志对象
* @param object 打印的日志参数
*/
private static void setLogInfo(LogInfo logInfo, Object[] object) {
for (Object obj : object) {
if (obj instanceof Exception) {
log.error((ExceptionUtils.getStackTrace((Throwable) obj)));
}
}

if (object.length > MagicNumConstant.ONE) {
logInfo.setInfo(MessageFormatter.arrayFormat(object[MagicNumConstant.ZERO].toString(),
Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)).getMessage());

} else if (object.length == MagicNumConstant.ONE && object[MagicNumConstant.ZERO] instanceof Exception) {
logInfo.setInfo((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO])));
} else if (object.length == MagicNumConstant.ONE) {
logInfo.setInfo(
object[MagicNumConstant.ZERO] == null ? SymbolConstant.BLANK : object[MagicNumConstant.ZERO]);
} else {
logInfo.setInfo(SymbolConstant.BLANK);
}

}

/**
* 处理Exception的情况
*
* @param object 打印的日志参数
*/
private static void errorObjectHandle(Object[] object) {
if (object.length >= MagicNumConstant.TWO) {
object[MagicNumConstant.ZERO] = String.valueOf(object[MagicNumConstant.ZERO])
.concat(SymbolConstant.BRACKETS);
}

if (object.length == MagicNumConstant.TWO && object[MagicNumConstant.ONE] instanceof Exception) {
log.error((ExceptionUtils.getStackTrace((Throwable) object[MagicNumConstant.ONE])));
object[MagicNumConstant.ONE] = ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ONE]);

} else if (object.length >= MagicNumConstant.THREE) {
for (int i = 0; i < object.length; i++) {
if (object[i] instanceof Exception) {
log.error((ExceptionUtils.getStackTrace((Throwable) object[i])));
object[i] = ExceptionUtils.getStackTrace((Exception) object[i]);
}

}
}
}

private static final String TRACE_TYPE = "TRACE_TYPE";

public static final String SCHEDULE_LEVEL = "SCHEDULE";

public static final String K8S_CALLBACK_LEVEL = "K8S_CALLBACK";

private static final String GLOBAL_REQUEST_LEVEL = "GLOBAL_REQUEST";

private static final String TRACE_LEVEL = "TRACE";

private static final String DEBUG_LEVEL = "DEBUG";

private static final String INFO_LEVEL = "INFO";

private static final String WARN_LEVEL = "WARN";

private static final String ERROR_LEVEL = "ERROR";


public static void startScheduleTrace() {
MDC.put(TRACE_TYPE, SCHEDULE_LEVEL);
}

public static void startK8sCallbackTrace() {
MDC.put(TRACE_TYPE, K8S_CALLBACK_LEVEL);
}

public static void cleanTrace() {
MDC.clear();
}

/**
* info级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
* @return void
*/

public static void info(LogEnum logType, Object... object) {

logHandle(logType, Level.INFO, object);
}

/**
* debug级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
* @return void
*/
public static void debug(LogEnum logType, Object... object) {
logHandle(logType, Level.DEBUG, object);
}

/**
* error级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
* @return void
*/
public static void error(LogEnum logType, Object... object) {
errorObjectHandle(object);
logHandle(logType, Level.ERROR, object);
}

/**
* warn级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
* @return void
*/
public static void warn(LogEnum logType, Object... object) {
logHandle(logType, Level.WARN, object);
}

/**
* trace级别的日志
*
* @param logType 日志类型
* @param object 打印的日志参数
* @return void
*/
public static void trace(LogEnum logType, Object... object) {
logHandle(logType, Level.TRACE, object);
}

/**
* 日志处理
*
* @param logType 日志类型
* @param level 日志级别
* @param object 打印的日志参数
* @return void
*/
private static void logHandle(LogEnum logType, Level level, Object[] object) {

LogInfo logInfo = generateLogInfo(logType, level, object);

switch (logInfo.getLevel()) {
case TRACE_LEVEL:
log.trace(MarkerFactory.getMarker(TRACE_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case DEBUG_LEVEL:
log.debug(MarkerFactory.getMarker(DEBUG_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case GLOBAL_REQUEST_LEVEL:
logInfo.setLevel(null);
logInfo.setType(null);
logInfo.setLocation(null);
log.info(MarkerFactory.getMarker(GLOBAL_REQUEST_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case SCHEDULE_LEVEL:
log.info(MarkerFactory.getMarker(SCHEDULE_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case K8S_CALLBACK_LEVEL:
log.info(MarkerFactory.getMarker(K8S_CALLBACK_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case INFO_LEVEL:
log.info(MarkerFactory.getMarker(INFO_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case WARN_LEVEL:
log.warn(MarkerFactory.getMarker(WARN_LEVEL), logJsonStringLengthLimit(logInfo));
break;
case ERROR_LEVEL:
log.error(MarkerFactory.getMarker(ERROR_LEVEL), logJsonStringLengthLimit(logInfo));
break;
default:
}

}


/**
* 日志信息组装的内部方法
*
* @param logType 日志类型
* @param level 日志级别
* @param object 打印的日志参数
* @return LogInfo
*/
private static LogInfo generateLogInfo(LogEnum logType, Level level, Object[] object) {


LogInfo logInfo = new LogInfo();
// 日志类型检测
if (!LogEnum.isLogType(logType)) {
level = Level.ERROR;
object = new Object[MagicNumConstant.ONE];
object[MagicNumConstant.ZERO] = "日志类型【".concat(String.valueOf(logType)).concat("】不正确!");
logType = LogEnum.SYS_ERR;
}

// 获取trace_id
if (StringUtils.isEmpty(MDC.get(LogAspect.TRACE_ID))) {
MDC.put(LogAspect.TRACE_ID, UUID.randomUUID().toString());
}
// 设置logInfo的level,type,traceId属性
logInfo.setLevel(level.levelStr)
.setType(logType.toString())
.setTraceId(MDC.get(LogAspect.TRACE_ID));


//自定义日志级别
//LogEnum、 MDC中的 TRACE_TYPE 做日志分流标识
if (Level.INFO.toInt() == level.toInt()) {
if (LogEnum.GLOBAL_REQ.equals(logType)) {
//info全局请求
logInfo.setLevel(GLOBAL_REQUEST_LEVEL);
} else if (LogEnum.BIZ_K8S.equals(logType)) {
logInfo.setLevel(K8S_CALLBACK_LEVEL);
} else {
//schedule定时等 链路记录
String traceType = MDC.get(TRACE_TYPE);
if (StringUtils.isNotBlank(traceType)) {
logInfo.setLevel(traceType);
}
}
}

// 设置logInfo的堆栈信息
setLogStackInfo(logInfo);
// 设置logInfo的info信息
setLogInfo(logInfo, object);
// 截取logInfo的长度并转换成json字符串
return logInfo;
}

/**
* 设置loginfo的堆栈信息
*
* @param logInfo 日志对象
* @return void
*/
private static void setLogStackInfo(LogInfo logInfo) {
StackTraceElement[] elements = Thread.currentThread().getStackTrace();
if (elements.length >= MagicNumConstant.SIX) {
StackTraceElement element = elements[MagicNumConstant.FIVE];
logInfo.setLocation(String.format("%s#%s:%s", element.getClassName(), element.getMethodName(), element.getLineNumber()));
}
}

/**
* 限制log日志的长度并转换成json
*
* @param logInfo 日志对象
* @return String
*/
private static String logJsonStringLengthLimit(LogInfo logInfo) {
try {

String jsonString = JSON.toJSONString(logInfo);
if (StringUtils.isBlank(jsonString)) {
return "";
}
if (jsonString.length() > MagicNumConstant.TEN_THOUSAND) {
String trunk = logInfo.getInfo().toString().substring(MagicNumConstant.ZERO, MagicNumConstant.NINE_THOUSAND);
logInfo.setInfo(trunk);
jsonString = JSON.toJSONString(logInfo);
}
return jsonString;

} catch (Exception e) {
logInfo.setLevel(Level.ERROR.levelStr).setType(LogEnum.SYS_ERR.toString())
.setInfo("cannot serialize exception: " + ExceptionUtils.getStackTrace(e));
return JSON.toJSONString(logInfo);
}
}

/**
* 设置日志对象的info信息
*
* @param logInfo 日志对象
* @param object 打印的日志参数
* @return void
*/
private static void setLogInfo(LogInfo logInfo, Object[] object) {

if (object.length > MagicNumConstant.ONE) {
logInfo.setInfo(MessageFormatter.arrayFormat(object[MagicNumConstant.ZERO].toString(),
Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)).getMessage());

} else if (object.length == MagicNumConstant.ONE && object[MagicNumConstant.ZERO] instanceof Exception) {
logInfo.setInfo((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO])));
log.error((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO])));
} else if (object.length == MagicNumConstant.ONE) {
logInfo.setInfo(object[MagicNumConstant.ZERO] == null ? "" : object[MagicNumConstant.ZERO]);
} else {
logInfo.setInfo("");
}

}

/**
* 处理Exception的情况
*
* @param object 打印的日志参数
* @return void
*/
private static void errorObjectHandle(Object[] object) {

if (object.length == MagicNumConstant.TWO && object[MagicNumConstant.ONE] instanceof Exception) {
log.error(String.valueOf(object[MagicNumConstant.ZERO]), (Exception) object[MagicNumConstant.ONE]);
object[MagicNumConstant.ONE] = ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ONE]);

} else if (object.length >= MagicNumConstant.THREE) {
log.error(String.valueOf(object[MagicNumConstant.ZERO]),
Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length));
for (int i = 0; i < object.length; i++) {
if (object[i] instanceof Exception) {
object[i] = ExceptionUtils.getStackTrace((Exception) object[i]);
}

}
}
}
}

+ 2
- 2
dubhe-server/common/src/main/java/org/dubhe/utils/MathUtils.java View File

@@ -20,8 +20,8 @@ package org.dubhe.utils;
import org.dubhe.base.MagicNumConstant;

/**
* @description: 计算工具类
* @create: 2020/6/4 14:53
* @description 计算工具类
* @date 2020-06-04
*/
public class MathUtils {



+ 20
- 12
dubhe-server/common/src/main/java/org/dubhe/utils/MinioUtil.java View File

@@ -68,9 +68,9 @@ public class MinioUtil {
try {
client = new MinioClient(url, accessKey, secretKey);
} catch (InvalidEndpointException e) {
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint invalid. e:", e);
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint invalid. e, {}", e);
} catch (InvalidPortException e) {
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint port invalid. e:", e);
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint port invalid. e, {}", e);
}
}

@@ -94,7 +94,7 @@ public class MinioUtil {
/**
* 读取文件
*
* @param bucket
* @param bucket
* @param fullFilePath 文件存储的全路径,包括文件名,非'/'开头. e.g. dataset/12/annotation/test.txt
* @return String
*/
@@ -107,7 +107,7 @@ public class MinioUtil {
/**
* 文件删除
*
* @param bucket
* @param bucket
* @param fullFilePath 文件存储的全路径,包括文件名,非'/'开头. e.g. dataset/12/annotation/test.txt
*/
public void del(String bucket, String fullFilePath) throws Exception {
@@ -125,8 +125,8 @@ public class MinioUtil {
/**
* 批量删除文件
*
* @param bucket
* @param objectNames
* @param bucket
* @param objectNames 对象名称
*/
public void delFiles(String bucket,List<String> objectNames) throws Exception{
Iterable<Result<DeleteError>> results = client.removeObjects(bucket, objectNames);
@@ -138,6 +138,10 @@ public class MinioUtil {
/**
* 获取对象名称
*
* @param bucketName 桶名称
* @param prefix 前缀
* @return
* @throws Exception
*/
public List<String> getObjects(String bucketName, String prefix)throws Exception{
List<String> fileNames = new ArrayList<>();
@@ -152,6 +156,10 @@ public class MinioUtil {
/**
* 获取文件流
*
* @param bucket 桶
* @param objectName 对象名称
* @return
* @throws Exception
*/
public InputStream getObjectInputStream(String bucket,String objectName)throws Exception{
return client.getObject(bucket, objectName);
@@ -160,9 +168,9 @@ public class MinioUtil {
/**
* 文件夹复制
*
* @param bucket
* @param bucket
* @param sourceFiles 源文件
* @param targetDir 目标文件夹
* @param targetDir 目标文件夹
*/
public void copyDir(String bucket, List<String> sourceFiles, String targetDir) {
sourceFiles.forEach(sourceFile -> {
@@ -211,10 +219,10 @@ public class MinioUtil {
/**
* 生成文件下载请求参数方法
*
* @param bucketName
* @param prefix
* @param objects
* @return MinioDownloadDto
* @param bucketName 桶名称
* @param prefix 前缀
* @param objects 对象名称
* @return MinioDownloadDto 下载请求参数
*/
public MinioDownloadDto getDownloadParam(String bucketName, String prefix, List<String> objects, String zipName) {
String paramTemplate = "{\"id\":%d,\"jsonrpc\":\"%s\",\"params\":{\"username\":\"%s\",\"password\":\"%s\"},\"method\":\"%s\"}";


+ 11
- 142
dubhe-server/common/src/main/java/org/dubhe/utils/NfsUtil.java View File

@@ -17,6 +17,7 @@

package org.dubhe.utils;

import cn.hutool.core.util.StrUtil;
import com.emc.ecs.nfsclient.nfs.io.Nfs3File;
import com.emc.ecs.nfsclient.nfs.io.NfsFileInputStream;
import com.emc.ecs.nfsclient.nfs.io.NfsFileOutputStream;
@@ -31,7 +32,6 @@ import org.dubhe.exception.NfsBizException;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;
import org.springframework.util.FileCopyUtils;

import java.io.*;
import java.util.ArrayList;
@@ -130,14 +130,14 @@ public class NfsUtil {
}

/**
* 校验文件或文件夹是否存在
* 校验文件或文件夹是否存在
*
* @param path 文件路径
* @return boolean
*/
public boolean fileOrDirIsEmpty(String path) {
if (!StringUtils.isEmpty(path)) {
path = formatPath(path);
path = formatPath(path.startsWith(nfsConfig.getRootDir()) ? path.replaceFirst(nfsConfig.getRootDir(), StrUtil.SLASH) : path);
Nfs3File nfs3File = getNfs3File(path);
try {
if (nfs3File.exists()) {
@@ -398,30 +398,6 @@ public class NfsUtil {
}


/**
* NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况
*
* 通过本地文件复制方式
*
* @param sourcePath 需要复制的文件目录 例如:/abc/def
* @param targetPath 需要放置的目标目录 例如:/abc/dd
* @return boolean
*/
public boolean copyPath(String sourcePath, String targetPath) {
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) {
return false;
}
sourcePath = formatPath(sourcePath);
targetPath = formatPath(targetPath);
try {
return copyLocalPath(nfsConfig.getRootDir() + sourcePath, nfsConfig.getRootDir() + targetPath);
} catch (Exception e) {
LogUtil.error(LogEnum.NFS_UTIL, " copyPath 复制失败: ", e);
return false;
}
}


/**
* NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况
*
@@ -468,126 +444,15 @@ public class NfsUtil {
}
}


/**
* 复制文件到指定目录下 单个文件
*
* @param sourcePath 需要复制的文件 例如:/abc/def/cc.txt
* @param targetPath 需要放置的目标目录 例如:/abc/dd
* @return boolean
*/
public boolean copyLocalFile(String sourcePath, String targetPath) {
LogUtil.info(LogEnum.NFS_UTIL, "复制文件原路径: {} ,目标路径: {}", sourcePath, targetPath);
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) {
return false;
}
sourcePath = formatPath(sourcePath);
targetPath = formatPath(targetPath);
LogUtil.info(LogEnum.NFS_UTIL, "过滤后文件原路径: {} ,目标路径:{}", sourcePath, targetPath);
try (InputStream input = new FileInputStream(sourcePath);
FileOutputStream output = new FileOutputStream(targetPath)) {
FileCopyUtils.copy(input, output);
LogUtil.info(LogEnum.NFS_UTIL, "复制文件成功");
return true;
} catch (IOException e) {
LogUtil.error(LogEnum.NFS_UTIL, " copyLocalFile 复制失败: ", e);
return false;
}
}


/**
* 复制文件 到指定目录下 多个文件 包含目录与文件并存情况
*
* @param sourcePath 需要复制的文件目录 例如:/abc/def
* @param targetPath 需要放置的目标目录 例如:/abc/dd
* @return boolean
*/
public boolean copyLocalPath(String sourcePath, String targetPath) {
if (!StringUtils.isEmpty(sourcePath) && !StringUtils.isEmpty(targetPath)) {
sourcePath = formatPath(sourcePath);
if (sourcePath.endsWith(FILE_SEPARATOR)) {
sourcePath = sourcePath.substring(MagicNumConstant.ZERO, sourcePath.lastIndexOf(FILE_SEPARATOR));
}
targetPath = formatPath(targetPath);
LogUtil.info(LogEnum.NFS_UTIL, "复制文件夹 原路径: {} , 目标路径 : {} ", sourcePath, targetPath);
File[] files = new File(sourcePath).listFiles();
LogUtil.info(LogEnum.NFS_UTIL, "需要复制的文件数量为: {}", files.length);
if (files.length != 0) {
for (File file : files) {
try {
if (file.isDirectory()) {
LogUtil.info(LogEnum.NFS_UTIL, "需要复制夹: {}", file.getAbsolutePath());
LogUtil.info(LogEnum.NFS_UTIL, "目标文件夹: {}", targetPath + FILE_SEPARATOR + file.getName());
File fileDir = new File(targetPath + FILE_SEPARATOR + file.getName());
if(!fileDir.exists()){
fileDir.mkdirs();
}
copyLocalPath(sourcePath + FILE_SEPARATOR + file.getName(), targetPath + FILE_SEPARATOR + file.getName());
}
if (file.isFile()) {
File fileTargetPath = new File(targetPath);
if(!fileTargetPath.exists()){
fileTargetPath.mkdirs();
}
LogUtil.info(LogEnum.NFS_UTIL, "需要复制文件: {}", file.getAbsolutePath());
LogUtil.info(LogEnum.NFS_UTIL, "需要复制文件名称: {}", file.getName());
copyLocalFile(file.getAbsolutePath() , targetPath + FILE_SEPARATOR + file.getName());
}
}catch (Exception e){
LogUtil.error(LogEnum.NFS_UTIL, "复制文件夹失败: {}", e);
return false;
}
}
}
return true;
}
return false;
}

/**
* 解压前清理同路径下其他文件(目前只支持路径下无文件夹,文件均为zip文件)
* 上传路径垃圾文件清理
*
* @param zipFilePath zip源文件 例如:/abc/z.zip
* @param path 文件夹 例如:/abc/
* @return boolean
*/
public boolean cleanPath(String zipFilePath, String path) {
if (!StringUtils.isEmpty(zipFilePath) && !StringUtils.isEmpty(path) && zipFilePath.toLowerCase().endsWith(ZIP)) {
zipFilePath = formatPath(zipFilePath);
path = formatPath(path);
Nfs3File nfs3Files = getNfs3File(path);
try {
String zipName = zipFilePath.substring(zipFilePath.lastIndexOf(FILE_SEPARATOR) + MagicNumConstant.ONE);
if (!StringUtils.isEmpty(zipName)) {
List<Nfs3File> nfs3FilesList = nfs3Files.listFiles();
if (!CollectionUtils.isEmpty(nfs3FilesList)) {
for (Nfs3File nfs3File : nfs3FilesList) {
if (!zipName.equals(nfs3File.getName())) {
nfs3File.delete();
}
}
return true;
}
}
} catch (Exception e) {
LogUtil.error(LogEnum.NFS_UTIL, "路径{}清理失败,错误原因为:{} ", path, e);
return false;
} finally {
nfsPool.revertNfs(nfs3Files.getNfs());
}
}
return false;
}

/**
* zip解压并删除压缩文件
*
* 当压缩包文件较多时,可能会因为RPC超时而解压失败
* 该方法已废弃,请使用org.dubhe.utils.LocalFileUtil类的unzipLocalPath方法来替代
* @param sourcePath zip源文件 例如:/abc/z.zip
* @param targetPath 解压后的目标文件夹 例如:/abc/
* @return boolean
*/
@Deprecated
public boolean unzip(String sourcePath, String targetPath) {
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) {
return false;
@@ -894,11 +759,15 @@ public class NfsUtil {
* @param path
* @return String
*/
private String formatPath(String path) {
public String formatPath(String path) {
if (!StringUtils.isEmpty(path)) {
return path.replaceAll("///*", FILE_SEPARATOR);
}
return path;
}

public String getAbsolutePath(String relativePath) {
return nfsConfig.getRootDir() + nfsConfig.getBucket() + relativePath;
}

}

+ 108
- 46
dubhe-server/common/src/main/java/org/dubhe/utils/RedisUtils.java View File

@@ -26,6 +26,9 @@ import org.springframework.data.redis.core.Cursor;
import org.springframework.data.redis.core.RedisConnectionUtils;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.ScanOptions;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer;
import org.springframework.stereotype.Component;
import org.springframework.util.CollectionUtils;

@@ -33,7 +36,7 @@ import java.util.*;
import java.util.concurrent.TimeUnit;

/**
* @description redis工具类
* @description redis工具类
* @date 2020-03-13
*/
@Component
@@ -62,7 +65,7 @@ public class RedisUtils {
redisTemplate.expire(key, time, TimeUnit.SECONDS);
}
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils expire key {} time {} error:{}",key,time,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils expire key {} time {} error:{}", key, time, e.getMessage(), e);
return false;
}
return true;
@@ -82,7 +85,7 @@ public class RedisUtils {
* 查找匹配key
*
* @param pattern key
* @return /
* @return List<String> 匹配的key集合
*/
public List<String> scan(String pattern) {
ScanOptions options = ScanOptions.scanOptions().match(pattern).build();
@@ -94,9 +97,9 @@ public class RedisUtils {
result.add(new String(cursor.next()));
}
try {
RedisConnectionUtils.releaseConnection(rc, factory,true);
RedisConnectionUtils.releaseConnection(rc, factory, true);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils scan pattern {} error:{}",pattern,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils scan pattern {} error:{}", pattern, e.getMessage(), e);
}
return result;
}
@@ -107,7 +110,7 @@ public class RedisUtils {
* @param patternKey key
* @param page 页码
* @param size 每页数目
* @return /
* @return 匹配到的key集合
*/
public List<String> findKeysForPage(String patternKey, int page, int size) {
ScanOptions options = ScanOptions.scanOptions().match(patternKey).build();
@@ -132,9 +135,9 @@ public class RedisUtils {
cursor.next();
}
try {
RedisConnectionUtils.releaseConnection(rc, factory,true);
RedisConnectionUtils.releaseConnection(rc, factory, true);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils findKeysForPage patternKey {} page {} size {} error:{}",patternKey,page,size,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils findKeysForPage patternKey {} page {} size {} error:{}", patternKey, page, size, e.getMessage(), e);
}
return result;
}
@@ -149,7 +152,7 @@ public class RedisUtils {
try {
return redisTemplate.hasKey(key);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hasKey key {} error:{}",key,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hasKey key {} error:{}", key, e.getMessage(), e);
return false;
}
}
@@ -175,7 +178,7 @@ public class RedisUtils {
* 普通缓存获取
*
* @param key 键
* @return 值
* @return key对应的value
*/
public Object get(String key) {

@@ -185,8 +188,8 @@ public class RedisUtils {
/**
* 批量获取
*
* @param keys
* @return
* @param keys key集合
* @return key集合对应的value集合
*/
public List<Object> multiGet(List<String> keys) {
Object obj = redisTemplate.opsForValue().multiGet(Collections.singleton(keys));
@@ -205,7 +208,7 @@ public class RedisUtils {
redisTemplate.opsForValue().set(key, value);
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} error:{}",key,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} error:{}", key, value, e.getMessage(), e);
return false;
}
}
@@ -227,7 +230,7 @@ public class RedisUtils {
}
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} time {} error:{}",key,value,time,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} time {} error:{}", key, value, time, e.getMessage(), e);
return false;
}
}
@@ -250,11 +253,56 @@ public class RedisUtils {
}
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} time {} timeUnit {} error:{}",key,value,time,timeUnit,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} time {} timeUnit {} error:{}", key, value, time, timeUnit, e.getMessage(), e);
return false;
}
}

//===============================Lock=================================

/**
* 加锁
* @param key 键
* @param requestId 请求id用以释放锁
* @param expireTime 超时时间(秒)
* @return
*/
public boolean getDistributedLock(String key, String requestId, long expireTime) {
String script = "if redis.call('setNx',KEYS[1],ARGV[1]) == 1 then if redis.call('get',KEYS[1]) == ARGV[1] then return redis.call('expire',KEYS[1],ARGV[2]) else return 0 end else return 0 end";
Object result = executeRedisScript(script, key, requestId, expireTime);
return result != null && result.equals(MagicNumConstant.ONE_LONG);
}

/**
* 释放锁
* @param key 键
* @param requestId 请求id用以释放锁
* @return
*/
public boolean releaseDistributedLock(String key, String requestId) {
String script = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end";
Object result = executeRedisScript(script, key, requestId);
return result != null && result.equals(MagicNumConstant.ONE_LONG);
}

/**
*
* @param script 脚本字符串
* @param key 键
* @param args 脚本其他参数
* @return
*/
public Object executeRedisScript(String script, String key, Object... args) {
try {
RedisScript<Long> redisScript = new DefaultRedisScript<>(script, Long.class);
redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(Object.class));
return redisTemplate.execute(redisScript, Collections.singletonList(key), args);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR, "executeRedisScript script {} key {} expireTime {} args {} error:{}", script, key, args, e);
return MagicNumConstant.ZERO_LONG;
}
}

// ================================Map=================================

/**
@@ -291,7 +339,7 @@ public class RedisUtils {
redisTemplate.opsForHash().putAll(key, map);
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hmset key {} map {} error:{}",key,map,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hmset key {} map {} error:{}", key, map, e.getMessage(), e);
return false;
}
}
@@ -312,7 +360,7 @@ public class RedisUtils {
}
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hmset key {} map {} time {} error:{}",key,map,time,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hmset key {} map {} time {} error:{}", key, map, time, e.getMessage(), e);
return false;
}
}
@@ -330,7 +378,7 @@ public class RedisUtils {
redisTemplate.opsForHash().put(key, item, value);
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hset key {} item {} value {} error:{}",key,item,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hset key {} item {} value {} error:{}", key, item, value, e.getMessage(), e);
return false;
}
}
@@ -352,7 +400,7 @@ public class RedisUtils {
}
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hset key {} item {} value {} time {} error:{}",key,item,value,time,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hset key {} item {} value {} time {} error:{}", key, item, value, time, e.getMessage(), e);
return false;
}
}
@@ -414,7 +462,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForSet().members(key);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sGet key {} error:{}",key,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sGet key {} error:{}", key, e.getMessage(), e);
return null;
}
}
@@ -430,7 +478,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForSet().isMember(key, value);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sHasKey key {} value {} error:{}",key,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sHasKey key {} value {} error:{}", key, value, e.getMessage(), e);
return false;
}
}
@@ -446,7 +494,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForSet().add(key, values);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sSet key {} values {} error:{}",key,values,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sSet key {} values {} error:{}", key, values, e.getMessage(), e);
return 0;
}
}
@@ -467,7 +515,7 @@ public class RedisUtils {
}
return count;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sSetAndTime key {} time {} values {} error:{}",key,time,values,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sSetAndTime key {} time {} values {} error:{}", key, time, values, e.getMessage(), e);
return 0;
}
}
@@ -482,7 +530,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForSet().size(key);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sGetSetSize key {} error:{}",key,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sGetSetSize key {} error:{}", key, e.getMessage(), e);
return 0;
}
}
@@ -499,7 +547,7 @@ public class RedisUtils {
Long count = redisTemplate.opsForSet().remove(key, values);
return count;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils setRemove key {} values {} error:{}",key,values,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils setRemove key {} values {} error:{}", key, values, e.getMessage(), e);
return 0;
}
}
@@ -507,22 +555,22 @@ public class RedisUtils {
// ===============================sorted set=================================

/**
*将zSet数据放入缓存
* 将zSet数据放入缓存
*
* @param key
* @param time
* @param values
* @return Boolean
*/
public Boolean zSet(String key, long time, Object value){
public Boolean zSet(String key, long time, Object value) {
try {
Boolean success = redisTemplate.opsForZSet().add(key, value,System.currentTimeMillis());
Boolean success = redisTemplate.opsForZSet().add(key, value, System.currentTimeMillis());
if (success) {
expire(key, time);
}
return success;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils zSet key {} time {} value {} error:{}",key,time,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils zSet key {} time {} value {} error:{}", key, time, value, e.getMessage(), e);
return false;
}
}
@@ -533,17 +581,16 @@ public class RedisUtils {
* @param key
* @return Set<Object>
*/
public Set<Object> zGet(String key){
public Set<Object> zGet(String key) {
try {
return redisTemplate.opsForZSet().reverseRange(key,Long.MIN_VALUE, Long.MAX_VALUE);
return redisTemplate.opsForZSet().reverseRange(key, Long.MIN_VALUE, Long.MAX_VALUE);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils zGet key {} error:{}",key,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils zGet key {} error:{}", key, e.getMessage(), e);
return null;
}
}



// ===============================list=================================

/**
@@ -558,7 +605,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForList().range(key, start, end);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetIndex key {} start {} end {} error:{}",key,start,end,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetIndex key {} start {} end {} error:{}", key, start, end, e.getMessage(), e);
return null;
}
}
@@ -573,7 +620,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForList().size(key);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetListSize key {} error:{}",key,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetListSize key {} error:{}", key, e.getMessage(), e);
return 0;
}
}
@@ -589,7 +636,7 @@ public class RedisUtils {
try {
return redisTemplate.opsForList().index(key, index);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetIndex key {} index {} error:{}",key,index,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetIndex key {} index {} error:{}", key, index, e.getMessage(), e);
return null;
}
}
@@ -606,7 +653,7 @@ public class RedisUtils {
redisTemplate.opsForList().rightPush(key, value);
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} error:{}",key,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} error:{}", key, value, e.getMessage(), e);
return false;
}
}
@@ -617,7 +664,7 @@ public class RedisUtils {
* @param key 键
* @param value 值
* @param time 时间(秒)
* @return
* @return 是否存储成功
*/
public boolean lSet(String key, Object value, long time) {
try {
@@ -627,7 +674,7 @@ public class RedisUtils {
}
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} time {} error:{}",key,value,time,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} time {} error:{}", key, value, time, e.getMessage(), e);
return false;
}
}
@@ -637,14 +684,14 @@ public class RedisUtils {
*
* @param key 键
* @param value 值
* @return
* @return 是否存储成功
*/
public boolean lSet(String key, List<Object> value) {
try {
redisTemplate.opsForList().rightPushAll(key, value);
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} error:{}",key,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} error:{}", key, value, e.getMessage(), e);
return false;
}
}
@@ -655,7 +702,7 @@ public class RedisUtils {
* @param key 键
* @param value 值
* @param time 时间(秒)
* @return
* @return 是否存储成功
*/
public boolean lSet(String key, List<Object> value, long time) {
try {
@@ -665,7 +712,7 @@ public class RedisUtils {
}
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} time {} error:{}",key,value,time,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} time {} error:{}", key, value, time, e.getMessage(), e);
return false;
}
}
@@ -676,14 +723,14 @@ public class RedisUtils {
* @param key 键
* @param index 索引
* @param value 值
* @return /
* @return 更新数据标识
*/
public boolean lUpdateIndex(String key, long index, Object value) {
try {
redisTemplate.opsForList().set(key, index, value);
return true;
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lUpdateIndex key {} index {} value {} error:{}",key,index,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lUpdateIndex key {} index {} value {} error:{}", key, index, value, e.getMessage(), e);
return false;
}
}
@@ -700,8 +747,23 @@ public class RedisUtils {
try {
return redisTemplate.opsForList().remove(key, count, value);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lRemove key {} count {} value {} error:{}",key,count,value,e.getMessage(),e);
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lRemove key {} count {} value {} error:{}", key, count, value, e.getMessage(), e);
return 0;
}
}

/**
* 队列从左弹出数据
*
* @param key
* @return key对应的value值
*/
public Object lpop(String key) {
try {
return redisTemplate.opsForList().leftPop(key);
} catch (Exception e) {
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lRemove key {} error:{}", key, e.getMessage(), e);
return null;
}
}
}

+ 2
- 1
dubhe-server/common/src/main/java/org/dubhe/utils/ReflectionUtils.java View File

@@ -18,7 +18,8 @@
package org.dubhe.utils;

import java.lang.reflect.Field;
import java.util.*;
import java.util.ArrayList;
import java.util.List;

/**
* @description 反射工具类


+ 2
- 2
dubhe-server/common/src/main/java/org/dubhe/utils/RegexUtil.java View File

@@ -23,8 +23,8 @@ import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
* @description: 正则匹配工具类
* @create: 2020/4/23 13:51
* @description 正则匹配工具类
* @date 2020-04-23
*/
@Slf4j
public class RegexUtil {


+ 50
- 1
dubhe-server/common/src/main/java/org/dubhe/utils/SqlUtil.java View File

@@ -17,11 +17,18 @@

package org.dubhe.utils;

import org.dubhe.base.BaseService;
import org.dubhe.base.DataContext;

import java.util.HashSet;
import java.util.Objects;
import java.util.Set;

/**
* @description sql语句转换的工具类
* @date 2020-07-06
*/
public class SqlUtil {

/**
@@ -46,4 +53,46 @@ public class SqlUtil {
return "";
}


/**
* 获取资源拥有着ID
*
* @return 资源拥有者id集合
*/
public static Set<Long> getResourceIds() {
if (!Objects.isNull(DataContext.get())) {
return DataContext.get().getResourceUserIds();
}
Set<Long> ids = new HashSet<>();
Long id = JwtUtils.getCurrentUserDto().getId();
ids.add(id);
return ids;

}


/**
* 构建目标sql语句
*
* @param originSql 原生sql
* @param resourceUserIds 所属资源用户ids
* @return 目标sql
*/
public static String buildTargetSql(String originSql, Set<Long> resourceUserIds) {
if (BaseService.isAdmin()) {
return originSql;
}
String sqlWhereBefore = org.dubhe.utils.StringUtils.substringBefore(originSql.toLowerCase(), "where");
String sqlWhereAfter = org.dubhe.utils.StringUtils.substringAfter(originSql.toLowerCase(), "where");
StringBuffer buffer = new StringBuffer();
//操作的sql拼接
String targetSql = buffer.append(sqlWhereBefore).append(" where ").append(" origin_user_id in (")
.append(org.dubhe.utils.StringUtils.join(resourceUserIds, ",")).append(") and ").append(sqlWhereAfter).toString();

return targetSql;
}




}

+ 37
- 0
dubhe-server/common/src/main/java/org/dubhe/utils/StringUtils.java View File

@@ -273,4 +273,41 @@ public class StringUtils extends org.apache.commons.lang3.StringUtils {
matcher.appendTail(sb);
return sb.toString();
}


/**
* 字符串截取前
* @param str
* @return
*/
public static String substringBefore(String str, String separator){

if (!isEmpty(str) && separator != null) {
if (separator.isEmpty()) {
return "";
} else {
int pos = str.indexOf(separator);
return pos == -1 ? str : str.substring(0, pos);
}
} else {
return str;
}
}

/**
* 字符串截取后
* @param str
* @return
*/
public static String substringAfter(String str, String separator){

if (isEmpty(str)) {
return str;
} else if (separator == null) {
return "";
} else {
int pos = str.indexOf(separator);
return pos == -1 ? "" : str.substring(pos + separator.length());
}
}
}

+ 18
- 27
dubhe-server/common/src/main/java/org/dubhe/utils/TimeTransferUtil.java View File

@@ -17,42 +17,33 @@

package org.dubhe.utils;


import lombok.extern.slf4j.Slf4j;

import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Calendar;
import java.util.Date;

import static org.dubhe.base.MagicNumConstant.EIGHT;

/**
* @description: UTC时间转换CST时间工具类
* @create: 2020/5/20 12:10
* @description 时间格式转换工具类
* @date 2020-05-20
*/
@Slf4j
public class TimeTransferUtil {

private static final String UTC_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.sss'Z'";

/**
* @param utcTime
* @return cstTime
* Date转换为UTC时间
*
* @param date
* @return utcTime
*/
public static String cstTransfer(String utcTime){
Date utcDate = null;
/**2020-05-20T03:13:22Z 对应的时间格式 yyyy-MM-dd'T'HH:mm:ss'Z'**/
SimpleDateFormat utcSimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'");

try {
utcDate = utcSimpleDateFormat.parse(utcTime);
} catch (ParseException e) {
log.info(e.getMessage());
return null;
}
/**System.out.println("UTC时间:"+date);**/
public static String dateTransferToUtc(Date date){
Calendar calendar = Calendar.getInstance();
calendar.setTime(utcDate);
calendar.set(Calendar.HOUR,calendar.get(Calendar.HOUR)+8);
SimpleDateFormat cstSimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
Date cstDate = calendar.getTime();
String cstTime = cstSimpleDateFormat.format(calendar.getTime());
return cstTime;
calendar.setTime(date);
/**UTC时间与CST时间相差8小时**/
calendar.set(Calendar.HOUR,calendar.get(Calendar.HOUR) - EIGHT);
SimpleDateFormat utcSimpleDateFormat = new SimpleDateFormat(UTC_FORMAT);
Date utcDate = calendar.getTime();
return utcSimpleDateFormat.format(utcDate);
}
}

+ 2
- 3
dubhe-server/common/src/main/java/org/dubhe/utils/UniqueKeyGenerator.java View File

@@ -25,9 +25,8 @@ import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;

/**
* @description: 唯一码生成器 (依赖时间轴)
*
* @date 2020.05.08
* @description 唯一码生成器 (依赖时间轴)
* @date 2020-05-08
*/
public class UniqueKeyGenerator {



+ 1
- 2
dubhe-server/common/src/main/java/org/dubhe/utils/WrapperHelp.java View File

@@ -20,7 +20,6 @@ package org.dubhe.utils;
import cn.hutool.core.util.ObjectUtil;
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
import org.dubhe.annotation.Query;

import java.lang.reflect.Field;
@@ -31,7 +30,7 @@ import java.util.List;

/**
* @description 构建Wrapper
* @date 2020-03-15 13:52:30
* @date 2020-03-15
*/
@Slf4j
public class WrapperHelp {


+ 2
- 2
dubhe-server/common/src/test/java/org/dubhe/HttpUtilsTest.java View File

@@ -23,8 +23,8 @@ import org.junit.Test;
import static org.dubhe.utils.HttpUtils.isSuccess;

/**
* @description: HttpUtil
* @date 2020.04.30
* @description HttpUtil
* @date 2020-04-30
*/
public class HttpUtilsTest {



+ 71
- 0
dubhe-server/deploy.sh View File

@@ -0,0 +1,71 @@
#!/bin/bash

PROG_NAME=$0
ACTION=$1
ENV=$2
APP_HOME=$3

APP_NAME=dubhe-${ENV}

APP_HOME=$APP_HOME/${APP_NAME} # 从package.tgz中解压出来的jar包放到这个目录下
JAR_NAME=${APP_HOME}/dubhe-admin/target/dubhe-admin-1.0-exec.jar # jar包的名字
JAVA_OUT=/dev/null

# 创建出相关目录
mkdir -p ${APP_HOME}
mkdir -p ${APP_HOME}/logs

usage() {
echo "Usage: $PROG_NAME {start|stop|restart} {dev|test|prod}"
exit 2
}

start_application() {
echo "starting java process"
echo "nohup java -jar ${JAR_NAME} > ${JAVA_OUT} --spring.profiles.active=${ENV} 2>&1 &"
nohup java -jar ${JAR_NAME} > ${JAVA_OUT} --spring.profiles.active=${ENV} 2>&1 &
echo "started java process"
}

stop_application() {
checkjavapid=`ps -ef | grep java | grep ${APP_NAME} | grep -v grep |grep -v 'deploy.sh'| awk '{print$2}'`

if [ -z $checkjavapid ];then
echo -e "\rno java process "$checkjavapid
return
fi

echo "stop java process"
times=60
for e in $(seq 60)
do
sleep 1
COSTTIME=$(($times - $e ))
checkjavapid=`ps -ef | grep java | grep ${APP_NAME} | grep -v grep |grep -v 'deploy.sh'| awk '{print$2}'`
if [ "$checkjavapid" != "" ];then
echo "kill "$checkjavapid
kill -9 $checkjavapid
echo -e "\r -- stopping java lasts `expr $COSTTIME` seconds."
else
echo -e "\rjava process has exited"
break;
fi
done
echo ""
}

case "$ACTION" in
start)
start_application
;;
stop)
stop_application
;;
restart)
stop_application
start_application
;;
*)
usage
;;
esac

+ 1
- 0
dubhe-server/dubhe-admin/pom.xml View File

@@ -75,6 +75,7 @@
<configuration>
<skip>false</skip>
<fork>true</fork>
<classifier>exec</classifier>
</configuration>
</plugin>
<!-- 跳过单元测试 -->


dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/HarborImagePushAsync.java → dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/HarborImagePushAsync.java View File

@@ -14,7 +14,7 @@
* limitations under the License.
* =============================================================
*/
package org.dubhe.task;
package org.dubhe.async;

import cn.hutool.core.util.StrUtil;
import org.dubhe.base.ResponseCode;
@@ -25,6 +25,7 @@ import org.dubhe.enums.ImageStateEnum;
import org.dubhe.enums.LogEnum;
import org.dubhe.exception.BusinessException;
import org.dubhe.harbor.api.HarborApi;
import org.dubhe.utils.IOUtil;
import org.dubhe.utils.LogUtil;
import org.dubhe.utils.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
@@ -62,46 +63,21 @@ public class HarborImagePushAsync {
String imageResource = trainHarborConfig.getAddress() + StrUtil.SLASH + trainHarborConfig.getModelName()
+ StrUtil.SLASH + imageNameandTag;
String cmdStr = "docker login --username=" + trainHarborConfig.getUsername() + " " + trainHarborConfig.getAddress() + " --password=" + trainHarborConfig.getPassword() + " ; docker " +
"load < " + imagePath + " |awk '{print $3}' |xargs -I str docker tag str " + imageResource + " ; docker push " + imageResource;
"load < " + imagePath + " |awk '{print $3}' |xargs -I str docker tag str " + imageResource + " ; docker push " + imageResource + "; docker rmi " + imageResource;
String[] cmd = {"/bin/bash", "-c", cmdStr};
LogUtil.info(LogEnum.BIZ_TRAIN, "镜像上传执行脚本参数:{}", cmd);

Process process = Runtime.getRuntime().exec(cmd);
//读取标准输出流
BufferedReader brOut = new BufferedReader(new InputStreamReader(process.getInputStream()));
//读取标准错误流
BufferedReader brErr = new BufferedReader(new InputStreamReader(process.getErrorStream()));
String line;
String outMessage = "";
String errMessage = "";
while ((line = brOut.readLine()) != null) {
outMessage += line;
}
if (StringUtils.isNotEmpty(outMessage)) {
LogUtil.info(LogEnum.BIZ_TRAIN, "shell上传镜像输出信息:" + outMessage);
}
while ((line = brErr.readLine()) != null) {
errMessage += line;
}
if (StringUtils.isNotEmpty(errMessage)) {
LogUtil.error(LogEnum.BIZ_TRAIN, "shell上传镜像异常信息:" + errMessage);
}
Integer status = process.waitFor();
LogUtil.info(LogEnum.BIZ_TRAIN, "上传镜像状态:{}", status);
if (status == null) {
if (harborApi.isExistImage(ptImage.getImageUrl())) {
updateImageStatus(ptImage, ImageStateEnum.SUCCESS.getCode());
} else {
updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode());
}
} else if (status == 0) {
if (checkImagePushIsOk(ptImage, process)) {
updateImageStatus(ptImage, ImageStateEnum.SUCCESS.getCode());
} else {
updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode());
}
} catch (Exception e) {
LogUtil.error(LogEnum.BIZ_TRAIN, "上传镜像异常:{}", e);
updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode());
throw new BusinessException("上传镜像异常!");

}
}

@@ -117,4 +93,52 @@ public class HarborImagePushAsync {
ptImageMapper.updateById(ptImage);
return ResponseCode.SUCCESS;
}


/**
* 校验镜像是否上传成功
*
* @param ptImage 镜像信息
* @param process process对象
* @return 是否上传成功
*/
public boolean checkImagePushIsOk(PtImage ptImage, Process process) {
//读取标准输出流
BufferedReader brOut = new BufferedReader(new InputStreamReader(process.getInputStream()));
//读取标准错误流
BufferedReader brErr = new BufferedReader(new InputStreamReader(process.getErrorStream()));
String line;
StringBuilder outMessage = new StringBuilder();
StringBuilder errMessage = new StringBuilder();
boolean isPushOk = true;
try {
while ((line = brOut.readLine()) != null) {
outMessage.append(line);
}
if (StringUtils.isNotEmpty(outMessage)) {
LogUtil.info(LogEnum.BIZ_TRAIN, "shell上传镜像输出信息:{}", outMessage.toString());
}
while ((line = brErr.readLine()) != null) {
errMessage.append(line);
}
if (StringUtils.isNotEmpty(errMessage)) {
LogUtil.error(LogEnum.BIZ_TRAIN, "shell上传镜像异常信息:{}", errMessage.toString());
}
Integer status = process.waitFor();
LogUtil.info(LogEnum.BIZ_TRAIN, "上传镜像状态:{}", status);
if (status == null) {
if (!harborApi.isExistImage(ptImage.getImageUrl())) {
isPushOk = false;
}
} else if (status != 0) {
isPushOk = false;
}
} catch (Exception e) {
LogUtil.error(LogEnum.BIZ_TRAIN, "上传镜像异常:{}", e);
return false;
} finally {
IOUtil.close(brErr, brOut);
}
return isPushOk;
}
}

+ 144
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/StopTrainJobAsync.java View File

@@ -0,0 +1,144 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.async;

import org.dubhe.config.TrainJobConfig;
import org.dubhe.dao.PtTrainJobMapper;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.domain.entity.PtTrainJob;
import org.dubhe.enums.LogEnum;
import org.dubhe.enums.TrainJobStatusEnum;
import org.dubhe.enums.TrainTypeEnum;
import org.dubhe.k8s.api.DistributeTrainApi;
import org.dubhe.k8s.api.PodApi;
import org.dubhe.k8s.api.TrainJobApi;
import org.dubhe.k8s.domain.resource.BizPod;
import org.dubhe.utils.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;

import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.function.Consumer;

/**
* @description 停止训练任务异步处理
* @date 2020-08-13
*/
@Component
public class StopTrainJobAsync {

@Autowired
private K8sNameTool k8sNameTool;

@Autowired
private PodApi podApi;

@Autowired
private TrainJobApi trainJobApi;

@Autowired
private TrainJobConfig trainJobConfig;

@Autowired
private DistributeTrainApi distributeTrainApi;

@Autowired
private PtTrainJobMapper ptTrainJobMapper;

/**
* 停止任务
*
* @param currentUser 用户
* @param jobList 任务集合
*/
@Async("trainExecutor")
public void stopJobs(UserDTO currentUser, List<PtTrainJob> jobList) {
String namespace = k8sNameTool.generateNamespace(currentUser.getId());
jobList.forEach(job -> {
BizPod bizPod = podApi.getWithResourceName(namespace, job.getJobName());
if (!bizPod.isSuccess()) {
LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job return code:{},message:{}", currentUser.getUsername(), Integer.valueOf(bizPod.getCode()), bizPod.getMessage());
}
boolean bool = TrainTypeEnum.isDistributeTrain(job.getTrainType()) ?
distributeTrainApi.deleteByResourceName(namespace, job.getJobName()).isSuccess() :
trainJobApi.delete(namespace, job.getJobName());
if (!bool) {
LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job and K8S fails in the stop process, namespace为{}, resourceName为{}",
currentUser.getUsername(), namespace, job.getJobName());
}
//更新训练状态
job.setRuntime(calculateRuntime(bizPod))
.setTrainStatus(TrainJobStatusEnum.STOP.getStatus());
ptTrainJobMapper.updateById(job);

});
}


/**
* 计算job训练时长
*
* @param bizPod pod信息
* @return String 训练时长
*/
private String calculateRuntime(BizPod bizPod) {
return calculateRuntime(bizPod, (x) -> {
});
}


/**
* 计算job训练时长
*
* @param bizPod
* @param consumer pod已经完成状态的回调函数
* @return res 返回训练时长
*/
private String calculateRuntime(BizPod bizPod, Consumer<String> consumer) {
Long completedTime;
if (StringUtils.isBlank(bizPod.getStartTime())) {
return TrainUtil.INIT_RUNTIME;
}
Long startTime = transformTime(bizPod.getStartTime());
boolean hasCompleted = StringUtils.isNotBlank(bizPod.getCompletedTime());
completedTime = hasCompleted ? transformTime(bizPod.getCompletedTime()) : LocalDateTime.now().toEpochSecond(ZoneOffset.of(trainJobConfig.getPlusEight()));
Long time = completedTime - startTime;
String res = DubheDateUtil.convert2Str(time);
if (hasCompleted) {
consumer.accept(res);
}
return res;
}


/**
* 时间转换
*
* @param time 时间
* @return Long 时间戳
*/
private Long transformTime(String time) {
LocalDateTime localDateTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_OFFSET_DATE_TIME);
//没有根据时区做处理, 默认当前为东八区
localDateTime = localDateTime.plusHours(Long.valueOf(trainJobConfig.getEight()));
return localDateTime.toEpochSecond(ZoneOffset.of(trainJobConfig.getPlusEight()));
}
}

+ 115
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainAlgorithmUploadAsync.java View File

@@ -0,0 +1,115 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.async;

import org.dubhe.config.NfsConfig;
import org.dubhe.dao.PtTrainAlgorithmMapper;
import org.dubhe.domain.dto.PtTrainAlgorithmCreateDTO;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.domain.entity.PtTrainAlgorithm;
import org.dubhe.enums.AlgorithmStatusEnum;
import org.dubhe.enums.BizNfsEnum;
import org.dubhe.enums.LogEnum;
import org.dubhe.exception.BusinessException;
import org.dubhe.service.NoteBookService;
import org.dubhe.utils.K8sNameTool;
import org.dubhe.utils.LocalFileUtil;
import org.dubhe.utils.LogUtil;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Component;

/**
* @description 异步上传算法
* @date 2020-08-10
*/
@Component
public class TrainAlgorithmUploadAsync {

@Autowired
private LocalFileUtil localFileUtil;

@Autowired
private NfsConfig nfsConfig;

@Autowired
private K8sNameTool k8sNameTool;

@Autowired
private NoteBookService noteBookService;

@Autowired
private PtTrainAlgorithmMapper trainAlgorithmMapper;

/**
* 异步任务创建算法
*
* @param user 当前登录用户信息
* @param ptTrainAlgorithm 算法信息
* @param trainAlgorithmCreateDTO 创建算法条件
*/
@Async("trainExecutor")
public void createTrainAlgorithm(UserDTO user, PtTrainAlgorithm ptTrainAlgorithm, PtTrainAlgorithmCreateDTO trainAlgorithmCreateDTO) {
String path = nfsConfig.getBucket() + trainAlgorithmCreateDTO.getCodeDir();
//校验创建算法来源(true:由fork创建算法,false:其它创建算法方式),若为true则拷贝预置算法文件至新路径
if (trainAlgorithmCreateDTO.getFork()) {
//生成算法相对路径
String algorithmPath = k8sNameTool.getNfsPath(BizNfsEnum.ALGORITHM, user.getId());
//拷贝预置算法文件夹
boolean copyResult = localFileUtil.copyPath(path, nfsConfig.getBucket() + algorithmPath);
if (!copyResult) {
LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} copied the preset algorithm path {} successfully", user.getUsername(), path);
updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, false);
throw new BusinessException("内部错误");
}

ptTrainAlgorithm.setCodeDir(algorithmPath);

//修改算法上传状态
updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, true);

} else {
updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, true);
}
}


/**
* 更新上传算法状态
*
* @param ptTrainAlgorithm 算法信息
* @param trainAlgorithmCreateDTO 创建算法的条件
* @param flag 创建算法是否成功(true:成功,false:失败)
*/
public void updateTrainAlgorithm(PtTrainAlgorithm ptTrainAlgorithm, PtTrainAlgorithmCreateDTO trainAlgorithmCreateDTO, boolean flag) {

LogUtil.info(LogEnum.BIZ_TRAIN, "async update algorithmPath by algorithmId:{} and update noteBook by noteBookId:{}", ptTrainAlgorithm.getId(), trainAlgorithmCreateDTO.getNoteBookId());
if (flag) {
ptTrainAlgorithm.setAlgorithmStatus(AlgorithmStatusEnum.SUCCESS.getCode());
//更新fork算法新路径
trainAlgorithmMapper.updateById(ptTrainAlgorithm);
//保存算法根据notbookId更新算法id
if (trainAlgorithmCreateDTO.getNoteBookId() != null) {
LogUtil.info(LogEnum.BIZ_TRAIN, "Save algorithm Update algorithm ID :{} according to notBookId:{}", trainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId());
noteBookService.updateTrainIdByNoteBookId(trainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId());
}
} else {
ptTrainAlgorithm.setAlgorithmStatus(AlgorithmStatusEnum.FAIL.getCode());
trainAlgorithmMapper.updateById(ptTrainAlgorithm);
}
}
}

+ 417
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TrainJobAsync.java View File

@@ -0,0 +1,417 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.async;

import cn.hutool.core.util.StrUtil;
import com.alibaba.fastjson.JSONObject;
import org.dubhe.base.MagicNumConstant;
import org.dubhe.constant.SymbolConstant;
import org.dubhe.config.TrainJobConfig;
import org.dubhe.dao.PtTrainJobMapper;
import org.dubhe.domain.dto.BaseTrainJobDTO;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.domain.entity.PtTrainJob;
import org.dubhe.domain.vo.PtImageAndAlgorithmVO;
import org.dubhe.enums.BizEnum;
import org.dubhe.enums.LogEnum;
import org.dubhe.enums.ResourcesPoolTypeEnum;
import org.dubhe.enums.TrainJobStatusEnum;
import org.dubhe.exception.BusinessException;
import org.dubhe.k8s.api.DistributeTrainApi;
import org.dubhe.k8s.api.NamespaceApi;
import org.dubhe.k8s.api.TrainJobApi;
import org.dubhe.k8s.domain.bo.DistributeTrainBO;
import org.dubhe.k8s.domain.bo.PtJupyterJobBO;
import org.dubhe.k8s.domain.resource.BizDistributeTrain;
import org.dubhe.k8s.domain.resource.BizNamespace;
import org.dubhe.k8s.domain.vo.PtJupyterJobVO;
import org.dubhe.utils.*;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
import java.util.ArrayList;
import java.util.List;

/**
* @description 提交训练任务
* @date 2020-07-17
*/
@Component
public class TrainJobAsync {

@Autowired
private K8sNameTool k8sNameTool;

@Autowired
private NamespaceApi namespaceApi;

@Autowired
private TrainJobConfig trainJobConfig;

@Autowired
private NfsUtil nfsUtil;

@Autowired
private LocalFileUtil localFileUtil;

@Autowired
private PtTrainJobMapper ptTrainJobMapper;

@Autowired
private TrainJobApi trainJobApi;

@Autowired
private DistributeTrainApi distributeTrainApi;


/**
* 提交分布式训练
*
* @param baseTrainJobDTO 训练任务信息
* @param currentUser 用户
* @param ptImageAndAlgorithmVO 镜像和算法信息
* @param ptTrainJob 训练任务实体信息
*/
public void doDistributedJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) {
try {
//判断是否存在相应的namespace,如果没有则创建
String namespace = getNamespace(currentUser);
// 构建DistributeTrainBO
DistributeTrainBO bo = buildDistributeTrainBO(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob, namespace);
if (null == bo) {
LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace);
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, "", false);
return;
}
// 调度K8s
BizDistributeTrain bizDistributeTrain = distributeTrainApi.create(bo);
if (bizDistributeTrain.isSuccess()) {
// 调度成功
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, bizDistributeTrain.getName(), true);
} else {
// 调度失败
LogUtil.error(LogEnum.BIZ_TRAIN, "distributeTrainApi.create FAILED! {}", bizDistributeTrain);
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, bizDistributeTrain.getName(), false);
}
} catch (Exception e) {
LogUtil.error(LogEnum.BIZ_TRAIN, "doDistributedJob ERROR!{} ", e);
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, "", false);
}
}

/**
* 构造分布式训练DistributeTrainBO
*
* @param baseTrainJobDTO 训练任务信息
* @param currentUser 用户
* @param ptImageAndAlgorithmVO 镜像和算法信息
* @param ptTrainJob 训练任务实体信息
* @param namespace 命名空间
* @return DistributeTrainBO
*/
private DistributeTrainBO buildDistributeTrainBO(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob, String namespace) {
//绝对路径
String basePath = nfsUtil.getNfsConfig().getBucket() + trainJobConfig.getManage() + StrUtil.SLASH
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName();
//相对路径
String relativePath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName();
String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH);
String workspaceDir = codeDirArray[codeDirArray.length - 1];
// 算法路径待拷贝的地址
String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1);
String trainDir = basePath.substring(1) + StrUtil.SLASH + workspaceDir;

if (!localFileUtil.copyPath(sourcePath, trainDir)) {
LogUtil.error(LogEnum.BIZ_TRAIN, "buildDistributeTrainBO copyPath failed ! sourcePath:{},basePath:{},trainDir:{}", sourcePath, basePath, trainDir);
return null;
}
// 参数前缀
String paramPrefix = trainJobConfig.getPythonFormat();
// 初始化固定头命令,获取分布式节点IP
StringBuilder sb = new StringBuilder("export NODE_IPS=`cat /home/hostfile.json |jq -r \".[]|.ip\"|paste -d \",\" -s` ");
// 切换到算法路径下
sb.append(" && cd ").append(trainJobConfig.getDockerTrainPath()).append(StrUtil.SLASH).append(workspaceDir).append(" && ");
// 拼接用户自定义python启动命令
sb.append(ptImageAndAlgorithmVO.getRunCommand());
// 拼接python固定参数 节点IP
sb.append(paramPrefix).append(trainJobConfig.getNodeIps()).append("=\"$NODE_IPS\" ");
// 拼接python固定参数 节点数量
sb.append(paramPrefix).append(trainJobConfig.getNodeNum()).append(SymbolConstant.FLAG_EQUAL).append(ptTrainJob.getResourcesPoolNode()).append(StrUtil.SPACE);
if (ptImageAndAlgorithmVO.getIsTrainOut()) {
// 拼接 out
nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getOutPath());
baseTrainJobDTO.setOutPath(relativePath + StrUtil.SLASH + trainJobConfig.getOutPath());
sb.append(paramPrefix).append(trainJobConfig.getDockerOutPath());
}
if (ptImageAndAlgorithmVO.getIsTrainLog()) {
// 拼接 输出日志
nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getLogPath());
baseTrainJobDTO.setLogPath(relativePath + StrUtil.SLASH + trainJobConfig.getLogPath());
sb.append(paramPrefix).append(trainJobConfig.getDockerLogPath());
}
if (ptImageAndAlgorithmVO.getIsVisualizedLog()) {
// 拼接 输出可视化日志
nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath());
baseTrainJobDTO.setVisualizedLogPath(relativePath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath());
sb.append(paramPrefix).append(trainJobConfig.getDockerVisualizedLogPath());
}
// 拼接python固定参数 数据集
sb.append(paramPrefix).append(trainJobConfig.getDockerDataset());
JSONObject runParams = baseTrainJobDTO.getRunParams();
if (null != runParams && !runParams.isEmpty()) {
// 拼接用户自定义参数
runParams.entrySet().forEach(entry ->
sb.append(paramPrefix).append(entry.getKey()).append(SymbolConstant.FLAG_EQUAL).append(entry.getValue()).append(StrUtil.SPACE)
);
}
// 在用户自定以参数拼接晚后拼接固定参数,防止被用户自定义参数覆盖
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) {
// 需要GPU
sb.append(paramPrefix).append(trainJobConfig.getGpuNumPerNode()).append(SymbolConstant.FLAG_EQUAL).append(baseTrainJobDTO.getGpuNumPerNode()).append(StrUtil.SPACE);
}
String mainCommand = sb.toString();
// 拼接辅助日志打印
String wholeCommand = " echo 'Distribute training mission begins... "
+ mainCommand
+ " ' && "
+ mainCommand
+ " && echo 'Distribute training mission is over' ";
DistributeTrainBO distributeTrainBO = new DistributeTrainBO()
.setNamespace(namespace)
.setName(baseTrainJobDTO.getJobName())
.setSize(ptTrainJob.getResourcesPoolNode())
.setImage(ptImageAndAlgorithmVO.getImageName())
.setMasterCmd(wholeCommand)
.setMemNum(baseTrainJobDTO.getMenNum())
.setCpuNum(baseTrainJobDTO.getCpuNum())
.setDatasetStoragePath(k8sNameTool.getAbsoluteNfsPath(baseTrainJobDTO.getDataSourcePath()))
.setWorkspaceStoragePath(localFileUtil.formatPath(nfsUtil.getNfsConfig().getRootDir() + basePath))
.setModelStoragePath(k8sNameTool.getAbsoluteNfsPath(relativePath + StrUtil.SLASH + trainJobConfig.getOutPath()))
.setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM));
//延时启动,单位为分钟
if (baseTrainJobDTO.getDelayCreateTime() != null && baseTrainJobDTO.getDelayCreateTime() > 0) {
distributeTrainBO.setDelayCreateTime(baseTrainJobDTO.getDelayCreateTime() * MagicNumConstant.SIXTY);
}
//定时停止,单位为分钟
if (baseTrainJobDTO.getDelayDeleteTime() != null && baseTrainJobDTO.getDelayDeleteTime() > 0) {
distributeTrainBO.setDelayDeleteTime(baseTrainJobDTO.getDelayDeleteTime() * MagicNumConstant.SIXTY);
}
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) {
// 需要GPU
distributeTrainBO.setGpuNum(baseTrainJobDTO.getGpuNumPerNode());
}
// 主从一致
distributeTrainBO.setSlaveCmd(distributeTrainBO.getMasterCmd());
return distributeTrainBO;
}


/**
* 提交job
*
* @param baseTrainJobDTO 训练任务信息
* @param currentUser 用户
* @param ptImageAndAlgorithmVO 镜像和算法信息
*/
public void doJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) {
PtJupyterJobBO jobBo = null;
String k8sJobName = "";
try {
//判断是否存在相应的namespace,如果没有则创建
String namespace = getNamespace(currentUser);

//封装PtJupyterJobBO对象,调用创建训练任务接口
jobBo = pkgPtJupyterJobBo(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, namespace);
if (null == jobBo) {
LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace);
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false);
}
PtJupyterJobVO ptJupyterJobResult = trainJobApi.create(jobBo);
if (!ptJupyterJobResult.isSuccess()) {
String message = null == ptJupyterJobResult.getMessage() ? "未知的错误" : ptJupyterJobResult.getMessage();
LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), jobBo, message);
ptTrainJob.setTrainMsg(message);
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false);
}
k8sJobName = ptJupyterJobResult.getName();
//更新训练任务状态
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, true);
} catch (Exception e) {
LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(),
jobBo, e);
ptTrainJob.setTrainMsg("内部错误");
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false);
}
}


/**
* 获取namespace
*
* @param currentUser 用户
* @return String 命名空间
*/
private String getNamespace(UserDTO currentUser) {
String namespaceStr = k8sNameTool.generateNamespace(currentUser.getId());
BizNamespace bizNamespace = namespaceApi.get(namespaceStr);
if (null == bizNamespace) {
BizNamespace namespace = namespaceApi.create(namespaceStr, null);
if (null == namespace || !namespace.isSuccess()) {
LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to create namespace during training job...");
throw new BusinessException("内部错误");
}
}
return namespaceStr;
}

/**
* 封装出创建job所需的BO
*
* @param baseTrainJobDTO 训练任务信息
* @param ptImageAndAlgorithmVO 镜像和算法信息
* @param namespace 命名空间
* @return PtJupyterJobBO jupyter任务BO
*/
private PtJupyterJobBO pkgPtJupyterJobBo(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser,
PtImageAndAlgorithmVO ptImageAndAlgorithmVO, String namespace) {

//绝对路径
String commonPath = nfsUtil.getNfsConfig().getBucket() + trainJobConfig.getManage() + StrUtil.SLASH
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName();
//相对路径
String relativeCommonPath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName();
String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH);
String workspaceDir = codeDirArray[codeDirArray.length - 1];
// 算法路径待拷贝的地址
String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1);
String trainDir = commonPath.substring(1) + StrUtil.SLASH + workspaceDir;
LogUtil.info(LogEnum.BIZ_TRAIN, "Algorithm path copy::sourcePath:{},commonPath:{},trainDir:{}", sourcePath, commonPath, trainDir);
boolean bool = localFileUtil.copyPath(sourcePath.substring(1), trainDir);
if (!bool) {
LogUtil.error(LogEnum.BIZ_TRAIN, "During the process of user {} creating training Job and encapsulating k8s creating job interface parameters, it failed to copy algorithm directory {} to the specified directory {}", currentUser.getUsername(), sourcePath.substring(1),
trainDir);
return null;
}

List<String> list = new ArrayList<>();
JSONObject runParams = baseTrainJobDTO.getRunParams();

StringBuilder sb = new StringBuilder();
sb.append(ptImageAndAlgorithmVO.getRunCommand());
// 拼接out,log和dataset
String pattern = trainJobConfig.getPythonFormat();
if (ptImageAndAlgorithmVO.getIsTrainOut()) {
nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getOutPath());
baseTrainJobDTO.setOutPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getOutPath());
sb.append(pattern).append(trainJobConfig.getDockerOutPath());
}
if (ptImageAndAlgorithmVO.getIsTrainLog()) {
nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getLogPath());
baseTrainJobDTO.setLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getLogPath());
sb.append(pattern).append(trainJobConfig.getDockerLogPath());
}
if (ptImageAndAlgorithmVO.getIsVisualizedLog()) {
nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath());
baseTrainJobDTO.setVisualizedLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath());
sb.append(pattern).append(trainJobConfig.getDockerVisualizedLogPath());
}
sb.append(pattern).append(trainJobConfig.getDockerDataset());

String valDataSourcePath = baseTrainJobDTO.getValDataSourcePath();
if (StringUtils.isNotBlank(valDataSourcePath)) {
sb.append(pattern).append(trainJobConfig.getLoadValDatasetKey()).append(SymbolConstant.FLAG_EQUAL).append(trainJobConfig.getDockerValDatasetPath());
}
//将模型加载路径拼接到
String modelLoadPathDir = baseTrainJobDTO.getModelLoadPathDir();
if (StringUtils.isNotBlank(modelLoadPathDir)) {
//将模型路径model_load_dir路径
sb.append(pattern).append(trainJobConfig.getLoadKey()).append(SymbolConstant.FLAG_EQUAL).append(trainJobConfig.getDockerModelPath());
}

if (null != runParams && !runParams.isEmpty()) {
runParams.forEach((k, v) ->
sb.append(pattern).append(k).append(SymbolConstant.FLAG_EQUAL).append(v).append(StrUtil.SPACE)
);
}
// 在用户自定以参数拼接晚后拼接固定参数,防止被用户自定义参数覆盖
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) {
// 需要GPU
sb.append(pattern).append(trainJobConfig.getGpuNumPerNode()).append(SymbolConstant.FLAG_EQUAL).append(baseTrainJobDTO.getGpuNumPerNode()).append(StrUtil.SPACE);
}
String executeCmd = sb.toString();
list.add("-c");

String workPath = trainJobConfig.getDockerTrainPath() + StrUtil.SLASH + workspaceDir;
String command = "echo 'training mission begins... " + executeCmd + "\r\n '" +
" && cd " + workPath +
" && " + executeCmd +
" && echo 'the training mission is over' ";
list.add(command);

PtJupyterJobBO jobBo = new PtJupyterJobBO();
jobBo.setNamespace(namespace)
.setName(baseTrainJobDTO.getJobName())
.setImage(ptImageAndAlgorithmVO.getImageName())
.putNfsMounts(trainJobConfig.getDockerDatasetPath(), nfsUtil.getNfsConfig().getRootDir() + nfsUtil.getNfsConfig().getBucket().substring(1) + baseTrainJobDTO.getDataSourcePath())
.setCmdLines(list)
.putNfsMounts(trainJobConfig.getDockerTrainPath(), nfsUtil.getNfsConfig().getRootDir() + commonPath.substring(1))
.putNfsMounts(trainJobConfig.getDockerModelPath(), nfsUtil.formatPath(nfsUtil.getAbsolutePath(modelLoadPathDir)))
.putNfsMounts(trainJobConfig.getDockerValDatasetPath(), nfsUtil.formatPath(nfsUtil.getAbsolutePath(valDataSourcePath)))
.setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM));
//延时启动,单位为分钟
if (baseTrainJobDTO.getDelayCreateTime() != null && baseTrainJobDTO.getDelayCreateTime() > 0) {
jobBo.setDelayCreateTime(baseTrainJobDTO.getDelayCreateTime() * MagicNumConstant.SIXTY);
}
//自动停止,单位为分钟
if (baseTrainJobDTO.getDelayDeleteTime() != null && baseTrainJobDTO.getDelayDeleteTime() > 0) {
jobBo.setDelayDeleteTime(baseTrainJobDTO.getDelayDeleteTime() * MagicNumConstant.SIXTY);
}
jobBo.setCpuNum(baseTrainJobDTO.getCpuNum()).setMemNum(baseTrainJobDTO.getMenNum());
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) {
jobBo.setUseGpu(true).setGpuNum(baseTrainJobDTO.getGpuNumPerNode());
} else {
jobBo.setUseGpu(false);
}
return jobBo;
}

/**
* 训练任务异步处理更新训练状态
*
* @param user 用户
* @param ptTrainJob 训练任务
* @param baseTrainJobDTO 训练任务信息
* @param k8sJobName k8s创建的job名称,或者分布式训练名称
* @param flag 创建训练任务是否异常(true:正常,false:失败)
**/
private void updateTrainStatus(UserDTO user, PtTrainJob ptTrainJob, BaseTrainJobDTO baseTrainJobDTO, String k8sJobName, boolean flag) {

ptTrainJob.setK8sJobName(k8sJobName)
.setOutPath(baseTrainJobDTO.getOutPath())
.setLogPath(baseTrainJobDTO.getLogPath())
.setVisualizedLogPath(baseTrainJobDTO.getVisualizedLogPath());
LogUtil.info(LogEnum.BIZ_TRAIN, "user {} training tasks are processed asynchronously to update training status,receiving parameters:{}", user.getId(), ptTrainJob);
if (flag) {
ptTrainJobMapper.updateById(ptTrainJob);
} else {
ptTrainJob.setTrainStatus(TrainJobStatusEnum.CREATE_FAILED.getStatus());
//训练任务创建失败
ptTrainJobMapper.updateById(ptTrainJob);
}
}
}

dubhe-server/dubhe-admin/src/main/java/org/dubhe/task/TransactionAsyncManager.java → dubhe-server/dubhe-admin/src/main/java/org/dubhe/async/TransactionAsyncManager.java View File

@@ -14,13 +14,16 @@
* limitations under the License.
* =============================================================
*/
package org.dubhe.task;
package org.dubhe.async;

import org.dubhe.aspect.LogAspect;
import org.dubhe.base.DataContext;
import org.dubhe.domain.dto.BaseTrainJobDTO;
import org.dubhe.domain.dto.CommonPermissionDataDTO;
import org.dubhe.domain.dto.UserDTO;
import org.dubhe.domain.entity.PtTrainJob;
import org.dubhe.domain.vo.PtImageAndAlgorithmVO;
import org.dubhe.enums.TrainTypeEnum;
import org.slf4j.MDC;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
@@ -38,26 +41,32 @@ import java.util.concurrent.Executor;
public class TransactionAsyncManager {

@Autowired
private TrainJobAsyncTask trainJobAsyncTask;

@Resource(name = "trainJobAsyncExecutor")
private Executor trainJobAsyncExecutor;
private TrainJobAsync trainJobAsync;

@Resource(name = "trainExecutor")
private Executor trainExecutor;

public void execute(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) {

String traceId = MDC.get(LogAspect.TRACE_ID);
CommonPermissionDataDTO commonPermissionDataDTO = DataContext.get();
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() {

@Override
public void afterCommit() {
trainJobAsyncExecutor.execute(
() -> {
MDC.put(LogAspect.TRACE_ID, traceId);
trainJobAsyncTask.doJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob);
MDC.remove(LogAspect.TRACE_ID);
}
);

trainExecutor.execute(() -> {
MDC.put(LogAspect.TRACE_ID, traceId);
DataContext.set(commonPermissionDataDTO);

if (TrainTypeEnum.isDistributeTrain(ptTrainJob.getTrainType())) {
// 分布式训练
trainJobAsync.doDistributedJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob);
} else {
// 普通训练
trainJobAsync.doJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob);
}
MDC.remove(LogAspect.TRACE_ID);
DataContext.remove();
});
}
});
}

+ 148
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/GlobalFilter.java View File

@@ -0,0 +1,148 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.config;

import com.alibaba.fastjson.JSON;
import org.dubhe.constant.StringConstant;
import org.dubhe.constatnts.UserConstant;
import org.dubhe.dto.GlobalRequestRecordDTO;
import org.dubhe.enums.LogEnum;
import org.dubhe.utils.JwtUtils;
import org.dubhe.utils.LogUtil;
import org.dubhe.utils.StringUtils;
import org.springframework.core.annotation.Order;
import org.springframework.stereotype.Component;

import javax.servlet.*;
import javax.servlet.annotation.WebFilter;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

import static org.dubhe.constant.StringConstant.K8S_CALLBACK_URI;

/**
* @description 全局请求拦截器 用于日志收集
* @date 2020-08-13
*/
@Order(1)
@Component
@WebFilter(filterName = "GlobalFilter", urlPatterns = "/**")
public class GlobalFilter implements Filter {
@Override
public void init(FilterConfig filterConfig) throws ServletException {

}

@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException {

long start = System.currentTimeMillis();
HttpServletRequest request = ((HttpServletRequest) servletRequest);
HttpServletResponse response = ((HttpServletResponse) servletResponse);

GlobalRequestRecordDTO dto = new GlobalRequestRecordDTO();

try {
if (StringUtils.isNotBlank(request.getContentType()) && request.getContentType().contains(StringConstant.MULTIPART)) {
chain.doFilter(request, response);
} else {
checkScheduleRequest(request);
checkK8sCallback(request);

RequestBodyWrapper requestBodyWrapper = new RequestBodyWrapper(request);
ResponseBodyWrapper responseBodyWrapper = new ResponseBodyWrapper(response);
dto.setRequestBody(requestBodyWrapper.getBodyString());
chain.doFilter(requestBodyWrapper, responseBodyWrapper);

if (StringConstant.JSON_REQUEST.equals(responseBodyWrapper.getContentType())) {
final String responseBody = responseBodyWrapper.getResponseBody();
dto.setResponseBody(responseBody);
} else {
responseBodyWrapper.flush();
}
}
} catch (Exception e) {
LogUtil.error(LogEnum.GLOBAL_REQ, "Global request record error : {}", e);
throw e;
} finally {
buildGlobalRequestDTO(dto, request, response);
dto.setTimeCost(System.currentTimeMillis() - start);
LogUtil.info(LogEnum.GLOBAL_REQ, "Global request record: {}", dto);
LogUtil.cleanTrace();
}
}

/**
* 构建全局请求对象
*
* @param dto
* @param request
* @param response
*/
private void buildGlobalRequestDTO(GlobalRequestRecordDTO dto, HttpServletRequest request, HttpServletResponse response) {
dto.setClientHost(request.getRemoteHost());
dto.setParams(JSON.toJSONString(request.getParameterMap()));
dto.setMethod(request.getMethod());
dto.setUri(request.getRequestURI());
//身份认证信息
String token = request.getHeader(UserConstant.USER_TOKEN_KEY);
dto.setAuthorization(token);
if (token != null) {
String userName = JwtUtils.getUserName(token);
dto.setUsername(userName);
}
dto.setContentType(response.getContentType());
dto.setStatus(response.getStatus());
}

/**
* 检查是否是前端的定时请求
*
* @param request 请求信息
* @return 是否是前端的定时请求
*/
private boolean checkScheduleRequest(HttpServletRequest request) {

if (StringConstant.REQUEST_METHOD_GET.equals(request.getMethod())
&& StringUtils.isNotBlank(request.getParameter(LogUtil.SCHEDULE_LEVEL))) {
LogUtil.startScheduleTrace();
return true;
}

return false;
}

/**
* 校验请求是否为k8s回调
* @param request 请求信息
* @return 是否为k8s回调
*/
private boolean checkK8sCallback(HttpServletRequest request) {
if (request.getRequestURI() != null && request.getRequestURI().contains(K8S_CALLBACK_URI)) {
LogUtil.startK8sCallbackTrace();
return true;
}
return false;
}

@Override
public void destroy() {

}
}

+ 101
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/RequestBodyWrapper.java View File

@@ -0,0 +1,101 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.config;

import org.dubhe.enums.LogEnum;
import org.dubhe.utils.LogUtil;

import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStreamReader;

/**
* @description 用于获取请求body参数的包装类
* @date 2020-08-20
*/
public class RequestBodyWrapper extends HttpServletRequestWrapper {

private ByteArrayInputStream byteArrayInputStream;

private String bodyString;

public RequestBodyWrapper(HttpServletRequest request)
throws IOException {
super(request);
try (BufferedReader reader = request.getReader()) {
StringBuilder sb = new StringBuilder();
String line = null;
while ((line = reader.readLine()) != null) {
sb.append(line);
}
if (sb.length() > 0) {
bodyString = sb.toString();
} else {
bodyString = "";
}
byteArrayInputStream = new ByteArrayInputStream(bodyString.getBytes());
} catch (Exception e) {
LogUtil.error(LogEnum.GLOBAL_REQ, "request get reader error : {}", e);
throw e;
}

}

/**
* 获取请求体的json数据
*
* @return
*/
public String getBodyString() {
return bodyString;
}

@Override
public BufferedReader getReader() throws IOException {
return new BufferedReader(new InputStreamReader(getInputStream()));
}

@Override
public ServletInputStream getInputStream() throws IOException {
return new ServletInputStream() {
@Override
public boolean isFinished() {
return false;
}

@Override
public boolean isReady() {
return false;
}

@Override
public void setReadListener(ReadListener readListener) {
}

@Override
public int read() throws IOException {
return byteArrayInputStream.read();
}
};
}
}

+ 126
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/ResponseBodyWrapper.java View File

@@ -0,0 +1,126 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/

package org.dubhe.config;

import org.dubhe.enums.LogEnum;
import org.dubhe.utils.LogUtil;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.io.UnsupportedEncodingException;

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import javax.servlet.http.HttpServletResponse;
import javax.servlet.http.HttpServletResponseWrapper;

/**
* @description 用于获取response 的 json 返回值
* @date 2020-08-19
*/
public class ResponseBodyWrapper extends HttpServletResponseWrapper {

private ByteArrayOutputStream byteArrayOutputStream = null;

private ServletOutputStream servletOutputStream = null;

private PrintWriter printWriter = null;

private HttpServletResponse response;

public ResponseBodyWrapper(HttpServletResponse response) throws IOException {
super(response);
this.response = response;
byteArrayOutputStream = new ByteArrayOutputStream();
printWriter = new PrintWriter(new OutputStreamWriter(byteArrayOutputStream, "UTF-8"));
servletOutputStream = new ServletOutputStream() {

@Override
public void write(int b) throws IOException {
byteArrayOutputStream.write(b);
}

@Override
public boolean isReady() {
return false;
}

@Override
public void setWriteListener(WriteListener writeListener) {
}
};
}

@Override
public ServletOutputStream getOutputStream() throws IOException {
return servletOutputStream;
}

@Override
public PrintWriter getWriter() throws IOException {
return printWriter;
}

@Override
public void flushBuffer() throws IOException {
if (servletOutputStream != null) {
servletOutputStream.flush();
}
if (printWriter != null) {
printWriter.flush();
}
}

@Override
public void reset() {
byteArrayOutputStream.reset();
}

/**
* 获取json返回值
* @return
* @throws IOException
*/
public String getResponseBody() throws IOException {
//清空response的流,之后再添加进去
flushBuffer();
byte[] bytes = byteArrayOutputStream.toByteArray();
try {
return new String(bytes, "UTF-8");
} catch (UnsupportedEncodingException e) {
LogUtil.error(LogEnum.GLOBAL_REQ, e);
} finally {
response.getOutputStream().write(bytes);
}
return "";
}

/**
* 清掉缓冲
* @throws IOException
*/
public void flush() throws IOException {
//清空response的流,之后再添加进去
flushBuffer();
byte[] bytes = byteArrayOutputStream.toByteArray();
response.getOutputStream().write(bytes);
}

}

+ 1
- 1
dubhe-server/dubhe-admin/src/main/java/org/dubhe/config/TimestampConverter.java View File

@@ -24,7 +24,7 @@ import org.springframework.stereotype.Component;
import java.sql.Timestamp;

/**
* @description: 转换时间戳类型
* @description 转换时间戳类型
* @date 2020-05-22
*/
@Component


+ 47
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/ModelQueryMapper.java View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.dao;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.dubhe.domain.dto.ModelQueryDTO;
import org.dubhe.domain.entity.ModelQuery;
import org.dubhe.domain.entity.ModelQueryBrance;

/**
* @description model mapper
* @date 2020-10-09
*/
public interface ModelQueryMapper extends BaseMapper<ModelQueryDTO> {
/**
* 根据modelId查询模型信息
*
* @param modelId 模型id
* @return modelQuery返回查询的模型对象
*/
@Select("select name,url from pt_model_info where id=#{modelId}")
ModelQuery findModelNameById(@Param("modelId") Integer modelId);

/**
* 根据模型路径查询模型版本信息
*
* @param modelLoadPathDir 模型路径
* @return ModelQueryBrance 模型版本信息
*/
@Select("select version from pt_model_branch where url=#{modelLoadPathDir}")
ModelQueryBrance findModelVersionByUrl(@Param("modelLoadPathDir") String modelLoadPathDir);
}

+ 6
- 6
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/NoteBookMapper.java View File

@@ -21,6 +21,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.NoteBook;

import java.util.List;
@@ -29,28 +30,27 @@ import java.util.List;
* @description notebook mapper
* @date 2020-04-28
*/
@DataPermission(ignoresMethod = {"insert","findByNamespaceAndResourceName","selectRunNotUrlList"})
public interface NoteBookMapper extends BaseMapper<NoteBook> {

/**
* 根据名称查询
*
* @param name
* @param userId
* @param status
* @return NoteBook
*/
@Select("select * from notebook where notebook_name = #{name} and user_id = #{userId} and status != #{status} and deleted = 0 limit 1")
NoteBook findByNameAndUserId(@Param("name") String name, @Param("userId") long userId, @Param("status") Integer status);
@Select("select * from notebook where notebook_name = #{name} and status != #{status} and deleted = 0 limit 1")
NoteBook findByNameAndStatus(@Param("name") String name, @Param("status") Integer status);

/**
* 查询正在运行的notebook数量
*
* @param userId
* @param status
* @return int
*/
@Select("select count(1) from notebook where user_id = #{userId} and status = #{status} and deleted = 0")
int selectRunNoteBookNum(@Param("userId") long userId, @Param("status") Integer status);
@Select("select count(1) from notebook where status = #{status} and deleted = 0")
int selectRunNoteBookNum( @Param("status") Integer status);

/**
* 根据namespace + resourceName查询


+ 2
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtImageMapper.java View File

@@ -19,12 +19,14 @@ package org.dubhe.dao;


import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.PtImage;

/**
* @description 镜像 Mapper 接口
* @date 2020-04-27
*/
@DataPermission(ignoresMethod = {"insert"})
public interface PtImageMapper extends BaseMapper<PtImage> {

}

+ 2
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmMapper.java View File

@@ -20,6 +20,7 @@ package org.dubhe.dao;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.PtTrainAlgorithm;

import java.util.List;
@@ -28,6 +29,7 @@ import java.util.List;
* @description 训练算法Mapper
* @date 2020-04-27
*/
@DataPermission(ignoresMethod = {"insert"})
public interface PtTrainAlgorithmMapper extends BaseMapper<PtTrainAlgorithm> {

/**


+ 4
- 4
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainAlgorithmUsageMapper.java View File

@@ -17,17 +17,17 @@

package org.dubhe.dao;

import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.PtTrainAlgorithmUsage;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;

/**
*
* 用户辅助信息Mapper 接口
* </p>
*
* @since 2020-06-23
* @description 用户辅助信息Mapper 接口
* @date 2020-06-23
*/
@DataPermission(ignoresMethod = "insert")
public interface PtTrainAlgorithmUsageMapper extends BaseMapper<PtTrainAlgorithmUsage> {

}

+ 2
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainJobMapper.java View File

@@ -21,6 +21,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import com.baomidou.mybatisplus.extension.plugins.pagination.Page;
import org.apache.ibatis.annotations.Param;
import org.apache.ibatis.annotations.Select;
import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.PtTrainJob;
import org.dubhe.domain.vo.PtTrainVO;

@@ -28,6 +29,7 @@ import org.dubhe.domain.vo.PtTrainVO;
* @description 训练作业job Mapper 接口
* @date 2020-04-27
*/
@DataPermission(ignoresMethod = {"insert","selectCountByStatus","getPageTrain"})
public interface PtTrainJobMapper extends BaseMapper<PtTrainJob> {

/**


+ 2
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainMapper.java View File

@@ -18,12 +18,14 @@
package org.dubhe.dao;

import com.baomidou.mybatisplus.core.mapper.BaseMapper;
import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.PtTrain;

/**
* @description 训练作业主 Mapper 接口
* @date 2020-04-27
*/
@DataPermission(ignoresMethod = {"insert"})
public interface PtTrainMapper extends BaseMapper<PtTrain> {

}

+ 2
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/dao/PtTrainParamMapper.java View File

@@ -17,6 +17,7 @@

package org.dubhe.dao;

import org.dubhe.annotation.DataPermission;
import org.dubhe.domain.entity.PtTrainParam;
import com.baomidou.mybatisplus.core.mapper.BaseMapper;

@@ -24,6 +25,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper;
* @description 任务参数 Mapper 接口
* @date 2020-04-27
*/
@DataPermission(ignoresMethod = "insert")
public interface PtTrainParamMapper extends BaseMapper<PtTrainParam> {

}

+ 36
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/BaseTrainJobDTO.java View File

@@ -40,4 +40,40 @@ public class BaseTrainJobDTO implements Serializable {
private String outPath;
private String logPath;
private String visualizedLogPath;
private Integer delayCreateTime;
private Integer delayDeleteTime;

/**
* @return 每个节点的GPU数量
*/
public Integer getGpuNumPerNode(){
return getPtTrainJobSpecs().getSpecsInfo().getInteger("gpuNum");
}

/**
* @return cpu数量
*/
public Integer getCpuNum(){
return getPtTrainJobSpecs().getSpecsInfo().getInteger("cpuNum");
}

/**
* @return memNum
*/
public Integer getMenNum(){
return getPtTrainJobSpecs().getSpecsInfo().getInteger("memNum");
}
/**
* "验证数据来源名称"
*/
private String valDataSourceName;

/**
* 验证数据来源路径
*/
private String valDataSourcePath;
/**
* 模型路径
*/
private String modelLoadPathDir;
}

+ 39
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/ModelQueryDTO.java View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.domain.dto;
import io.swagger.annotations.ApiModel;
import io.swagger.annotations.ApiModelProperty;
import lombok.Data;

/**
* @description model 查询dto
* @date 2020-10-09
*/
@Data
@ApiModel("模型查询")
public class ModelQueryDTO {
@ApiModelProperty(value = "模型类型")
private Integer modelResource;
@ApiModelProperty(value = "模型名称")
private String name;
@ApiModelProperty(value = "模型版本")
private String version;
@ApiModelProperty(value = "模型id")
private Integer id;
@ApiModelProperty(value = "模型路径")
private String modelPath;
}

+ 1
- 2
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookListQueryDTO.java View File

@@ -30,7 +30,6 @@ import java.io.Serializable;
@Data
public class NoteBookListQueryDTO implements Serializable {

@Query(propName = "status", type = Query.Type.EQ)
@ApiModelProperty("0运行中,1停止, 2删除, 3启动中,4停止中,5删除中,6运行异常(暂未启用)")
private Integer status;

@@ -38,7 +37,7 @@ public class NoteBookListQueryDTO implements Serializable {
@ApiModelProperty("notebook名称")
private String noteBookName;

@Query(propName = "user_id", type = Query.Type.EQ)
@Query(propName = "origin_user_id", type = Query.Type.EQ)
@ApiModelProperty(value = "所属用户ID", hidden = true)
private Long userId;



+ 1
- 1
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/NoteBookQueryDTO.java View File

@@ -42,7 +42,7 @@ public class NoteBookQueryDTO implements Serializable {
@Query(propName = "k8s_pvc_path", type = Query.Type.EQ)
private String k8sPvcPath;

@Query(propName = "user_id", type = Query.Type.EQ)
@Query(propName = "origin_user_id", type = Query.Type.EQ)
private Long userId;

@Query(propName = "last_operation_timeout", type = Query.Type.LT)


+ 2
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDTO.java View File

@@ -66,4 +66,6 @@ public class PtImageDTO implements Serializable {
@ApiModelProperty("删除(0正常,1已删除)")
private Boolean deleted;

@ApiModelProperty("资源拥有者ID")
private Long originUserId;
}

+ 39
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageDeleteDTO.java View File

@@ -0,0 +1,39 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.domain.dto;

import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.experimental.Accessors;

import javax.validation.constraints.NotNull;
import java.io.Serializable;
import java.util.List;

/**
* @description 训练镜像删除DTO
* @date 2020-08-13
*/
@Data
@Accessors(chain = true)
public class PtImageDeleteDTO implements Serializable {
private static final long serialVersionUID = 1L;

@ApiModelProperty(value = "id", required = true)
@NotNull(message = "镜像id不能为空")
private List<Long> ids;
}

+ 3
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageQueryDTO.java View File

@@ -40,4 +40,7 @@ public class PtImageQueryDTO extends PageQueryBase implements Serializable {
@ApiModelProperty(value = "镜像状态,0为制作中,1位制作成功,2位制作失败")
private Integer imageStatus;

@ApiModelProperty(value = "镜像名称或id")
private String imageNameOrId;

}

+ 47
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUpdateDTO.java View File

@@ -0,0 +1,47 @@
/**
* Copyright 2020 Zhejiang Lab. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================
*/
package org.dubhe.domain.dto;

import io.swagger.annotations.ApiModelProperty;
import lombok.Data;
import lombok.experimental.Accessors;
import org.dubhe.utils.TrainUtil;
import org.hibernate.validator.constraints.Length;

import javax.validation.constraints.NotNull;
import java.io.Serializable;
import java.util.List;

/**
* @description 训练镜像信息修改DTO
* @date 2020-08-13
*/
@Data
@Accessors(chain = true)
public class PtImageUpdateDTO implements Serializable {

private static final long serialVersionUID = 1L;

@ApiModelProperty(value = "id", required = true)
@NotNull(message = "镜像id不能为空")
private List<Long> ids;

@ApiModelProperty("镜像描述")
@Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "镜像描述-输入长度不能超过1024个字符")
private String remark;

}

+ 2
- 2
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtImageUploadDTO.java View File

@@ -55,7 +55,7 @@ public class PtImageUploadDTO implements Serializable {
@Pattern(regexp = TrainUtil.REGEXP_TAG, message = "镜像版本号支持字母、数字、英文横杠、英文.号和下划线")
private String imageTag;

@ApiModelProperty("备注")
@Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "备注-输入长度不能超过1024个字符")
@ApiModelProperty("镜像描述")
@Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "镜像描述-输入长度不能超过1024个字符")
private String remark;
}

+ 3
- 0
dubhe-server/dubhe-admin/src/main/java/org/dubhe/domain/dto/PtTrainAlgorithmCreateDTO.java View File

@@ -85,4 +85,7 @@ public class PtTrainAlgorithmCreateDTO implements Serializable {
@ApiModelProperty("noteBookId")
private Long noteBookId;

@ApiModelProperty("资源拥有者ID")
private Long originUserId;

}

Some files were not shown because too many files changed in this diff

Loading…
Cancel
Save