@@ -48,6 +48,5 @@ output/ | |||
.classpath | |||
logs/ | |||
/dubhe-k8s/src/main/resources/kubeconfig | |||
/dubhe-k8s/src/main/resources | |||
*.log | |||
/dubhe-admin/kubeconfig |
@@ -1,13 +1,22 @@ | |||
# 一站式开发平台-服务端 | |||
# 之江天枢-服务端 | |||
## 本地开发 | |||
**之江天枢一站式人工智能开源平台**(简称:**之江天枢**),包括海量数据处理、交互式模型构建(包含Notebook和模型可视化)、AI模型高效训练。多维度产品形态满足从开发者到大型企业的不同需求,将提升人工智能技术的研发效率、扩大算法模型的应用范围,进一步构建人工智能生态“朋友圈”。 | |||
## 源码部署 | |||
### 准备环境 | |||
安装如下软件环境。 | |||
- OpenJDK:1.8+ | |||
- Redis: 3.0+ | |||
- Maven: 3.0+ | |||
- MYSQL: 5.5.0+ | |||
- MYSQL: 5.7.0+ | |||
### 下载源码 | |||
``` bash | |||
git clone https://codeup.teambition.com/zhejianglab/dubhe-server.git | |||
# 进入项目根目录 | |||
cd dubhe-server | |||
``` | |||
### 创建DB | |||
在MySQL中依次执行如下sql文件 | |||
@@ -20,14 +29,35 @@ sql/v1/02-Dubhe-DML.sql | |||
### 配置 | |||
根据实际情况修改如下配置文件。 | |||
``` | |||
dubhe-admin/src/main/resources/config/application-dev.yml | |||
dubhe-admin/src/main/resources/config/application-prod.yml | |||
``` | |||
### 启动: | |||
### 构建 | |||
``` bash | |||
# 构建,生成的 jar 包位于 ./dubhe-admin/target/dubhe-admin-1.0.jar | |||
mvn clean compile package | |||
``` | |||
mvn spring-boot:run | |||
### 启动 | |||
``` bash | |||
# 指定启动环境为 prod | |||
## admin模块 | |||
java -jar ./dubhe-admin/target/dubhe-admin-1.0-exec.jar --spring.profiles.active=prod | |||
## task模块 | |||
java -jar ./dubhe-task/target/dubhe-task-1.0.jar --spring.profiles.active=prod | |||
``` | |||
## 本地开发 | |||
### 必要条件: | |||
导入maven项目,下载所需的依赖包 | |||
mysql下创建数据库dubhe,初始化数据脚本 | |||
安装redis | |||
### 启动: | |||
mvn spring-boot:run | |||
## 代码结构: | |||
``` | |||
├── common 公共模块 | |||
@@ -49,6 +79,7 @@ mvn spring-boot:run | |||
├── dubhe-data 数据处理模块 | |||
├── dubhe-model 模型管理模块 | |||
├── dubhe-system 系统管理 | |||
├── dubhe-task 定时任务模块 | |||
``` | |||
## docker服务器 | |||
@@ -29,11 +29,6 @@ | |||
<artifactId>guava</artifactId> | |||
<version>21.0</version> | |||
</dependency> | |||
<dependency> | |||
<groupId>com.github.penggle</groupId> | |||
<artifactId>kaptcha</artifactId> | |||
<version>${kaptcha.version}</version> | |||
</dependency> | |||
<!-- shiro --> | |||
<dependency> | |||
<groupId>org.apache.shiro</groupId> | |||
@@ -95,6 +90,10 @@ | |||
<artifactId>commons-compress</artifactId> | |||
<version>1.20</version> | |||
</dependency> | |||
<dependency> | |||
<groupId>com.github.whvcse</groupId> | |||
<artifactId>easy-captcha</artifactId> | |||
</dependency> | |||
</dependencies> | |||
<build> | |||
<plugins> | |||
@@ -14,31 +14,28 @@ | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.annotation; | |||
import java.lang.annotation.*; | |||
/** | |||
* 数据权限过滤Mapper拦截 | |||
* | |||
* @date 2020-06-22 | |||
* @description 数据权限注解 | |||
* @date 2020-09-24 | |||
*/ | |||
import java.lang.annotation.*; | |||
@Target({ElementType.METHOD, ElementType.TYPE}) | |||
@Retention(RetentionPolicy.RUNTIME) | |||
@Documented | |||
public @interface DataPermission { | |||
/** | |||
* 不需要数据权限的方法名 | |||
* 只在类的注解上使用,代表方法的数据权限类型 | |||
* @return | |||
*/ | |||
String[] ignores() default {}; | |||
String permission() default ""; | |||
/** | |||
* 只在方法的注解上使用,代表方法的数据权限类型,如果不加注解,只会识别带"select"方法名的方法 | |||
* | |||
* 不需要数据权限的方法名 | |||
* @return | |||
*/ | |||
String[] permission() default {}; | |||
String[] ignoresMethod() default {}; | |||
} |
@@ -0,0 +1,47 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.annotation; | |||
import org.dubhe.enums.DatasetTypeEnum; | |||
import java.lang.annotation.ElementType; | |||
import java.lang.annotation.Retention; | |||
import java.lang.annotation.RetentionPolicy; | |||
import java.lang.annotation.Target; | |||
/** | |||
* @description 数据权限方法注解 | |||
* @date 2020-09-24 | |||
*/ | |||
@Target(ElementType.METHOD) | |||
@Retention(RetentionPolicy.RUNTIME) | |||
public @interface DataPermissionMethod { | |||
/** | |||
* 是否需要拦截标识 true: 不拦截 false: 拦截 | |||
* | |||
* @return 拦截标识 | |||
*/ | |||
boolean interceptFlag() default false; | |||
/** | |||
* 数据类型 | |||
* | |||
* @return 数据集类型 | |||
*/ | |||
DatasetTypeEnum dataType() default DatasetTypeEnum.PRIVATE; | |||
} |
@@ -17,19 +17,20 @@ | |||
package org.dubhe.annotation; | |||
import javax.validation.Constraint; | |||
import javax.validation.ConstraintValidator; | |||
import javax.validation.ConstraintValidatorContext; | |||
import javax.validation.Payload; | |||
import java.lang.annotation.ElementType; | |||
import java.lang.annotation.Retention; | |||
import java.lang.annotation.RetentionPolicy; | |||
import java.lang.annotation.Target; | |||
import java.lang.reflect.InvocationTargetException; | |||
import java.lang.reflect.Method; | |||
import javax.validation.Constraint; | |||
import javax.validation.ConstraintValidator; | |||
import javax.validation.ConstraintValidatorContext; | |||
import javax.validation.Payload; | |||
/** | |||
* @date: 2020-05-21 | |||
* @description 接口枚举类检测标注类 | |||
* @date 2020-05-21 | |||
*/ | |||
@Target({ ElementType.METHOD, ElementType.FIELD, ElementType.ANNOTATION_TYPE }) | |||
@Retention(RetentionPolicy.RUNTIME) | |||
@@ -0,0 +1,66 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.annotation; | |||
import javax.validation.Constraint; | |||
import javax.validation.ConstraintValidator; | |||
import javax.validation.ConstraintValidatorContext; | |||
import javax.validation.Payload; | |||
import java.lang.annotation.*; | |||
import java.util.Arrays; | |||
/** | |||
* @description 自定义状态校验注解(传入值是否在指定状态范围内) | |||
* @date 2020-09-18 | |||
*/ | |||
@Target({ElementType.FIELD, ElementType.PARAMETER}) | |||
@Retention(RetentionPolicy.RUNTIME) | |||
@Constraint(validatedBy = FlagValidator.Validator.class) | |||
@Documented | |||
public @interface FlagValidator { | |||
String[] value() default {}; | |||
String message() default "flag value is invalid"; | |||
Class<?>[] groups() default {}; | |||
Class<? extends Payload>[] payload() default {}; | |||
/** | |||
* @description 校验传入值是否在默认值范围校验逻辑 | |||
* @date 2020-09-18 | |||
*/ | |||
class Validator implements ConstraintValidator<FlagValidator, Integer> { | |||
private String[] values; | |||
@Override | |||
public void initialize(FlagValidator flagValidator) { | |||
this.values = flagValidator.value(); | |||
} | |||
@Override | |||
public boolean isValid(Integer value, ConstraintValidatorContext constraintValidatorContext) { | |||
if (value == null) { | |||
//当状态为空时,使用默认值 | |||
return false; | |||
} | |||
return Arrays.stream(values).anyMatch(value::equals); | |||
} | |||
} | |||
} |
@@ -15,8 +15,7 @@ | |||
*/ | |||
package org.dubhe.aspect; | |||
import java.util.UUID; | |||
import lombok.extern.slf4j.Slf4j; | |||
import org.aspectj.lang.JoinPoint; | |||
import org.aspectj.lang.ProceedingJoinPoint; | |||
import org.aspectj.lang.annotation.Around; | |||
@@ -28,10 +27,11 @@ import org.slf4j.MDC; | |||
import org.springframework.stereotype.Component; | |||
import org.springframework.util.StringUtils; | |||
import lombok.extern.slf4j.Slf4j; | |||
import java.util.UUID; | |||
/** | |||
* @date 2020/04/10 | |||
* @description 日志切面 | |||
* @date 2020-04-10 | |||
*/ | |||
@Component | |||
@Aspect | |||
@@ -54,7 +54,7 @@ public class LogAspect { | |||
public void taskAspect() { | |||
} | |||
@Pointcut(" serviceAspect() || taskAspect() ") | |||
@Pointcut(" serviceAspect() ") | |||
public void aroundAspect() { | |||
} | |||
@@ -0,0 +1,116 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.aspect; | |||
import org.aspectj.lang.ProceedingJoinPoint; | |||
import org.aspectj.lang.annotation.Around; | |||
import org.aspectj.lang.annotation.Aspect; | |||
import org.aspectj.lang.annotation.Pointcut; | |||
import org.aspectj.lang.reflect.MethodSignature; | |||
import org.dubhe.annotation.DataPermissionMethod; | |||
import org.dubhe.base.BaseService; | |||
import org.dubhe.base.DataContext; | |||
import org.dubhe.domain.dto.CommonPermissionDataDTO; | |||
import org.dubhe.enums.DatasetTypeEnum; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.utils.JwtUtils; | |||
import org.dubhe.utils.LogUtil; | |||
import org.springframework.stereotype.Component; | |||
import java.lang.reflect.Method; | |||
import java.util.HashSet; | |||
import java.util.Objects; | |||
import java.util.Set; | |||
/** | |||
* @description 数据权限切面 | |||
* @date 2020-09-24 | |||
*/ | |||
@Aspect | |||
@Component | |||
public class PermissionAspect { | |||
/** | |||
* 公共数据的有用户ID | |||
*/ | |||
public static final Long PUBLIC_DATA_USER_ID = 0L; | |||
/** | |||
* 基于注解的切面方法 | |||
*/ | |||
@Pointcut("@annotation(org.dubhe.annotation.DataPermissionMethod)") | |||
private void cutMethod() { | |||
} | |||
/** | |||
*环绕通知 | |||
* @param joinPoint 切入参数对象 | |||
* @return 返回方法结果集 | |||
* @throws Throwable | |||
*/ | |||
@Around("cutMethod()") | |||
public Object around(ProceedingJoinPoint joinPoint) throws Throwable { | |||
// 获取方法传入参数 | |||
Object[] params = joinPoint.getArgs(); | |||
DataPermissionMethod dataPermissionMethod = getDeclaredAnnotation(joinPoint); | |||
if (!Objects.isNull(JwtUtils.getCurrentUserDto()) && !Objects.isNull(dataPermissionMethod)) { | |||
Set<Long> ids = new HashSet<>(); | |||
ids.add(JwtUtils.getCurrentUserDto().getId()); | |||
CommonPermissionDataDTO commonPermissionDataDTO = CommonPermissionDataDTO.builder().type(dataPermissionMethod.interceptFlag()).resourceUserIds(ids).build(); | |||
if (DatasetTypeEnum.PUBLIC.equals(dataPermissionMethod.dataType())) { | |||
ids.add(PUBLIC_DATA_USER_ID); | |||
commonPermissionDataDTO.setResourceUserIds(ids); | |||
} | |||
DataContext.set(commonPermissionDataDTO); | |||
} | |||
// 执行源方法 | |||
try { | |||
return joinPoint.proceed(params); | |||
} finally { | |||
// 模拟进行验证 | |||
BaseService.removeContext(); | |||
} | |||
} | |||
/** | |||
* 获取方法中声明的注解 | |||
* | |||
* @param joinPoint 切入参数对象 | |||
* @return DataPermissionMethod 方法注解类型 | |||
*/ | |||
public DataPermissionMethod getDeclaredAnnotation(ProceedingJoinPoint joinPoint){ | |||
// 获取方法名 | |||
String methodName = joinPoint.getSignature().getName(); | |||
// 反射获取目标类 | |||
Class<?> targetClass = joinPoint.getTarget().getClass(); | |||
// 拿到方法对应的参数类型 | |||
Class<?>[] parameterTypes = ((MethodSignature) joinPoint.getSignature()).getParameterTypes(); | |||
// 根据类、方法、参数类型(重载)获取到方法的具体信息 | |||
Method objMethod = null; | |||
try { | |||
objMethod = targetClass.getMethod(methodName, parameterTypes); | |||
} catch (NoSuchMethodException e) { | |||
LogUtil.error(LogEnum.BIZ_DATASET,"获取注解方法参数异常 error:{}",e); | |||
} | |||
// 拿到方法定义的注解信息 | |||
DataPermissionMethod annotation = objMethod.getDeclaredAnnotation(DataPermissionMethod.class); | |||
// 返回 | |||
return annotation; | |||
} | |||
} |
@@ -25,7 +25,7 @@ import java.io.Serializable; | |||
/** | |||
* @description 镜像基础类DTO | |||
* @date: 2020-07-14 | |||
* @date 2020-07-14 | |||
*/ | |||
@Data | |||
@Accessors(chain = true) | |||
@@ -0,0 +1,77 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.base; | |||
import org.dubhe.constant.PermissionConstant; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.domain.entity.Role; | |||
import org.dubhe.exception.BaseErrorCode; | |||
import org.dubhe.exception.BusinessException; | |||
import org.dubhe.utils.JwtUtils; | |||
import org.springframework.util.CollectionUtils; | |||
import java.util.List; | |||
import java.util.Objects; | |||
import java.util.stream.Collectors; | |||
/** | |||
* @description 服务层基础数据公共方法类 | |||
* @date 2020-03-27 | |||
*/ | |||
public class BaseService { | |||
private BaseService (){} | |||
/** | |||
* 校验是否具有管理员权限 | |||
*/ | |||
public static void checkAdminPermission() { | |||
if(!isAdmin()){ | |||
throw new BusinessException(BaseErrorCode.DATASET_ADMIN_PERMISSION_ERROR); | |||
} | |||
} | |||
/** | |||
* 校验是否是管理管理员 | |||
* | |||
* @return 校验标识 | |||
*/ | |||
public static Boolean isAdmin() { | |||
UserDTO currentUserDto = JwtUtils.getCurrentUserDto(); | |||
if (currentUserDto != null && !CollectionUtils.isEmpty(currentUserDto.getRoles())) { | |||
List<Role> roles = currentUserDto.getRoles(); | |||
List<Role> roleList = roles.stream(). | |||
filter(a -> a.getId().compareTo(PermissionConstant.ADMIN_USER_ID) == 0) | |||
.collect(Collectors.toList()); | |||
if (!CollectionUtils.isEmpty(roleList)) { | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
/** | |||
* 清除本地线程数据权限数据 | |||
*/ | |||
public static void removeContext(){ | |||
if( !Objects.isNull(DataContext.get())){ | |||
DataContext.remove(); | |||
} | |||
} | |||
} |
@@ -24,8 +24,8 @@ import java.io.Serializable; | |||
import java.sql.Timestamp; | |||
/** | |||
* @description: VO基础类 | |||
* @date: 2020-05-22 | |||
* @description VO基础类 | |||
* @date 2020-05-22 | |||
*/ | |||
@Data | |||
public class BaseVO implements Serializable { | |||
@@ -0,0 +1,62 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.base; | |||
import org.dubhe.domain.dto.CommonPermissionDataDTO; | |||
/** | |||
* @description 共享上下文数据集信息 | |||
* @date 2020-04-10 | |||
*/ | |||
public class DataContext { | |||
/** | |||
* 私有化构造参数 | |||
*/ | |||
private DataContext() { | |||
} | |||
private static final ThreadLocal<CommonPermissionDataDTO> CONTEXT = new ThreadLocal<>(); | |||
/** | |||
* 存放数据集信息 | |||
* | |||
* @param datasetVO | |||
*/ | |||
public static void set(CommonPermissionDataDTO datasetVO) { | |||
CONTEXT.set(datasetVO); | |||
} | |||
/** | |||
* 获取用户信息 | |||
* | |||
* @return | |||
*/ | |||
public static CommonPermissionDataDTO get() { | |||
return CONTEXT.get(); | |||
} | |||
/** | |||
* 清除当前线程内引用,防止内存泄漏 | |||
*/ | |||
public static void remove() { | |||
CONTEXT.remove(); | |||
} | |||
} |
@@ -18,7 +18,8 @@ | |||
package org.dubhe.base; | |||
/** | |||
* @date: 2020-05-14 | |||
* @description 常用常量类 | |||
* @date 2020-05-14 | |||
*/ | |||
public final class MagicNumConstant { | |||
@@ -86,6 +87,7 @@ public final class MagicNumConstant { | |||
public static final long TWELVE_LONG = 12L; | |||
public static final long SIXTY_LONG = 60L; | |||
public static final long THOUSAND_LONG = 1000L; | |||
public static final long TEN_THOUSAND_LONG = 10000L; | |||
public static final long ONE_ZERO_ONE_ZERO_ONE_ZERO_LONG = 101010L; | |||
public static final long NINE_ZERO_NINE_ZERO_NINE_ZERO_LONG = 909090L; | |||
@@ -26,8 +26,8 @@ import org.dubhe.constant.NumberConstant; | |||
import javax.validation.constraints.Min; | |||
/** | |||
* @description: 分页基类 | |||
* @date: 2020-05-8 | |||
* @description 分页基类 | |||
* @date 2020-05-08 | |||
*/ | |||
@Data | |||
@Accessors(chain = true) | |||
@@ -14,32 +14,31 @@ | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.task; | |||
package org.dubhe.base; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.service.PtImageService; | |||
import org.dubhe.utils.LogUtil; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.scheduling.annotation.Scheduled; | |||
import org.springframework.stereotype.Component; | |||
/** | |||
* @description 从harbor同步projectName | |||
* @date 2020-6-23 | |||
**/ | |||
@Component | |||
public class HarborProjectNameSyncTask { | |||
* @description 定时任务处理器, 主要做日志标识 | |||
* @date 2020-08-13 | |||
*/ | |||
public class ScheduleTaskHandler { | |||
public static void process(Handler handler) { | |||
LogUtil.startScheduleTrace(); | |||
try { | |||
handler.run(); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.BIZ_SYS, "There is something wrong in schedule task handler :{}", e); | |||
} finally { | |||
LogUtil.cleanTrace(); | |||
} | |||
} | |||
@Autowired | |||
private PtImageService ptImageService; | |||
/** | |||
* 每天晚上11点开始同步 | |||
**/ | |||
@Scheduled(cron = "0 0 23 * * ?") | |||
public void syncProjectName() { | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "开始到harbor同步projectName到harbor_project表。。。。。"); | |||
ptImageService.harborImageNameSync(); | |||
public interface Handler { | |||
void run(); | |||
} | |||
} |
@@ -1,65 +0,0 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.config; | |||
import com.google.code.kaptcha.impl.DefaultKaptcha; | |||
import com.google.code.kaptcha.util.Config; | |||
import org.springframework.context.annotation.Bean; | |||
import org.springframework.stereotype.Component; | |||
import java.util.Properties; | |||
/** | |||
* @description 验证码配置 | |||
* @date 2020-02-23 | |||
*/ | |||
@Component | |||
public class KaptchaConfig { | |||
private final static String CODE_LENGTH = "4"; | |||
private final static String SESSION_KEY = "verification_session_key"; | |||
@Bean | |||
public DefaultKaptcha defaultKaptcha() { | |||
DefaultKaptcha defaultKaptcha = new DefaultKaptcha(); | |||
Properties properties = new Properties(); | |||
// 设置边框 | |||
properties.setProperty("kaptcha.border", "yes"); | |||
// 设置边框颜色 | |||
properties.setProperty("kaptcha.border.color", "105,179,90"); | |||
// 设置字体颜色 | |||
properties.setProperty("kaptcha.textproducer.font.color", "blue"); | |||
// 设置图片宽度 | |||
properties.setProperty("kaptcha.image.width", "108"); | |||
// 设置图片高度 | |||
properties.setProperty("kaptcha.image.height", "28"); | |||
// 设置字体尺寸 | |||
properties.setProperty("kaptcha.textproducer.font.size", "26"); | |||
// 设置session key | |||
properties.setProperty("kaptcha.session.key", SESSION_KEY); | |||
// 设置验证码长度 | |||
properties.setProperty("kaptcha.textproducer.char.length", CODE_LENGTH); | |||
// 设置字体 | |||
properties.setProperty("kaptcha.textproducer.font.names", "宋体,楷体,黑体"); | |||
//去噪点 | |||
properties.setProperty("kaptcha.noise.impl", "com.google.code.kaptcha.impl.NoNoise"); | |||
Config config = new Config(properties); | |||
defaultKaptcha.setConfig(config); | |||
return defaultKaptcha; | |||
} | |||
} |
@@ -1,12 +1,12 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
@@ -30,21 +30,17 @@ import java.util.Objects; | |||
/** | |||
* @description 处理新增和更新的基础数据填充,配合BaseEntity和MyBatisPlusConfig使用 | |||
* @date 2020-6-10 | |||
* @date 2020-06-10 | |||
*/ | |||
@Component | |||
public class MetaHandlerConfig implements MetaObjectHandler { | |||
private final String LOCK_USER_ID = "LOCK_USER_ID"; | |||
/** | |||
* 新增数据执行 | |||
* | |||
* @param metaObject | |||
* @param metaObject 基础数据 | |||
*/ | |||
@Override | |||
public void insertFill(MetaObject metaObject) { | |||
if (Objects.isNull(getFieldValByName(StringConstant.CREATE_TIME, metaObject))) { | |||
@@ -53,13 +49,14 @@ public class MetaHandlerConfig implements MetaObjectHandler { | |||
if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_TIME, metaObject))) { | |||
this.setFieldValByName(StringConstant.UPDATE_TIME, DateUtil.getCurrentTimestamp(), metaObject); | |||
} | |||
synchronized (LOCK_USER_ID){ | |||
if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_USER_ID, metaObject))) { | |||
this.setFieldValByName(StringConstant.UPDATE_USER_ID, getUserId(), metaObject); | |||
} | |||
if (Objects.isNull(getFieldValByName(StringConstant.CREATE_USER_ID, metaObject))) { | |||
this.setFieldValByName(StringConstant.CREATE_USER_ID, getUserId(), metaObject); | |||
} | |||
if (Objects.isNull(getFieldValByName(StringConstant.UPDATE_USER_ID, metaObject))) { | |||
this.setFieldValByName(StringConstant.UPDATE_USER_ID, getUserId(), metaObject); | |||
} | |||
if (Objects.isNull(getFieldValByName(StringConstant.CREATE_USER_ID, metaObject))) { | |||
this.setFieldValByName(StringConstant.CREATE_USER_ID, getUserId(), metaObject); | |||
} | |||
if (Objects.isNull(getFieldValByName(StringConstant.ORIGIN_USER_ID, metaObject))) { | |||
this.setFieldValByName(StringConstant.ORIGIN_USER_ID, getUserId(), metaObject); | |||
} | |||
if (Objects.isNull(getFieldValByName(StringConstant.DELETED, metaObject))) { | |||
this.setFieldValByName(StringConstant.DELETED, SwitchEnum.getBooleanValue(SwitchEnum.OFF.getValue()), metaObject); | |||
@@ -69,7 +66,7 @@ public class MetaHandlerConfig implements MetaObjectHandler { | |||
/** | |||
* 更新数据执行 | |||
* | |||
* @param metaObject | |||
* @param metaObject 基础数据 | |||
*/ | |||
@Override | |||
public void updateFill(MetaObject metaObject) { | |||
@@ -17,294 +17,27 @@ | |||
package org.dubhe.config; | |||
import com.baomidou.mybatisplus.core.override.MybatisMapperProxy; | |||
import com.baomidou.mybatisplus.core.parser.ISqlParser; | |||
import com.baomidou.mybatisplus.core.parser.SqlParserHelper; | |||
import com.baomidou.mybatisplus.extension.plugins.PaginationInterceptor; | |||
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantHandler; | |||
import com.baomidou.mybatisplus.extension.plugins.tenant.TenantSqlParser; | |||
import com.google.common.collect.Sets; | |||
import net.sf.jsqlparser.expression.Expression; | |||
import net.sf.jsqlparser.expression.LongValue; | |||
import net.sf.jsqlparser.expression.operators.relational.ExpressionList; | |||
import net.sf.jsqlparser.expression.operators.relational.InExpression; | |||
import net.sf.jsqlparser.schema.Column; | |||
import org.apache.ibatis.mapping.MappedStatement; | |||
import org.apache.shiro.UnavailableSecurityManagerException; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.base.MagicNumConstant; | |||
import org.dubhe.constant.PermissionConstant; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.domain.entity.Role; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.utils.JwtUtils; | |||
import org.dubhe.utils.LogUtil; | |||
import org.springframework.beans.BeansException; | |||
import org.springframework.context.ApplicationContext; | |||
import org.springframework.context.ApplicationContextAware; | |||
import org.springframework.context.ApplicationListener; | |||
import org.dubhe.interceptor.PaginationInterceptor; | |||
import org.springframework.context.annotation.Bean; | |||
import org.springframework.context.annotation.Configuration; | |||
import org.springframework.context.event.ContextRefreshedEvent; | |||
import org.springframework.core.annotation.AnnotationUtils; | |||
import org.springframework.transaction.annotation.EnableTransactionManagement; | |||
import org.springframework.util.CollectionUtils; | |||
import java.lang.annotation.Annotation; | |||
import java.lang.reflect.Field; | |||
import java.lang.reflect.Method; | |||
import java.lang.reflect.Proxy; | |||
import java.util.*; | |||
import java.util.stream.Collectors; | |||
/** | |||
* @description MybatisPlus配置类 | |||
* @date 2020-06-24 | |||
* @description mybatis plus拦截器 | |||
* @date 2020-06-10 | |||
*/ | |||
@EnableTransactionManagement | |||
@Configuration | |||
public class MybatisPlusConfig implements ApplicationListener<ContextRefreshedEvent>, ApplicationContextAware { | |||
/** | |||
* 以此字段作为租户实现数据隔离 | |||
*/ | |||
private static final String TENANT_ID_COLUMN = "create_user_id"; | |||
/** | |||
* 以0作为公共数据的标识 | |||
*/ | |||
private static final long PUBLIC_TENANT_ID = MagicNumConstant.ZERO; | |||
private static final Set<Long> PUBLIC_TENANT_ID_SET = new HashSet<Long>() {{ | |||
add(PUBLIC_TENANT_ID); | |||
}}; | |||
private static final String PACKAGE_SEPARATOR = "."; | |||
private static final Set<String> SELECT_PERMISSION = new HashSet<String>() {{ | |||
add(PermissionConstant.SELECT); | |||
}}; | |||
private static final Set<String> UPDATE_DELETE_PERMISSION = new HashSet<String>() {{ | |||
add(PermissionConstant.UPDATE); | |||
add(PermissionConstant.DELETE); | |||
}}; | |||
public class MybatisPlusConfig { | |||
private static final String SELECT_STR = "select"; | |||
/** | |||
* 优先级高于dataFilters,如果ignore,则不进行sql注入 | |||
*/ | |||
private Map<String, Set<String>> dataFilters = new HashMap<>(); | |||
private ApplicationContext applicationContext; | |||
public Set<Long> tenantId; | |||
/** | |||
* mybatis plus 分页插件 | |||
* 其中增加了通过多租户实现了数据权限功能 | |||
* 注入 MybatisPlus 分页拦截器 | |||
* | |||
* @return | |||
* @return 自定义MybatisPlus分页拦截器 | |||
*/ | |||
@Bean | |||
public PaginationInterceptor paginationInterceptor() { | |||
PaginationInterceptor paginationInterceptor = new PaginationInterceptor(); | |||
List<ISqlParser> sqlParserList = new ArrayList<>(); | |||
TenantSqlParser tenantSqlParser = new TenantSqlParser(); | |||
tenantSqlParser.setTenantHandler(new TenantHandler() { | |||
@Override | |||
public Expression getTenantId(boolean where) { | |||
Set<Long> tenants = tenantId; | |||
final boolean multipleTenantIds = tenants.size() > MagicNumConstant.ONE; | |||
if (multipleTenantIds) { | |||
return multipleTenantIdCondition(tenants); | |||
} else { | |||
return singleTenantIdCondition(tenants); | |||
} | |||
} | |||
private Expression singleTenantIdCondition(Set<Long> tenants) { | |||
return new LongValue((Long) tenants.toArray()[0]); | |||
} | |||
private Expression multipleTenantIdCondition(Set<Long> tenants) { | |||
final InExpression inExpression = new InExpression(); | |||
inExpression.setLeftExpression(new Column(getTenantIdColumn())); | |||
final ExpressionList itemsList = new ExpressionList(); | |||
final List<Expression> inValues = new ArrayList<>(tenants.size()); | |||
tenants.forEach(i -> | |||
inValues.add(new LongValue(i)) | |||
); | |||
itemsList.setExpressions(inValues); | |||
inExpression.setRightItemsList(itemsList); | |||
return inExpression; | |||
} | |||
@Override | |||
public String getTenantIdColumn() { | |||
return TENANT_ID_COLUMN; | |||
} | |||
@Override | |||
public boolean doTableFilter(String tableName) { | |||
return false; | |||
} | |||
}); | |||
sqlParserList.add(tenantSqlParser); | |||
paginationInterceptor.setSqlParserList(sqlParserList); | |||
paginationInterceptor.setSqlParserFilter(metaObject -> { | |||
MappedStatement ms = SqlParserHelper.getMappedStatement(metaObject); | |||
String method = ms.getId(); | |||
if (!dataFilters.containsKey(method) || isAdmin()) { | |||
return true; | |||
} | |||
Set<String> permission = dataFilters.get(method); | |||
tenantId = getTenantId(permission); | |||
return false; | |||
}); | |||
return paginationInterceptor; | |||
} | |||
/** | |||
* 判断用户是否是管理员 | |||
* 如果未登录,无法请求任何接口,所以不会到该层,因此匿名认为是定时任务,给予admin权限。 | |||
* | |||
* @return 判断用户是否是管理员 | |||
*/ | |||
private boolean isAdmin() { | |||
UserDTO user; | |||
try { | |||
user = JwtUtils.getCurrentUserDto(); | |||
} catch (UnavailableSecurityManagerException e) { | |||
return true; | |||
} | |||
if (Objects.isNull(user)) { | |||
return true; | |||
} | |||
List<Role> roles; | |||
if ((roles = user.getRoles()) == null) { | |||
return false; | |||
} | |||
Set<String> permissions = roles.stream().map(Role::getPermission).collect(Collectors.toSet()); | |||
if (CollectionUtils.isEmpty(permissions)) { | |||
return false; | |||
} | |||
return user.getId() == PermissionConstant.ANONYMOUS_USER || user.getId() == PermissionConstant.ADMIN_USER_ID; | |||
} | |||
/** | |||
* 如果是管理员,在前一步isAdmin已过滤; | |||
* 如果是匿名用户,在shiro层被过滤; | |||
* 因此只会是无角色、权限用户或普通用户 | |||
* | |||
* @return Set<Long> 租户ID集合 | |||
*/ | |||
private Set<Long> getTenantId(Set<String> permission) { | |||
UserDTO user = JwtUtils.getCurrentUserDto(); | |||
List<Role> roles; | |||
if (Objects.isNull(user) || (roles = user.getRoles()) == null) { | |||
if (permission.contains(PermissionConstant.SELECT)) { | |||
return PUBLIC_TENANT_ID_SET; | |||
} | |||
return Collections.EMPTY_SET; | |||
} | |||
Set<String> permissions = roles.stream().map(Role::getPermission).collect(Collectors.toSet()); | |||
if (CollectionUtils.isEmpty(permissions)) { | |||
if (permission.contains(PermissionConstant.SELECT)) { | |||
return PUBLIC_TENANT_ID_SET; | |||
} | |||
return Collections.EMPTY_SET; | |||
} | |||
if (permission.contains(PermissionConstant.SELECT)) { | |||
return new HashSet<Long>() {{ | |||
add(PUBLIC_TENANT_ID); | |||
add(user.getId()); | |||
}}; | |||
} | |||
return new HashSet<Long>() {{ | |||
add(user.getId()); | |||
}}; | |||
} | |||
/** | |||
* 设置上下文 | |||
* #需要通过上下文 获取SpringBean | |||
* | |||
* @param applicationContext spring上下文 | |||
* @throws BeansException 找不到bean异常 | |||
*/ | |||
@Override | |||
public void setApplicationContext(ApplicationContext applicationContext) throws BeansException { | |||
this.applicationContext = applicationContext; | |||
} | |||
@Override | |||
public void onApplicationEvent(ContextRefreshedEvent contextRefreshedEvent) { | |||
Class<? extends Annotation> annotationClass = DataPermission.class; | |||
Map<String, Object> beanWithAnnotation = applicationContext.getBeansWithAnnotation(annotationClass); | |||
Set<Map.Entry<String, Object>> entitySet = beanWithAnnotation.entrySet(); | |||
for (Map.Entry<String, Object> entry : entitySet) { | |||
Proxy proxy = (Proxy) entry.getValue(); | |||
Class clazz = getMapperClass(proxy); | |||
populateDataFilters(clazz); | |||
} | |||
} | |||
/** | |||
* 根据mapper对应代理对象获取Class | |||
* | |||
* @param proxy mapper对应代理对象 | |||
* @return | |||
*/ | |||
private Class getMapperClass(Proxy proxy) { | |||
try { | |||
Field field = proxy.getClass().getSuperclass().getDeclaredField("h"); | |||
field.setAccessible(true); | |||
MybatisMapperProxy mapperProxy = (MybatisMapperProxy) field.get(proxy); | |||
field = mapperProxy.getClass().getDeclaredField("mapperInterface"); | |||
field.setAccessible(true); | |||
return (Class) field.get(mapperProxy); | |||
} catch (NoSuchFieldException | IllegalAccessException e) { | |||
LogUtil.error(LogEnum.BIZ_DATASET, "reflect error", e); | |||
} | |||
return null; | |||
} | |||
/** | |||
* 填充数据权限过滤,处理那些需要排除的方法 | |||
* | |||
* @param clazz 需要处理的类(mapper) | |||
*/ | |||
private void populateDataFilters(Class clazz) { | |||
if (clazz == null) { | |||
return; | |||
} | |||
Method[] methods = clazz.getMethods(); | |||
DataPermission dataPermission = AnnotationUtils.findAnnotation((Class<?>) clazz, DataPermission.class); | |||
Set<String> ignores = Sets.newHashSet(dataPermission.ignores()); | |||
for (Method method : methods) { | |||
if (ignores.contains(method.getName())) { | |||
continue; | |||
} | |||
Set<String> permission = getDataPermission(method); | |||
dataFilters.put(clazz.getName() + PACKAGE_SEPARATOR + method.getName(), permission); | |||
} | |||
} | |||
/** | |||
* 获取方法上权限注解 | |||
* 权限注解包含 | |||
* 1.用户拥有指定权限才可以执行该方法:比如 PermissionConstant.SELECT 表示用户必须拥有select权限,才可以使用该方法 | |||
* 2.方法权限校验排除:比如 ignores = {"insert"} 表示insert方法不做权限处理 | |||
* | |||
* @param method 方法对象 | |||
* @return | |||
*/ | |||
private Set<String> getDataPermission(Method method) { | |||
DataPermission dataPermission = AnnotationUtils.findAnnotation(method, DataPermission.class); | |||
// 无注解时以方法名判断 | |||
if (dataPermission == null) { | |||
if (method.getName().contains(SELECT_STR)) { | |||
return SELECT_PERMISSION; | |||
} | |||
return UPDATE_DELETE_PERMISSION; | |||
} | |||
return Sets.newHashSet(dataPermission.permission()); | |||
} | |||
} | |||
@@ -0,0 +1,62 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.config; | |||
import lombok.Data; | |||
import org.springframework.boot.context.properties.ConfigurationProperties; | |||
import org.springframework.stereotype.Component; | |||
/** | |||
* @description 垃圾回收机制配置常量 | |||
* @date 2020-09-21 | |||
*/ | |||
@Data | |||
@Component | |||
@ConfigurationProperties(prefix = "recycle.timeout") | |||
public class RecycleConfig { | |||
/** | |||
* 回收无效文件的默认有效时长 | |||
*/ | |||
private Integer date; | |||
/** | |||
* 用户上传文件至临时路径下后文件最大有效时长,以小时为单位 | |||
*/ | |||
private Integer fileValid; | |||
/** | |||
* 用户删除某一算法后,其算法文件最大有效时长,以天为单位 | |||
*/ | |||
private Integer algorithmValid; | |||
/** | |||
* 用户删除某一模型后,其模型文件最大有效时长,以天为单位 | |||
*/ | |||
private Integer modelValid; | |||
/** | |||
* 用户删除训练任务后,其训练管理文件最大有效时长,以天为单位 | |||
*/ | |||
private Integer trainValid; | |||
/** | |||
* 删除服务器无效文件(大文件) | |||
* 示例:rsync --delete-before -d /空目录 /需要回收的源目录 | |||
*/ | |||
public static final String DEL_COMMAND = "ssh %s@%s \"mkdir -p %s; rsync --delete-before -d %s %s; rmdir %s %s\""; | |||
} |
@@ -15,77 +15,71 @@ | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.constant; | |||
package org.dubhe.config; | |||
import lombok.Data; | |||
import org.springframework.beans.factory.annotation.Value; | |||
import org.springframework.boot.context.properties.ConfigurationProperties; | |||
import org.springframework.stereotype.Component; | |||
/** | |||
* @description 训练常量 | |||
* @create: 2020-05-12 | |||
* @date 2020-05-12 | |||
*/ | |||
@Component | |||
@Data | |||
public class TrainJobConstant { | |||
@ConfigurationProperties(prefix = "train-job") | |||
public class TrainJobConfig { | |||
@Value("${train-job.namespace}") | |||
private String namespace; | |||
@Value("${train-job.version-label}") | |||
private String versionLabel; | |||
@Value("${train-job.separator}") | |||
private String separator; | |||
@Value("${train-job.pod-name}") | |||
private String podName; | |||
@Value("${train-job.python-format}") | |||
private String pythonFormat; | |||
@Value("${train-job.manage}") | |||
private String manage; | |||
@Value("${train-job.out-path}") | |||
private String outPath; | |||
@Value("${train-job.log-path}") | |||
private String logPath; | |||
@Value("${train-job.visualized-log-path}") | |||
private String visualizedLogPath; | |||
@Value("${train-job.docker-dataset-path}") | |||
private String dockerDatasetPath; | |||
@Value("${train-job.docker-train-path}") | |||
private String dockerTrainPath; | |||
@Value("${train-job.docker-out-path}") | |||
private String dockerOutPath; | |||
@Value("${train-job.docker-log-path}") | |||
private String dockerLogPath; | |||
@Value("${train-job.docker-dataset}") | |||
private String dockerDataset; | |||
@Value("${train-job.docker-visualized-log-path}") | |||
private String dockerModelPath; | |||
private String dockerValDatasetPath; | |||
private String loadValDatasetKey; | |||
private String dockerVisualizedLogPath; | |||
@Value("${train-job.load-path}") | |||
private String loadPath; | |||
@Value("${train-job.load-key}") | |||
private String loadKey; | |||
@Value("${train-job.eight}") | |||
private String eight; | |||
@Value("${train-job.plus-eight}") | |||
private String plusEight; | |||
private String nodeIps; | |||
private String nodeNum; | |||
private String gpuNumPerNode; | |||
public static final String TRAIN_ID = "trainId"; | |||
public static final String TRAIN_VERSION = "trainVersion"; | |||
@@ -97,4 +91,5 @@ public class TrainJobConstant { | |||
public static final String CREATE_TIME = "createTime"; | |||
public static final String ALGORITHM_NAME = "algorithmName"; | |||
} |
@@ -16,9 +16,13 @@ | |||
*/ | |||
package org.dubhe.config; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.utils.LogUtil; | |||
import org.springframework.aop.interceptor.AsyncUncaughtExceptionHandler; | |||
import org.springframework.beans.factory.annotation.Value; | |||
import org.springframework.context.annotation.Bean; | |||
import org.springframework.context.annotation.Configuration; | |||
import org.springframework.scheduling.annotation.AsyncConfigurer; | |||
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; | |||
import java.util.concurrent.Executor; | |||
@@ -29,7 +33,7 @@ import java.util.concurrent.ThreadPoolExecutor; | |||
* @date 2020-07-17 | |||
*/ | |||
@Configuration | |||
public class TrainPoolConfig { | |||
public class TrainPoolConfig implements AsyncConfigurer { | |||
@Value("${basepool.corePoolSize:40}") | |||
private Integer corePoolSize; | |||
@@ -44,8 +48,9 @@ public class TrainPoolConfig { | |||
* 训练任务异步处理线程池 | |||
* @return Executor 线程实例 | |||
*/ | |||
@Bean | |||
public Executor trainJobAsyncExecutor() { | |||
@Bean("trainExecutor") | |||
@Override | |||
public Executor getAsyncExecutor() { | |||
ThreadPoolTaskExecutor taskExecutor = new ThreadPoolTaskExecutor(); | |||
//核心线程数 | |||
taskExecutor.setCorePoolSize(corePoolSize); | |||
@@ -57,10 +62,18 @@ public class TrainPoolConfig { | |||
//配置队列大小 | |||
taskExecutor.setQueueCapacity(blockQueueSize); | |||
//配置线程池前缀 | |||
taskExecutor.setThreadNamePrefix("async-train-job-"); | |||
taskExecutor.setThreadNamePrefix("async-train-"); | |||
//拒绝策略 | |||
taskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.AbortPolicy()); | |||
taskExecutor.initialize(); | |||
return taskExecutor; | |||
} | |||
@Override | |||
public AsyncUncaughtExceptionHandler getAsyncUncaughtExceptionHandler() { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "开始捕获训练管理异步任务异常信息-----》》》"); | |||
return (ex, method, params) -> { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "训练管理方法名{}的异步任务执行失败,参数信息:{},异常信息:{}", method.getName(), params, ex); | |||
}; | |||
} | |||
} |
@@ -18,8 +18,8 @@ | |||
package org.dubhe.constant; | |||
/** | |||
* @Description 数字常量 | |||
* @Date 2020-6-9 | |||
* @description 数字常量 | |||
* @date 2020-06-09 | |||
*/ | |||
public class NumberConstant { | |||
@@ -32,6 +32,8 @@ public class NumberConstant { | |||
public final static int NUMBER_30 = 30; | |||
public final static int NUMBER_50 = 50; | |||
public final static int NUMBER_60 = 60; | |||
public final static int NUMBER_1024 = 1024; | |||
public final static int NUMBER_1000 = 1000; | |||
public final static int HOUR_SECOND = 60 * 60; | |||
public final static int DAY_SECOND = 60 * 60 * 24; | |||
public final static int WEEK_SECOND = 60 * 60 * 24 * 7; | |||
@@ -21,8 +21,8 @@ import lombok.Data; | |||
import org.springframework.stereotype.Component; | |||
/** | |||
* @description: 权限常量 | |||
* @since: 2020-05-25 14:39 | |||
* @description 权限常量 | |||
* @date 2020-05-25 | |||
*/ | |||
@Component | |||
@Data | |||
@@ -32,9 +32,10 @@ public class PermissionConstant { | |||
* 超级用户 | |||
*/ | |||
public static final long ADMIN_USER_ID = 1L; | |||
public static final long ANONYMOUS_USER = -1L; | |||
public static final String SELECT = "select"; | |||
public static final String UPDATE = "update"; | |||
public static final String DELETE = "delete"; | |||
/** | |||
* 数据集模块类型 | |||
*/ | |||
public static final Integer RESOURCE_DATA_MODEL = 1; | |||
} |
@@ -23,24 +23,29 @@ package org.dubhe.constant; | |||
*/ | |||
public final class StringConstant { | |||
public static final String MSIE = "MSIE"; | |||
public static final String MOZILLA = "Mozilla"; | |||
public static final String MSIE = "MSIE"; | |||
public static final String MOZILLA = "Mozilla"; | |||
public static final String REQUEST_METHOD_GET = "GET"; | |||
/** | |||
* 公共字段 | |||
*/ | |||
public static final String CREATE_TIME = "createTime"; | |||
public static final String UPDATE_TIME = "updateTime"; | |||
public static final String UPDATE_USER_ID = "updateUserId"; | |||
public static final String CREATE_USER_ID = "createUserId"; | |||
public static final String DELETED = "deleted"; | |||
public static final String UTF8 = "utf-8"; | |||
/** | |||
* 公共字段 | |||
*/ | |||
public static final String CREATE_TIME = "createTime"; | |||
public static final String UPDATE_TIME = "updateTime"; | |||
public static final String UPDATE_USER_ID = "updateUserId"; | |||
public static final String CREATE_USER_ID = "createUserId"; | |||
public static final String ORIGIN_USER_ID = "originUserId"; | |||
public static final String DELETED = "deleted"; | |||
public static final String UTF8 = "utf-8"; | |||
public static final String JSON_REQUEST = "application/json"; | |||
public static final String K8S_CALLBACK_URI = "/api/k8s/callback/pod"; | |||
public static final String MULTIPART = "multipart/form-data"; | |||
/** | |||
* 测试环境 | |||
*/ | |||
public static final String PROFILE_ACTIVE_TEST = "test"; | |||
/** | |||
* 测试环境 | |||
*/ | |||
public static final String PROFILE_ACTIVE_TEST = "test"; | |||
private StringConstant() { | |||
} | |||
private StringConstant() { | |||
} | |||
} |
@@ -19,7 +19,7 @@ package org.dubhe.constant; | |||
/** | |||
* @description 符号常量 | |||
* @Date 2020-5-29 | |||
* @date 2020-5-29 | |||
*/ | |||
public class SymbolConstant { | |||
public static final String SLASH = "/"; | |||
@@ -40,6 +40,8 @@ public class SymbolConstant { | |||
public static final String DOUBLE_MARK= "\"\""; | |||
public static final String MARK= "\""; | |||
public static final String FLAG_EQUAL = "="; | |||
private SymbolConstant() { | |||
} | |||
@@ -17,15 +17,12 @@ | |||
package org.dubhe.constant; | |||
import org.springframework.stereotype.Component; | |||
import lombok.Data; | |||
/** | |||
* @description 算法用途 | |||
* @date: 2020-06-23 | |||
* @date 2020-06-23 | |||
*/ | |||
@Component | |||
@Data | |||
public class UserAuxiliaryInfoConstant { | |||
@@ -0,0 +1,52 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.domain.dto; | |||
import lombok.AllArgsConstructor; | |||
import lombok.Builder; | |||
import lombok.Data; | |||
import lombok.NoArgsConstructor; | |||
import java.io.Serializable; | |||
import java.util.Set; | |||
/** | |||
* @description 公共权限信息DTO | |||
* @date 2020-09-24 | |||
*/ | |||
@AllArgsConstructor | |||
@NoArgsConstructor | |||
@Builder | |||
@Data | |||
public class CommonPermissionDataDTO implements Serializable { | |||
/** | |||
* 资源拥有者ID | |||
*/ | |||
private Long id; | |||
/** | |||
* 公共类型 | |||
*/ | |||
private Boolean type; | |||
/** | |||
* 资源所属用户ids | |||
*/ | |||
private Set<Long> resourceUserIds; | |||
} |
@@ -16,15 +16,14 @@ | |||
*/ | |||
package org.dubhe.domain.entity; | |||
import java.io.Serializable; | |||
import org.dubhe.base.MagicNumConstant; | |||
import com.alibaba.fastjson.annotation.JSONField; | |||
import cn.hutool.core.date.DateUtil; | |||
import com.alibaba.fastjson.annotation.JSONField; | |||
import lombok.Data; | |||
import lombok.experimental.Accessors; | |||
import org.dubhe.base.MagicNumConstant; | |||
import java.io.Serializable; | |||
/** | |||
* @description 日志对象封装类 | |||
@@ -34,31 +33,27 @@ import lombok.experimental.Accessors; | |||
@Accessors(chain = true) | |||
public class LogInfo implements Serializable { | |||
@JSONField(ordinal = MagicNumConstant.ONE) | |||
private String traceId; | |||
private static final long serialVersionUID = 5250395474667395607L; | |||
@JSONField(ordinal = MagicNumConstant.ONE) | |||
private String traceId; | |||
@JSONField(ordinal = MagicNumConstant.TWO) | |||
private String type; | |||
@JSONField(ordinal = MagicNumConstant.TWO) | |||
private String type; | |||
@JSONField(ordinal = MagicNumConstant.THREE) | |||
private String level; | |||
@JSONField(ordinal = MagicNumConstant.THREE) | |||
private String level; | |||
@JSONField(ordinal = MagicNumConstant.FOUR) | |||
private String cName; | |||
@JSONField(ordinal = MagicNumConstant.FOUR) | |||
private String location; | |||
@JSONField(ordinal = MagicNumConstant.FIVE) | |||
private String mName; | |||
@JSONField(ordinal = MagicNumConstant.SIX) | |||
private String line; | |||
@JSONField(ordinal = MagicNumConstant.SEVEN) | |||
private String time = DateUtil.now(); | |||
@JSONField(ordinal = MagicNumConstant.FIVE) | |||
private String time = DateUtil.now(); | |||
@JSONField(ordinal = MagicNumConstant.EIGHT) | |||
private Object info; | |||
@JSONField(ordinal = MagicNumConstant.SIX) | |||
private Object info; | |||
public void setInfo(Object info) { | |||
this.info = info; | |||
} | |||
public void setInfo(Object info) { | |||
this.info = info; | |||
} | |||
} |
@@ -22,12 +22,9 @@ import lombok.Builder; | |||
import lombok.Data; | |||
import lombok.NoArgsConstructor; | |||
import org.dubhe.base.BaseEntity; | |||
import org.hibernate.validator.constraints.Length; | |||
import javax.validation.constraints.NotBlank; | |||
import javax.validation.constraints.NotNull; | |||
import java.io.Serializable; | |||
import java.sql.Timestamp; | |||
import java.util.Objects; | |||
/** | |||
@@ -22,12 +22,7 @@ import lombok.Builder; | |||
import lombok.Data; | |||
import lombok.NoArgsConstructor; | |||
import org.dubhe.base.BaseEntity; | |||
import org.hibernate.validator.constraints.Length; | |||
import javax.validation.constraints.NotBlank; | |||
import javax.validation.constraints.NotEmpty; | |||
import java.io.Serializable; | |||
import java.sql.Timestamp; | |||
import java.util.Objects; | |||
import java.util.Set; | |||
@@ -0,0 +1,73 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.dto; | |||
import lombok.Data; | |||
/** | |||
* @description 全局请求日志信息 | |||
* @date 2020-08-13 | |||
*/ | |||
@Data | |||
public class GlobalRequestRecordDTO { | |||
/** | |||
* 客户主机地址 | |||
*/ | |||
private String clientHost; | |||
/** | |||
* 请求地址 | |||
*/ | |||
private String uri; | |||
/** | |||
* 授权信息 | |||
*/ | |||
private String authorization; | |||
/** | |||
* 用户名 | |||
*/ | |||
private String username; | |||
/** | |||
* form参数 | |||
*/ | |||
private String params; | |||
/** | |||
* 返回值类型 | |||
*/ | |||
private String contentType; | |||
/** | |||
* 返回状态 | |||
*/ | |||
private Integer status; | |||
/** | |||
* 时间耗费 | |||
*/ | |||
private Long timeCost; | |||
/** | |||
* 请求方式 | |||
*/ | |||
private String method; | |||
/** | |||
* 请求体body参数 | |||
*/ | |||
private String requestBody; | |||
/** | |||
* 返回值json数据 | |||
*/ | |||
private String responseBody; | |||
} |
@@ -44,6 +44,14 @@ public class BaseK8sPodCallbackCreateDTO { | |||
@NotEmpty(message = "podName 不能为空!") | |||
private String podName; | |||
@ApiModelProperty(required = true,value = "k8s pod parent type") | |||
@NotEmpty(message = "podParentType 不能为空!") | |||
private String podParentType; | |||
@ApiModelProperty(required = true,value = "k8s pod parent name") | |||
@NotEmpty(message = "podParentName 不能为空!") | |||
private String podParentName; | |||
@ApiModelProperty(value = "k8s pod phase",notes = "对应PodPhaseEnum") | |||
@NotEmpty(message = "phase 不能为空!") | |||
private String phase; | |||
@@ -55,10 +63,12 @@ public class BaseK8sPodCallbackCreateDTO { | |||
} | |||
public BaseK8sPodCallbackCreateDTO(String namespace,String resourceName,String podName,String phase,String messages){ | |||
public BaseK8sPodCallbackCreateDTO(String namespace,String resourceName,String podName,String podParentType,String podParentName,String phase,String messages){ | |||
this.namespace = namespace; | |||
this.resourceName = resourceName; | |||
this.podName = podName; | |||
this.podParentType = podParentType; | |||
this.podParentName = podParentName; | |||
this.phase = phase; | |||
this.messages = messages; | |||
} | |||
@@ -69,6 +79,8 @@ public class BaseK8sPodCallbackCreateDTO { | |||
"namespace='" + namespace + '\'' + | |||
", resourceName='" + resourceName + '\'' + | |||
", podName='" + podName + '\'' + | |||
", podParentType='" + podParentType + '\'' + | |||
", podParentName='" + podParentName + '\'' + | |||
", phase='" + phase + '\'' + | |||
", messages=" + messages + | |||
'}'; | |||
@@ -23,9 +23,8 @@ import java.util.HashMap; | |||
import java.util.Map; | |||
/** | |||
* @desc: 业务模块 | |||
* | |||
* @date 2020.05.25 | |||
* @description 业务模块 | |||
* @date 2020-05-25 | |||
*/ | |||
@Getter | |||
public enum BizEnum { | |||
@@ -38,6 +37,10 @@ public enum BizEnum { | |||
* 算法管理 | |||
*/ | |||
ALGORITHM("算法管理","algorithm",1), | |||
/** | |||
* 模型管理 | |||
*/ | |||
MODEL("模型管理","model",2), | |||
; | |||
/** | |||
@@ -21,8 +21,8 @@ import java.util.HashMap; | |||
import java.util.Map; | |||
/** | |||
* @desc: 业务NFS路径枚举 | |||
* @date 2020.05.13 | |||
* @description 业务NFS路径枚举 | |||
* @date 2020-05-13 | |||
*/ | |||
public enum BizNfsEnum { | |||
/** | |||
@@ -33,6 +33,10 @@ public enum BizNfsEnum { | |||
* 算法管理 NFS 路径命名 | |||
*/ | |||
ALGORITHM(BizEnum.ALGORITHM, "algorithm-manage"), | |||
/** | |||
* 模型管理 NFS 路径命名 | |||
*/ | |||
MODEL(BizEnum.MODEL, "model"), | |||
; | |||
BizNfsEnum(BizEnum bizEnum, String bizNfsPath) { | |||
@@ -81,11 +85,11 @@ public enum BizNfsEnum { | |||
return bizNfsPath; | |||
} | |||
public BizEnum getBizEnum(){ | |||
public BizEnum getBizEnum() { | |||
return bizEnum; | |||
} | |||
public String getBizCode() { | |||
return bizEnum == null ? null :bizEnum.getBizCode(); | |||
return bizEnum == null ? null : bizEnum.getBizCode(); | |||
} | |||
} |
@@ -15,7 +15,7 @@ | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.data.constant; | |||
package org.dubhe.enums; | |||
import lombok.Getter; | |||
@@ -26,36 +26,52 @@ import lombok.Getter; | |||
@Getter | |||
public enum LogEnum { | |||
// 系统报错日志 | |||
SYS_ERR, | |||
// 用户请求日志 | |||
REST_REQ, | |||
// 训练模块 | |||
BIZ_TRAIN, | |||
// 系统模块 | |||
BIZ_SYS, | |||
// 模型模块 | |||
BIZ_MODEL, | |||
// 数据集模块 | |||
BIZ_DATASET, | |||
// k8s模块 | |||
BIZ_K8S, | |||
//note book | |||
NOTE_BOOK, | |||
//NFS UTILS | |||
NFS_UTIL; | |||
// 系统报错日志 | |||
SYS_ERR, | |||
// 用户请求日志 | |||
REST_REQ, | |||
//全局请求日志 | |||
GLOBAL_REQ, | |||
// 训练模块 | |||
BIZ_TRAIN, | |||
// 系统模块 | |||
BIZ_SYS, | |||
// 模型模块 | |||
BIZ_MODEL, | |||
// 数据集模块 | |||
BIZ_DATASET, | |||
// k8s模块 | |||
BIZ_K8S, | |||
//note book | |||
NOTE_BOOK, | |||
//NFS UTILS | |||
NFS_UTIL, | |||
//localFileUtil | |||
LOCAL_FILE_UTIL, | |||
//FILE UTILS | |||
FILE_UTIL, | |||
//FILE UTILS | |||
UPLOAD_TEMP, | |||
//STATE MACHINE | |||
STATE_MACHINE, | |||
//全局垃圾回收 | |||
GARBAGE_RECYCLE, | |||
//DATA_SEQUENCE | |||
DATA_SEQUENCE, | |||
//IO UTIL | |||
IO_UTIL; | |||
/** | |||
* 判断日志类型不能为空 | |||
* | |||
* @param logType 日志类型 | |||
* @return boolean 返回类型 | |||
*/ | |||
public static boolean isLogType(LogEnum logType) { | |||
/** | |||
* 判断日志类型不能为空 | |||
* | |||
* @param logType 日志类型 | |||
* @return boolean 返回类型 | |||
*/ | |||
public static boolean isLogType(LogEnum logType) { | |||
if (logType != null) { | |||
return true; | |||
} | |||
return false; | |||
} | |||
if (logType != null) { | |||
return true; | |||
} | |||
return false; | |||
} | |||
} |
@@ -0,0 +1,72 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.enums; | |||
import lombok.Getter; | |||
import lombok.ToString; | |||
/** | |||
* @Description 操作类型枚举 | |||
* @Date 2020-08-24 | |||
*/ | |||
@ToString | |||
@Getter | |||
public enum OperationTypeEnum { | |||
/** | |||
* SELECT 查询类型 | |||
*/ | |||
SELECT("select", "查询"), | |||
/** | |||
* UPDATE 修改类型 | |||
*/ | |||
UPDATE("update", "修改"), | |||
/** | |||
* DELETE 删除类型 | |||
*/ | |||
DELETE("delete", "删除"), | |||
/** | |||
* LIMIT 禁止操作类型 | |||
*/ | |||
LIMIT("limit", "禁止操作"), | |||
/** | |||
* INSERT 新增类型 | |||
*/ | |||
INSERT("insert", "新增类型"), | |||
; | |||
/** | |||
* 操作类型值 | |||
*/ | |||
private String type; | |||
/** | |||
* 操作类型备注 | |||
*/ | |||
private String desc; | |||
OperationTypeEnum(String type, String desc) { | |||
this.type = type; | |||
this.desc = desc; | |||
} | |||
} |
@@ -0,0 +1,54 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.enums; | |||
import lombok.Getter; | |||
import java.util.HashSet; | |||
import java.util.Set; | |||
/** | |||
* @description 资源回收枚举类 | |||
* @date 2020-10-10 | |||
*/ | |||
@Getter | |||
public enum RecycleResourceEnum { | |||
/** | |||
* 数据集文件回收 | |||
*/ | |||
DATASET_RECYCLE_FILE("datasetRecycleFile", "数据集文件回收"), | |||
/** | |||
* 数据集版本文件回收 | |||
*/ | |||
DATASET_RECYCLE_VERSION_FILE("datasetRecycleVersionFile", "数据集版本文件回收"), | |||
; | |||
private String className; | |||
private String message; | |||
RecycleResourceEnum(String className, String message) { | |||
this.className = className; | |||
this.message = message; | |||
} | |||
} |
@@ -19,8 +19,8 @@ package org.dubhe.enums; | |||
import lombok.Getter; | |||
import java.util.Arrays; | |||
import java.util.List; | |||
import java.util.HashSet; | |||
import java.util.Set; | |||
/** | |||
* @description 训练任务枚举类 | |||
@@ -72,7 +72,13 @@ public enum TrainJobStatusEnum { | |||
this.message = message; | |||
} | |||
public static TrainJobStatusEnum get(String msg) { | |||
/** | |||
* 根据信息获取枚举类对象 | |||
* | |||
* @param msg 信息 | |||
* @return 枚举类对象 | |||
*/ | |||
public static TrainJobStatusEnum getByMessage(String msg) { | |||
for (TrainJobStatusEnum statusEnum : values()) { | |||
if (statusEnum.message.equalsIgnoreCase(msg)) { | |||
return statusEnum; | |||
@@ -81,21 +87,63 @@ public enum TrainJobStatusEnum { | |||
return UNKNOWN; | |||
} | |||
/** | |||
* 回调状态转换 若是DELETED则转换为STOP,避免状态不统一 | |||
* @param phase k8s pod phase | |||
* @return | |||
*/ | |||
public static TrainJobStatusEnum transferStatus(String phase) { | |||
TrainJobStatusEnum enums = getByMessage(phase); | |||
if (enums != DELETED) { | |||
return enums; | |||
} | |||
return STOP; | |||
} | |||
/** | |||
* 根据状态获取枚举类对象 | |||
* | |||
* @param status 状态 | |||
* @return 枚举类对象 | |||
*/ | |||
public static TrainJobStatusEnum getByStatus(Integer status) { | |||
for (TrainJobStatusEnum statusEnum : values()) { | |||
if (statusEnum.status.equals(status)) { | |||
return statusEnum; | |||
} | |||
} | |||
return UNKNOWN; | |||
} | |||
/** | |||
* 结束状态枚举集合 | |||
*/ | |||
public static final Set<TrainJobStatusEnum> END_TRAIN_JOB_STATUS; | |||
static { | |||
END_TRAIN_JOB_STATUS = new HashSet<>(); | |||
END_TRAIN_JOB_STATUS.add(SUCCEEDED); | |||
END_TRAIN_JOB_STATUS.add(FAILED); | |||
END_TRAIN_JOB_STATUS.add(STOP); | |||
END_TRAIN_JOB_STATUS.add(CREATE_FAILED); | |||
END_TRAIN_JOB_STATUS.add(DELETED); | |||
} | |||
public static boolean isEnd(String msg) { | |||
List<String> endList = Arrays.asList("SUCCEEDED", "FAILED", "STOP", "CREATE_FAILED"); | |||
return endList.stream().anyMatch(s -> s.equalsIgnoreCase(msg)); | |||
return END_TRAIN_JOB_STATUS.contains(getByMessage(msg)); | |||
} | |||
public static boolean isEnd(Integer num) { | |||
List<Integer> endList = Arrays.asList(2, 3, 4, 7); | |||
return endList.stream().anyMatch(s -> s.equals(num)); | |||
public static boolean isEnd(Integer status) { | |||
return END_TRAIN_JOB_STATUS.contains(getByStatus(status)); | |||
} | |||
public static boolean checkStopStatus(Integer num) { | |||
return SUCCEEDED.getStatus().equals(num) || | |||
FAILED.getStatus().equals(num) || | |||
STOP.getStatus().equals(num) || | |||
CREATE_FAILED.getStatus().equals(num) || | |||
DELETED.getStatus().equals(num); | |||
return isEnd(num); | |||
} | |||
public static boolean checkRunStatus(Integer num) { | |||
return PENDING.getStatus().equals(num) || | |||
RUNNING.getStatus().equals(num); | |||
} | |||
} |
@@ -54,6 +54,7 @@ public enum BaseErrorCode implements ErrorCode { | |||
SYSTEM_USER_CANNOT_DELETE(20014, "系统默认用户不可删除!"), | |||
SYSTEM_ROLE_CANNOT_DELETE(20015, "系统默认角色不可删除!"), | |||
DATASET_ADMIN_PERMISSION_ERROR(1310,"无此权限,请联系管理员"), | |||
; | |||
@@ -0,0 +1,42 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.exception; | |||
import lombok.Getter; | |||
/** | |||
* @description 获取序列异常 | |||
* @date 2020-09-23 | |||
*/ | |||
@Getter | |||
public class DataSequenceException extends BusinessException { | |||
private static final long serialVersionUID = 1L; | |||
public DataSequenceException(String msg) { | |||
super(msg); | |||
} | |||
public DataSequenceException(String msg, Throwable cause) { | |||
super(msg,cause); | |||
} | |||
public DataSequenceException(Throwable cause) { | |||
super(cause); | |||
} | |||
} |
@@ -20,8 +20,7 @@ package org.dubhe.exception; | |||
import lombok.Getter; | |||
/** | |||
* @description: Notebook 业务处理异常 | |||
* | |||
* @description Notebook 业务处理异常 | |||
* @date 2020.04.27 | |||
*/ | |||
@Getter | |||
@@ -17,8 +17,8 @@ | |||
package org.dubhe.exception.handler; | |||
import java.util.Objects; | |||
import lombok.extern.slf4j.Slf4j; | |||
import org.apache.ibatis.exceptions.IbatisException; | |||
import org.apache.shiro.ShiroException; | |||
import org.apache.shiro.authc.AuthenticationException; | |||
import org.apache.shiro.authc.IncorrectCredentialsException; | |||
@@ -27,11 +27,7 @@ import org.apache.shiro.authc.UnknownAccountException; | |||
import org.dubhe.base.DataResponseBody; | |||
import org.dubhe.base.ResponseCode; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.exception.BusinessException; | |||
import org.dubhe.exception.CaptchaException; | |||
import org.dubhe.exception.LoginException; | |||
import org.dubhe.exception.NotebookBizException; | |||
import org.dubhe.exception.UnauthorizedException; | |||
import org.dubhe.exception.*; | |||
import org.dubhe.utils.LogUtil; | |||
import org.springframework.http.HttpStatus; | |||
import org.springframework.http.ResponseEntity; | |||
@@ -41,7 +37,7 @@ import org.springframework.web.bind.MethodArgumentNotValidException; | |||
import org.springframework.web.bind.annotation.ExceptionHandler; | |||
import org.springframework.web.bind.annotation.RestControllerAdvice; | |||
import lombok.extern.slf4j.Slf4j; | |||
import java.util.Objects; | |||
/** | |||
* @description 处理异常 | |||
@@ -51,140 +47,144 @@ import lombok.extern.slf4j.Slf4j; | |||
@RestControllerAdvice | |||
public class GlobalExceptionHandler { | |||
/** | |||
* 处理所有不可知的异常 | |||
*/ | |||
@ExceptionHandler(Throwable.class) | |||
public ResponseEntity<DataResponseBody> handleException(Throwable e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
return buildResponseEntity(HttpStatus.INTERNAL_SERVER_ERROR, | |||
new DataResponseBody(ResponseCode.ERROR, e.getMessage())); | |||
} | |||
/** | |||
* UnauthorizedException | |||
*/ | |||
@ExceptionHandler(UnauthorizedException.class) | |||
public ResponseEntity<DataResponseBody> badCredentialsException(UnauthorizedException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
String message = "坏的凭证".equals(e.getMessage()) ? "用户名或密码不正确" : e.getMessage(); | |||
return buildResponseEntity(HttpStatus.UNAUTHORIZED, new DataResponseBody(ResponseCode.ERROR, message)); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = BusinessException.class) | |||
public ResponseEntity<DataResponseBody> badRequestException(BusinessException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = AuthenticationException.class) | |||
public ResponseEntity<DataResponseBody> badRequestException(AuthenticationException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); | |||
} | |||
/** | |||
* shiro 异常捕捉 | |||
*/ | |||
@ExceptionHandler(value = ShiroException.class) | |||
public ResponseEntity<DataResponseBody> accountException(ShiroException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
ResponseEntity<DataResponseBody> responseEntity; | |||
if (e instanceof IncorrectCredentialsException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "密码不正确")); | |||
} else if (e instanceof UnknownAccountException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "此账户不存在")); | |||
} | |||
else if (e instanceof LockedAccountException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "未知的账号")); | |||
} | |||
else if (e instanceof UnknownAccountException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "账户已被禁用")); | |||
} | |||
else { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, | |||
new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); | |||
} | |||
return responseEntity; | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = LoginException.class) | |||
public ResponseEntity<DataResponseBody> loginException(LoginException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
return buildResponseEntity(HttpStatus.UNAUTHORIZED, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = CaptchaException.class) | |||
public ResponseEntity<DataResponseBody> captchaException(CaptchaException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = NotebookBizException.class) | |||
public ResponseEntity<DataResponseBody> captchaException(NotebookBizException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理所有接口数据验证异常 | |||
*/ | |||
@ExceptionHandler(MethodArgumentNotValidException.class) | |||
public ResponseEntity<DataResponseBody> handleMethodArgumentNotValidException(MethodArgumentNotValidException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
String[] str = Objects.requireNonNull(e.getBindingResult().getAllErrors().get(0).getCodes())[1].split("\\."); | |||
String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage(); | |||
String msg = "不能为空"; | |||
if (msg.equals(message)) { | |||
message = str[1] + ":" + message; | |||
} | |||
return buildResponseEntity(HttpStatus.BAD_REQUEST, new DataResponseBody(ResponseCode.ERROR, message)); | |||
} | |||
@ExceptionHandler(BindException.class) | |||
public ResponseEntity<DataResponseBody> bindException(BindException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, e); | |||
ObjectError error = e.getAllErrors().get(0); | |||
return buildResponseEntity(HttpStatus.BAD_REQUEST, | |||
new DataResponseBody(ResponseCode.ERROR, error.getDefaultMessage())); | |||
} | |||
/** | |||
* 统一返回 | |||
* | |||
* @param httpStatus | |||
* @param responseBody | |||
* @return | |||
*/ | |||
private ResponseEntity<DataResponseBody> buildResponseEntity(HttpStatus httpStatus, DataResponseBody responseBody) { | |||
return new ResponseEntity<>(responseBody, httpStatus); | |||
} | |||
/** | |||
* 处理所有不可知的异常 | |||
*/ | |||
@ExceptionHandler(Throwable.class) | |||
public ResponseEntity<DataResponseBody> handleException(Throwable e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.INTERNAL_SERVER_ERROR, | |||
new DataResponseBody(ResponseCode.ERROR, e.getMessage())); | |||
} | |||
/** | |||
* UnauthorizedException | |||
*/ | |||
@ExceptionHandler(UnauthorizedException.class) | |||
public ResponseEntity<DataResponseBody> badCredentialsException(UnauthorizedException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
String message = "坏的凭证".equals(e.getMessage()) ? "用户名或密码不正确" : e.getMessage(); | |||
return buildResponseEntity(HttpStatus.UNAUTHORIZED, new DataResponseBody(ResponseCode.ERROR, message)); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = BusinessException.class) | |||
public ResponseEntity<DataResponseBody> badRequestException(BusinessException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = IbatisException.class) | |||
public ResponseEntity<DataResponseBody> persistenceException(IbatisException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, e.getMessage())); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = AuthenticationException.class) | |||
public ResponseEntity<DataResponseBody> badRequestException(AuthenticationException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); | |||
} | |||
/** | |||
* shiro 异常捕捉 | |||
*/ | |||
@ExceptionHandler(value = ShiroException.class) | |||
public ResponseEntity<DataResponseBody> accountException(ShiroException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
ResponseEntity<DataResponseBody> responseEntity; | |||
if (e instanceof IncorrectCredentialsException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "密码不正确")); | |||
} else if (e instanceof UnknownAccountException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "此账户不存在")); | |||
} else if (e instanceof LockedAccountException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "未知的账号")); | |||
} else if (e instanceof UnknownAccountException) { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, new DataResponseBody(ResponseCode.ERROR, "账户已被禁用")); | |||
} else { | |||
responseEntity = buildResponseEntity(HttpStatus.OK, | |||
new DataResponseBody(ResponseCode.UNAUTHORIZED, "无权访问")); | |||
} | |||
return responseEntity; | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = LoginException.class) | |||
public ResponseEntity<DataResponseBody> loginException(LoginException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.UNAUTHORIZED, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = CaptchaException.class) | |||
public ResponseEntity<DataResponseBody> captchaException(CaptchaException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理自定义异常 | |||
*/ | |||
@ExceptionHandler(value = NotebookBizException.class) | |||
public ResponseEntity<DataResponseBody> captchaException(NotebookBizException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
return buildResponseEntity(HttpStatus.OK, e.getResponseBody()); | |||
} | |||
/** | |||
* 处理所有接口数据验证异常 | |||
*/ | |||
@ExceptionHandler(MethodArgumentNotValidException.class) | |||
public ResponseEntity<DataResponseBody> handleMethodArgumentNotValidException(MethodArgumentNotValidException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
String[] str = Objects.requireNonNull(e.getBindingResult().getAllErrors().get(0).getCodes())[1].split("\\."); | |||
String message = e.getBindingResult().getAllErrors().get(0).getDefaultMessage(); | |||
String msg = "不能为空"; | |||
if (msg.equals(message)) { | |||
message = str[1] + ":" + message; | |||
} | |||
return buildResponseEntity(HttpStatus.BAD_REQUEST, new DataResponseBody(ResponseCode.ERROR, message)); | |||
} | |||
@ExceptionHandler(BindException.class) | |||
public ResponseEntity<DataResponseBody> bindException(BindException e) { | |||
// 打印堆栈信息 | |||
LogUtil.error(LogEnum.SYS_ERR, "引起异常的堆栈信息:{}", e); | |||
ObjectError error = e.getAllErrors().get(0); | |||
return buildResponseEntity(HttpStatus.BAD_REQUEST, | |||
new DataResponseBody(ResponseCode.ERROR, error.getDefaultMessage())); | |||
} | |||
/** | |||
* 统一返回 | |||
* | |||
* @param httpStatus | |||
* @param responseBody | |||
* @return | |||
*/ | |||
private ResponseEntity<DataResponseBody> buildResponseEntity(HttpStatus httpStatus, DataResponseBody responseBody) { | |||
return new ResponseEntity<>(responseBody, httpStatus); | |||
} | |||
} |
@@ -0,0 +1,74 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.filter; | |||
import ch.qos.logback.classic.Level; | |||
import ch.qos.logback.classic.spi.ILoggingEvent; | |||
import ch.qos.logback.core.filter.AbstractMatcherFilter; | |||
import ch.qos.logback.core.spi.FilterReply; | |||
import cn.hutool.core.util.StrUtil; | |||
import org.slf4j.Marker; | |||
/** | |||
* @description 自定义日志过滤器 | |||
* @date 2020-07-21 | |||
*/ | |||
public class BaseLogFilter extends AbstractMatcherFilter<ILoggingEvent> { | |||
Level level; | |||
/** | |||
* 重写decide方法 | |||
* | |||
* @param iLoggingEvent event to decide upon. | |||
* @return FilterReply | |||
*/ | |||
@Override | |||
public FilterReply decide(ILoggingEvent iLoggingEvent) { | |||
if (!isStarted()) { | |||
return FilterReply.NEUTRAL; | |||
} | |||
final String msg = iLoggingEvent.getMessage(); | |||
//自定义级别 | |||
if (checkLevel(iLoggingEvent) && msg != null && msg.startsWith(StrUtil.DELIM_START) && msg.endsWith(StrUtil.DELIM_END)) { | |||
final Marker marker = iLoggingEvent.getMarker(); | |||
if (marker != null && this.getName() != null && this.getName().contains(marker.getName())) { | |||
return onMatch; | |||
} | |||
} | |||
return onMismatch; | |||
} | |||
protected boolean checkLevel(ILoggingEvent iLoggingEvent) { | |||
return this.level != null | |||
&& iLoggingEvent.getLevel() != null | |||
&& iLoggingEvent.getLevel().toInt() == this.level.toInt(); | |||
} | |||
public void setLevel(Level level) { | |||
this.level = level; | |||
} | |||
@Override | |||
public void start() { | |||
if (this.level != null) { | |||
super.start(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,49 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.filter; | |||
import ch.qos.logback.classic.spi.ILoggingEvent; | |||
import ch.qos.logback.core.spi.FilterReply; | |||
import org.dubhe.utils.LogUtil; | |||
import org.slf4j.MarkerFactory; | |||
/** | |||
* @description 自定义日志过滤器 | |||
* @date 2020-07-21 | |||
*/ | |||
public class ConsoleLogFilter extends BaseLogFilter { | |||
@Override | |||
public FilterReply decide(ILoggingEvent iLoggingEvent) { | |||
if (!isStarted()) { | |||
return FilterReply.NEUTRAL; | |||
} | |||
return checkLevel(iLoggingEvent) ? onMatch : onMismatch; | |||
} | |||
protected boolean checkLevel(ILoggingEvent iLoggingEvent) { | |||
return this.level != null | |||
&& iLoggingEvent.getLevel() != null | |||
&& iLoggingEvent.getLevel().toInt() >= this.level.toInt() | |||
&& !MarkerFactory.getMarker(LogUtil.K8S_CALLBACK_LEVEL).equals(iLoggingEvent.getMarker()) | |||
&& !MarkerFactory.getMarker(LogUtil.SCHEDULE_LEVEL).equals(iLoggingEvent.getMarker()) | |||
&& !"log4jdbc.log4j2".equals(iLoggingEvent.getLoggerName()); | |||
} | |||
} |
@@ -1,60 +0,0 @@ | |||
/** | |||
* Copyright 2019-2020 Zheng Jie | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
package org.dubhe.filter; | |||
import ch.qos.logback.classic.Level; | |||
import ch.qos.logback.classic.spi.ILoggingEvent; | |||
import ch.qos.logback.core.filter.AbstractMatcherFilter; | |||
import ch.qos.logback.core.spi.FilterReply; | |||
/** | |||
* @description 自定义日志过滤器 | |||
* @date 2020-07-21 | |||
*/ | |||
public class FileLogFilter extends AbstractMatcherFilter<ILoggingEvent> { | |||
Level level; | |||
/** | |||
* 重写decide方法 | |||
* | |||
* @param iLoggingEvent event to decide upon. | |||
* @return FilterReply | |||
*/ | |||
@Override | |||
public FilterReply decide(ILoggingEvent iLoggingEvent) { | |||
if (!isStarted()) { | |||
return FilterReply.NEUTRAL; | |||
} | |||
if (iLoggingEvent.getLevel().equals(level) && iLoggingEvent.getMessage() != null | |||
&& iLoggingEvent.getMessage().startsWith("{") && iLoggingEvent.getMessage().endsWith("}")) { | |||
return onMatch; | |||
} | |||
return onMismatch; | |||
} | |||
public void setLevel(Level level) { | |||
this.level = level; | |||
} | |||
@Override | |||
public void start() { | |||
if (this.level != null) { | |||
super.start(); | |||
} | |||
} | |||
} |
@@ -0,0 +1,34 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.filter; | |||
import ch.qos.logback.classic.spi.ILoggingEvent; | |||
/** | |||
* @description 全局请求 日志过滤器 | |||
* @date 2020-08-13 | |||
*/ | |||
public class GlobalRequestLogFilter extends BaseLogFilter { | |||
@Override | |||
public boolean checkLevel(ILoggingEvent iLoggingEvent) { | |||
return this.level != null; | |||
} | |||
} |
@@ -0,0 +1,113 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.interceptor; | |||
import org.apache.ibatis.executor.statement.StatementHandler; | |||
import org.apache.ibatis.mapping.BoundSql; | |||
import org.apache.ibatis.mapping.MappedStatement; | |||
import org.apache.ibatis.plugin.*; | |||
import org.apache.ibatis.reflection.DefaultReflectorFactory; | |||
import org.apache.ibatis.reflection.MetaObject; | |||
import org.apache.ibatis.reflection.SystemMetaObject; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.base.DataContext; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.enums.OperationTypeEnum; | |||
import org.dubhe.utils.JwtUtils; | |||
import org.dubhe.utils.SqlUtil; | |||
import org.springframework.stereotype.Component; | |||
import java.lang.reflect.Field; | |||
import java.sql.Connection; | |||
import java.util.Arrays; | |||
import java.util.Objects; | |||
import java.util.Properties; | |||
/** | |||
* @description mybatis拦截器 | |||
* @date 2020-06-10 | |||
*/ | |||
@Component | |||
@Intercepts({@Signature(type = StatementHandler.class, method = "prepare", args = {Connection.class, Integer.class})}) | |||
public class MySqlInterceptor implements Interceptor { | |||
@Override | |||
public Object intercept(Invocation invocation) throws Throwable { | |||
StatementHandler statementHandler = (StatementHandler) invocation.getTarget(); | |||
MetaObject metaObject = MetaObject.forObject(statementHandler, SystemMetaObject.DEFAULT_OBJECT_FACTORY, | |||
SystemMetaObject.DEFAULT_OBJECT_WRAPPER_FACTORY, new DefaultReflectorFactory()); | |||
/* | |||
* 先拦截到RoutingStatementHandler,里面有个StatementHandler类型的delegate变量,其实现类是BaseStatementHandler, | |||
* 然后就到BaseStatementHandler的成员变量mappedStatement | |||
*/ | |||
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); | |||
//id为执行的mapper方法的全路径名,如com.uv.dao.UserDao.selectPageVo | |||
String id = mappedStatement.getId(); | |||
//sql语句类型 select、delete、insert、update | |||
String sqlCommandType = mappedStatement.getSqlCommandType().toString(); | |||
BoundSql boundSql = statementHandler.getBoundSql(); | |||
//获取到原始sql语句 | |||
String sql = boundSql.getSql(); | |||
String mSql = sql; | |||
//注解逻辑判断 添加注解了才拦截 | |||
Class<?> classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf("."))); | |||
String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1, mappedStatement.getId().length()); | |||
UserDTO currentUserDto = JwtUtils.getCurrentUserDto(); | |||
//获取类注解 获取需要忽略拦截的方法名称 | |||
DataPermission dataAnnotation = classType.getAnnotation(DataPermission.class); | |||
if (!Objects.isNull(dataAnnotation)) { | |||
String[] ignores = dataAnnotation.ignoresMethod(); | |||
//校验拦截忽略方法名 忽略新增方法 忽略回调/定时方法 | |||
if ((!Objects.isNull(ignores) && Arrays.asList(ignores).contains(mName)) | |||
|| OperationTypeEnum.INSERT.getType().equals(sqlCommandType.toLowerCase()) | |||
|| Objects.isNull(currentUserDto) | |||
|| (!Objects.isNull(DataContext.get()) && DataContext.get().getType()) | |||
) { | |||
return invocation.proceed(); | |||
} else { | |||
//拦截所有sql操作类型 | |||
mSql = SqlUtil.buildTargetSql(sql, SqlUtil.getResourceIds()); | |||
} | |||
} | |||
//通过反射修改sql语句 | |||
Field field = boundSql.getClass().getDeclaredField("sql"); | |||
field.setAccessible(true); | |||
field.set(boundSql, mSql); | |||
return invocation.proceed(); | |||
} | |||
@Override | |||
public Object plugin(Object target) { | |||
if (target instanceof StatementHandler) { | |||
return Plugin.wrap(target, this); | |||
} else { | |||
return target; | |||
} | |||
} | |||
@Override | |||
public void setProperties(Properties properties) { | |||
} | |||
} |
@@ -0,0 +1,457 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.interceptor; | |||
import com.baomidou.mybatisplus.annotation.DbType; | |||
import com.baomidou.mybatisplus.core.MybatisDefaultParameterHandler; | |||
import com.baomidou.mybatisplus.core.metadata.IPage; | |||
import com.baomidou.mybatisplus.core.metadata.OrderItem; | |||
import com.baomidou.mybatisplus.core.parser.ISqlParser; | |||
import com.baomidou.mybatisplus.core.parser.SqlInfo; | |||
import com.baomidou.mybatisplus.core.toolkit.CollectionUtils; | |||
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils; | |||
import com.baomidou.mybatisplus.core.toolkit.PluginUtils; | |||
import com.baomidou.mybatisplus.core.toolkit.StringUtils; | |||
import com.baomidou.mybatisplus.extension.handlers.AbstractSqlParserHandler; | |||
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectFactory; | |||
import com.baomidou.mybatisplus.extension.plugins.pagination.DialectModel; | |||
import com.baomidou.mybatisplus.extension.plugins.pagination.dialects.IDialect; | |||
import com.baomidou.mybatisplus.extension.toolkit.JdbcUtils; | |||
import com.baomidou.mybatisplus.extension.toolkit.SqlParserUtils; | |||
import net.sf.jsqlparser.JSQLParserException; | |||
import net.sf.jsqlparser.parser.CCJSqlParserUtil; | |||
import net.sf.jsqlparser.schema.Column; | |||
import net.sf.jsqlparser.statement.select.*; | |||
import org.apache.ibatis.executor.statement.StatementHandler; | |||
import org.apache.ibatis.logging.Log; | |||
import org.apache.ibatis.logging.LogFactory; | |||
import org.apache.ibatis.mapping.*; | |||
import org.apache.ibatis.plugin.*; | |||
import org.apache.ibatis.reflection.MetaObject; | |||
import org.apache.ibatis.reflection.SystemMetaObject; | |||
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler; | |||
import org.apache.ibatis.session.Configuration; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.base.DataContext; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.enums.OperationTypeEnum; | |||
import org.dubhe.utils.JwtUtils; | |||
import org.dubhe.utils.SqlUtil; | |||
import java.sql.Connection; | |||
import java.sql.PreparedStatement; | |||
import java.sql.ResultSet; | |||
import java.util.*; | |||
import java.util.stream.Collectors; | |||
/** | |||
* @description MybatisPlus 分页拦截器 | |||
* @date 2020-10-09 | |||
*/ | |||
@Intercepts({@Signature( | |||
type = StatementHandler.class, | |||
method = "prepare", | |||
args = {Connection.class, Integer.class} | |||
)}) | |||
public class PaginationInterceptor extends AbstractSqlParserHandler implements Interceptor { | |||
protected static final Log logger = LogFactory.getLog(PaginationInterceptor.class); | |||
/** | |||
* COUNT SQL 解析 | |||
*/ | |||
protected ISqlParser countSqlParser; | |||
/** | |||
* 溢出总页数,设置第一页 | |||
*/ | |||
protected boolean overflow = false; | |||
/** | |||
* 单页限制 500 条,小于 0 如 -1 不受限制 | |||
*/ | |||
protected long limit = 500L; | |||
/** | |||
* 数据类型 | |||
*/ | |||
private DbType dbType; | |||
/** | |||
* 方言 | |||
*/ | |||
private IDialect dialect; | |||
/** | |||
* 方言类型 | |||
*/ | |||
@Deprecated | |||
protected String dialectType; | |||
/** | |||
* 方言实现类 | |||
*/ | |||
@Deprecated | |||
protected String dialectClazz; | |||
public PaginationInterceptor() { | |||
} | |||
/** | |||
* 构建分页sql | |||
* | |||
* @param originalSql 原生sql | |||
* @param page 分页参数 | |||
* @return 构建后 sql | |||
*/ | |||
public static String concatOrderBy(String originalSql, IPage<?> page) { | |||
if (CollectionUtils.isNotEmpty(page.orders())) { | |||
try { | |||
List<OrderItem> orderList = page.orders(); | |||
Select selectStatement = (Select) CCJSqlParserUtil.parse(originalSql); | |||
List orderByElements; | |||
List orderByElementsReturn; | |||
if (selectStatement.getSelectBody() instanceof PlainSelect) { | |||
PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody(); | |||
orderByElements = plainSelect.getOrderByElements(); | |||
orderByElementsReturn = addOrderByElements(orderList, orderByElements); | |||
plainSelect.setOrderByElements(orderByElementsReturn); | |||
return plainSelect.toString(); | |||
} | |||
if (selectStatement.getSelectBody() instanceof SetOperationList) { | |||
SetOperationList setOperationList = (SetOperationList) selectStatement.getSelectBody(); | |||
orderByElements = setOperationList.getOrderByElements(); | |||
orderByElementsReturn = addOrderByElements(orderList, orderByElements); | |||
setOperationList.setOrderByElements(orderByElementsReturn); | |||
return setOperationList.toString(); | |||
} | |||
if (selectStatement.getSelectBody() instanceof WithItem) { | |||
return originalSql; | |||
} | |||
return originalSql; | |||
} catch (JSQLParserException var7) { | |||
logger.error("failed to concat orderBy from IPage, exception=", var7); | |||
} | |||
} | |||
return originalSql; | |||
} | |||
/** | |||
* 添加分页排序规则 | |||
* | |||
* @param orderList 分页规则 | |||
* @param orderByElements 分页排序元素 | |||
* @return 分页规则 | |||
*/ | |||
private static List<OrderByElement> addOrderByElements(List<OrderItem> orderList, List<OrderByElement> orderByElements) { | |||
orderByElements = CollectionUtils.isEmpty(orderByElements) ? new ArrayList(orderList.size()) : orderByElements; | |||
List<OrderByElement> orderByElementList = (List) orderList.stream().filter((item) -> { | |||
return StringUtils.isNotBlank(item.getColumn()); | |||
}).map((item) -> { | |||
OrderByElement element = new OrderByElement(); | |||
element.setExpression(new Column(item.getColumn())); | |||
element.setAsc(item.isAsc()); | |||
element.setAscDescPresent(true); | |||
return element; | |||
}).collect(Collectors.toList()); | |||
((List) orderByElements).addAll(orderByElementList); | |||
return (List) orderByElements; | |||
} | |||
/** | |||
* 执行sql查询逻辑 | |||
* | |||
* @param invocation mybatis 调用类 | |||
* @return | |||
* @throws Throwable | |||
*/ | |||
@Override | |||
public Object intercept(Invocation invocation) throws Throwable { | |||
StatementHandler statementHandler = (StatementHandler) PluginUtils.realTarget(invocation.getTarget()); | |||
MetaObject metaObject = SystemMetaObject.forObject(statementHandler); | |||
this.sqlParser(metaObject); | |||
MappedStatement mappedStatement = (MappedStatement) metaObject.getValue("delegate.mappedStatement"); | |||
if (SqlCommandType.SELECT == mappedStatement.getSqlCommandType() && StatementType.CALLABLE != mappedStatement.getStatementType()) { | |||
BoundSql boundSql = (BoundSql) metaObject.getValue("delegate.boundSql"); | |||
Object paramObj = boundSql.getParameterObject(); | |||
IPage<?> page = null; | |||
if (paramObj instanceof IPage) { | |||
page = (IPage) paramObj; | |||
} else if (paramObj instanceof Map) { | |||
Iterator var8 = ((Map) paramObj).values().iterator(); | |||
while (var8.hasNext()) { | |||
Object arg = var8.next(); | |||
if (arg instanceof IPage) { | |||
page = (IPage) arg; | |||
break; | |||
} | |||
} | |||
} | |||
if (null != page && page.getSize() >= 0L) { | |||
if (this.limit > 0L && this.limit <= page.getSize()) { | |||
this.handlerLimit(page); | |||
} | |||
String originalSql = boundSql.getSql(); | |||
//注解逻辑判断 添加注解了才拦截 | |||
Class<?> classType = Class.forName(mappedStatement.getId().substring(0, mappedStatement.getId().lastIndexOf("."))); | |||
String mName = mappedStatement.getId().substring(mappedStatement.getId().lastIndexOf(".") + 1, mappedStatement.getId().length()); | |||
UserDTO currentUserDto = JwtUtils.getCurrentUserDto(); | |||
String sqlCommandType = mappedStatement.getSqlCommandType().toString(); | |||
//获取类注解 获取需要忽略拦截的方法名称 | |||
DataPermission dataAnnotation = classType.getAnnotation(DataPermission.class); | |||
if (!Objects.isNull(dataAnnotation)) { | |||
String[] ignores = dataAnnotation.ignoresMethod(); | |||
//校验拦截忽略方法名 忽略新增方法 忽略回调/定时方法 | |||
if (!((!Objects.isNull(ignores) && Arrays.asList(ignores).contains(mName)) | |||
|| OperationTypeEnum.INSERT.getType().equals(sqlCommandType.toLowerCase()) | |||
|| Objects.isNull(currentUserDto) | |||
|| (!Objects.isNull(DataContext.get()) && DataContext.get().getType())) | |||
) { | |||
originalSql = SqlUtil.buildTargetSql(originalSql, SqlUtil.getResourceIds()); | |||
} | |||
} | |||
Connection connection = (Connection) invocation.getArgs()[0]; | |||
if (page.isSearchCount() && !page.isHitCount()) { | |||
SqlInfo sqlInfo = SqlParserUtils.getOptimizeCountSql(page.optimizeCountSql(), this.countSqlParser, originalSql); | |||
this.queryTotal(sqlInfo.getSql(), mappedStatement, boundSql, page, connection); | |||
if (page.getTotal() <= 0L) { | |||
return null; | |||
} | |||
} | |||
DbType dbType = Optional.ofNullable(this.dbType).orElse(JdbcUtils.getDbType(connection.getMetaData().getURL())); | |||
IDialect dialect = Optional.ofNullable(this.dialect).orElse(DialectFactory.getDialect(dbType)); | |||
String buildSql = concatOrderBy(originalSql, page); | |||
DialectModel model = dialect.buildPaginationSql(buildSql, page.offset(), page.getSize()); | |||
Configuration configuration = mappedStatement.getConfiguration(); | |||
List<ParameterMapping> mappings = new ArrayList(boundSql.getParameterMappings()); | |||
Map<String, Object> additionalParameters = (Map) metaObject.getValue("delegate.boundSql.additionalParameters"); | |||
model.consumers(mappings, configuration, additionalParameters); | |||
metaObject.setValue("delegate.boundSql.sql", model.getDialectSql()); | |||
metaObject.setValue("delegate.boundSql.parameterMappings", mappings); | |||
return invocation.proceed(); | |||
} else { | |||
return invocation.proceed(); | |||
} | |||
} else { | |||
return invocation.proceed(); | |||
} | |||
} | |||
/** | |||
* 处理分页数量 | |||
* | |||
* @param page 分页参数 | |||
*/ | |||
protected void handlerLimit(IPage<?> page) { | |||
page.setSize(this.limit); | |||
} | |||
/** | |||
* 查询总数量 | |||
* | |||
* @param sql sql语句 | |||
* @param mappedStatement 映射语句包装类 | |||
* @param boundSql sql包装类 | |||
* @param page 分页参数 | |||
* @param connection JDBC连接包装类 | |||
*/ | |||
protected void queryTotal(String sql, MappedStatement mappedStatement, BoundSql boundSql, IPage<?> page, Connection connection) { | |||
try { | |||
PreparedStatement statement = connection.prepareStatement(sql); | |||
Throwable var7 = null; | |||
try { | |||
DefaultParameterHandler parameterHandler = new MybatisDefaultParameterHandler(mappedStatement, boundSql.getParameterObject(), boundSql); | |||
parameterHandler.setParameters(statement); | |||
long total = 0L; | |||
ResultSet resultSet = statement.executeQuery(); | |||
Throwable var12 = null; | |||
try { | |||
if (resultSet.next()) { | |||
total = resultSet.getLong(1); | |||
} | |||
} catch (Throwable var37) { | |||
var12 = var37; | |||
throw var37; | |||
} finally { | |||
if (resultSet != null) { | |||
if (var12 != null) { | |||
try { | |||
resultSet.close(); | |||
} catch (Throwable var36) { | |||
var12.addSuppressed(var36); | |||
} | |||
} else { | |||
resultSet.close(); | |||
} | |||
} | |||
} | |||
page.setTotal(total); | |||
if (this.overflow && page.getCurrent() > page.getPages()) { | |||
this.handlerOverflow(page); | |||
} | |||
} catch (Throwable var39) { | |||
var7 = var39; | |||
throw var39; | |||
} finally { | |||
if (statement != null) { | |||
if (var7 != null) { | |||
try { | |||
statement.close(); | |||
} catch (Throwable var35) { | |||
var7.addSuppressed(var35); | |||
} | |||
} else { | |||
statement.close(); | |||
} | |||
} | |||
} | |||
} catch (Exception var41) { | |||
throw ExceptionUtils.mpe("Error: Method queryTotal execution error of sql : \n %s \n", var41, new Object[]{sql}); | |||
} | |||
} | |||
/** | |||
* 设置默认当前页 | |||
* | |||
* @param page 分页参数 | |||
*/ | |||
protected void handlerOverflow(IPage<?> page) { | |||
page.setCurrent(1L); | |||
} | |||
/** | |||
* MybatisPlus拦截器实现自定义插件 | |||
* | |||
* @param target 拦截目标对象 | |||
* @return | |||
*/ | |||
@Override | |||
public Object plugin(Object target) { | |||
return target instanceof StatementHandler ? Plugin.wrap(target, this) : target; | |||
} | |||
/** | |||
* MybatisPlus拦截器实现自定义属性设置 | |||
* | |||
* @param prop 属性参数 | |||
*/ | |||
@Override | |||
public void setProperties(Properties prop) { | |||
String dialectType = prop.getProperty("dialectType"); | |||
String dialectClazz = prop.getProperty("dialectClazz"); | |||
if (StringUtils.isNotBlank(dialectType)) { | |||
this.setDialectType(dialectType); | |||
} | |||
if (StringUtils.isNotBlank(dialectClazz)) { | |||
this.setDialectClazz(dialectClazz); | |||
} | |||
} | |||
/** | |||
* 设置数据源类型 | |||
* | |||
* @param dialectType 数据源类型 | |||
*/ | |||
@Deprecated | |||
public void setDialectType(String dialectType) { | |||
this.setDbType(DbType.getDbType(dialectType)); | |||
} | |||
/** | |||
* 设置方言实现类配置 | |||
* | |||
* @param dialectClazz 方言实现类 | |||
*/ | |||
@Deprecated | |||
public void setDialectClazz(String dialectClazz) { | |||
this.setDialect(DialectFactory.getDialect(dialectClazz)); | |||
} | |||
/** | |||
* 设置获取总数的sql解析器 | |||
* | |||
* @param countSqlParser 总数的sql解析器 | |||
* @return 自定义MybatisPlus拦截器 | |||
*/ | |||
public PaginationInterceptor setCountSqlParser(final ISqlParser countSqlParser) { | |||
this.countSqlParser = countSqlParser; | |||
return this; | |||
} | |||
/** | |||
* 溢出总页数,设置第一页 | |||
* | |||
* @param overflow 溢出总页数 | |||
* @return 自定义MybatisPlus拦截器 | |||
*/ | |||
public PaginationInterceptor setOverflow(final boolean overflow) { | |||
this.overflow = overflow; | |||
return this; | |||
} | |||
/** | |||
* 设置分页规则 | |||
* | |||
* @param limit 分页数量 | |||
* @return 自定义MybatisPlus拦截器 | |||
*/ | |||
public PaginationInterceptor setLimit(final long limit) { | |||
this.limit = limit; | |||
return this; | |||
} | |||
/** | |||
* 设置数据类型 | |||
* | |||
* @param dbType 数据类型 | |||
* @return 自定义MybatisPlus拦截器 | |||
*/ | |||
public PaginationInterceptor setDbType(final DbType dbType) { | |||
this.dbType = dbType; | |||
return this; | |||
} | |||
/** | |||
* 设置方言 | |||
* | |||
* @param dialect 方言 | |||
* @return 自定义MybatisPlus拦截器 | |||
*/ | |||
public PaginationInterceptor setDialect(final IDialect dialect) { | |||
this.dialect = dialect; | |||
return this; | |||
} | |||
} | |||
@@ -18,16 +18,25 @@ | |||
package org.dubhe.utils; | |||
import java.sql.Timestamp; | |||
import java.text.DateFormat; | |||
import java.text.SimpleDateFormat; | |||
import java.time.Instant; | |||
import java.time.LocalDate; | |||
import java.time.LocalDateTime; | |||
import java.time.ZoneId; | |||
import java.util.Date; | |||
/** | |||
* @description 日期工具类 | |||
* @date 2020-6-10 | |||
* @date 2020-06-10 | |||
*/ | |||
public class DateUtil { | |||
private DateUtil(){ | |||
} | |||
/** | |||
* 获取当前时间戳 | |||
* | |||
@@ -77,4 +86,13 @@ public class DateUtil { | |||
return (milli-l1); | |||
} | |||
/** | |||
* @return 当前字符串时间yyyy-MM-dd HH:mm:ss SSS | |||
*/ | |||
public static String getCurrentTimeStr(){ | |||
Date date = new Date(); | |||
DateFormat dateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss SSS"); | |||
return dateFormat.format(date); | |||
} | |||
} |
@@ -19,10 +19,12 @@ package org.dubhe.utils; | |||
import cn.hutool.core.codec.Base64; | |||
import cn.hutool.core.io.IoUtil; | |||
import cn.hutool.core.util.CharsetUtil; | |||
import cn.hutool.core.util.IdUtil; | |||
import cn.hutool.poi.excel.BigExcelWriter; | |||
import cn.hutool.poi.excel.ExcelUtil; | |||
import org.apache.poi.util.IOUtils; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.exception.BusinessException; | |||
import org.springframework.web.multipart.MultipartFile; | |||
@@ -352,4 +354,56 @@ public class FileUtil extends cn.hutool.core.io.FileUtil { | |||
return getMd5(getByte(file)); | |||
} | |||
/** | |||
* 生成文件 | |||
* @param filePath 文件绝对路径 | |||
* @param content 文件内容 | |||
* @param append 文件是否是追加 | |||
* @return | |||
*/ | |||
public static boolean generateFile(String filePath,String content,boolean append){ | |||
File file = new File(filePath); | |||
FileOutputStream outputStream = null; | |||
try { | |||
if (!file.exists()){ | |||
file.createNewFile(); | |||
} | |||
outputStream = new FileOutputStream(file,append); | |||
outputStream.write(content.getBytes(CharsetUtil.defaultCharset())); | |||
outputStream.flush(); | |||
}catch (IOException e) { | |||
LogUtil.error(LogEnum.FILE_UTIL,e); | |||
return false; | |||
}finally { | |||
if (outputStream != null){ | |||
try { | |||
outputStream.close(); | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.FILE_UTIL,e); | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
/** | |||
* 压缩文件目录 | |||
* | |||
* @param zipDir 待压缩文件夹路径 | |||
* @param zipFile 压缩完成zip文件绝对路径 | |||
* @return | |||
*/ | |||
public static boolean zipPath(String zipDir,String zipFile) { | |||
if (zipDir == null) { | |||
return false; | |||
} | |||
File zip = new File(zipFile); | |||
cn.hutool.core.util.ZipUtil.zip(zip, CharsetUtil.defaultCharset(), true, | |||
(f) -> !f.isDirectory(), | |||
new File(zipDir).listFiles()); | |||
return true; | |||
} | |||
} |
@@ -20,6 +20,7 @@ import org.apache.commons.io.IOUtils; | |||
import org.dubhe.enums.LogEnum; | |||
import javax.net.ssl.HttpsURLConnection; | |||
import javax.net.ssl.SSLContext; | |||
import javax.net.ssl.SSLSocketFactory; | |||
@@ -31,13 +32,13 @@ import java.io.InputStreamReader; | |||
import java.net.URL; | |||
import java.security.cert.CertificateException; | |||
import java.security.cert.X509Certificate; | |||
import org.apache.commons.codec.binary.Base64; | |||
import static org.dubhe.constant.StringConstant.UTF8; | |||
import static org.dubhe.constant.SymbolConstant.BLANK; | |||
/** | |||
* @description: httpClient工具类,不校验SSL证书 | |||
* @date: 2020-5-21 | |||
* @description httpClient工具类,不校验SSL证书 | |||
* @date 2020-05-21 | |||
*/ | |||
public class HttpClientUtils { | |||
@@ -45,7 +46,7 @@ public class HttpClientUtils { | |||
InputStream inputStream = null; | |||
BufferedReader bufferedReader = null; | |||
InputStreamReader inputStreamReader = null; | |||
StringBuilder stringBuider = new StringBuilder(); | |||
StringBuilder stringBuilder = new StringBuilder(); | |||
String result = BLANK; | |||
HttpsURLConnection con = null; | |||
try { | |||
@@ -59,10 +60,10 @@ public class HttpClientUtils { | |||
String str = null; | |||
while ((str = bufferedReader.readLine()) != null) { | |||
stringBuider.append(str); | |||
stringBuilder.append(str); | |||
} | |||
result = stringBuider.toString(); | |||
result = stringBuilder.toString(); | |||
LogUtil.info(LogEnum.BIZ_SYS,"Request path:{}, SUCCESS, result:{}", path, result); | |||
} catch (Exception e) { | |||
@@ -74,17 +75,54 @@ public class HttpClientUtils { | |||
return result; | |||
} | |||
public static String sendHttpsDelete(String path,String username,String password) { | |||
InputStream inputStream = null; | |||
BufferedReader bufferedReader = null; | |||
InputStreamReader inputStreamReader = null; | |||
StringBuilder stringBuilder = new StringBuilder(); | |||
String result = BLANK; | |||
HttpsURLConnection con = null; | |||
try { | |||
con = getConnection(path); | |||
String input =username+ ":" +password; | |||
String encoding=Base64.encodeBase64String(input.getBytes()); | |||
con.setRequestProperty(JwtUtils.AUTH_HEADER, "Basic " + encoding); | |||
con.setRequestMethod("DELETE"); | |||
con.connect(); | |||
/**将返回的输入流转换成字符串**/ | |||
inputStream = con.getInputStream(); | |||
inputStreamReader = new InputStreamReader(inputStream, UTF8); | |||
bufferedReader = new BufferedReader(inputStreamReader); | |||
private static void closeResource(BufferedReader bufferedReader,InputStreamReader inputStreamReader,InputStream inputStream,HttpsURLConnection con) { | |||
String str = null; | |||
while ((str = bufferedReader.readLine()) != null) { | |||
stringBuilder.append(str); | |||
} | |||
IOUtils.closeQuietly(bufferedReader); | |||
result = stringBuilder.toString(); | |||
LogUtil.info(LogEnum.BIZ_SYS,"Request path:{}, SUCCESS, result:{}", path, result); | |||
if (inputStreamReader != null) { | |||
IOUtils.closeQuietly(inputStreamReader); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.BIZ_SYS,"Request path:{}, ERROR, exception:{}", path, e); | |||
return result; | |||
} finally { | |||
closeResource(bufferedReader,inputStreamReader,inputStream,con); | |||
} | |||
return result; | |||
} | |||
private static void closeResource(BufferedReader bufferedReader,InputStreamReader inputStreamReader,InputStream inputStream,HttpsURLConnection con) { | |||
if (inputStream != null) { | |||
IOUtils.closeQuietly(inputStream); | |||
} | |||
if (inputStreamReader != null) { | |||
IOUtils.closeQuietly(inputStreamReader); | |||
} | |||
if (bufferedReader != null) { | |||
IOUtils.closeQuietly(bufferedReader); | |||
} | |||
if (con != null) { | |||
con.disconnect(); | |||
} | |||
@@ -20,8 +20,8 @@ package org.dubhe.utils; | |||
import lombok.extern.slf4j.Slf4j; | |||
/** | |||
* @description: HttpUtil | |||
* @date 2020.04.30 | |||
* @description HttpUtil | |||
* @date 2020-04-30 | |||
*/ | |||
@Slf4j | |||
public class HttpUtils { | |||
@@ -0,0 +1,46 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.utils; | |||
import org.dubhe.enums.LogEnum; | |||
import java.io.Closeable; | |||
import java.io.IOException; | |||
/** | |||
* @description IO流操作工具类 | |||
* @date 2020-10-14 | |||
*/ | |||
public class IOUtil { | |||
/** | |||
* 循环的依次关闭流 | |||
* | |||
* @param closeableList 要被关闭的流集合 | |||
*/ | |||
public static void close(Closeable... closeableList) { | |||
for (Closeable closeable : closeableList) { | |||
try { | |||
if (closeable != null) { | |||
closeable.close(); | |||
} | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.IO_UTIL, "关闭流异常,异常信息:{}", e); | |||
} | |||
} | |||
} | |||
} |
@@ -139,7 +139,7 @@ public class JwtUtils { | |||
public static boolean isTokenExpired(String token) { | |||
Date now = Calendar.getInstance().getTime(); | |||
DecodedJWT jwt = JWT.decode(token); | |||
return jwt.getExpiresAt().before(now); | |||
return jwt.getExpiresAt() == null || jwt.getExpiresAt().before(now); | |||
} | |||
/** | |||
@@ -67,12 +67,12 @@ public class K8sNameTool { | |||
} | |||
/** | |||
* 生成 Notebook的NameSpace | |||
* 生成 Notebook的Namespace | |||
* | |||
* @param userId | |||
* @return namespace | |||
*/ | |||
public String generateNameSpace(long userId) { | |||
public String generateNamespace(long userId) { | |||
return this.k8sNameConfig.getNamespace() + SEPARATOR + userId; | |||
} | |||
@@ -96,7 +96,7 @@ public class K8sNameTool { | |||
* @param namespace | |||
* @return Long | |||
*/ | |||
public Long getUserIdFromNameSpace(String namespace) { | |||
public Long getUserIdFromNamespace(String namespace) { | |||
if (StringUtils.isEmpty(namespace) || !namespace.contains(this.k8sNameConfig.getNamespace() + SEPARATOR)) { | |||
return null; | |||
} | |||
@@ -0,0 +1,307 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.utils; | |||
import cn.hutool.core.util.StrUtil; | |||
import lombok.Getter; | |||
import org.apache.commons.compress.archivers.zip.ZipArchiveEntry; | |||
import org.apache.commons.compress.archivers.zip.ZipFile; | |||
import org.apache.commons.io.IOUtils; | |||
import org.dubhe.base.MagicNumConstant; | |||
import org.dubhe.config.NfsConfig; | |||
import org.dubhe.enums.LogEnum; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.beans.factory.annotation.Value; | |||
import org.springframework.stereotype.Component; | |||
import org.springframework.util.FileCopyUtils; | |||
import java.io.*; | |||
import java.util.Enumeration; | |||
import java.util.zip.ZipEntry; | |||
/** | |||
* @description 本地文件操作工具类 | |||
* @date 2020-08-19 | |||
*/ | |||
@Component | |||
@Getter | |||
public class LocalFileUtil { | |||
@Autowired | |||
private NfsConfig nfsConfig; | |||
private static final String FILE_SEPARATOR = File.separator; | |||
private static final String ZIP = ".zip"; | |||
private static final String CHARACTER_GBK = "GBK"; | |||
private static final String OS_NAME = "os.name"; | |||
private static final String WINDOWS = "Windows"; | |||
@Value("${k8s.nfs-root-path}") | |||
private String nfsRootPath; | |||
@Value("${k8s.nfs-root-windows-path}") | |||
private String nfsRootWindowsPath; | |||
/** | |||
* windows 与 linux 的路径兼容 | |||
* | |||
* @param path linux下的路径 | |||
* @return path 兼容windows后的路径 | |||
*/ | |||
private String compatiblePath(String path) { | |||
if (path == null) { | |||
return null; | |||
} | |||
if (System.getProperties().getProperty(OS_NAME).contains(WINDOWS)) { | |||
path = path.replace(nfsRootPath, StrUtil.SLASH); | |||
path = path.replace(StrUtil.SLASH, FILE_SEPARATOR); | |||
path = nfsRootWindowsPath + path; | |||
} | |||
return path; | |||
} | |||
/** | |||
* 本地解压zip包并删除压缩文件 | |||
* | |||
* @param sourcePath zip源文件 例如:/abc/z.zip | |||
* @param targetPath 解压后的目标文件夹 例如:/abc/ | |||
* @return boolean | |||
*/ | |||
public boolean unzipLocalPath(String sourcePath, String targetPath) { | |||
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { | |||
return false; | |||
} | |||
if (!sourcePath.toLowerCase().endsWith(ZIP)) { | |||
return false; | |||
} | |||
//绝对路径 | |||
String sourceAbsolutePath = nfsConfig.getRootDir() + sourcePath; | |||
String targetPathAbsolutePath = nfsConfig.getRootDir() + targetPath; | |||
ZipFile zipFile = null; | |||
InputStream in = null; | |||
OutputStream out = null; | |||
File sourceFile = new File(compatiblePath(sourceAbsolutePath)); | |||
File targetFileDir = new File(compatiblePath(targetPathAbsolutePath)); | |||
if (!targetFileDir.exists()) { | |||
boolean targetMkdir = targetFileDir.mkdirs(); | |||
if (!targetMkdir) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}failed to create target folder before decompression", sourceAbsolutePath); | |||
} | |||
} | |||
try { | |||
zipFile = new ZipFile(sourceFile); | |||
//判断压缩文件编码方式,并重新获取文件对象 | |||
try { | |||
zipFile.close(); | |||
zipFile = new ZipFile(sourceFile, CHARACTER_GBK); | |||
} catch (Exception e) { | |||
zipFile.close(); | |||
zipFile = new ZipFile(sourceFile); | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}the encoding mode of decompressed compressed file is changed to UTF-8:{}", sourceAbsolutePath, e); | |||
} | |||
ZipEntry entry; | |||
Enumeration enumeration = zipFile.getEntries(); | |||
while (enumeration.hasMoreElements()) { | |||
entry = (ZipEntry) enumeration.nextElement(); | |||
String entryName = entry.getName(); | |||
File fileDir; | |||
if (entry.isDirectory()) { | |||
fileDir = new File(targetPathAbsolutePath + entry.getName()); | |||
if (!fileDir.exists()) { | |||
boolean fileMkdir = fileDir.mkdirs(); | |||
if (!fileMkdir) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to create folder {} while decompressing {}", fileDir, sourceAbsolutePath); | |||
} | |||
} | |||
} else { | |||
//若文件夹未创建则创建文件夹 | |||
if (entryName.contains(FILE_SEPARATOR)) { | |||
String zipDirName = entryName.substring(MagicNumConstant.ZERO, entryName.lastIndexOf(FILE_SEPARATOR)); | |||
fileDir = new File(targetPathAbsolutePath + zipDirName); | |||
if (!fileDir.exists()) { | |||
boolean fileMkdir = fileDir.mkdirs(); | |||
if (!fileMkdir) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to create folder {} while decompressing {}", fileDir, sourceAbsolutePath); | |||
} | |||
} | |||
} | |||
in = zipFile.getInputStream((ZipArchiveEntry) entry); | |||
out = new FileOutputStream(new File(targetPathAbsolutePath, entryName)); | |||
IOUtils.copyLarge(in, out); | |||
in.close(); | |||
out.close(); | |||
} | |||
} | |||
boolean deleteZipFile = sourceFile.delete(); | |||
if (!deleteZipFile) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}compressed file deletion failed after decompression", sourceAbsolutePath); | |||
} | |||
return true; | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}decompression failed: {}", sourceAbsolutePath, e); | |||
return false; | |||
} finally { | |||
//关闭未关闭的io流 | |||
closeIoFlow(sourceAbsolutePath, zipFile, in, out); | |||
} | |||
} | |||
/** | |||
* 关闭未关闭的io流 | |||
* | |||
* @param sourceAbsolutePath 源路径 | |||
* @param zipFile 压缩文件对象 | |||
* @param in 输入流 | |||
* @param out 输出流 | |||
*/ | |||
private void closeIoFlow(String sourceAbsolutePath, ZipFile zipFile, InputStream in, OutputStream out) { | |||
if (in != null) { | |||
try { | |||
in.close(); | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}input stream shutdown failed: {}", sourceAbsolutePath, e); | |||
} | |||
} | |||
if (out != null) { | |||
try { | |||
out.close(); | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}output stream shutdown failed: {}", sourceAbsolutePath, e); | |||
} | |||
} | |||
if (zipFile != null) { | |||
try { | |||
zipFile.close(); | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "{}input stream shutdown failed: {}", sourceAbsolutePath, e); | |||
} | |||
} | |||
} | |||
/** | |||
* NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况 | |||
* | |||
* 通过本地文件复制方式 | |||
* | |||
* @param sourcePath 需要复制的文件目录 例如:/abc/def | |||
* @param targetPath 需要放置的目标目录 例如:/abc/dd | |||
* @return boolean | |||
*/ | |||
public boolean copyPath(String sourcePath, String targetPath) { | |||
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { | |||
return false; | |||
} | |||
sourcePath = formatPath(sourcePath); | |||
targetPath = formatPath(targetPath); | |||
try { | |||
return copyLocalPath(nfsConfig.getRootDir() + sourcePath, nfsConfig.getRootDir() + targetPath); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, " failed to Copy file original path: {} ,target path: {} ,copyPath: {}", sourcePath, targetPath, e); | |||
return false; | |||
} | |||
} | |||
/** | |||
* 复制文件到指定目录下 单个文件 | |||
* | |||
* @param sourcePath 需要复制的文件 例如:/abc/def/cc.txt | |||
* @param targetPath 需要放置的目标目录 例如:/abc/dd | |||
* @return boolean | |||
*/ | |||
private boolean copyLocalFile(String sourcePath, String targetPath) { | |||
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { | |||
return false; | |||
} | |||
sourcePath = formatPath(sourcePath); | |||
targetPath = formatPath(targetPath); | |||
try (InputStream input = new FileInputStream(sourcePath); | |||
FileOutputStream output = new FileOutputStream(targetPath)) { | |||
FileCopyUtils.copy(input, output); | |||
return true; | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, " failed to copy file original path: {} ,target path: {} ,copyLocalFile:{} ", sourcePath, targetPath, e); | |||
return false; | |||
} | |||
} | |||
/** | |||
* 复制文件 到指定目录下 多个文件 包含目录与文件并存情况 | |||
* | |||
* @param sourcePath 需要复制的文件目录 例如:/abc/def | |||
* @param targetPath 需要放置的目标目录 例如:/abc/dd | |||
* @return boolean | |||
*/ | |||
private boolean copyLocalPath(String sourcePath, String targetPath) { | |||
if (!StringUtils.isEmpty(sourcePath) && !StringUtils.isEmpty(targetPath)) { | |||
sourcePath = formatPath(sourcePath); | |||
if (sourcePath.endsWith(FILE_SEPARATOR)) { | |||
sourcePath = sourcePath.substring(MagicNumConstant.ZERO, sourcePath.lastIndexOf(FILE_SEPARATOR)); | |||
} | |||
targetPath = formatPath(targetPath); | |||
File sourceFile = new File(sourcePath); | |||
if (sourceFile.exists()) { | |||
File[] files = sourceFile.listFiles(); | |||
if (files != null && files.length != 0) { | |||
for (File file : files) { | |||
try { | |||
if (file.isDirectory()) { | |||
File fileDir = new File(targetPath + FILE_SEPARATOR + file.getName()); | |||
if (!fileDir.exists()) { | |||
fileDir.mkdirs(); | |||
} | |||
copyLocalPath(sourcePath + FILE_SEPARATOR + file.getName(), targetPath + FILE_SEPARATOR + file.getName()); | |||
} | |||
if (file.isFile()) { | |||
File fileTargetPath = new File(targetPath); | |||
if (!fileTargetPath.exists()) { | |||
fileTargetPath.mkdirs(); | |||
} | |||
copyLocalFile(file.getAbsolutePath(), targetPath + FILE_SEPARATOR + file.getName()); | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.LOCAL_FILE_UTIL, "failed to copy folder original path: {} , target path : {} ,copyLocalPath: {}", sourcePath, targetPath, e); | |||
return false; | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
} | |||
return false; | |||
} | |||
/** | |||
* 替换路径中多余的 "/" | |||
* | |||
* @param path | |||
* @return String | |||
*/ | |||
public String formatPath(String path) { | |||
if (!StringUtils.isEmpty(path)) { | |||
return path.replaceAll("///*", FILE_SEPARATOR); | |||
} | |||
return path; | |||
} | |||
} |
@@ -23,10 +23,10 @@ import lombok.extern.slf4j.Slf4j; | |||
import org.apache.commons.lang3.exception.ExceptionUtils; | |||
import org.dubhe.aspect.LogAspect; | |||
import org.dubhe.base.MagicNumConstant; | |||
import org.dubhe.constant.SymbolConstant; | |||
import org.dubhe.domain.entity.LogInfo; | |||
import org.dubhe.enums.LogEnum; | |||
import org.slf4j.MDC; | |||
import org.slf4j.MarkerFactory; | |||
import org.slf4j.helpers.MessageFormatter; | |||
import java.util.Arrays; | |||
@@ -38,218 +38,283 @@ import java.util.UUID; | |||
*/ | |||
@Slf4j | |||
public class LogUtil { | |||
/** | |||
* info级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
*/ | |||
public static void info(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.INFO, object); | |||
} | |||
/** | |||
* debug级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
*/ | |||
public static void debug(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.DEBUG, object); | |||
} | |||
/** | |||
* error级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
*/ | |||
public static void error(LogEnum logType, Object... object) { | |||
errorObjectHandle(object); | |||
logHandle(logType, Level.ERROR, object); | |||
} | |||
/** | |||
* warn级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
*/ | |||
public static void warn(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.WARN, object); | |||
} | |||
/** | |||
* trace级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
*/ | |||
public static void trace(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.TRACE, object); | |||
} | |||
/** | |||
* 日志处理 | |||
* | |||
* @param logType 日志类型 | |||
* @param level 日志级别 | |||
* @param object 打印的日志参数 | |||
*/ | |||
private static void logHandle(LogEnum logType, Level level, Object[] object) { | |||
LogInfo logInfo = generateLogInfo(logType, level, object); | |||
String logInfoJsonStr = logJsonStringLengthLimit(logInfo); | |||
switch (Level.toLevel(logInfo.getLevel()).levelInt) { | |||
case Level.TRACE_INT: | |||
log.trace(logInfoJsonStr); | |||
break; | |||
case Level.DEBUG_INT: | |||
log.debug(logInfoJsonStr); | |||
break; | |||
case Level.INFO_INT: | |||
log.info(logInfoJsonStr); | |||
break; | |||
case Level.WARN_INT: | |||
log.warn(logInfoJsonStr); | |||
break; | |||
case Level.ERROR_INT: | |||
log.error(logInfoJsonStr); | |||
break; | |||
default: | |||
} | |||
} | |||
/** | |||
* 日志信息组装的内部方法 | |||
* | |||
* @param logType 日志类型 | |||
* @param level 日志级别 | |||
* @param object 打印的日志参数 | |||
* @return LogInfo日志对象信息 | |||
*/ | |||
private static LogInfo generateLogInfo(LogEnum logType, Level level, Object[] object) { | |||
LogInfo logInfo = new LogInfo(); | |||
// 日志类型检测 | |||
if (!LogEnum.isLogType(logType)) { | |||
level = Level.ERROR; | |||
object = new Object[MagicNumConstant.ONE]; | |||
object[MagicNumConstant.ZERO] = String.valueOf("logType【").concat(String.valueOf(logType)) | |||
.concat("】is error !"); | |||
logType = LogEnum.SYS_ERR; | |||
} | |||
// 获取trace_id | |||
if (StringUtils.isEmpty(MDC.get(LogAspect.TRACE_ID))) { | |||
MDC.put(LogAspect.TRACE_ID, UUID.randomUUID().toString()); | |||
} | |||
// 设置logInfo的level,type,traceId属性 | |||
logInfo.setLevel(level.levelStr).setType(logType.toString()).setTraceId(MDC.get(LogAspect.TRACE_ID)); | |||
// 设置logInfo的堆栈信息 | |||
setLogStackInfo(logInfo); | |||
// 设置logInfo的info信息 | |||
setLogInfo(logInfo, object); | |||
// 截取loginfo的长度并转换成json字符串 | |||
return logInfo; | |||
} | |||
/** | |||
* 设置loginfo的堆栈信息 | |||
* | |||
* @param logInfo 日志对象 | |||
*/ | |||
private static void setLogStackInfo(LogInfo logInfo) { | |||
StackTraceElement[] elements = Thread.currentThread().getStackTrace(); | |||
if (elements.length >= MagicNumConstant.SIX) { | |||
logInfo.setCName(elements[MagicNumConstant.FIVE].getClassName()) | |||
.setMName(elements[MagicNumConstant.FIVE].getMethodName()) | |||
.setLine(String.valueOf(elements[MagicNumConstant.FIVE].getLineNumber())); | |||
} | |||
} | |||
/** | |||
* 限制log日志的长度并转换成json | |||
* | |||
* @param logInfo 日志对象 | |||
* @return String 日志对象Json字符串 | |||
*/ | |||
private static String logJsonStringLengthLimit(LogInfo logInfo) { | |||
try { | |||
String jsonString = JSON.toJSONString(logInfo.getInfo()); | |||
if (jsonString.length() > MagicNumConstant.TEN_THOUSAND) { | |||
jsonString = jsonString.substring(MagicNumConstant.ZERO, MagicNumConstant.TEN_THOUSAND); | |||
} | |||
logInfo.setInfo(jsonString); | |||
jsonString = JSON.toJSONString(logInfo); | |||
jsonString = jsonString.replace(SymbolConstant.BACKSLASH_MARK, SymbolConstant.MARK) | |||
.replace(SymbolConstant.DOUBLE_MARK, SymbolConstant.MARK) | |||
.replace(SymbolConstant.BRACKETS, SymbolConstant.BLANK); | |||
return jsonString; | |||
} catch (Exception e) { | |||
logInfo.setLevel(Level.ERROR.levelStr).setType(LogEnum.SYS_ERR.toString()) | |||
.setInfo("cannot serialize exception: " + ExceptionUtils.getStackTrace(e)); | |||
return JSON.toJSONString(logInfo); | |||
} | |||
} | |||
/** | |||
* 设置日志对象的info信息 | |||
* | |||
* @param logInfo 日志对象 | |||
* @param object 打印的日志参数 | |||
*/ | |||
private static void setLogInfo(LogInfo logInfo, Object[] object) { | |||
for (Object obj : object) { | |||
if (obj instanceof Exception) { | |||
log.error((ExceptionUtils.getStackTrace((Throwable) obj))); | |||
} | |||
} | |||
if (object.length > MagicNumConstant.ONE) { | |||
logInfo.setInfo(MessageFormatter.arrayFormat(object[MagicNumConstant.ZERO].toString(), | |||
Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)).getMessage()); | |||
} else if (object.length == MagicNumConstant.ONE && object[MagicNumConstant.ZERO] instanceof Exception) { | |||
logInfo.setInfo((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO]))); | |||
} else if (object.length == MagicNumConstant.ONE) { | |||
logInfo.setInfo( | |||
object[MagicNumConstant.ZERO] == null ? SymbolConstant.BLANK : object[MagicNumConstant.ZERO]); | |||
} else { | |||
logInfo.setInfo(SymbolConstant.BLANK); | |||
} | |||
} | |||
/** | |||
* 处理Exception的情况 | |||
* | |||
* @param object 打印的日志参数 | |||
*/ | |||
private static void errorObjectHandle(Object[] object) { | |||
if (object.length >= MagicNumConstant.TWO) { | |||
object[MagicNumConstant.ZERO] = String.valueOf(object[MagicNumConstant.ZERO]) | |||
.concat(SymbolConstant.BRACKETS); | |||
} | |||
if (object.length == MagicNumConstant.TWO && object[MagicNumConstant.ONE] instanceof Exception) { | |||
log.error((ExceptionUtils.getStackTrace((Throwable) object[MagicNumConstant.ONE]))); | |||
object[MagicNumConstant.ONE] = ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ONE]); | |||
} else if (object.length >= MagicNumConstant.THREE) { | |||
for (int i = 0; i < object.length; i++) { | |||
if (object[i] instanceof Exception) { | |||
log.error((ExceptionUtils.getStackTrace((Throwable) object[i]))); | |||
object[i] = ExceptionUtils.getStackTrace((Exception) object[i]); | |||
} | |||
} | |||
} | |||
} | |||
private static final String TRACE_TYPE = "TRACE_TYPE"; | |||
public static final String SCHEDULE_LEVEL = "SCHEDULE"; | |||
public static final String K8S_CALLBACK_LEVEL = "K8S_CALLBACK"; | |||
private static final String GLOBAL_REQUEST_LEVEL = "GLOBAL_REQUEST"; | |||
private static final String TRACE_LEVEL = "TRACE"; | |||
private static final String DEBUG_LEVEL = "DEBUG"; | |||
private static final String INFO_LEVEL = "INFO"; | |||
private static final String WARN_LEVEL = "WARN"; | |||
private static final String ERROR_LEVEL = "ERROR"; | |||
public static void startScheduleTrace() { | |||
MDC.put(TRACE_TYPE, SCHEDULE_LEVEL); | |||
} | |||
public static void startK8sCallbackTrace() { | |||
MDC.put(TRACE_TYPE, K8S_CALLBACK_LEVEL); | |||
} | |||
public static void cleanTrace() { | |||
MDC.clear(); | |||
} | |||
/** | |||
* info级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
public static void info(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.INFO, object); | |||
} | |||
/** | |||
* debug级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
public static void debug(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.DEBUG, object); | |||
} | |||
/** | |||
* error级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
public static void error(LogEnum logType, Object... object) { | |||
errorObjectHandle(object); | |||
logHandle(logType, Level.ERROR, object); | |||
} | |||
/** | |||
* warn级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
public static void warn(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.WARN, object); | |||
} | |||
/** | |||
* trace级别的日志 | |||
* | |||
* @param logType 日志类型 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
public static void trace(LogEnum logType, Object... object) { | |||
logHandle(logType, Level.TRACE, object); | |||
} | |||
/** | |||
* 日志处理 | |||
* | |||
* @param logType 日志类型 | |||
* @param level 日志级别 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
private static void logHandle(LogEnum logType, Level level, Object[] object) { | |||
LogInfo logInfo = generateLogInfo(logType, level, object); | |||
switch (logInfo.getLevel()) { | |||
case TRACE_LEVEL: | |||
log.trace(MarkerFactory.getMarker(TRACE_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case DEBUG_LEVEL: | |||
log.debug(MarkerFactory.getMarker(DEBUG_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case GLOBAL_REQUEST_LEVEL: | |||
logInfo.setLevel(null); | |||
logInfo.setType(null); | |||
logInfo.setLocation(null); | |||
log.info(MarkerFactory.getMarker(GLOBAL_REQUEST_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case SCHEDULE_LEVEL: | |||
log.info(MarkerFactory.getMarker(SCHEDULE_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case K8S_CALLBACK_LEVEL: | |||
log.info(MarkerFactory.getMarker(K8S_CALLBACK_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case INFO_LEVEL: | |||
log.info(MarkerFactory.getMarker(INFO_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case WARN_LEVEL: | |||
log.warn(MarkerFactory.getMarker(WARN_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
case ERROR_LEVEL: | |||
log.error(MarkerFactory.getMarker(ERROR_LEVEL), logJsonStringLengthLimit(logInfo)); | |||
break; | |||
default: | |||
} | |||
} | |||
/** | |||
* 日志信息组装的内部方法 | |||
* | |||
* @param logType 日志类型 | |||
* @param level 日志级别 | |||
* @param object 打印的日志参数 | |||
* @return LogInfo | |||
*/ | |||
private static LogInfo generateLogInfo(LogEnum logType, Level level, Object[] object) { | |||
LogInfo logInfo = new LogInfo(); | |||
// 日志类型检测 | |||
if (!LogEnum.isLogType(logType)) { | |||
level = Level.ERROR; | |||
object = new Object[MagicNumConstant.ONE]; | |||
object[MagicNumConstant.ZERO] = "日志类型【".concat(String.valueOf(logType)).concat("】不正确!"); | |||
logType = LogEnum.SYS_ERR; | |||
} | |||
// 获取trace_id | |||
if (StringUtils.isEmpty(MDC.get(LogAspect.TRACE_ID))) { | |||
MDC.put(LogAspect.TRACE_ID, UUID.randomUUID().toString()); | |||
} | |||
// 设置logInfo的level,type,traceId属性 | |||
logInfo.setLevel(level.levelStr) | |||
.setType(logType.toString()) | |||
.setTraceId(MDC.get(LogAspect.TRACE_ID)); | |||
//自定义日志级别 | |||
//LogEnum、 MDC中的 TRACE_TYPE 做日志分流标识 | |||
if (Level.INFO.toInt() == level.toInt()) { | |||
if (LogEnum.GLOBAL_REQ.equals(logType)) { | |||
//info全局请求 | |||
logInfo.setLevel(GLOBAL_REQUEST_LEVEL); | |||
} else if (LogEnum.BIZ_K8S.equals(logType)) { | |||
logInfo.setLevel(K8S_CALLBACK_LEVEL); | |||
} else { | |||
//schedule定时等 链路记录 | |||
String traceType = MDC.get(TRACE_TYPE); | |||
if (StringUtils.isNotBlank(traceType)) { | |||
logInfo.setLevel(traceType); | |||
} | |||
} | |||
} | |||
// 设置logInfo的堆栈信息 | |||
setLogStackInfo(logInfo); | |||
// 设置logInfo的info信息 | |||
setLogInfo(logInfo, object); | |||
// 截取logInfo的长度并转换成json字符串 | |||
return logInfo; | |||
} | |||
/** | |||
* 设置loginfo的堆栈信息 | |||
* | |||
* @param logInfo 日志对象 | |||
* @return void | |||
*/ | |||
private static void setLogStackInfo(LogInfo logInfo) { | |||
StackTraceElement[] elements = Thread.currentThread().getStackTrace(); | |||
if (elements.length >= MagicNumConstant.SIX) { | |||
StackTraceElement element = elements[MagicNumConstant.FIVE]; | |||
logInfo.setLocation(String.format("%s#%s:%s", element.getClassName(), element.getMethodName(), element.getLineNumber())); | |||
} | |||
} | |||
/** | |||
* 限制log日志的长度并转换成json | |||
* | |||
* @param logInfo 日志对象 | |||
* @return String | |||
*/ | |||
private static String logJsonStringLengthLimit(LogInfo logInfo) { | |||
try { | |||
String jsonString = JSON.toJSONString(logInfo); | |||
if (StringUtils.isBlank(jsonString)) { | |||
return ""; | |||
} | |||
if (jsonString.length() > MagicNumConstant.TEN_THOUSAND) { | |||
String trunk = logInfo.getInfo().toString().substring(MagicNumConstant.ZERO, MagicNumConstant.NINE_THOUSAND); | |||
logInfo.setInfo(trunk); | |||
jsonString = JSON.toJSONString(logInfo); | |||
} | |||
return jsonString; | |||
} catch (Exception e) { | |||
logInfo.setLevel(Level.ERROR.levelStr).setType(LogEnum.SYS_ERR.toString()) | |||
.setInfo("cannot serialize exception: " + ExceptionUtils.getStackTrace(e)); | |||
return JSON.toJSONString(logInfo); | |||
} | |||
} | |||
/** | |||
* 设置日志对象的info信息 | |||
* | |||
* @param logInfo 日志对象 | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
private static void setLogInfo(LogInfo logInfo, Object[] object) { | |||
if (object.length > MagicNumConstant.ONE) { | |||
logInfo.setInfo(MessageFormatter.arrayFormat(object[MagicNumConstant.ZERO].toString(), | |||
Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)).getMessage()); | |||
} else if (object.length == MagicNumConstant.ONE && object[MagicNumConstant.ZERO] instanceof Exception) { | |||
logInfo.setInfo((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO]))); | |||
log.error((ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ZERO]))); | |||
} else if (object.length == MagicNumConstant.ONE) { | |||
logInfo.setInfo(object[MagicNumConstant.ZERO] == null ? "" : object[MagicNumConstant.ZERO]); | |||
} else { | |||
logInfo.setInfo(""); | |||
} | |||
} | |||
/** | |||
* 处理Exception的情况 | |||
* | |||
* @param object 打印的日志参数 | |||
* @return void | |||
*/ | |||
private static void errorObjectHandle(Object[] object) { | |||
if (object.length == MagicNumConstant.TWO && object[MagicNumConstant.ONE] instanceof Exception) { | |||
log.error(String.valueOf(object[MagicNumConstant.ZERO]), (Exception) object[MagicNumConstant.ONE]); | |||
object[MagicNumConstant.ONE] = ExceptionUtils.getStackTrace((Exception) object[MagicNumConstant.ONE]); | |||
} else if (object.length >= MagicNumConstant.THREE) { | |||
log.error(String.valueOf(object[MagicNumConstant.ZERO]), | |||
Arrays.copyOfRange(object, MagicNumConstant.ONE, object.length)); | |||
for (int i = 0; i < object.length; i++) { | |||
if (object[i] instanceof Exception) { | |||
object[i] = ExceptionUtils.getStackTrace((Exception) object[i]); | |||
} | |||
} | |||
} | |||
} | |||
} |
@@ -20,8 +20,8 @@ package org.dubhe.utils; | |||
import org.dubhe.base.MagicNumConstant; | |||
/** | |||
* @description: 计算工具类 | |||
* @create: 2020/6/4 14:53 | |||
* @description 计算工具类 | |||
* @date 2020-06-04 | |||
*/ | |||
public class MathUtils { | |||
@@ -68,9 +68,9 @@ public class MinioUtil { | |||
try { | |||
client = new MinioClient(url, accessKey, secretKey); | |||
} catch (InvalidEndpointException e) { | |||
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint invalid. e:", e); | |||
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint invalid. e, {}", e); | |||
} catch (InvalidPortException e) { | |||
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint port invalid. e:", e); | |||
LogUtil.warn(LogEnum.BIZ_DATASET, "MinIO endpoint port invalid. e, {}", e); | |||
} | |||
} | |||
@@ -94,7 +94,7 @@ public class MinioUtil { | |||
/** | |||
* 读取文件 | |||
* | |||
* @param bucket | |||
* @param bucket 桶 | |||
* @param fullFilePath 文件存储的全路径,包括文件名,非'/'开头. e.g. dataset/12/annotation/test.txt | |||
* @return String | |||
*/ | |||
@@ -107,7 +107,7 @@ public class MinioUtil { | |||
/** | |||
* 文件删除 | |||
* | |||
* @param bucket | |||
* @param bucket 桶 | |||
* @param fullFilePath 文件存储的全路径,包括文件名,非'/'开头. e.g. dataset/12/annotation/test.txt | |||
*/ | |||
public void del(String bucket, String fullFilePath) throws Exception { | |||
@@ -125,8 +125,8 @@ public class MinioUtil { | |||
/** | |||
* 批量删除文件 | |||
* | |||
* @param bucket | |||
* @param objectNames | |||
* @param bucket 桶 | |||
* @param objectNames 对象名称 | |||
*/ | |||
public void delFiles(String bucket,List<String> objectNames) throws Exception{ | |||
Iterable<Result<DeleteError>> results = client.removeObjects(bucket, objectNames); | |||
@@ -138,6 +138,10 @@ public class MinioUtil { | |||
/** | |||
* 获取对象名称 | |||
* | |||
* @param bucketName 桶名称 | |||
* @param prefix 前缀 | |||
* @return | |||
* @throws Exception | |||
*/ | |||
public List<String> getObjects(String bucketName, String prefix)throws Exception{ | |||
List<String> fileNames = new ArrayList<>(); | |||
@@ -152,6 +156,10 @@ public class MinioUtil { | |||
/** | |||
* 获取文件流 | |||
* | |||
* @param bucket 桶 | |||
* @param objectName 对象名称 | |||
* @return | |||
* @throws Exception | |||
*/ | |||
public InputStream getObjectInputStream(String bucket,String objectName)throws Exception{ | |||
return client.getObject(bucket, objectName); | |||
@@ -160,9 +168,9 @@ public class MinioUtil { | |||
/** | |||
* 文件夹复制 | |||
* | |||
* @param bucket | |||
* @param bucket 桶 | |||
* @param sourceFiles 源文件 | |||
* @param targetDir 目标文件夹 | |||
* @param targetDir 目标文件夹 | |||
*/ | |||
public void copyDir(String bucket, List<String> sourceFiles, String targetDir) { | |||
sourceFiles.forEach(sourceFile -> { | |||
@@ -211,10 +219,10 @@ public class MinioUtil { | |||
/** | |||
* 生成文件下载请求参数方法 | |||
* | |||
* @param bucketName | |||
* @param prefix | |||
* @param objects | |||
* @return MinioDownloadDto | |||
* @param bucketName 桶名称 | |||
* @param prefix 前缀 | |||
* @param objects 对象名称 | |||
* @return MinioDownloadDto 下载请求参数 | |||
*/ | |||
public MinioDownloadDto getDownloadParam(String bucketName, String prefix, List<String> objects, String zipName) { | |||
String paramTemplate = "{\"id\":%d,\"jsonrpc\":\"%s\",\"params\":{\"username\":\"%s\",\"password\":\"%s\"},\"method\":\"%s\"}"; | |||
@@ -17,6 +17,7 @@ | |||
package org.dubhe.utils; | |||
import cn.hutool.core.util.StrUtil; | |||
import com.emc.ecs.nfsclient.nfs.io.Nfs3File; | |||
import com.emc.ecs.nfsclient.nfs.io.NfsFileInputStream; | |||
import com.emc.ecs.nfsclient.nfs.io.NfsFileOutputStream; | |||
@@ -31,7 +32,6 @@ import org.dubhe.exception.NfsBizException; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.stereotype.Component; | |||
import org.springframework.util.CollectionUtils; | |||
import org.springframework.util.FileCopyUtils; | |||
import java.io.*; | |||
import java.util.ArrayList; | |||
@@ -130,14 +130,14 @@ public class NfsUtil { | |||
} | |||
/** | |||
* 校验文件或文件夹是否存在 | |||
* 校验文件或文件夹是否不存在 | |||
* | |||
* @param path 文件路径 | |||
* @return boolean | |||
*/ | |||
public boolean fileOrDirIsEmpty(String path) { | |||
if (!StringUtils.isEmpty(path)) { | |||
path = formatPath(path); | |||
path = formatPath(path.startsWith(nfsConfig.getRootDir()) ? path.replaceFirst(nfsConfig.getRootDir(), StrUtil.SLASH) : path); | |||
Nfs3File nfs3File = getNfs3File(path); | |||
try { | |||
if (nfs3File.exists()) { | |||
@@ -398,30 +398,6 @@ public class NfsUtil { | |||
} | |||
/** | |||
* NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况 | |||
* | |||
* 通过本地文件复制方式 | |||
* | |||
* @param sourcePath 需要复制的文件目录 例如:/abc/def | |||
* @param targetPath 需要放置的目标目录 例如:/abc/dd | |||
* @return boolean | |||
*/ | |||
public boolean copyPath(String sourcePath, String targetPath) { | |||
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { | |||
return false; | |||
} | |||
sourcePath = formatPath(sourcePath); | |||
targetPath = formatPath(targetPath); | |||
try { | |||
return copyLocalPath(nfsConfig.getRootDir() + sourcePath, nfsConfig.getRootDir() + targetPath); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.NFS_UTIL, " copyPath 复制失败: ", e); | |||
return false; | |||
} | |||
} | |||
/** | |||
* NFS 复制目录到指定目录下 多个文件 包含目录与文件并存情况 | |||
* | |||
@@ -468,126 +444,15 @@ public class NfsUtil { | |||
} | |||
} | |||
/** | |||
* 复制文件到指定目录下 单个文件 | |||
* | |||
* @param sourcePath 需要复制的文件 例如:/abc/def/cc.txt | |||
* @param targetPath 需要放置的目标目录 例如:/abc/dd | |||
* @return boolean | |||
*/ | |||
public boolean copyLocalFile(String sourcePath, String targetPath) { | |||
LogUtil.info(LogEnum.NFS_UTIL, "复制文件原路径: {} ,目标路径: {}", sourcePath, targetPath); | |||
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { | |||
return false; | |||
} | |||
sourcePath = formatPath(sourcePath); | |||
targetPath = formatPath(targetPath); | |||
LogUtil.info(LogEnum.NFS_UTIL, "过滤后文件原路径: {} ,目标路径:{}", sourcePath, targetPath); | |||
try (InputStream input = new FileInputStream(sourcePath); | |||
FileOutputStream output = new FileOutputStream(targetPath)) { | |||
FileCopyUtils.copy(input, output); | |||
LogUtil.info(LogEnum.NFS_UTIL, "复制文件成功"); | |||
return true; | |||
} catch (IOException e) { | |||
LogUtil.error(LogEnum.NFS_UTIL, " copyLocalFile 复制失败: ", e); | |||
return false; | |||
} | |||
} | |||
/** | |||
* 复制文件 到指定目录下 多个文件 包含目录与文件并存情况 | |||
* | |||
* @param sourcePath 需要复制的文件目录 例如:/abc/def | |||
* @param targetPath 需要放置的目标目录 例如:/abc/dd | |||
* @return boolean | |||
*/ | |||
public boolean copyLocalPath(String sourcePath, String targetPath) { | |||
if (!StringUtils.isEmpty(sourcePath) && !StringUtils.isEmpty(targetPath)) { | |||
sourcePath = formatPath(sourcePath); | |||
if (sourcePath.endsWith(FILE_SEPARATOR)) { | |||
sourcePath = sourcePath.substring(MagicNumConstant.ZERO, sourcePath.lastIndexOf(FILE_SEPARATOR)); | |||
} | |||
targetPath = formatPath(targetPath); | |||
LogUtil.info(LogEnum.NFS_UTIL, "复制文件夹 原路径: {} , 目标路径 : {} ", sourcePath, targetPath); | |||
File[] files = new File(sourcePath).listFiles(); | |||
LogUtil.info(LogEnum.NFS_UTIL, "需要复制的文件数量为: {}", files.length); | |||
if (files.length != 0) { | |||
for (File file : files) { | |||
try { | |||
if (file.isDirectory()) { | |||
LogUtil.info(LogEnum.NFS_UTIL, "需要复制夹: {}", file.getAbsolutePath()); | |||
LogUtil.info(LogEnum.NFS_UTIL, "目标文件夹: {}", targetPath + FILE_SEPARATOR + file.getName()); | |||
File fileDir = new File(targetPath + FILE_SEPARATOR + file.getName()); | |||
if(!fileDir.exists()){ | |||
fileDir.mkdirs(); | |||
} | |||
copyLocalPath(sourcePath + FILE_SEPARATOR + file.getName(), targetPath + FILE_SEPARATOR + file.getName()); | |||
} | |||
if (file.isFile()) { | |||
File fileTargetPath = new File(targetPath); | |||
if(!fileTargetPath.exists()){ | |||
fileTargetPath.mkdirs(); | |||
} | |||
LogUtil.info(LogEnum.NFS_UTIL, "需要复制文件: {}", file.getAbsolutePath()); | |||
LogUtil.info(LogEnum.NFS_UTIL, "需要复制文件名称: {}", file.getName()); | |||
copyLocalFile(file.getAbsolutePath() , targetPath + FILE_SEPARATOR + file.getName()); | |||
} | |||
}catch (Exception e){ | |||
LogUtil.error(LogEnum.NFS_UTIL, "复制文件夹失败: {}", e); | |||
return false; | |||
} | |||
} | |||
} | |||
return true; | |||
} | |||
return false; | |||
} | |||
/** | |||
* 解压前清理同路径下其他文件(目前只支持路径下无文件夹,文件均为zip文件) | |||
* 上传路径垃圾文件清理 | |||
* | |||
* @param zipFilePath zip源文件 例如:/abc/z.zip | |||
* @param path 文件夹 例如:/abc/ | |||
* @return boolean | |||
*/ | |||
public boolean cleanPath(String zipFilePath, String path) { | |||
if (!StringUtils.isEmpty(zipFilePath) && !StringUtils.isEmpty(path) && zipFilePath.toLowerCase().endsWith(ZIP)) { | |||
zipFilePath = formatPath(zipFilePath); | |||
path = formatPath(path); | |||
Nfs3File nfs3Files = getNfs3File(path); | |||
try { | |||
String zipName = zipFilePath.substring(zipFilePath.lastIndexOf(FILE_SEPARATOR) + MagicNumConstant.ONE); | |||
if (!StringUtils.isEmpty(zipName)) { | |||
List<Nfs3File> nfs3FilesList = nfs3Files.listFiles(); | |||
if (!CollectionUtils.isEmpty(nfs3FilesList)) { | |||
for (Nfs3File nfs3File : nfs3FilesList) { | |||
if (!zipName.equals(nfs3File.getName())) { | |||
nfs3File.delete(); | |||
} | |||
} | |||
return true; | |||
} | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.NFS_UTIL, "路径{}清理失败,错误原因为:{} ", path, e); | |||
return false; | |||
} finally { | |||
nfsPool.revertNfs(nfs3Files.getNfs()); | |||
} | |||
} | |||
return false; | |||
} | |||
/** | |||
* zip解压并删除压缩文件 | |||
* | |||
* 当压缩包文件较多时,可能会因为RPC超时而解压失败 | |||
* 该方法已废弃,请使用org.dubhe.utils.LocalFileUtil类的unzipLocalPath方法来替代 | |||
* @param sourcePath zip源文件 例如:/abc/z.zip | |||
* @param targetPath 解压后的目标文件夹 例如:/abc/ | |||
* @return boolean | |||
*/ | |||
@Deprecated | |||
public boolean unzip(String sourcePath, String targetPath) { | |||
if (StringUtils.isEmpty(sourcePath) || StringUtils.isEmpty(targetPath)) { | |||
return false; | |||
@@ -894,11 +759,15 @@ public class NfsUtil { | |||
* @param path | |||
* @return String | |||
*/ | |||
private String formatPath(String path) { | |||
public String formatPath(String path) { | |||
if (!StringUtils.isEmpty(path)) { | |||
return path.replaceAll("///*", FILE_SEPARATOR); | |||
} | |||
return path; | |||
} | |||
public String getAbsolutePath(String relativePath) { | |||
return nfsConfig.getRootDir() + nfsConfig.getBucket() + relativePath; | |||
} | |||
} |
@@ -26,6 +26,9 @@ import org.springframework.data.redis.core.Cursor; | |||
import org.springframework.data.redis.core.RedisConnectionUtils; | |||
import org.springframework.data.redis.core.RedisTemplate; | |||
import org.springframework.data.redis.core.ScanOptions; | |||
import org.springframework.data.redis.core.script.DefaultRedisScript; | |||
import org.springframework.data.redis.core.script.RedisScript; | |||
import org.springframework.data.redis.serializer.Jackson2JsonRedisSerializer; | |||
import org.springframework.stereotype.Component; | |||
import org.springframework.util.CollectionUtils; | |||
@@ -33,7 +36,7 @@ import java.util.*; | |||
import java.util.concurrent.TimeUnit; | |||
/** | |||
* @description redis工具类 | |||
* @description redis工具类 | |||
* @date 2020-03-13 | |||
*/ | |||
@Component | |||
@@ -62,7 +65,7 @@ public class RedisUtils { | |||
redisTemplate.expire(key, time, TimeUnit.SECONDS); | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils expire key {} time {} error:{}",key,time,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils expire key {} time {} error:{}", key, time, e.getMessage(), e); | |||
return false; | |||
} | |||
return true; | |||
@@ -82,7 +85,7 @@ public class RedisUtils { | |||
* 查找匹配key | |||
* | |||
* @param pattern key | |||
* @return / | |||
* @return List<String> 匹配的key集合 | |||
*/ | |||
public List<String> scan(String pattern) { | |||
ScanOptions options = ScanOptions.scanOptions().match(pattern).build(); | |||
@@ -94,9 +97,9 @@ public class RedisUtils { | |||
result.add(new String(cursor.next())); | |||
} | |||
try { | |||
RedisConnectionUtils.releaseConnection(rc, factory,true); | |||
RedisConnectionUtils.releaseConnection(rc, factory, true); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils scan pattern {} error:{}",pattern,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils scan pattern {} error:{}", pattern, e.getMessage(), e); | |||
} | |||
return result; | |||
} | |||
@@ -107,7 +110,7 @@ public class RedisUtils { | |||
* @param patternKey key | |||
* @param page 页码 | |||
* @param size 每页数目 | |||
* @return / | |||
* @return 匹配到的key集合 | |||
*/ | |||
public List<String> findKeysForPage(String patternKey, int page, int size) { | |||
ScanOptions options = ScanOptions.scanOptions().match(patternKey).build(); | |||
@@ -132,9 +135,9 @@ public class RedisUtils { | |||
cursor.next(); | |||
} | |||
try { | |||
RedisConnectionUtils.releaseConnection(rc, factory,true); | |||
RedisConnectionUtils.releaseConnection(rc, factory, true); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils findKeysForPage patternKey {} page {} size {} error:{}",patternKey,page,size,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils findKeysForPage patternKey {} page {} size {} error:{}", patternKey, page, size, e.getMessage(), e); | |||
} | |||
return result; | |||
} | |||
@@ -149,7 +152,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.hasKey(key); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hasKey key {} error:{}",key,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hasKey key {} error:{}", key, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -175,7 +178,7 @@ public class RedisUtils { | |||
* 普通缓存获取 | |||
* | |||
* @param key 键 | |||
* @return 值 | |||
* @return key对应的value值 | |||
*/ | |||
public Object get(String key) { | |||
@@ -185,8 +188,8 @@ public class RedisUtils { | |||
/** | |||
* 批量获取 | |||
* | |||
* @param keys | |||
* @return | |||
* @param keys key集合 | |||
* @return key集合对应的value集合 | |||
*/ | |||
public List<Object> multiGet(List<String> keys) { | |||
Object obj = redisTemplate.opsForValue().multiGet(Collections.singleton(keys)); | |||
@@ -205,7 +208,7 @@ public class RedisUtils { | |||
redisTemplate.opsForValue().set(key, value); | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} error:{}",key,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} error:{}", key, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -227,7 +230,7 @@ public class RedisUtils { | |||
} | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} time {} error:{}",key,value,time,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} time {} error:{}", key, value, time, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -250,11 +253,56 @@ public class RedisUtils { | |||
} | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils set key {} value {} time {} timeUnit {} error:{}",key,value,time,timeUnit,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils set key {} value {} time {} timeUnit {} error:{}", key, value, time, timeUnit, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
//===============================Lock================================= | |||
/** | |||
* 加锁 | |||
* @param key 键 | |||
* @param requestId 请求id用以释放锁 | |||
* @param expireTime 超时时间(秒) | |||
* @return | |||
*/ | |||
public boolean getDistributedLock(String key, String requestId, long expireTime) { | |||
String script = "if redis.call('setNx',KEYS[1],ARGV[1]) == 1 then if redis.call('get',KEYS[1]) == ARGV[1] then return redis.call('expire',KEYS[1],ARGV[2]) else return 0 end else return 0 end"; | |||
Object result = executeRedisScript(script, key, requestId, expireTime); | |||
return result != null && result.equals(MagicNumConstant.ONE_LONG); | |||
} | |||
/** | |||
* 释放锁 | |||
* @param key 键 | |||
* @param requestId 请求id用以释放锁 | |||
* @return | |||
*/ | |||
public boolean releaseDistributedLock(String key, String requestId) { | |||
String script = "if redis.call('get', KEYS[1]) == ARGV[1] then return redis.call('del', KEYS[1]) else return 0 end"; | |||
Object result = executeRedisScript(script, key, requestId); | |||
return result != null && result.equals(MagicNumConstant.ONE_LONG); | |||
} | |||
/** | |||
* | |||
* @param script 脚本字符串 | |||
* @param key 键 | |||
* @param args 脚本其他参数 | |||
* @return | |||
*/ | |||
public Object executeRedisScript(String script, String key, Object... args) { | |||
try { | |||
RedisScript<Long> redisScript = new DefaultRedisScript<>(script, Long.class); | |||
redisTemplate.setValueSerializer(new Jackson2JsonRedisSerializer<>(Object.class)); | |||
return redisTemplate.execute(redisScript, Collections.singletonList(key), args); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR, "executeRedisScript script {} key {} expireTime {} args {} error:{}", script, key, args, e); | |||
return MagicNumConstant.ZERO_LONG; | |||
} | |||
} | |||
// ================================Map================================= | |||
/** | |||
@@ -291,7 +339,7 @@ public class RedisUtils { | |||
redisTemplate.opsForHash().putAll(key, map); | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hmset key {} map {} error:{}",key,map,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hmset key {} map {} error:{}", key, map, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -312,7 +360,7 @@ public class RedisUtils { | |||
} | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hmset key {} map {} time {} error:{}",key,map,time,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hmset key {} map {} time {} error:{}", key, map, time, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -330,7 +378,7 @@ public class RedisUtils { | |||
redisTemplate.opsForHash().put(key, item, value); | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hset key {} item {} value {} error:{}",key,item,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hset key {} item {} value {} error:{}", key, item, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -352,7 +400,7 @@ public class RedisUtils { | |||
} | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils hset key {} item {} value {} time {} error:{}",key,item,value,time,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils hset key {} item {} value {} time {} error:{}", key, item, value, time, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -414,7 +462,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForSet().members(key); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sGet key {} error:{}",key,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sGet key {} error:{}", key, e.getMessage(), e); | |||
return null; | |||
} | |||
} | |||
@@ -430,7 +478,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForSet().isMember(key, value); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sHasKey key {} value {} error:{}",key,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sHasKey key {} value {} error:{}", key, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -446,7 +494,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForSet().add(key, values); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sSet key {} values {} error:{}",key,values,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sSet key {} values {} error:{}", key, values, e.getMessage(), e); | |||
return 0; | |||
} | |||
} | |||
@@ -467,7 +515,7 @@ public class RedisUtils { | |||
} | |||
return count; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sSetAndTime key {} time {} values {} error:{}",key,time,values,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sSetAndTime key {} time {} values {} error:{}", key, time, values, e.getMessage(), e); | |||
return 0; | |||
} | |||
} | |||
@@ -482,7 +530,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForSet().size(key); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils sGetSetSize key {} error:{}",key,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils sGetSetSize key {} error:{}", key, e.getMessage(), e); | |||
return 0; | |||
} | |||
} | |||
@@ -499,7 +547,7 @@ public class RedisUtils { | |||
Long count = redisTemplate.opsForSet().remove(key, values); | |||
return count; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils setRemove key {} values {} error:{}",key,values,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils setRemove key {} values {} error:{}", key, values, e.getMessage(), e); | |||
return 0; | |||
} | |||
} | |||
@@ -507,22 +555,22 @@ public class RedisUtils { | |||
// ===============================sorted set================================= | |||
/** | |||
*将zSet数据放入缓存 | |||
* 将zSet数据放入缓存 | |||
* | |||
* @param key | |||
* @param time | |||
* @param values | |||
* @return Boolean | |||
*/ | |||
public Boolean zSet(String key, long time, Object value){ | |||
public Boolean zSet(String key, long time, Object value) { | |||
try { | |||
Boolean success = redisTemplate.opsForZSet().add(key, value,System.currentTimeMillis()); | |||
Boolean success = redisTemplate.opsForZSet().add(key, value, System.currentTimeMillis()); | |||
if (success) { | |||
expire(key, time); | |||
} | |||
return success; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils zSet key {} time {} value {} error:{}",key,time,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils zSet key {} time {} value {} error:{}", key, time, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -533,17 +581,16 @@ public class RedisUtils { | |||
* @param key | |||
* @return Set<Object> | |||
*/ | |||
public Set<Object> zGet(String key){ | |||
public Set<Object> zGet(String key) { | |||
try { | |||
return redisTemplate.opsForZSet().reverseRange(key,Long.MIN_VALUE, Long.MAX_VALUE); | |||
return redisTemplate.opsForZSet().reverseRange(key, Long.MIN_VALUE, Long.MAX_VALUE); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils zGet key {} error:{}",key,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils zGet key {} error:{}", key, e.getMessage(), e); | |||
return null; | |||
} | |||
} | |||
// ===============================list================================= | |||
/** | |||
@@ -558,7 +605,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForList().range(key, start, end); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetIndex key {} start {} end {} error:{}",key,start,end,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetIndex key {} start {} end {} error:{}", key, start, end, e.getMessage(), e); | |||
return null; | |||
} | |||
} | |||
@@ -573,7 +620,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForList().size(key); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetListSize key {} error:{}",key,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetListSize key {} error:{}", key, e.getMessage(), e); | |||
return 0; | |||
} | |||
} | |||
@@ -589,7 +636,7 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForList().index(key, index); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lGetIndex key {} index {} error:{}",key,index,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lGetIndex key {} index {} error:{}", key, index, e.getMessage(), e); | |||
return null; | |||
} | |||
} | |||
@@ -606,7 +653,7 @@ public class RedisUtils { | |||
redisTemplate.opsForList().rightPush(key, value); | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} error:{}",key,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} error:{}", key, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -617,7 +664,7 @@ public class RedisUtils { | |||
* @param key 键 | |||
* @param value 值 | |||
* @param time 时间(秒) | |||
* @return | |||
* @return 是否存储成功 | |||
*/ | |||
public boolean lSet(String key, Object value, long time) { | |||
try { | |||
@@ -627,7 +674,7 @@ public class RedisUtils { | |||
} | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} time {} error:{}",key,value,time,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} time {} error:{}", key, value, time, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -637,14 +684,14 @@ public class RedisUtils { | |||
* | |||
* @param key 键 | |||
* @param value 值 | |||
* @return | |||
* @return 是否存储成功 | |||
*/ | |||
public boolean lSet(String key, List<Object> value) { | |||
try { | |||
redisTemplate.opsForList().rightPushAll(key, value); | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} error:{}",key,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} error:{}", key, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -655,7 +702,7 @@ public class RedisUtils { | |||
* @param key 键 | |||
* @param value 值 | |||
* @param time 时间(秒) | |||
* @return | |||
* @return 是否存储成功 | |||
*/ | |||
public boolean lSet(String key, List<Object> value, long time) { | |||
try { | |||
@@ -665,7 +712,7 @@ public class RedisUtils { | |||
} | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lSet key {} value {} time {} error:{}",key,value,time,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lSet key {} value {} time {} error:{}", key, value, time, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -676,14 +723,14 @@ public class RedisUtils { | |||
* @param key 键 | |||
* @param index 索引 | |||
* @param value 值 | |||
* @return / | |||
* @return 更新数据标识 | |||
*/ | |||
public boolean lUpdateIndex(String key, long index, Object value) { | |||
try { | |||
redisTemplate.opsForList().set(key, index, value); | |||
return true; | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lUpdateIndex key {} index {} value {} error:{}",key,index,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lUpdateIndex key {} index {} value {} error:{}", key, index, value, e.getMessage(), e); | |||
return false; | |||
} | |||
} | |||
@@ -700,8 +747,23 @@ public class RedisUtils { | |||
try { | |||
return redisTemplate.opsForList().remove(key, count, value); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR,"RedisUtils lRemove key {} count {} value {} error:{}",key,count,value,e.getMessage(),e); | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lRemove key {} count {} value {} error:{}", key, count, value, e.getMessage(), e); | |||
return 0; | |||
} | |||
} | |||
/** | |||
* 队列从左弹出数据 | |||
* | |||
* @param key | |||
* @return key对应的value值 | |||
*/ | |||
public Object lpop(String key) { | |||
try { | |||
return redisTemplate.opsForList().leftPop(key); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.SYS_ERR, "RedisUtils lRemove key {} error:{}", key, e.getMessage(), e); | |||
return null; | |||
} | |||
} | |||
} |
@@ -18,7 +18,8 @@ | |||
package org.dubhe.utils; | |||
import java.lang.reflect.Field; | |||
import java.util.*; | |||
import java.util.ArrayList; | |||
import java.util.List; | |||
/** | |||
* @description 反射工具类 | |||
@@ -23,8 +23,8 @@ import java.util.regex.Matcher; | |||
import java.util.regex.Pattern; | |||
/** | |||
* @description: 正则匹配工具类 | |||
* @create: 2020/4/23 13:51 | |||
* @description 正则匹配工具类 | |||
* @date 2020-04-23 | |||
*/ | |||
@Slf4j | |||
public class RegexUtil { | |||
@@ -17,11 +17,18 @@ | |||
package org.dubhe.utils; | |||
import org.dubhe.base.BaseService; | |||
import org.dubhe.base.DataContext; | |||
import java.util.HashSet; | |||
import java.util.Objects; | |||
import java.util.Set; | |||
/** | |||
* @description sql语句转换的工具类 | |||
* @date 2020-07-06 | |||
*/ | |||
public class SqlUtil { | |||
/** | |||
@@ -46,4 +53,46 @@ public class SqlUtil { | |||
return ""; | |||
} | |||
/** | |||
* 获取资源拥有着ID | |||
* | |||
* @return 资源拥有者id集合 | |||
*/ | |||
public static Set<Long> getResourceIds() { | |||
if (!Objects.isNull(DataContext.get())) { | |||
return DataContext.get().getResourceUserIds(); | |||
} | |||
Set<Long> ids = new HashSet<>(); | |||
Long id = JwtUtils.getCurrentUserDto().getId(); | |||
ids.add(id); | |||
return ids; | |||
} | |||
/** | |||
* 构建目标sql语句 | |||
* | |||
* @param originSql 原生sql | |||
* @param resourceUserIds 所属资源用户ids | |||
* @return 目标sql | |||
*/ | |||
public static String buildTargetSql(String originSql, Set<Long> resourceUserIds) { | |||
if (BaseService.isAdmin()) { | |||
return originSql; | |||
} | |||
String sqlWhereBefore = org.dubhe.utils.StringUtils.substringBefore(originSql.toLowerCase(), "where"); | |||
String sqlWhereAfter = org.dubhe.utils.StringUtils.substringAfter(originSql.toLowerCase(), "where"); | |||
StringBuffer buffer = new StringBuffer(); | |||
//操作的sql拼接 | |||
String targetSql = buffer.append(sqlWhereBefore).append(" where ").append(" origin_user_id in (") | |||
.append(org.dubhe.utils.StringUtils.join(resourceUserIds, ",")).append(") and ").append(sqlWhereAfter).toString(); | |||
return targetSql; | |||
} | |||
} |
@@ -273,4 +273,41 @@ public class StringUtils extends org.apache.commons.lang3.StringUtils { | |||
matcher.appendTail(sb); | |||
return sb.toString(); | |||
} | |||
/** | |||
* 字符串截取前 | |||
* @param str | |||
* @return | |||
*/ | |||
public static String substringBefore(String str, String separator){ | |||
if (!isEmpty(str) && separator != null) { | |||
if (separator.isEmpty()) { | |||
return ""; | |||
} else { | |||
int pos = str.indexOf(separator); | |||
return pos == -1 ? str : str.substring(0, pos); | |||
} | |||
} else { | |||
return str; | |||
} | |||
} | |||
/** | |||
* 字符串截取后 | |||
* @param str | |||
* @return | |||
*/ | |||
public static String substringAfter(String str, String separator){ | |||
if (isEmpty(str)) { | |||
return str; | |||
} else if (separator == null) { | |||
return ""; | |||
} else { | |||
int pos = str.indexOf(separator); | |||
return pos == -1 ? "" : str.substring(pos + separator.length()); | |||
} | |||
} | |||
} |
@@ -17,42 +17,33 @@ | |||
package org.dubhe.utils; | |||
import lombok.extern.slf4j.Slf4j; | |||
import java.text.ParseException; | |||
import java.text.SimpleDateFormat; | |||
import java.util.Calendar; | |||
import java.util.Date; | |||
import static org.dubhe.base.MagicNumConstant.EIGHT; | |||
/** | |||
* @description: UTC时间转换CST时间工具类 | |||
* @create: 2020/5/20 12:10 | |||
* @description 时间格式转换工具类 | |||
* @date 2020-05-20 | |||
*/ | |||
@Slf4j | |||
public class TimeTransferUtil { | |||
private static final String UTC_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.sss'Z'"; | |||
/** | |||
* @param utcTime | |||
* @return cstTime | |||
* Date转换为UTC时间 | |||
* | |||
* @param date | |||
* @return utcTime | |||
*/ | |||
public static String cstTransfer(String utcTime){ | |||
Date utcDate = null; | |||
/**2020-05-20T03:13:22Z 对应的时间格式 yyyy-MM-dd'T'HH:mm:ss'Z'**/ | |||
SimpleDateFormat utcSimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'"); | |||
try { | |||
utcDate = utcSimpleDateFormat.parse(utcTime); | |||
} catch (ParseException e) { | |||
log.info(e.getMessage()); | |||
return null; | |||
} | |||
/**System.out.println("UTC时间:"+date);**/ | |||
public static String dateTransferToUtc(Date date){ | |||
Calendar calendar = Calendar.getInstance(); | |||
calendar.setTime(utcDate); | |||
calendar.set(Calendar.HOUR,calendar.get(Calendar.HOUR)+8); | |||
SimpleDateFormat cstSimpleDateFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"); | |||
Date cstDate = calendar.getTime(); | |||
String cstTime = cstSimpleDateFormat.format(calendar.getTime()); | |||
return cstTime; | |||
calendar.setTime(date); | |||
/**UTC时间与CST时间相差8小时**/ | |||
calendar.set(Calendar.HOUR,calendar.get(Calendar.HOUR) - EIGHT); | |||
SimpleDateFormat utcSimpleDateFormat = new SimpleDateFormat(UTC_FORMAT); | |||
Date utcDate = calendar.getTime(); | |||
return utcSimpleDateFormat.format(utcDate); | |||
} | |||
} |
@@ -25,9 +25,8 @@ import java.util.concurrent.ConcurrentHashMap; | |||
import java.util.concurrent.atomic.AtomicInteger; | |||
/** | |||
* @description: 唯一码生成器 (依赖时间轴) | |||
* | |||
* @date 2020.05.08 | |||
* @description 唯一码生成器 (依赖时间轴) | |||
* @date 2020-05-08 | |||
*/ | |||
public class UniqueKeyGenerator { | |||
@@ -20,7 +20,6 @@ package org.dubhe.utils; | |||
import cn.hutool.core.util.ObjectUtil; | |||
import com.baomidou.mybatisplus.core.conditions.query.QueryWrapper; | |||
import lombok.extern.slf4j.Slf4j; | |||
import lombok.val; | |||
import org.dubhe.annotation.Query; | |||
import java.lang.reflect.Field; | |||
@@ -31,7 +30,7 @@ import java.util.List; | |||
/** | |||
* @description 构建Wrapper | |||
* @date 2020-03-15 13:52:30 | |||
* @date 2020-03-15 | |||
*/ | |||
@Slf4j | |||
public class WrapperHelp { | |||
@@ -23,8 +23,8 @@ import org.junit.Test; | |||
import static org.dubhe.utils.HttpUtils.isSuccess; | |||
/** | |||
* @description: HttpUtil | |||
* @date 2020.04.30 | |||
* @description HttpUtil | |||
* @date 2020-04-30 | |||
*/ | |||
public class HttpUtilsTest { | |||
@@ -0,0 +1,71 @@ | |||
#!/bin/bash | |||
PROG_NAME=$0 | |||
ACTION=$1 | |||
ENV=$2 | |||
APP_HOME=$3 | |||
APP_NAME=dubhe-${ENV} | |||
APP_HOME=$APP_HOME/${APP_NAME} # 从package.tgz中解压出来的jar包放到这个目录下 | |||
JAR_NAME=${APP_HOME}/dubhe-admin/target/dubhe-admin-1.0-exec.jar # jar包的名字 | |||
JAVA_OUT=/dev/null | |||
# 创建出相关目录 | |||
mkdir -p ${APP_HOME} | |||
mkdir -p ${APP_HOME}/logs | |||
usage() { | |||
echo "Usage: $PROG_NAME {start|stop|restart} {dev|test|prod}" | |||
exit 2 | |||
} | |||
start_application() { | |||
echo "starting java process" | |||
echo "nohup java -jar ${JAR_NAME} > ${JAVA_OUT} --spring.profiles.active=${ENV} 2>&1 &" | |||
nohup java -jar ${JAR_NAME} > ${JAVA_OUT} --spring.profiles.active=${ENV} 2>&1 & | |||
echo "started java process" | |||
} | |||
stop_application() { | |||
checkjavapid=`ps -ef | grep java | grep ${APP_NAME} | grep -v grep |grep -v 'deploy.sh'| awk '{print$2}'` | |||
if [ -z $checkjavapid ];then | |||
echo -e "\rno java process "$checkjavapid | |||
return | |||
fi | |||
echo "stop java process" | |||
times=60 | |||
for e in $(seq 60) | |||
do | |||
sleep 1 | |||
COSTTIME=$(($times - $e )) | |||
checkjavapid=`ps -ef | grep java | grep ${APP_NAME} | grep -v grep |grep -v 'deploy.sh'| awk '{print$2}'` | |||
if [ "$checkjavapid" != "" ];then | |||
echo "kill "$checkjavapid | |||
kill -9 $checkjavapid | |||
echo -e "\r -- stopping java lasts `expr $COSTTIME` seconds." | |||
else | |||
echo -e "\rjava process has exited" | |||
break; | |||
fi | |||
done | |||
echo "" | |||
} | |||
case "$ACTION" in | |||
start) | |||
start_application | |||
;; | |||
stop) | |||
stop_application | |||
;; | |||
restart) | |||
stop_application | |||
start_application | |||
;; | |||
*) | |||
usage | |||
;; | |||
esac |
@@ -75,6 +75,7 @@ | |||
<configuration> | |||
<skip>false</skip> | |||
<fork>true</fork> | |||
<classifier>exec</classifier> | |||
</configuration> | |||
</plugin> | |||
<!-- 跳过单元测试 --> | |||
@@ -14,7 +14,7 @@ | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.task; | |||
package org.dubhe.async; | |||
import cn.hutool.core.util.StrUtil; | |||
import org.dubhe.base.ResponseCode; | |||
@@ -25,6 +25,7 @@ import org.dubhe.enums.ImageStateEnum; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.exception.BusinessException; | |||
import org.dubhe.harbor.api.HarborApi; | |||
import org.dubhe.utils.IOUtil; | |||
import org.dubhe.utils.LogUtil; | |||
import org.dubhe.utils.StringUtils; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
@@ -62,46 +63,21 @@ public class HarborImagePushAsync { | |||
String imageResource = trainHarborConfig.getAddress() + StrUtil.SLASH + trainHarborConfig.getModelName() | |||
+ StrUtil.SLASH + imageNameandTag; | |||
String cmdStr = "docker login --username=" + trainHarborConfig.getUsername() + " " + trainHarborConfig.getAddress() + " --password=" + trainHarborConfig.getPassword() + " ; docker " + | |||
"load < " + imagePath + " |awk '{print $3}' |xargs -I str docker tag str " + imageResource + " ; docker push " + imageResource; | |||
"load < " + imagePath + " |awk '{print $3}' |xargs -I str docker tag str " + imageResource + " ; docker push " + imageResource + "; docker rmi " + imageResource; | |||
String[] cmd = {"/bin/bash", "-c", cmdStr}; | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "镜像上传执行脚本参数:{}", cmd); | |||
Process process = Runtime.getRuntime().exec(cmd); | |||
//读取标准输出流 | |||
BufferedReader brOut = new BufferedReader(new InputStreamReader(process.getInputStream())); | |||
//读取标准错误流 | |||
BufferedReader brErr = new BufferedReader(new InputStreamReader(process.getErrorStream())); | |||
String line; | |||
String outMessage = ""; | |||
String errMessage = ""; | |||
while ((line = brOut.readLine()) != null) { | |||
outMessage += line; | |||
} | |||
if (StringUtils.isNotEmpty(outMessage)) { | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "shell上传镜像输出信息:" + outMessage); | |||
} | |||
while ((line = brErr.readLine()) != null) { | |||
errMessage += line; | |||
} | |||
if (StringUtils.isNotEmpty(errMessage)) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "shell上传镜像异常信息:" + errMessage); | |||
} | |||
Integer status = process.waitFor(); | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "上传镜像状态:{}", status); | |||
if (status == null) { | |||
if (harborApi.isExistImage(ptImage.getImageUrl())) { | |||
updateImageStatus(ptImage, ImageStateEnum.SUCCESS.getCode()); | |||
} else { | |||
updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); | |||
} | |||
} else if (status == 0) { | |||
if (checkImagePushIsOk(ptImage, process)) { | |||
updateImageStatus(ptImage, ImageStateEnum.SUCCESS.getCode()); | |||
} else { | |||
updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "上传镜像异常:{}", e); | |||
updateImageStatus(ptImage, ImageStateEnum.FAIL.getCode()); | |||
throw new BusinessException("上传镜像异常!"); | |||
} | |||
} | |||
@@ -117,4 +93,52 @@ public class HarborImagePushAsync { | |||
ptImageMapper.updateById(ptImage); | |||
return ResponseCode.SUCCESS; | |||
} | |||
/** | |||
* 校验镜像是否上传成功 | |||
* | |||
* @param ptImage 镜像信息 | |||
* @param process process对象 | |||
* @return 是否上传成功 | |||
*/ | |||
public boolean checkImagePushIsOk(PtImage ptImage, Process process) { | |||
//读取标准输出流 | |||
BufferedReader brOut = new BufferedReader(new InputStreamReader(process.getInputStream())); | |||
//读取标准错误流 | |||
BufferedReader brErr = new BufferedReader(new InputStreamReader(process.getErrorStream())); | |||
String line; | |||
StringBuilder outMessage = new StringBuilder(); | |||
StringBuilder errMessage = new StringBuilder(); | |||
boolean isPushOk = true; | |||
try { | |||
while ((line = brOut.readLine()) != null) { | |||
outMessage.append(line); | |||
} | |||
if (StringUtils.isNotEmpty(outMessage)) { | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "shell上传镜像输出信息:{}", outMessage.toString()); | |||
} | |||
while ((line = brErr.readLine()) != null) { | |||
errMessage.append(line); | |||
} | |||
if (StringUtils.isNotEmpty(errMessage)) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "shell上传镜像异常信息:{}", errMessage.toString()); | |||
} | |||
Integer status = process.waitFor(); | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "上传镜像状态:{}", status); | |||
if (status == null) { | |||
if (!harborApi.isExistImage(ptImage.getImageUrl())) { | |||
isPushOk = false; | |||
} | |||
} else if (status != 0) { | |||
isPushOk = false; | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "上传镜像异常:{}", e); | |||
return false; | |||
} finally { | |||
IOUtil.close(brErr, brOut); | |||
} | |||
return isPushOk; | |||
} | |||
} |
@@ -0,0 +1,144 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.async; | |||
import org.dubhe.config.TrainJobConfig; | |||
import org.dubhe.dao.PtTrainJobMapper; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.domain.entity.PtTrainJob; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.enums.TrainJobStatusEnum; | |||
import org.dubhe.enums.TrainTypeEnum; | |||
import org.dubhe.k8s.api.DistributeTrainApi; | |||
import org.dubhe.k8s.api.PodApi; | |||
import org.dubhe.k8s.api.TrainJobApi; | |||
import org.dubhe.k8s.domain.resource.BizPod; | |||
import org.dubhe.utils.*; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.scheduling.annotation.Async; | |||
import org.springframework.stereotype.Component; | |||
import java.time.LocalDateTime; | |||
import java.time.ZoneOffset; | |||
import java.time.format.DateTimeFormatter; | |||
import java.util.List; | |||
import java.util.function.Consumer; | |||
/** | |||
* @description 停止训练任务异步处理 | |||
* @date 2020-08-13 | |||
*/ | |||
@Component | |||
public class StopTrainJobAsync { | |||
@Autowired | |||
private K8sNameTool k8sNameTool; | |||
@Autowired | |||
private PodApi podApi; | |||
@Autowired | |||
private TrainJobApi trainJobApi; | |||
@Autowired | |||
private TrainJobConfig trainJobConfig; | |||
@Autowired | |||
private DistributeTrainApi distributeTrainApi; | |||
@Autowired | |||
private PtTrainJobMapper ptTrainJobMapper; | |||
/** | |||
* 停止任务 | |||
* | |||
* @param currentUser 用户 | |||
* @param jobList 任务集合 | |||
*/ | |||
@Async("trainExecutor") | |||
public void stopJobs(UserDTO currentUser, List<PtTrainJob> jobList) { | |||
String namespace = k8sNameTool.generateNamespace(currentUser.getId()); | |||
jobList.forEach(job -> { | |||
BizPod bizPod = podApi.getWithResourceName(namespace, job.getJobName()); | |||
if (!bizPod.isSuccess()) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job return code:{},message:{}", currentUser.getUsername(), Integer.valueOf(bizPod.getCode()), bizPod.getMessage()); | |||
} | |||
boolean bool = TrainTypeEnum.isDistributeTrain(job.getTrainType()) ? | |||
distributeTrainApi.deleteByResourceName(namespace, job.getJobName()).isSuccess() : | |||
trainJobApi.delete(namespace, job.getJobName()); | |||
if (!bool) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "User {} stops training Job and K8S fails in the stop process, namespace为{}, resourceName为{}", | |||
currentUser.getUsername(), namespace, job.getJobName()); | |||
} | |||
//更新训练状态 | |||
job.setRuntime(calculateRuntime(bizPod)) | |||
.setTrainStatus(TrainJobStatusEnum.STOP.getStatus()); | |||
ptTrainJobMapper.updateById(job); | |||
}); | |||
} | |||
/** | |||
* 计算job训练时长 | |||
* | |||
* @param bizPod pod信息 | |||
* @return String 训练时长 | |||
*/ | |||
private String calculateRuntime(BizPod bizPod) { | |||
return calculateRuntime(bizPod, (x) -> { | |||
}); | |||
} | |||
/** | |||
* 计算job训练时长 | |||
* | |||
* @param bizPod | |||
* @param consumer pod已经完成状态的回调函数 | |||
* @return res 返回训练时长 | |||
*/ | |||
private String calculateRuntime(BizPod bizPod, Consumer<String> consumer) { | |||
Long completedTime; | |||
if (StringUtils.isBlank(bizPod.getStartTime())) { | |||
return TrainUtil.INIT_RUNTIME; | |||
} | |||
Long startTime = transformTime(bizPod.getStartTime()); | |||
boolean hasCompleted = StringUtils.isNotBlank(bizPod.getCompletedTime()); | |||
completedTime = hasCompleted ? transformTime(bizPod.getCompletedTime()) : LocalDateTime.now().toEpochSecond(ZoneOffset.of(trainJobConfig.getPlusEight())); | |||
Long time = completedTime - startTime; | |||
String res = DubheDateUtil.convert2Str(time); | |||
if (hasCompleted) { | |||
consumer.accept(res); | |||
} | |||
return res; | |||
} | |||
/** | |||
* 时间转换 | |||
* | |||
* @param time 时间 | |||
* @return Long 时间戳 | |||
*/ | |||
private Long transformTime(String time) { | |||
LocalDateTime localDateTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_OFFSET_DATE_TIME); | |||
//没有根据时区做处理, 默认当前为东八区 | |||
localDateTime = localDateTime.plusHours(Long.valueOf(trainJobConfig.getEight())); | |||
return localDateTime.toEpochSecond(ZoneOffset.of(trainJobConfig.getPlusEight())); | |||
} | |||
} |
@@ -0,0 +1,115 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.async; | |||
import org.dubhe.config.NfsConfig; | |||
import org.dubhe.dao.PtTrainAlgorithmMapper; | |||
import org.dubhe.domain.dto.PtTrainAlgorithmCreateDTO; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.domain.entity.PtTrainAlgorithm; | |||
import org.dubhe.enums.AlgorithmStatusEnum; | |||
import org.dubhe.enums.BizNfsEnum; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.exception.BusinessException; | |||
import org.dubhe.service.NoteBookService; | |||
import org.dubhe.utils.K8sNameTool; | |||
import org.dubhe.utils.LocalFileUtil; | |||
import org.dubhe.utils.LogUtil; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.scheduling.annotation.Async; | |||
import org.springframework.stereotype.Component; | |||
/** | |||
* @description 异步上传算法 | |||
* @date 2020-08-10 | |||
*/ | |||
@Component | |||
public class TrainAlgorithmUploadAsync { | |||
@Autowired | |||
private LocalFileUtil localFileUtil; | |||
@Autowired | |||
private NfsConfig nfsConfig; | |||
@Autowired | |||
private K8sNameTool k8sNameTool; | |||
@Autowired | |||
private NoteBookService noteBookService; | |||
@Autowired | |||
private PtTrainAlgorithmMapper trainAlgorithmMapper; | |||
/** | |||
* 异步任务创建算法 | |||
* | |||
* @param user 当前登录用户信息 | |||
* @param ptTrainAlgorithm 算法信息 | |||
* @param trainAlgorithmCreateDTO 创建算法条件 | |||
*/ | |||
@Async("trainExecutor") | |||
public void createTrainAlgorithm(UserDTO user, PtTrainAlgorithm ptTrainAlgorithm, PtTrainAlgorithmCreateDTO trainAlgorithmCreateDTO) { | |||
String path = nfsConfig.getBucket() + trainAlgorithmCreateDTO.getCodeDir(); | |||
//校验创建算法来源(true:由fork创建算法,false:其它创建算法方式),若为true则拷贝预置算法文件至新路径 | |||
if (trainAlgorithmCreateDTO.getFork()) { | |||
//生成算法相对路径 | |||
String algorithmPath = k8sNameTool.getNfsPath(BizNfsEnum.ALGORITHM, user.getId()); | |||
//拷贝预置算法文件夹 | |||
boolean copyResult = localFileUtil.copyPath(path, nfsConfig.getBucket() + algorithmPath); | |||
if (!copyResult) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "The user {} copied the preset algorithm path {} successfully", user.getUsername(), path); | |||
updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, false); | |||
throw new BusinessException("内部错误"); | |||
} | |||
ptTrainAlgorithm.setCodeDir(algorithmPath); | |||
//修改算法上传状态 | |||
updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, true); | |||
} else { | |||
updateTrainAlgorithm(ptTrainAlgorithm, trainAlgorithmCreateDTO, true); | |||
} | |||
} | |||
/** | |||
* 更新上传算法状态 | |||
* | |||
* @param ptTrainAlgorithm 算法信息 | |||
* @param trainAlgorithmCreateDTO 创建算法的条件 | |||
* @param flag 创建算法是否成功(true:成功,false:失败) | |||
*/ | |||
public void updateTrainAlgorithm(PtTrainAlgorithm ptTrainAlgorithm, PtTrainAlgorithmCreateDTO trainAlgorithmCreateDTO, boolean flag) { | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "async update algorithmPath by algorithmId:{} and update noteBook by noteBookId:{}", ptTrainAlgorithm.getId(), trainAlgorithmCreateDTO.getNoteBookId()); | |||
if (flag) { | |||
ptTrainAlgorithm.setAlgorithmStatus(AlgorithmStatusEnum.SUCCESS.getCode()); | |||
//更新fork算法新路径 | |||
trainAlgorithmMapper.updateById(ptTrainAlgorithm); | |||
//保存算法根据notbookId更新算法id | |||
if (trainAlgorithmCreateDTO.getNoteBookId() != null) { | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "Save algorithm Update algorithm ID :{} according to notBookId:{}", trainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId()); | |||
noteBookService.updateTrainIdByNoteBookId(trainAlgorithmCreateDTO.getNoteBookId(), ptTrainAlgorithm.getId()); | |||
} | |||
} else { | |||
ptTrainAlgorithm.setAlgorithmStatus(AlgorithmStatusEnum.FAIL.getCode()); | |||
trainAlgorithmMapper.updateById(ptTrainAlgorithm); | |||
} | |||
} | |||
} |
@@ -0,0 +1,417 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.async; | |||
import cn.hutool.core.util.StrUtil; | |||
import com.alibaba.fastjson.JSONObject; | |||
import org.dubhe.base.MagicNumConstant; | |||
import org.dubhe.constant.SymbolConstant; | |||
import org.dubhe.config.TrainJobConfig; | |||
import org.dubhe.dao.PtTrainJobMapper; | |||
import org.dubhe.domain.dto.BaseTrainJobDTO; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.domain.entity.PtTrainJob; | |||
import org.dubhe.domain.vo.PtImageAndAlgorithmVO; | |||
import org.dubhe.enums.BizEnum; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.enums.ResourcesPoolTypeEnum; | |||
import org.dubhe.enums.TrainJobStatusEnum; | |||
import org.dubhe.exception.BusinessException; | |||
import org.dubhe.k8s.api.DistributeTrainApi; | |||
import org.dubhe.k8s.api.NamespaceApi; | |||
import org.dubhe.k8s.api.TrainJobApi; | |||
import org.dubhe.k8s.domain.bo.DistributeTrainBO; | |||
import org.dubhe.k8s.domain.bo.PtJupyterJobBO; | |||
import org.dubhe.k8s.domain.resource.BizDistributeTrain; | |||
import org.dubhe.k8s.domain.resource.BizNamespace; | |||
import org.dubhe.k8s.domain.vo.PtJupyterJobVO; | |||
import org.dubhe.utils.*; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.stereotype.Component; | |||
import java.util.ArrayList; | |||
import java.util.List; | |||
/** | |||
* @description 提交训练任务 | |||
* @date 2020-07-17 | |||
*/ | |||
@Component | |||
public class TrainJobAsync { | |||
@Autowired | |||
private K8sNameTool k8sNameTool; | |||
@Autowired | |||
private NamespaceApi namespaceApi; | |||
@Autowired | |||
private TrainJobConfig trainJobConfig; | |||
@Autowired | |||
private NfsUtil nfsUtil; | |||
@Autowired | |||
private LocalFileUtil localFileUtil; | |||
@Autowired | |||
private PtTrainJobMapper ptTrainJobMapper; | |||
@Autowired | |||
private TrainJobApi trainJobApi; | |||
@Autowired | |||
private DistributeTrainApi distributeTrainApi; | |||
/** | |||
* 提交分布式训练 | |||
* | |||
* @param baseTrainJobDTO 训练任务信息 | |||
* @param currentUser 用户 | |||
* @param ptImageAndAlgorithmVO 镜像和算法信息 | |||
* @param ptTrainJob 训练任务实体信息 | |||
*/ | |||
public void doDistributedJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { | |||
try { | |||
//判断是否存在相应的namespace,如果没有则创建 | |||
String namespace = getNamespace(currentUser); | |||
// 构建DistributeTrainBO | |||
DistributeTrainBO bo = buildDistributeTrainBO(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob, namespace); | |||
if (null == bo) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "user{}create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace); | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, "", false); | |||
return; | |||
} | |||
// 调度K8s | |||
BizDistributeTrain bizDistributeTrain = distributeTrainApi.create(bo); | |||
if (bizDistributeTrain.isSuccess()) { | |||
// 调度成功 | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, bizDistributeTrain.getName(), true); | |||
} else { | |||
// 调度失败 | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "distributeTrainApi.create FAILED! {}", bizDistributeTrain); | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, bizDistributeTrain.getName(), false); | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "doDistributedJob ERROR!{} ", e); | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, "", false); | |||
} | |||
} | |||
/** | |||
* 构造分布式训练DistributeTrainBO | |||
* | |||
* @param baseTrainJobDTO 训练任务信息 | |||
* @param currentUser 用户 | |||
* @param ptImageAndAlgorithmVO 镜像和算法信息 | |||
* @param ptTrainJob 训练任务实体信息 | |||
* @param namespace 命名空间 | |||
* @return DistributeTrainBO | |||
*/ | |||
private DistributeTrainBO buildDistributeTrainBO(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob, String namespace) { | |||
//绝对路径 | |||
String basePath = nfsUtil.getNfsConfig().getBucket() + trainJobConfig.getManage() + StrUtil.SLASH | |||
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); | |||
//相对路径 | |||
String relativePath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH | |||
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); | |||
String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH); | |||
String workspaceDir = codeDirArray[codeDirArray.length - 1]; | |||
// 算法路径待拷贝的地址 | |||
String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1); | |||
String trainDir = basePath.substring(1) + StrUtil.SLASH + workspaceDir; | |||
if (!localFileUtil.copyPath(sourcePath, trainDir)) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "buildDistributeTrainBO copyPath failed ! sourcePath:{},basePath:{},trainDir:{}", sourcePath, basePath, trainDir); | |||
return null; | |||
} | |||
// 参数前缀 | |||
String paramPrefix = trainJobConfig.getPythonFormat(); | |||
// 初始化固定头命令,获取分布式节点IP | |||
StringBuilder sb = new StringBuilder("export NODE_IPS=`cat /home/hostfile.json |jq -r \".[]|.ip\"|paste -d \",\" -s` "); | |||
// 切换到算法路径下 | |||
sb.append(" && cd ").append(trainJobConfig.getDockerTrainPath()).append(StrUtil.SLASH).append(workspaceDir).append(" && "); | |||
// 拼接用户自定义python启动命令 | |||
sb.append(ptImageAndAlgorithmVO.getRunCommand()); | |||
// 拼接python固定参数 节点IP | |||
sb.append(paramPrefix).append(trainJobConfig.getNodeIps()).append("=\"$NODE_IPS\" "); | |||
// 拼接python固定参数 节点数量 | |||
sb.append(paramPrefix).append(trainJobConfig.getNodeNum()).append(SymbolConstant.FLAG_EQUAL).append(ptTrainJob.getResourcesPoolNode()).append(StrUtil.SPACE); | |||
if (ptImageAndAlgorithmVO.getIsTrainOut()) { | |||
// 拼接 out | |||
nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getOutPath()); | |||
baseTrainJobDTO.setOutPath(relativePath + StrUtil.SLASH + trainJobConfig.getOutPath()); | |||
sb.append(paramPrefix).append(trainJobConfig.getDockerOutPath()); | |||
} | |||
if (ptImageAndAlgorithmVO.getIsTrainLog()) { | |||
// 拼接 输出日志 | |||
nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getLogPath()); | |||
baseTrainJobDTO.setLogPath(relativePath + StrUtil.SLASH + trainJobConfig.getLogPath()); | |||
sb.append(paramPrefix).append(trainJobConfig.getDockerLogPath()); | |||
} | |||
if (ptImageAndAlgorithmVO.getIsVisualizedLog()) { | |||
// 拼接 输出可视化日志 | |||
nfsUtil.createDir(basePath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); | |||
baseTrainJobDTO.setVisualizedLogPath(relativePath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); | |||
sb.append(paramPrefix).append(trainJobConfig.getDockerVisualizedLogPath()); | |||
} | |||
// 拼接python固定参数 数据集 | |||
sb.append(paramPrefix).append(trainJobConfig.getDockerDataset()); | |||
JSONObject runParams = baseTrainJobDTO.getRunParams(); | |||
if (null != runParams && !runParams.isEmpty()) { | |||
// 拼接用户自定义参数 | |||
runParams.entrySet().forEach(entry -> | |||
sb.append(paramPrefix).append(entry.getKey()).append(SymbolConstant.FLAG_EQUAL).append(entry.getValue()).append(StrUtil.SPACE) | |||
); | |||
} | |||
// 在用户自定以参数拼接晚后拼接固定参数,防止被用户自定义参数覆盖 | |||
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { | |||
// 需要GPU | |||
sb.append(paramPrefix).append(trainJobConfig.getGpuNumPerNode()).append(SymbolConstant.FLAG_EQUAL).append(baseTrainJobDTO.getGpuNumPerNode()).append(StrUtil.SPACE); | |||
} | |||
String mainCommand = sb.toString(); | |||
// 拼接辅助日志打印 | |||
String wholeCommand = " echo 'Distribute training mission begins... " | |||
+ mainCommand | |||
+ " ' && " | |||
+ mainCommand | |||
+ " && echo 'Distribute training mission is over' "; | |||
DistributeTrainBO distributeTrainBO = new DistributeTrainBO() | |||
.setNamespace(namespace) | |||
.setName(baseTrainJobDTO.getJobName()) | |||
.setSize(ptTrainJob.getResourcesPoolNode()) | |||
.setImage(ptImageAndAlgorithmVO.getImageName()) | |||
.setMasterCmd(wholeCommand) | |||
.setMemNum(baseTrainJobDTO.getMenNum()) | |||
.setCpuNum(baseTrainJobDTO.getCpuNum()) | |||
.setDatasetStoragePath(k8sNameTool.getAbsoluteNfsPath(baseTrainJobDTO.getDataSourcePath())) | |||
.setWorkspaceStoragePath(localFileUtil.formatPath(nfsUtil.getNfsConfig().getRootDir() + basePath)) | |||
.setModelStoragePath(k8sNameTool.getAbsoluteNfsPath(relativePath + StrUtil.SLASH + trainJobConfig.getOutPath())) | |||
.setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM)); | |||
//延时启动,单位为分钟 | |||
if (baseTrainJobDTO.getDelayCreateTime() != null && baseTrainJobDTO.getDelayCreateTime() > 0) { | |||
distributeTrainBO.setDelayCreateTime(baseTrainJobDTO.getDelayCreateTime() * MagicNumConstant.SIXTY); | |||
} | |||
//定时停止,单位为分钟 | |||
if (baseTrainJobDTO.getDelayDeleteTime() != null && baseTrainJobDTO.getDelayDeleteTime() > 0) { | |||
distributeTrainBO.setDelayDeleteTime(baseTrainJobDTO.getDelayDeleteTime() * MagicNumConstant.SIXTY); | |||
} | |||
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { | |||
// 需要GPU | |||
distributeTrainBO.setGpuNum(baseTrainJobDTO.getGpuNumPerNode()); | |||
} | |||
// 主从一致 | |||
distributeTrainBO.setSlaveCmd(distributeTrainBO.getMasterCmd()); | |||
return distributeTrainBO; | |||
} | |||
/** | |||
* 提交job | |||
* | |||
* @param baseTrainJobDTO 训练任务信息 | |||
* @param currentUser 用户 | |||
* @param ptImageAndAlgorithmVO 镜像和算法信息 | |||
*/ | |||
public void doJob(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { | |||
PtJupyterJobBO jobBo = null; | |||
String k8sJobName = ""; | |||
try { | |||
//判断是否存在相应的namespace,如果没有则创建 | |||
String namespace = getNamespace(currentUser); | |||
//封装PtJupyterJobBO对象,调用创建训练任务接口 | |||
jobBo = pkgPtJupyterJobBo(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, namespace); | |||
if (null == jobBo) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob,Encapsulating ptjupyterjobbo object is empty,the received parameters namespace:{}", currentUser.getId(), namespace); | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false); | |||
} | |||
PtJupyterJobVO ptJupyterJobResult = trainJobApi.create(jobBo); | |||
if (!ptJupyterJobResult.isSuccess()) { | |||
String message = null == ptJupyterJobResult.getMessage() ? "未知的错误" : ptJupyterJobResult.getMessage(); | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), jobBo, message); | |||
ptTrainJob.setTrainMsg(message); | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false); | |||
} | |||
k8sJobName = ptJupyterJobResult.getName(); | |||
//更新训练任务状态 | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, true); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "user {} create TrainJob, K8s creation failed, the received parameters are {}, the wrong information is{}", currentUser.getUsername(), | |||
jobBo, e); | |||
ptTrainJob.setTrainMsg("内部错误"); | |||
updateTrainStatus(currentUser, ptTrainJob, baseTrainJobDTO, k8sJobName, false); | |||
} | |||
} | |||
/** | |||
* 获取namespace | |||
* | |||
* @param currentUser 用户 | |||
* @return String 命名空间 | |||
*/ | |||
private String getNamespace(UserDTO currentUser) { | |||
String namespaceStr = k8sNameTool.generateNamespace(currentUser.getId()); | |||
BizNamespace bizNamespace = namespaceApi.get(namespaceStr); | |||
if (null == bizNamespace) { | |||
BizNamespace namespace = namespaceApi.create(namespaceStr, null); | |||
if (null == namespace || !namespace.isSuccess()) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "User {} failed to create namespace during training job..."); | |||
throw new BusinessException("内部错误"); | |||
} | |||
} | |||
return namespaceStr; | |||
} | |||
/** | |||
* 封装出创建job所需的BO | |||
* | |||
* @param baseTrainJobDTO 训练任务信息 | |||
* @param ptImageAndAlgorithmVO 镜像和算法信息 | |||
* @param namespace 命名空间 | |||
* @return PtJupyterJobBO jupyter任务BO | |||
*/ | |||
private PtJupyterJobBO pkgPtJupyterJobBo(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, | |||
PtImageAndAlgorithmVO ptImageAndAlgorithmVO, String namespace) { | |||
//绝对路径 | |||
String commonPath = nfsUtil.getNfsConfig().getBucket() + trainJobConfig.getManage() + StrUtil.SLASH | |||
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); | |||
//相对路径 | |||
String relativeCommonPath = StrUtil.SLASH + trainJobConfig.getManage() + StrUtil.SLASH | |||
+ currentUser.getId() + StrUtil.SLASH + baseTrainJobDTO.getJobName(); | |||
String[] codeDirArray = ptImageAndAlgorithmVO.getCodeDir().split(StrUtil.SLASH); | |||
String workspaceDir = codeDirArray[codeDirArray.length - 1]; | |||
// 算法路径待拷贝的地址 | |||
String sourcePath = nfsUtil.getNfsConfig().getBucket() + ptImageAndAlgorithmVO.getCodeDir().substring(1); | |||
String trainDir = commonPath.substring(1) + StrUtil.SLASH + workspaceDir; | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "Algorithm path copy::sourcePath:{},commonPath:{},trainDir:{}", sourcePath, commonPath, trainDir); | |||
boolean bool = localFileUtil.copyPath(sourcePath.substring(1), trainDir); | |||
if (!bool) { | |||
LogUtil.error(LogEnum.BIZ_TRAIN, "During the process of user {} creating training Job and encapsulating k8s creating job interface parameters, it failed to copy algorithm directory {} to the specified directory {}", currentUser.getUsername(), sourcePath.substring(1), | |||
trainDir); | |||
return null; | |||
} | |||
List<String> list = new ArrayList<>(); | |||
JSONObject runParams = baseTrainJobDTO.getRunParams(); | |||
StringBuilder sb = new StringBuilder(); | |||
sb.append(ptImageAndAlgorithmVO.getRunCommand()); | |||
// 拼接out,log和dataset | |||
String pattern = trainJobConfig.getPythonFormat(); | |||
if (ptImageAndAlgorithmVO.getIsTrainOut()) { | |||
nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getOutPath()); | |||
baseTrainJobDTO.setOutPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getOutPath()); | |||
sb.append(pattern).append(trainJobConfig.getDockerOutPath()); | |||
} | |||
if (ptImageAndAlgorithmVO.getIsTrainLog()) { | |||
nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getLogPath()); | |||
baseTrainJobDTO.setLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getLogPath()); | |||
sb.append(pattern).append(trainJobConfig.getDockerLogPath()); | |||
} | |||
if (ptImageAndAlgorithmVO.getIsVisualizedLog()) { | |||
nfsUtil.createDir(commonPath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); | |||
baseTrainJobDTO.setVisualizedLogPath(relativeCommonPath + StrUtil.SLASH + trainJobConfig.getVisualizedLogPath()); | |||
sb.append(pattern).append(trainJobConfig.getDockerVisualizedLogPath()); | |||
} | |||
sb.append(pattern).append(trainJobConfig.getDockerDataset()); | |||
String valDataSourcePath = baseTrainJobDTO.getValDataSourcePath(); | |||
if (StringUtils.isNotBlank(valDataSourcePath)) { | |||
sb.append(pattern).append(trainJobConfig.getLoadValDatasetKey()).append(SymbolConstant.FLAG_EQUAL).append(trainJobConfig.getDockerValDatasetPath()); | |||
} | |||
//将模型加载路径拼接到 | |||
String modelLoadPathDir = baseTrainJobDTO.getModelLoadPathDir(); | |||
if (StringUtils.isNotBlank(modelLoadPathDir)) { | |||
//将模型路径model_load_dir路径 | |||
sb.append(pattern).append(trainJobConfig.getLoadKey()).append(SymbolConstant.FLAG_EQUAL).append(trainJobConfig.getDockerModelPath()); | |||
} | |||
if (null != runParams && !runParams.isEmpty()) { | |||
runParams.forEach((k, v) -> | |||
sb.append(pattern).append(k).append(SymbolConstant.FLAG_EQUAL).append(v).append(StrUtil.SPACE) | |||
); | |||
} | |||
// 在用户自定以参数拼接晚后拼接固定参数,防止被用户自定义参数覆盖 | |||
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { | |||
// 需要GPU | |||
sb.append(pattern).append(trainJobConfig.getGpuNumPerNode()).append(SymbolConstant.FLAG_EQUAL).append(baseTrainJobDTO.getGpuNumPerNode()).append(StrUtil.SPACE); | |||
} | |||
String executeCmd = sb.toString(); | |||
list.add("-c"); | |||
String workPath = trainJobConfig.getDockerTrainPath() + StrUtil.SLASH + workspaceDir; | |||
String command = "echo 'training mission begins... " + executeCmd + "\r\n '" + | |||
" && cd " + workPath + | |||
" && " + executeCmd + | |||
" && echo 'the training mission is over' "; | |||
list.add(command); | |||
PtJupyterJobBO jobBo = new PtJupyterJobBO(); | |||
jobBo.setNamespace(namespace) | |||
.setName(baseTrainJobDTO.getJobName()) | |||
.setImage(ptImageAndAlgorithmVO.getImageName()) | |||
.putNfsMounts(trainJobConfig.getDockerDatasetPath(), nfsUtil.getNfsConfig().getRootDir() + nfsUtil.getNfsConfig().getBucket().substring(1) + baseTrainJobDTO.getDataSourcePath()) | |||
.setCmdLines(list) | |||
.putNfsMounts(trainJobConfig.getDockerTrainPath(), nfsUtil.getNfsConfig().getRootDir() + commonPath.substring(1)) | |||
.putNfsMounts(trainJobConfig.getDockerModelPath(), nfsUtil.formatPath(nfsUtil.getAbsolutePath(modelLoadPathDir))) | |||
.putNfsMounts(trainJobConfig.getDockerValDatasetPath(), nfsUtil.formatPath(nfsUtil.getAbsolutePath(valDataSourcePath))) | |||
.setBusinessLabel(k8sNameTool.getPodLabel(BizEnum.ALGORITHM)); | |||
//延时启动,单位为分钟 | |||
if (baseTrainJobDTO.getDelayCreateTime() != null && baseTrainJobDTO.getDelayCreateTime() > 0) { | |||
jobBo.setDelayCreateTime(baseTrainJobDTO.getDelayCreateTime() * MagicNumConstant.SIXTY); | |||
} | |||
//自动停止,单位为分钟 | |||
if (baseTrainJobDTO.getDelayDeleteTime() != null && baseTrainJobDTO.getDelayDeleteTime() > 0) { | |||
jobBo.setDelayDeleteTime(baseTrainJobDTO.getDelayDeleteTime() * MagicNumConstant.SIXTY); | |||
} | |||
jobBo.setCpuNum(baseTrainJobDTO.getCpuNum()).setMemNum(baseTrainJobDTO.getMenNum()); | |||
if (ResourcesPoolTypeEnum.isGpuCode(baseTrainJobDTO.getPtTrainJobSpecs().getResourcesPoolType())) { | |||
jobBo.setUseGpu(true).setGpuNum(baseTrainJobDTO.getGpuNumPerNode()); | |||
} else { | |||
jobBo.setUseGpu(false); | |||
} | |||
return jobBo; | |||
} | |||
/** | |||
* 训练任务异步处理更新训练状态 | |||
* | |||
* @param user 用户 | |||
* @param ptTrainJob 训练任务 | |||
* @param baseTrainJobDTO 训练任务信息 | |||
* @param k8sJobName k8s创建的job名称,或者分布式训练名称 | |||
* @param flag 创建训练任务是否异常(true:正常,false:失败) | |||
**/ | |||
private void updateTrainStatus(UserDTO user, PtTrainJob ptTrainJob, BaseTrainJobDTO baseTrainJobDTO, String k8sJobName, boolean flag) { | |||
ptTrainJob.setK8sJobName(k8sJobName) | |||
.setOutPath(baseTrainJobDTO.getOutPath()) | |||
.setLogPath(baseTrainJobDTO.getLogPath()) | |||
.setVisualizedLogPath(baseTrainJobDTO.getVisualizedLogPath()); | |||
LogUtil.info(LogEnum.BIZ_TRAIN, "user {} training tasks are processed asynchronously to update training status,receiving parameters:{}", user.getId(), ptTrainJob); | |||
if (flag) { | |||
ptTrainJobMapper.updateById(ptTrainJob); | |||
} else { | |||
ptTrainJob.setTrainStatus(TrainJobStatusEnum.CREATE_FAILED.getStatus()); | |||
//训练任务创建失败 | |||
ptTrainJobMapper.updateById(ptTrainJob); | |||
} | |||
} | |||
} |
@@ -14,13 +14,16 @@ | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.task; | |||
package org.dubhe.async; | |||
import org.dubhe.aspect.LogAspect; | |||
import org.dubhe.base.DataContext; | |||
import org.dubhe.domain.dto.BaseTrainJobDTO; | |||
import org.dubhe.domain.dto.CommonPermissionDataDTO; | |||
import org.dubhe.domain.dto.UserDTO; | |||
import org.dubhe.domain.entity.PtTrainJob; | |||
import org.dubhe.domain.vo.PtImageAndAlgorithmVO; | |||
import org.dubhe.enums.TrainTypeEnum; | |||
import org.slf4j.MDC; | |||
import org.springframework.beans.factory.annotation.Autowired; | |||
import org.springframework.stereotype.Component; | |||
@@ -38,26 +41,32 @@ import java.util.concurrent.Executor; | |||
public class TransactionAsyncManager { | |||
@Autowired | |||
private TrainJobAsyncTask trainJobAsyncTask; | |||
@Resource(name = "trainJobAsyncExecutor") | |||
private Executor trainJobAsyncExecutor; | |||
private TrainJobAsync trainJobAsync; | |||
@Resource(name = "trainExecutor") | |||
private Executor trainExecutor; | |||
public void execute(BaseTrainJobDTO baseTrainJobDTO, UserDTO currentUser, PtImageAndAlgorithmVO ptImageAndAlgorithmVO, PtTrainJob ptTrainJob) { | |||
String traceId = MDC.get(LogAspect.TRACE_ID); | |||
CommonPermissionDataDTO commonPermissionDataDTO = DataContext.get(); | |||
TransactionSynchronizationManager.registerSynchronization(new TransactionSynchronizationAdapter() { | |||
@Override | |||
public void afterCommit() { | |||
trainJobAsyncExecutor.execute( | |||
() -> { | |||
MDC.put(LogAspect.TRACE_ID, traceId); | |||
trainJobAsyncTask.doJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); | |||
MDC.remove(LogAspect.TRACE_ID); | |||
} | |||
); | |||
trainExecutor.execute(() -> { | |||
MDC.put(LogAspect.TRACE_ID, traceId); | |||
DataContext.set(commonPermissionDataDTO); | |||
if (TrainTypeEnum.isDistributeTrain(ptTrainJob.getTrainType())) { | |||
// 分布式训练 | |||
trainJobAsync.doDistributedJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); | |||
} else { | |||
// 普通训练 | |||
trainJobAsync.doJob(baseTrainJobDTO, currentUser, ptImageAndAlgorithmVO, ptTrainJob); | |||
} | |||
MDC.remove(LogAspect.TRACE_ID); | |||
DataContext.remove(); | |||
}); | |||
} | |||
}); | |||
} |
@@ -0,0 +1,148 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.config; | |||
import com.alibaba.fastjson.JSON; | |||
import org.dubhe.constant.StringConstant; | |||
import org.dubhe.constatnts.UserConstant; | |||
import org.dubhe.dto.GlobalRequestRecordDTO; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.utils.JwtUtils; | |||
import org.dubhe.utils.LogUtil; | |||
import org.dubhe.utils.StringUtils; | |||
import org.springframework.core.annotation.Order; | |||
import org.springframework.stereotype.Component; | |||
import javax.servlet.*; | |||
import javax.servlet.annotation.WebFilter; | |||
import javax.servlet.http.HttpServletRequest; | |||
import javax.servlet.http.HttpServletResponse; | |||
import java.io.IOException; | |||
import static org.dubhe.constant.StringConstant.K8S_CALLBACK_URI; | |||
/** | |||
* @description 全局请求拦截器 用于日志收集 | |||
* @date 2020-08-13 | |||
*/ | |||
@Order(1) | |||
@Component | |||
@WebFilter(filterName = "GlobalFilter", urlPatterns = "/**") | |||
public class GlobalFilter implements Filter { | |||
@Override | |||
public void init(FilterConfig filterConfig) throws ServletException { | |||
} | |||
@Override | |||
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain chain) throws IOException, ServletException { | |||
long start = System.currentTimeMillis(); | |||
HttpServletRequest request = ((HttpServletRequest) servletRequest); | |||
HttpServletResponse response = ((HttpServletResponse) servletResponse); | |||
GlobalRequestRecordDTO dto = new GlobalRequestRecordDTO(); | |||
try { | |||
if (StringUtils.isNotBlank(request.getContentType()) && request.getContentType().contains(StringConstant.MULTIPART)) { | |||
chain.doFilter(request, response); | |||
} else { | |||
checkScheduleRequest(request); | |||
checkK8sCallback(request); | |||
RequestBodyWrapper requestBodyWrapper = new RequestBodyWrapper(request); | |||
ResponseBodyWrapper responseBodyWrapper = new ResponseBodyWrapper(response); | |||
dto.setRequestBody(requestBodyWrapper.getBodyString()); | |||
chain.doFilter(requestBodyWrapper, responseBodyWrapper); | |||
if (StringConstant.JSON_REQUEST.equals(responseBodyWrapper.getContentType())) { | |||
final String responseBody = responseBodyWrapper.getResponseBody(); | |||
dto.setResponseBody(responseBody); | |||
} else { | |||
responseBodyWrapper.flush(); | |||
} | |||
} | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.GLOBAL_REQ, "Global request record error : {}", e); | |||
throw e; | |||
} finally { | |||
buildGlobalRequestDTO(dto, request, response); | |||
dto.setTimeCost(System.currentTimeMillis() - start); | |||
LogUtil.info(LogEnum.GLOBAL_REQ, "Global request record: {}", dto); | |||
LogUtil.cleanTrace(); | |||
} | |||
} | |||
/** | |||
* 构建全局请求对象 | |||
* | |||
* @param dto | |||
* @param request | |||
* @param response | |||
*/ | |||
private void buildGlobalRequestDTO(GlobalRequestRecordDTO dto, HttpServletRequest request, HttpServletResponse response) { | |||
dto.setClientHost(request.getRemoteHost()); | |||
dto.setParams(JSON.toJSONString(request.getParameterMap())); | |||
dto.setMethod(request.getMethod()); | |||
dto.setUri(request.getRequestURI()); | |||
//身份认证信息 | |||
String token = request.getHeader(UserConstant.USER_TOKEN_KEY); | |||
dto.setAuthorization(token); | |||
if (token != null) { | |||
String userName = JwtUtils.getUserName(token); | |||
dto.setUsername(userName); | |||
} | |||
dto.setContentType(response.getContentType()); | |||
dto.setStatus(response.getStatus()); | |||
} | |||
/** | |||
* 检查是否是前端的定时请求 | |||
* | |||
* @param request 请求信息 | |||
* @return 是否是前端的定时请求 | |||
*/ | |||
private boolean checkScheduleRequest(HttpServletRequest request) { | |||
if (StringConstant.REQUEST_METHOD_GET.equals(request.getMethod()) | |||
&& StringUtils.isNotBlank(request.getParameter(LogUtil.SCHEDULE_LEVEL))) { | |||
LogUtil.startScheduleTrace(); | |||
return true; | |||
} | |||
return false; | |||
} | |||
/** | |||
* 校验请求是否为k8s回调 | |||
* @param request 请求信息 | |||
* @return 是否为k8s回调 | |||
*/ | |||
private boolean checkK8sCallback(HttpServletRequest request) { | |||
if (request.getRequestURI() != null && request.getRequestURI().contains(K8S_CALLBACK_URI)) { | |||
LogUtil.startK8sCallbackTrace(); | |||
return true; | |||
} | |||
return false; | |||
} | |||
@Override | |||
public void destroy() { | |||
} | |||
} |
@@ -0,0 +1,101 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.config; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.utils.LogUtil; | |||
import javax.servlet.ReadListener; | |||
import javax.servlet.ServletInputStream; | |||
import javax.servlet.http.HttpServletRequest; | |||
import javax.servlet.http.HttpServletRequestWrapper; | |||
import java.io.BufferedReader; | |||
import java.io.ByteArrayInputStream; | |||
import java.io.IOException; | |||
import java.io.InputStreamReader; | |||
/** | |||
* @description 用于获取请求body参数的包装类 | |||
* @date 2020-08-20 | |||
*/ | |||
public class RequestBodyWrapper extends HttpServletRequestWrapper { | |||
private ByteArrayInputStream byteArrayInputStream; | |||
private String bodyString; | |||
public RequestBodyWrapper(HttpServletRequest request) | |||
throws IOException { | |||
super(request); | |||
try (BufferedReader reader = request.getReader()) { | |||
StringBuilder sb = new StringBuilder(); | |||
String line = null; | |||
while ((line = reader.readLine()) != null) { | |||
sb.append(line); | |||
} | |||
if (sb.length() > 0) { | |||
bodyString = sb.toString(); | |||
} else { | |||
bodyString = ""; | |||
} | |||
byteArrayInputStream = new ByteArrayInputStream(bodyString.getBytes()); | |||
} catch (Exception e) { | |||
LogUtil.error(LogEnum.GLOBAL_REQ, "request get reader error : {}", e); | |||
throw e; | |||
} | |||
} | |||
/** | |||
* 获取请求体的json数据 | |||
* | |||
* @return | |||
*/ | |||
public String getBodyString() { | |||
return bodyString; | |||
} | |||
@Override | |||
public BufferedReader getReader() throws IOException { | |||
return new BufferedReader(new InputStreamReader(getInputStream())); | |||
} | |||
@Override | |||
public ServletInputStream getInputStream() throws IOException { | |||
return new ServletInputStream() { | |||
@Override | |||
public boolean isFinished() { | |||
return false; | |||
} | |||
@Override | |||
public boolean isReady() { | |||
return false; | |||
} | |||
@Override | |||
public void setReadListener(ReadListener readListener) { | |||
} | |||
@Override | |||
public int read() throws IOException { | |||
return byteArrayInputStream.read(); | |||
} | |||
}; | |||
} | |||
} |
@@ -0,0 +1,126 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.config; | |||
import org.dubhe.enums.LogEnum; | |||
import org.dubhe.utils.LogUtil; | |||
import java.io.ByteArrayOutputStream; | |||
import java.io.IOException; | |||
import java.io.OutputStreamWriter; | |||
import java.io.PrintWriter; | |||
import java.io.UnsupportedEncodingException; | |||
import javax.servlet.ServletOutputStream; | |||
import javax.servlet.WriteListener; | |||
import javax.servlet.http.HttpServletResponse; | |||
import javax.servlet.http.HttpServletResponseWrapper; | |||
/** | |||
* @description 用于获取response 的 json 返回值 | |||
* @date 2020-08-19 | |||
*/ | |||
public class ResponseBodyWrapper extends HttpServletResponseWrapper { | |||
private ByteArrayOutputStream byteArrayOutputStream = null; | |||
private ServletOutputStream servletOutputStream = null; | |||
private PrintWriter printWriter = null; | |||
private HttpServletResponse response; | |||
public ResponseBodyWrapper(HttpServletResponse response) throws IOException { | |||
super(response); | |||
this.response = response; | |||
byteArrayOutputStream = new ByteArrayOutputStream(); | |||
printWriter = new PrintWriter(new OutputStreamWriter(byteArrayOutputStream, "UTF-8")); | |||
servletOutputStream = new ServletOutputStream() { | |||
@Override | |||
public void write(int b) throws IOException { | |||
byteArrayOutputStream.write(b); | |||
} | |||
@Override | |||
public boolean isReady() { | |||
return false; | |||
} | |||
@Override | |||
public void setWriteListener(WriteListener writeListener) { | |||
} | |||
}; | |||
} | |||
@Override | |||
public ServletOutputStream getOutputStream() throws IOException { | |||
return servletOutputStream; | |||
} | |||
@Override | |||
public PrintWriter getWriter() throws IOException { | |||
return printWriter; | |||
} | |||
@Override | |||
public void flushBuffer() throws IOException { | |||
if (servletOutputStream != null) { | |||
servletOutputStream.flush(); | |||
} | |||
if (printWriter != null) { | |||
printWriter.flush(); | |||
} | |||
} | |||
@Override | |||
public void reset() { | |||
byteArrayOutputStream.reset(); | |||
} | |||
/** | |||
* 获取json返回值 | |||
* @return | |||
* @throws IOException | |||
*/ | |||
public String getResponseBody() throws IOException { | |||
//清空response的流,之后再添加进去 | |||
flushBuffer(); | |||
byte[] bytes = byteArrayOutputStream.toByteArray(); | |||
try { | |||
return new String(bytes, "UTF-8"); | |||
} catch (UnsupportedEncodingException e) { | |||
LogUtil.error(LogEnum.GLOBAL_REQ, e); | |||
} finally { | |||
response.getOutputStream().write(bytes); | |||
} | |||
return ""; | |||
} | |||
/** | |||
* 清掉缓冲 | |||
* @throws IOException | |||
*/ | |||
public void flush() throws IOException { | |||
//清空response的流,之后再添加进去 | |||
flushBuffer(); | |||
byte[] bytes = byteArrayOutputStream.toByteArray(); | |||
response.getOutputStream().write(bytes); | |||
} | |||
} |
@@ -24,7 +24,7 @@ import org.springframework.stereotype.Component; | |||
import java.sql.Timestamp; | |||
/** | |||
* @description: 转换时间戳类型 | |||
* @description 转换时间戳类型 | |||
* @date 2020-05-22 | |||
*/ | |||
@Component | |||
@@ -0,0 +1,47 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.dao; | |||
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
import org.apache.ibatis.annotations.Param; | |||
import org.apache.ibatis.annotations.Select; | |||
import org.dubhe.domain.dto.ModelQueryDTO; | |||
import org.dubhe.domain.entity.ModelQuery; | |||
import org.dubhe.domain.entity.ModelQueryBrance; | |||
/** | |||
* @description model mapper | |||
* @date 2020-10-09 | |||
*/ | |||
public interface ModelQueryMapper extends BaseMapper<ModelQueryDTO> { | |||
/** | |||
* 根据modelId查询模型信息 | |||
* | |||
* @param modelId 模型id | |||
* @return modelQuery返回查询的模型对象 | |||
*/ | |||
@Select("select name,url from pt_model_info where id=#{modelId}") | |||
ModelQuery findModelNameById(@Param("modelId") Integer modelId); | |||
/** | |||
* 根据模型路径查询模型版本信息 | |||
* | |||
* @param modelLoadPathDir 模型路径 | |||
* @return ModelQueryBrance 模型版本信息 | |||
*/ | |||
@Select("select version from pt_model_branch where url=#{modelLoadPathDir}") | |||
ModelQueryBrance findModelVersionByUrl(@Param("modelLoadPathDir") String modelLoadPathDir); | |||
} |
@@ -21,6 +21,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; | |||
import org.apache.ibatis.annotations.Param; | |||
import org.apache.ibatis.annotations.Select; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.NoteBook; | |||
import java.util.List; | |||
@@ -29,28 +30,27 @@ import java.util.List; | |||
* @description notebook mapper | |||
* @date 2020-04-28 | |||
*/ | |||
@DataPermission(ignoresMethod = {"insert","findByNamespaceAndResourceName","selectRunNotUrlList"}) | |||
public interface NoteBookMapper extends BaseMapper<NoteBook> { | |||
/** | |||
* 根据名称查询 | |||
* | |||
* @param name | |||
* @param userId | |||
* @param status | |||
* @return NoteBook | |||
*/ | |||
@Select("select * from notebook where notebook_name = #{name} and user_id = #{userId} and status != #{status} and deleted = 0 limit 1") | |||
NoteBook findByNameAndUserId(@Param("name") String name, @Param("userId") long userId, @Param("status") Integer status); | |||
@Select("select * from notebook where notebook_name = #{name} and status != #{status} and deleted = 0 limit 1") | |||
NoteBook findByNameAndStatus(@Param("name") String name, @Param("status") Integer status); | |||
/** | |||
* 查询正在运行的notebook数量 | |||
* | |||
* @param userId | |||
* @param status | |||
* @return int | |||
*/ | |||
@Select("select count(1) from notebook where user_id = #{userId} and status = #{status} and deleted = 0") | |||
int selectRunNoteBookNum(@Param("userId") long userId, @Param("status") Integer status); | |||
@Select("select count(1) from notebook where status = #{status} and deleted = 0") | |||
int selectRunNoteBookNum( @Param("status") Integer status); | |||
/** | |||
* 根据namespace + resourceName查询 | |||
@@ -19,12 +19,14 @@ package org.dubhe.dao; | |||
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.PtImage; | |||
/** | |||
* @description 镜像 Mapper 接口 | |||
* @date 2020-04-27 | |||
*/ | |||
@DataPermission(ignoresMethod = {"insert"}) | |||
public interface PtImageMapper extends BaseMapper<PtImage> { | |||
} |
@@ -20,6 +20,7 @@ package org.dubhe.dao; | |||
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
import org.apache.ibatis.annotations.Param; | |||
import org.apache.ibatis.annotations.Select; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.PtTrainAlgorithm; | |||
import java.util.List; | |||
@@ -28,6 +29,7 @@ import java.util.List; | |||
* @description 训练算法Mapper | |||
* @date 2020-04-27 | |||
*/ | |||
@DataPermission(ignoresMethod = {"insert"}) | |||
public interface PtTrainAlgorithmMapper extends BaseMapper<PtTrainAlgorithm> { | |||
/** | |||
@@ -17,17 +17,17 @@ | |||
package org.dubhe.dao; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.PtTrainAlgorithmUsage; | |||
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
/** | |||
* | |||
* 用户辅助信息Mapper 接口 | |||
* </p> | |||
* | |||
* @since 2020-06-23 | |||
* @description 用户辅助信息Mapper 接口 | |||
* @date 2020-06-23 | |||
*/ | |||
@DataPermission(ignoresMethod = "insert") | |||
public interface PtTrainAlgorithmUsageMapper extends BaseMapper<PtTrainAlgorithmUsage> { | |||
} |
@@ -21,6 +21,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
import com.baomidou.mybatisplus.extension.plugins.pagination.Page; | |||
import org.apache.ibatis.annotations.Param; | |||
import org.apache.ibatis.annotations.Select; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.PtTrainJob; | |||
import org.dubhe.domain.vo.PtTrainVO; | |||
@@ -28,6 +29,7 @@ import org.dubhe.domain.vo.PtTrainVO; | |||
* @description 训练作业job Mapper 接口 | |||
* @date 2020-04-27 | |||
*/ | |||
@DataPermission(ignoresMethod = {"insert","selectCountByStatus","getPageTrain"}) | |||
public interface PtTrainJobMapper extends BaseMapper<PtTrainJob> { | |||
/** | |||
@@ -18,12 +18,14 @@ | |||
package org.dubhe.dao; | |||
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.PtTrain; | |||
/** | |||
* @description 训练作业主 Mapper 接口 | |||
* @date 2020-04-27 | |||
*/ | |||
@DataPermission(ignoresMethod = {"insert"}) | |||
public interface PtTrainMapper extends BaseMapper<PtTrain> { | |||
} |
@@ -17,6 +17,7 @@ | |||
package org.dubhe.dao; | |||
import org.dubhe.annotation.DataPermission; | |||
import org.dubhe.domain.entity.PtTrainParam; | |||
import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
@@ -24,6 +25,7 @@ import com.baomidou.mybatisplus.core.mapper.BaseMapper; | |||
* @description 任务参数 Mapper 接口 | |||
* @date 2020-04-27 | |||
*/ | |||
@DataPermission(ignoresMethod = "insert") | |||
public interface PtTrainParamMapper extends BaseMapper<PtTrainParam> { | |||
} |
@@ -40,4 +40,40 @@ public class BaseTrainJobDTO implements Serializable { | |||
private String outPath; | |||
private String logPath; | |||
private String visualizedLogPath; | |||
private Integer delayCreateTime; | |||
private Integer delayDeleteTime; | |||
/** | |||
* @return 每个节点的GPU数量 | |||
*/ | |||
public Integer getGpuNumPerNode(){ | |||
return getPtTrainJobSpecs().getSpecsInfo().getInteger("gpuNum"); | |||
} | |||
/** | |||
* @return cpu数量 | |||
*/ | |||
public Integer getCpuNum(){ | |||
return getPtTrainJobSpecs().getSpecsInfo().getInteger("cpuNum"); | |||
} | |||
/** | |||
* @return memNum | |||
*/ | |||
public Integer getMenNum(){ | |||
return getPtTrainJobSpecs().getSpecsInfo().getInteger("memNum"); | |||
} | |||
/** | |||
* "验证数据来源名称" | |||
*/ | |||
private String valDataSourceName; | |||
/** | |||
* 验证数据来源路径 | |||
*/ | |||
private String valDataSourcePath; | |||
/** | |||
* 模型路径 | |||
*/ | |||
private String modelLoadPathDir; | |||
} |
@@ -0,0 +1,39 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.domain.dto; | |||
import io.swagger.annotations.ApiModel; | |||
import io.swagger.annotations.ApiModelProperty; | |||
import lombok.Data; | |||
/** | |||
* @description model 查询dto | |||
* @date 2020-10-09 | |||
*/ | |||
@Data | |||
@ApiModel("模型查询") | |||
public class ModelQueryDTO { | |||
@ApiModelProperty(value = "模型类型") | |||
private Integer modelResource; | |||
@ApiModelProperty(value = "模型名称") | |||
private String name; | |||
@ApiModelProperty(value = "模型版本") | |||
private String version; | |||
@ApiModelProperty(value = "模型id") | |||
private Integer id; | |||
@ApiModelProperty(value = "模型路径") | |||
private String modelPath; | |||
} |
@@ -30,7 +30,6 @@ import java.io.Serializable; | |||
@Data | |||
public class NoteBookListQueryDTO implements Serializable { | |||
@Query(propName = "status", type = Query.Type.EQ) | |||
@ApiModelProperty("0运行中,1停止, 2删除, 3启动中,4停止中,5删除中,6运行异常(暂未启用)") | |||
private Integer status; | |||
@@ -38,7 +37,7 @@ public class NoteBookListQueryDTO implements Serializable { | |||
@ApiModelProperty("notebook名称") | |||
private String noteBookName; | |||
@Query(propName = "user_id", type = Query.Type.EQ) | |||
@Query(propName = "origin_user_id", type = Query.Type.EQ) | |||
@ApiModelProperty(value = "所属用户ID", hidden = true) | |||
private Long userId; | |||
@@ -42,7 +42,7 @@ public class NoteBookQueryDTO implements Serializable { | |||
@Query(propName = "k8s_pvc_path", type = Query.Type.EQ) | |||
private String k8sPvcPath; | |||
@Query(propName = "user_id", type = Query.Type.EQ) | |||
@Query(propName = "origin_user_id", type = Query.Type.EQ) | |||
private Long userId; | |||
@Query(propName = "last_operation_timeout", type = Query.Type.LT) | |||
@@ -66,4 +66,6 @@ public class PtImageDTO implements Serializable { | |||
@ApiModelProperty("删除(0正常,1已删除)") | |||
private Boolean deleted; | |||
@ApiModelProperty("资源拥有者ID") | |||
private Long originUserId; | |||
} |
@@ -0,0 +1,39 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.domain.dto; | |||
import io.swagger.annotations.ApiModelProperty; | |||
import lombok.Data; | |||
import lombok.experimental.Accessors; | |||
import javax.validation.constraints.NotNull; | |||
import java.io.Serializable; | |||
import java.util.List; | |||
/** | |||
* @description 训练镜像删除DTO | |||
* @date 2020-08-13 | |||
*/ | |||
@Data | |||
@Accessors(chain = true) | |||
public class PtImageDeleteDTO implements Serializable { | |||
private static final long serialVersionUID = 1L; | |||
@ApiModelProperty(value = "id", required = true) | |||
@NotNull(message = "镜像id不能为空") | |||
private List<Long> ids; | |||
} |
@@ -40,4 +40,7 @@ public class PtImageQueryDTO extends PageQueryBase implements Serializable { | |||
@ApiModelProperty(value = "镜像状态,0为制作中,1位制作成功,2位制作失败") | |||
private Integer imageStatus; | |||
@ApiModelProperty(value = "镜像名称或id") | |||
private String imageNameOrId; | |||
} |
@@ -0,0 +1,47 @@ | |||
/** | |||
* Copyright 2020 Zhejiang Lab. All Rights Reserved. | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
* ============================================================= | |||
*/ | |||
package org.dubhe.domain.dto; | |||
import io.swagger.annotations.ApiModelProperty; | |||
import lombok.Data; | |||
import lombok.experimental.Accessors; | |||
import org.dubhe.utils.TrainUtil; | |||
import org.hibernate.validator.constraints.Length; | |||
import javax.validation.constraints.NotNull; | |||
import java.io.Serializable; | |||
import java.util.List; | |||
/** | |||
* @description 训练镜像信息修改DTO | |||
* @date 2020-08-13 | |||
*/ | |||
@Data | |||
@Accessors(chain = true) | |||
public class PtImageUpdateDTO implements Serializable { | |||
private static final long serialVersionUID = 1L; | |||
@ApiModelProperty(value = "id", required = true) | |||
@NotNull(message = "镜像id不能为空") | |||
private List<Long> ids; | |||
@ApiModelProperty("镜像描述") | |||
@Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "镜像描述-输入长度不能超过1024个字符") | |||
private String remark; | |||
} |
@@ -55,7 +55,7 @@ public class PtImageUploadDTO implements Serializable { | |||
@Pattern(regexp = TrainUtil.REGEXP_TAG, message = "镜像版本号支持字母、数字、英文横杠、英文.号和下划线") | |||
private String imageTag; | |||
@ApiModelProperty("备注") | |||
@Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "备注-输入长度不能超过1024个字符") | |||
@ApiModelProperty("镜像描述") | |||
@Length(max = TrainUtil.NUMBER_ONE_THOUSAND_AND_TWENTY_FOUR, message = "镜像描述-输入长度不能超过1024个字符") | |||
private String remark; | |||
} |
@@ -85,4 +85,7 @@ public class PtTrainAlgorithmCreateDTO implements Serializable { | |||
@ApiModelProperty("noteBookId") | |||
private Long noteBookId; | |||
@ApiModelProperty("资源拥有者ID") | |||
private Long originUserId; | |||
} |